ganeshkonapalli commited on
Commit
f2c7813
·
verified ·
1 Parent(s): aaa245d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -90,18 +90,18 @@ class TransactionData(BaseModel):
90
  class PredictionRequest(BaseModel):
91
  transaction_data: TransactionData
92
 
93
- # --- Root Route ---
94
  @app.get("/")
95
  def health_check():
96
  return {"status": "healthy", "message": "XGBoost TF-IDF API is running"}
97
 
98
- # --- Predict Route ---
99
  @app.post("/predict")
100
  async def predict(request: PredictionRequest):
101
  try:
102
  input_data = pd.DataFrame([request.transaction_data.dict()])
103
 
104
- # Concatenate relevant fields into a single string
105
  text_input = "\n".join([
106
  str(input_data[col].iloc[0]) for col in input_data.columns if pd.notna(input_data[col].iloc[0])
107
  ])
@@ -109,7 +109,7 @@ async def predict(request: PredictionRequest):
109
  # TF-IDF transform
110
  X_tfidf = tfidf_vectorizer.transform([text_input])
111
 
112
- # Predict for each label
113
  response = {}
114
  for label, model in models.items():
115
  proba = model.predict_proba(X_tfidf)[0]
@@ -129,6 +129,11 @@ async def predict(request: PredictionRequest):
129
  except Exception as e:
130
  raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}")
131
 
 
 
 
 
 
132
  # --- Run Locally (optional) ---
133
  if __name__ == "__main__":
134
  import uvicorn
 
90
  class PredictionRequest(BaseModel):
91
  transaction_data: TransactionData
92
 
93
+ # --- Health Check ---
94
  @app.get("/")
95
  def health_check():
96
  return {"status": "healthy", "message": "XGBoost TF-IDF API is running"}
97
 
98
+ # --- Prediction Endpoint ---
99
  @app.post("/predict")
100
  async def predict(request: PredictionRequest):
101
  try:
102
  input_data = pd.DataFrame([request.transaction_data.dict()])
103
 
104
+ # Combine text fields
105
  text_input = "\n".join([
106
  str(input_data[col].iloc[0]) for col in input_data.columns if pd.notna(input_data[col].iloc[0])
107
  ])
 
109
  # TF-IDF transform
110
  X_tfidf = tfidf_vectorizer.transform([text_input])
111
 
112
+ # Predict each label
113
  response = {}
114
  for label, model in models.items():
115
  proba = model.predict_proba(X_tfidf)[0]
 
129
  except Exception as e:
130
  raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}")
131
 
132
+ # --- Validation Endpoint ---
133
+ @app.post("/validate")
134
+ def validate_input(data: TransactionData):
135
+ return {"message": "Input is valid."}
136
+
137
  # --- Run Locally (optional) ---
138
  if __name__ == "__main__":
139
  import uvicorn