sravan837 commited on
Commit
e34f8c8
·
1 Parent(s): 467ff7c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -9
app.py CHANGED
@@ -1,10 +1,13 @@
1
- from fastapi import FastAPI
2
  import joblib
3
  import pandas as pd
4
  from pydantic import BaseModel
5
 
6
  # Load the trained model
7
- model = joblib.load("slope_stability_model.pkl")
 
 
 
8
 
9
  # Initialize FastAPI
10
  app = FastAPI()
@@ -17,15 +20,32 @@ class SlopeStabilityInput(BaseModel):
17
  slope_angle: float
18
  slope_height: float
19
  water_pressure_ratio: float
20
- reinforcement_type: str
 
 
 
 
 
 
 
 
21
 
22
- # Define API endpoint
23
  @app.post("/predict")
24
  def predict_slope_stability(data: SlopeStabilityInput):
25
- # Convert input data into a DataFrame
26
- input_data = pd.DataFrame([data.dict()])
 
 
 
 
 
 
 
 
 
 
27
 
28
- # Make prediction using the model
29
- prediction = model.predict(input_data)
30
 
31
- return {"Factor of Safety": float(prediction[0])}
 
 
1
+ from fastapi import FastAPI, HTTPException
2
  import joblib
3
  import pandas as pd
4
  from pydantic import BaseModel
5
 
6
  # Load the trained model
7
+ try:
8
+ model = joblib.load("slope_stability_model.pkl")
9
+ except Exception as e:
10
+ raise RuntimeError(f"Error loading model: {e}")
11
 
12
  # Initialize FastAPI
13
  app = FastAPI()
 
20
  slope_angle: float
21
  slope_height: float
22
  water_pressure_ratio: float
23
+ reinforcement_type: str # Categorical feature
24
+
25
+ # Dummy mapping (You should replace this with actual mapping used during training)
26
+ reinforcement_mapping = {
27
+ "Geogrid": 0,
28
+ "Anchors": 1,
29
+ "Shotcrete": 2,
30
+ "Gabions": 3
31
+ }
32
 
 
33
  @app.post("/predict")
34
  def predict_slope_stability(data: SlopeStabilityInput):
35
+ try:
36
+ # Convert input data into a DataFrame
37
+ input_data = pd.DataFrame([data.dict()])
38
+
39
+ # Encode reinforcement_type (ensure consistency with training)
40
+ if data.reinforcement_type not in reinforcement_mapping:
41
+ raise HTTPException(status_code=400, detail="Invalid reinforcement type")
42
+
43
+ input_data["reinforcement_type"] = reinforcement_mapping[data.reinforcement_type]
44
+
45
+ # Make prediction using the model
46
+ prediction = model.predict(input_data)
47
 
48
+ return {"Factor of Safety": float(prediction[0])}
 
49
 
50
+ except Exception as e:
51
+ raise HTTPException(status_code=500, detail=str(e))