Spaces:
Sleeping
Sleeping
import asyncio
import os
from typing import Optional

import joblib
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# Application object; the __main__ block at the bottom of the file hands it to uvicorn.
app: FastAPI = FastAPI()
# --- Model paths ---
# Serialized artifacts read with joblib at import time. The paths are relative,
# so they resolve against the process's current working directory.
_MODEL_DIR = "models"
TFIDF_VECTORIZER_PATH = f"{_MODEL_DIR}/tfidf_vectorizer.pkl"
MODELS_PATH = f"{_MODEL_DIR}/xgb_models.pkl"
LABEL_ENCODERS_PATH = f"{_MODEL_DIR}/label_encoders.pkl"
# --- Load Models ---
# Artifacts are loaded once at import time and shared by every request.
# `models` is iterated as a {label: classifier} mapping and `label_encoders` is
# indexed by the same labels in the prediction endpoint.
# NOTE: joblib.load unpickles arbitrary Python objects; only point these paths
# at trusted artifact files.
try:
    tfidf_vectorizer = joblib.load(TFIDF_VECTORIZER_PATH)
    models = joblib.load(MODELS_PATH)
    label_encoders = joblib.load(LABEL_ENCODERS_PATH)
except Exception as e:
    # Fail fast: the service cannot answer predictions without these artifacts.
    # Chain the original error so the traceback shows the real cause.
    raise RuntimeError(f"Model loading failed: {e}") from e
# --- Input Schemas ---
# NOTE: the prediction endpoint serializes an instance with `.dict()` and joins
# the values, one per line, into the single text document fed to the TF-IDF
# vectorizer. Declaration order below therefore determines the order of that
# text, so do not reorder, add or remove fields casually. The vectorizer was fit
# on documents built this way (assumption: confirm against the training code).
# NOTE(review): several fields (e.g. Is_Match, Alert_Status,
# Investigation_Outcome, Maker_Action) read like outcomes. If any of them is also
# a predicted label in `models`, sending it as input leaks the answer. Verify.
class TransactionData(BaseModel):
    """A single screened transaction to score."""

    # Every field without a default is required; only the two payment-name
    # fields below are optional.
    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    # Optional: clients may omit these (default "") or send null. A null value
    # is skipped when the prediction text is assembled.
    Payment_Sender_Name: Optional[str] = ""
    # (sic) "Reciever": the misspelling is part of the public request schema;
    # renaming it would break existing clients.
    Payment_Reciever_Name: Optional[str] = ""
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int  # presumably a 0/1 flag; confirm
    Red_Flag_Reason: str
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    # Date/time fields are accepted as plain strings; their format is not validated.
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str
    Maker_Action: str
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str
    Risk_Drivers: str
    Alert_Status: str
    Investigation_Outcome: str
    Case_Owner_Analyst: str
    Escalation_Level: str
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool
class PredictionRequest(BaseModel):
    """Request body wrapper: the transaction is nested under ``transaction_data``."""

    # Single nested record; validated against the TransactionData schema.
    transaction_data: TransactionData
# --- Health Check ---
# Registered on `app`: as originally written the handler was a plain function
# that was never attached, so the service exposed no health route.
@app.get("/")
def health_check():
    """Report that the service is up.

    Returns a static payload; it does not inspect the loaded models.
    """
    return {"status": "healthy", "message": "XGBoost TF-IDF API is running"}
# --- Prediction Endpoint ---
def _to_builtin(value):
    """Return ``value`` as a plain Python object if it is a NumPy scalar.

    Encoder classes and ``inverse_transform`` results may be NumPy scalars
    (e.g. ``numpy.int64`` / ``numpy.bool_`` when labels are non-string). FastAPI's
    JSON encoder cannot serialize those, so ``.item()`` converts them. Objects
    without a callable ``item`` attribute are returned unchanged.
    """
    item = getattr(value, "item", None)
    return item() if callable(item) else value


def _build_text(payload):
    """Join the non-missing values of ``payload`` into one newline-separated document.

    Values are stringified in the dict's iteration order. ``None``/NaN values are
    skipped (``pd.notna``); empty strings are kept as empty lines.
    """
    return "\n".join(str(value) for value in payload.values() if pd.notna(value))


def _run_inference(payload):
    """Vectorize ``payload`` and score it with every per-label model (CPU-bound).

    Returns ``{label: {"prediction": ..., "probabilities": {class: prob}}}``.
    """
    X_tfidf = tfidf_vectorizer.transform([_build_text(payload)])
    response = {}
    for label, model in models.items():
        encoder = label_encoders[label]
        proba = model.predict_proba(X_tfidf)[0]
        pred_idx = proba.argmax()
        pred_label = encoder.inverse_transform([pred_idx])[0]
        classes = encoder.classes_
        response[label] = {
            "prediction": _to_builtin(pred_label),
            # Index classes_ by position (rather than zip) so a length mismatch
            # between probabilities and classes raises instead of truncating.
            "probabilities": {
                _to_builtin(classes[i]): float(prob) for i, prob in enumerate(proba)
            },
        }
    return response


@app.post("/predict")
async def predict(request: PredictionRequest):
    """Score one transaction with every loaded per-label model.

    Args:
        request: Body holding the transaction under ``transaction_data``.

    Returns:
        Mapping of each label in ``models`` to its predicted class and the
        per-class probabilities.

    Raises:
        HTTPException: status 500 if serialization, vectorization or prediction fails.
    """
    try:
        record = request.transaction_data
        # Pydantic v2 deprecates .dict() in favour of .model_dump(); accept either.
        payload = record.model_dump() if hasattr(record, "model_dump") else record.dict()
        # TF-IDF + XGBoost scoring is CPU-bound. Run it in the default thread pool
        # so it does not stall the event loop for concurrent requests. This
        # assumes the shared vectorizer/models tolerate concurrent read-only
        # prediction, the same assumption FastAPI makes for plain `def` endpoints.
        # Confirm for the XGBoost version in use.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _run_inference, payload)
    except Exception as e:
        # NOTE(review): the exception text is returned to the client; consider
        # logging it server-side and returning a generic message instead.
        raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}") from e
# --- Validation Endpoint ---
# Registered on `app`: as originally written the handler was never attached,
# so no validation route existed.
@app.post("/validate")
def validate_input(data: TransactionData):
    """Confirm that a request body parses as a ``TransactionData``.

    FastAPI validates ``data`` against the schema before this function runs.
    A malformed body is rejected with FastAPI's standard validation-error
    response, so reaching this body means validation passed. ``data`` itself
    is not otherwise used.
    """
    return {"message": "Input is valid."}
# --- Run Locally (optional) ---
# Only runs when this file is executed directly; importing the module
# (e.g. by an ASGI server) skips it.
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces; the port comes from $PORT, defaulting to 7860.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))