# Web-page residue kept as comments so the module parses as Python:
# ganeshkonapalli's picture
# Update app.py
# f2c7813 verified
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import pandas as pd
import joblib
import os
# Initialize FastAPI app
app = FastAPI()

# --- Model paths ---
# Relative paths: they resolve against the working directory of the process.
TFIDF_VECTORIZER_PATH = "models/tfidf_vectorizer.pkl"
MODELS_PATH = "models/xgb_models.pkl"
LABEL_ENCODERS_PATH = "models/label_encoders.pkl"

# --- Load Models ---
# All three artifacts are loaded once, at import time, so a missing or
# corrupt file stops the service from starting instead of failing per request.
# NOTE: joblib.load unpickles its input; only point these paths at trusted files.
try:
    tfidf_vectorizer = joblib.load(TFIDF_VECTORIZER_PATH)
    models = joblib.load(MODELS_PATH)
    label_encoders = joblib.load(LABEL_ENCODERS_PATH)
except Exception as e:
    # Chain explicitly so the underlying error (e.g. a missing file) is
    # reported as the direct cause of the RuntimeError.
    raise RuntimeError(f"Model loading failed: {e}") from e
# --- Input Schemas ---
class TransactionData(BaseModel):
    """Flat request schema describing a single screened transaction.

    Every field is required except the two payment-party names, which
    default to an empty string. Date and timestamp fields are plain strings;
    this model does not enforce any particular format on them.

    The declaration order is significant: the /predict handler in this file
    concatenates the field values in column order, so fields must not be
    reordered.
    """

    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str

    # The only optional fields: omitted -> "", explicit null is also accepted.
    # ("Reciever" is the spelling used by the request schema; keep it as is.)
    Payment_Sender_Name: Optional[str] = ""
    Payment_Reciever_Name: Optional[str] = ""

    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int
    Red_Flag_Reason: str
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str
    Maker_Action: str
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str
    Risk_Drivers: str
    Alert_Status: str
    Investigation_Outcome: str
    Case_Owner_Analyst: str
    Escalation_Level: str
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool
class PredictionRequest(BaseModel):
    """Request body for /predict: one transaction under ``transaction_data``."""

    transaction_data: TransactionData
# --- Health Check ---
@app.get("/")
def health_check():
    """Return a fixed payload signalling that the service is up."""
    payload = dict(
        status="healthy",
        message="XGBoost TF-IDF API is running",
    )
    return payload
# --- Prediction Endpoint ---
@app.post("/predict")
async def predict(request: PredictionRequest):
    """Run every loaded classifier on one transaction and report the results.

    The non-null field values of ``request.transaction_data`` are converted
    to strings and joined with newlines (in column order). That text is
    vectorized with the module-level TF-IDF vectorizer, and every entry of
    ``models`` is asked for class probabilities on it.

    Returns:
        A dict keyed by label name. Each value contains ``"prediction"``
        (the decoded class with the highest probability) and
        ``"probabilities"`` (decoded class -> probability as a float).

    Raises:
        HTTPException: status 500 with an ``"Inference error: ..."`` detail
            when any step fails. The failure is logged with its traceback
            and attached as the exception's cause.
    """
    try:
        input_data = pd.DataFrame([request.transaction_data.dict()])

        # Combine text fields: read each cell once and drop nulls (for
        # example an optional field that was sent as null).
        cells = (input_data[col].iloc[0] for col in input_data.columns)
        text_input = "\n".join(str(cell) for cell in cells if pd.notna(cell))

        # TF-IDF transform
        X_tfidf = tfidf_vectorizer.transform([text_input])

        # Predict each label
        response = {}
        for label, model in models.items():
            # One encoder lookup per label instead of one per class.
            encoder = label_encoders[label]
            proba = model.predict_proba(X_tfidf)[0]
            pred_idx = proba.argmax()
            pred_label = encoder.inverse_transform([pred_idx])[0]
            # Index classes_ explicitly so a length mismatch between the
            # model output and the encoder still raises rather than truncating.
            class_probs = {
                encoder.classes_[i]: float(prob)
                for i, prob in enumerate(proba)
            }
            response[label] = {
                "prediction": pred_label,
                "probabilities": class_probs,
            }
        return response
    except Exception as e:
        # Without this the traceback is lost: the client only ever sees the
        # one-line detail string below.
        import logging

        logging.getLogger(__name__).exception("Inference failed")
        raise HTTPException(
            status_code=500, detail=f"Inference error: {e}"
        ) from e
# --- Validation Endpoint ---
@app.post("/validate")
def validate_input(data: TransactionData):
    """Confirm that a request body was accepted as ``TransactionData``.

    The handler body never inspects ``data``; declaring the parameter with
    the ``TransactionData`` type is what triggers schema validation, so
    reaching this function means the payload parsed successfully.
    """
    confirmation = {"message": "Input is valid."}
    return confirmation
# --- Run Locally (optional) ---
if __name__ == "__main__":
    import uvicorn

    # The PORT environment variable wins; otherwise listen on 7860.
    listen_port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, port=listen_port, host="0.0.0.0")