danialsiddiqui commited on
Commit
3ee223a
·
1 Parent(s): d66b3b4

Fix app and add predict endpoint

Browse files
Files changed (2) hide show
  1. app.py +30 -14
  2. requirements.txt +3 -0
app.py CHANGED
@@ -1,23 +1,39 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
 
3
  import joblib
4
- import numpy as np
5
 
6
- # Load your trained model
7
- model = joblib.load("model.joblib")
8
 
9
- # Initialize FastAPI app
10
- app = FastAPI(title="Supermarket Sales Forecast API")
 
 
 
11
 
12
- # Define input schema (replace with actual model features)
13
- class SalesInput(BaseModel):
14
- feature1: float
15
- feature2: float
16
- feature3: float
 
 
 
 
 
 
17
 
18
  @app.post("/predict")
19
- def predict(data: SalesInput):
20
- # Convert input to array for prediction
21
- input_data = np.array([[data.feature1, data.feature2, data.feature3]])
22
- prediction = model.predict(input_data)
 
 
 
 
 
 
 
23
  return {"prediction": prediction.tolist()}
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
+ from huggingface_hub import hf_hub_download
4
  import joblib
5
+ import pandas as pd
6
 
7
+ app = FastAPI()
 
8
 
9
+ # Load model from your HF model repo
10
+ model_path = hf_hub_download(
11
+ repo_id="danialsiddiqui/nexus-task6-model",
12
+ filename="model.joblib"
13
+ )
14
 
15
+ model_data = joblib.load(model_path)
16
+ model = model_data["model"]
17
+ columns = model_data["columns"]
18
+
19
+ class PredictInput(BaseModel):
20
+ gender: str
21
+ customer_type: str
22
+ product_line: str
23
+ unit_price: float
24
+ quantity: int
25
+ tax_5: float
26
 
27
  @app.post("/predict")
28
+ def predict(input_data: PredictInput):
29
+ df = pd.DataFrame([input_data.dict()])
30
+ df = pd.get_dummies(df)
31
+
32
+ # Align columns
33
+ for col in columns:
34
+ if col not in df.columns:
35
+ df[col] = 0
36
+
37
+ df = df[columns]
38
+ prediction = model.predict(df)
39
  return {"prediction": prediction.tolist()}
requirements.txt CHANGED
@@ -5,3 +5,6 @@ joblib
5
  numpy
6
  requests
7
  streamlit
 
 
 
 
5
  numpy
6
  requests
7
  streamlit
8
+ huggingface_hub
9
+ pandas
10
+ scikit-learn