Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -90,18 +90,18 @@ class TransactionData(BaseModel):
|
|
| 90 |
class PredictionRequest(BaseModel):
|
| 91 |
transaction_data: TransactionData
|
| 92 |
|
| 93 |
-
# ---
|
| 94 |
@app.get("/")
|
| 95 |
def health_check():
|
| 96 |
return {"status": "healthy", "message": "XGBoost TF-IDF API is running"}
|
| 97 |
|
| 98 |
-
# ---
|
| 99 |
@app.post("/predict")
|
| 100 |
async def predict(request: PredictionRequest):
|
| 101 |
try:
|
| 102 |
input_data = pd.DataFrame([request.transaction_data.dict()])
|
| 103 |
|
| 104 |
-
#
|
| 105 |
text_input = "\n".join([
|
| 106 |
str(input_data[col].iloc[0]) for col in input_data.columns if pd.notna(input_data[col].iloc[0])
|
| 107 |
])
|
|
@@ -109,7 +109,7 @@ async def predict(request: PredictionRequest):
|
|
| 109 |
# TF-IDF transform
|
| 110 |
X_tfidf = tfidf_vectorizer.transform([text_input])
|
| 111 |
|
| 112 |
-
# Predict
|
| 113 |
response = {}
|
| 114 |
for label, model in models.items():
|
| 115 |
proba = model.predict_proba(X_tfidf)[0]
|
|
@@ -129,6 +129,11 @@ async def predict(request: PredictionRequest):
|
|
| 129 |
except Exception as e:
|
| 130 |
raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}")
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
# --- Run Locally (optional) ---
|
| 133 |
if __name__ == "__main__":
|
| 134 |
import uvicorn
|
|
|
|
| 90 |
class PredictionRequest(BaseModel):
|
| 91 |
transaction_data: TransactionData
|
| 92 |
|
| 93 |
+
# --- Health Check ---
|
| 94 |
@app.get("/")
|
| 95 |
def health_check():
|
| 96 |
return {"status": "healthy", "message": "XGBoost TF-IDF API is running"}
|
| 97 |
|
| 98 |
+
# --- Prediction Endpoint ---
|
| 99 |
@app.post("/predict")
|
| 100 |
async def predict(request: PredictionRequest):
|
| 101 |
try:
|
| 102 |
input_data = pd.DataFrame([request.transaction_data.dict()])
|
| 103 |
|
| 104 |
+
# Combine text fields
|
| 105 |
text_input = "\n".join([
|
| 106 |
str(input_data[col].iloc[0]) for col in input_data.columns if pd.notna(input_data[col].iloc[0])
|
| 107 |
])
|
|
|
|
| 109 |
# TF-IDF transform
|
| 110 |
X_tfidf = tfidf_vectorizer.transform([text_input])
|
| 111 |
|
| 112 |
+
# Predict each label
|
| 113 |
response = {}
|
| 114 |
for label, model in models.items():
|
| 115 |
proba = model.predict_proba(X_tfidf)[0]
|
|
|
|
| 129 |
except Exception as e:
|
| 130 |
raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}")
|
| 131 |
|
| 132 |
+
# --- Validation Endpoint ---
|
| 133 |
+
@app.post("/validate")
|
| 134 |
+
def validate_input(data: TransactionData):
|
| 135 |
+
return {"message": "Input is valid."}
|
| 136 |
+
|
| 137 |
# --- Run Locally (optional) ---
|
| 138 |
if __name__ == "__main__":
|
| 139 |
import uvicorn
|