BeeBasic commited on
Commit
a08c7f9
·
verified ·
1 Parent(s): 28ba1f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -20
app.py CHANGED
@@ -6,6 +6,7 @@ import pandas as pd
6
 
7
  app = FastAPI(title="Food Surplus Predictor API")
8
 
 
9
  model_path = hf_hub_download(
10
  repo_id="BeeBasic/food-for-all",
11
  filename="best_model.joblib",
@@ -13,7 +14,7 @@ model_path = hf_hub_download(
13
  )
14
  model = joblib.load(model_path)
15
 
16
- # Define raw input schema
17
  class CanteenInput(BaseModel):
18
  canteen_id: str
19
  canteen_name: str
@@ -31,26 +32,22 @@ def home():
31
 
32
  @app.post("/predict")
33
  def predict_surplus(request: RequestBody):
 
34
  df = pd.DataFrame([canteen.dict() for canteen in request.data])
35
 
36
- # One-hot encode based on training columns
37
- all_columns = [
38
- 'day','month','year','day_of_week',
39
- 'canteen_id_C002','canteen_id_C003','canteen_id_C004',
40
- 'canteen_id_C005','canteen_id_C006','canteen_id_C010',
41
- 'canteen_name_Anna University Mess','canteen_name_Buhari Hotel Canteen',
42
- 'canteen_name_Crescent College Cafeteria','canteen_name_IIT Madras Hostel Mess',
43
- 'canteen_name_Murugan Idli Shop','canteen_name_SRM Campus Canteen',
44
- 'canteen_name_Sangeetha Veg Restaurant','canteen_name_The Marina Café',
45
- 'canteen_name_VIT University Main Canteen'
46
- ]
47
-
48
- encoded = pd.get_dummies(df, columns=["canteen_id","canteen_name"])
49
- for col in all_columns:
50
- if col not in encoded.columns:
51
- encoded[col] = 0
52
- encoded = encoded[all_columns]
53
-
54
- predictions = model.predict(encoded)
55
  df["predicted_surplus"] = predictions
 
56
  return df.to_dict(orient="records")
 
6
 
7
  app = FastAPI(title="Food Surplus Predictor API")
8
 
9
+ # Download model
10
  model_path = hf_hub_download(
11
  repo_id="BeeBasic/food-for-all",
12
  filename="best_model.joblib",
 
14
  )
15
  model = joblib.load(model_path)
16
 
17
+ # Define schema
18
  class CanteenInput(BaseModel):
19
  canteen_id: str
20
  canteen_name: str
 
32
 
33
  @app.post("/predict")
34
  def predict_surplus(request: RequestBody):
35
+ # Convert input to DataFrame
36
  df = pd.DataFrame([canteen.dict() for canteen in request.data])
37
 
38
+ # One-hot encode categorical columns
39
+ df_encoded = pd.get_dummies(df, columns=["canteen_id", "canteen_name"])
40
+
41
+ # Align columns with model’s expected input
42
+ model_features = model.feature_names_ if hasattr(model, "feature_names_") else None
43
+ if model_features:
44
+ for col in model_features:
45
+ if col not in df_encoded.columns:
46
+ df_encoded[col] = 0
47
+ df_encoded = df_encoded[model_features]
48
+
49
+ # Run prediction
50
+ predictions = model.predict(df_encoded)
 
 
 
 
 
 
51
  df["predicted_surplus"] = predictions
52
+
53
  return df.to_dict(orient="records")