import logging
import os
from typing import Any, Dict, Optional

import joblib
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI()

# --- Model paths ---
TFIDF_VECTORIZER_PATH = "models/tfidf_vectorizer.pkl"
MODELS_PATH = "models/xgb_models.pkl"
LABEL_ENCODERS_PATH = "models/label_encoders.pkl"

# --- Load Models ---
# NOTE(review): joblib.load unpickles its input, which can execute arbitrary
# code. These artifacts must only ever come from a trusted build pipeline.
try:
    tfidf_vectorizer = joblib.load(TFIDF_VECTORIZER_PATH)
    models = joblib.load(MODELS_PATH)
    label_encoders = joblib.load(LABEL_ENCODERS_PATH)
except Exception as e:
    # Chain the cause so the original traceback (missing file, version
    # mismatch, ...) is preserved in the startup failure.
    raise RuntimeError(f"Model loading failed: {e}") from e


# --- Input Schemas ---
# Flat record describing one screened transaction. Field names are part of the
# public request contract and must not be renamed (including the existing
# "Payment_Reciever_Name" spelling).
class TransactionData(BaseModel):
    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    Payment_Sender_Name: Optional[str] = ""
    Payment_Reciever_Name: Optional[str] = ""
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int
    Red_Flag_Reason: str
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str
    Maker_Action: str
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str
    Risk_Drivers: str
    Alert_Status: str
    Investigation_Outcome: str
    Case_Owner_Analyst: str
    Escalation_Level: str
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool


# Request body for /predict: wraps a single TransactionData record.
class PredictionRequest(BaseModel):
    transaction_data: TransactionData


# --- Inference helpers ---
def _model_to_dict(model: BaseModel) -> Dict[str, Any]:
    """Return the model's fields as a dict on both Pydantic v1 and v2.

    Pydantic v2 deprecates ``.dict()`` in favour of ``.model_dump()``; prefer
    the new method when it exists and fall back to the old one otherwise.
    """
    dump = getattr(model, "model_dump", None)
    if dump is None:
        dump = model.dict
    return dump()


def _build_text_input(fields: Dict[str, Any]) -> str:
    """Join every non-missing field value into one newline-separated string.

    Values are stringified in field-declaration order. ``None`` and NaN values
    are skipped; empty strings are kept (they contribute an empty line).
    """
    return "\n".join(
        str(value) for value in fields.values() if pd.notna(value)
    )


def _to_builtin(value: Any) -> Any:
    """Convert a NumPy scalar to the equivalent built-in Python value.

    NumPy scalars such as ``numpy.int64`` are not instances of the built-in
    types and may not be JSON-serializable; anything exposing ``.item()`` is
    unwrapped, everything else is returned unchanged.
    """
    item = getattr(value, "item", None)
    return item() if callable(item) else value


def _predict_all_labels(features: Any) -> Dict[str, Dict[str, Any]]:
    """Run every loaded model on ``features`` and decode its output.

    Returns a mapping ``label -> {"prediction": ..., "probabilities": {...}}``
    where class indices are decoded with the label encoder stored under the
    same key as the model.
    """
    response: Dict[str, Dict[str, Any]] = {}
    for label, model in models.items():
        encoder = label_encoders[label]
        proba = model.predict_proba(features)[0]
        pred_idx = proba.argmax()
        pred_label = encoder.inverse_transform([pred_idx])[0]
        class_probs = {
            _to_builtin(encoder.classes_[i]): float(prob)
            for i, prob in enumerate(proba)
        }
        response[label] = {
            "prediction": _to_builtin(pred_label),
            "probabilities": class_probs,
        }
    return response


# --- Health Check ---
# Liveness probe: returns a static payload and does not touch the models.
@app.get("/")
def health_check():
    return {"status": "healthy", "message": "XGBoost TF-IDF API is running"}


# --- Prediction Endpoint ---
# Concatenates all transaction fields into one text blob, vectorizes it with
# the loaded TF-IDF vectorizer and returns, per label, the predicted class and
# the per-class probabilities. Any failure is reported as HTTP 500.
#
# NOTE(review): inference runs synchronously inside an async handler, so the
# event loop is blocked for the duration of each request. Consider a plain
# `def` handler or a threadpool once the models' thread-safety is confirmed.
@app.post("/predict")
async def predict(request: PredictionRequest):
    try:
        fields = _model_to_dict(request.transaction_data)

        # Combine text fields
        text_input = _build_text_input(fields)

        # TF-IDF transform
        X_tfidf = tfidf_vectorizer.transform([text_input])

        # Predict each label
        return _predict_all_labels(X_tfidf)
    except Exception as e:
        # Record the full traceback server-side; the client only receives the
        # short message below.
        logger.exception("Inference failed")
        raise HTTPException(
            status_code=500, detail=f"Inference error: {str(e)}"
        ) from e


# --- Validation Endpoint ---
# Schema check only: if the body parses as TransactionData the request reaches
# this handler and a fixed confirmation is returned.
@app.post("/validate")
def validate_input(data: TransactionData):
    return {"message": "Input is valid."}


# --- Run Locally (optional) ---
if __name__ == "__main__":
    import uvicorn

    # PORT is read from the environment, defaulting to 7860.
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)