"""FastAPI service wrapping a TF-IDF + Logistic Regression multi-output
classifier for sanctions-screening transaction alerts.

Endpoints:
    GET  /         -- liveness message
    POST /train    -- train from the CSV dataset and persist the pipeline
    POST /predict  -- predict the six target labels for one transaction
    POST /validate -- schema-validate an input payload (validation is done
                      by FastAPI/pydantic before the handler runs)
    GET  /test     -- simple smoke-test endpoint
"""

import os
from typing import Optional

import joblib
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.multioutput import MultiOutputClassifier
from sklearn.pipeline import Pipeline

# ========== Config ==========
DATA_PATH = "data/synthetic_transactions_samples_5000.csv"
MODEL_DIR = "models"
MODEL_PATH = os.path.join(MODEL_DIR, "logreg_model.pkl")

# The six label columns the model predicts. Defined once so /train and
# /predict always agree on count and order (the original code trained on
# six targets but returned only five from /predict).
TARGET_COLUMNS = [
    "Maker_Action",
    "Escalation_Level",
    "Risk_Category",
    "Risk_Drivers",
    "Investigation_Outcome",
    "Red_Flag_Reason",
]

# ========== FastAPI Init ==========
app = FastAPI()


# ========== Input Schema ==========
class TransactionData(BaseModel):
    """One screening-alert transaction record.

    Field names mirror the CSV columns exactly (including the
    'Payment_Reciever_Name' spelling), so a row can be round-tripped
    between the dataset and this schema without renaming.
    """

    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    Payment_Sender_Name: Optional[str] = ""
    Payment_Reciever_Name: Optional[str] = ""
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int
    Red_Flag_Reason: str
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str
    Maker_Action: str
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str
    Risk_Drivers: str
    Alert_Status: str
    Investigation_Outcome: str
    Case_Owner_Analyst: str
    Escalation_Level: str
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool


# ========== Utils ==========
def create_text_input(row):
    """Serialize one transaction row into the free-text feature string
    fed to the TF-IDF vectorizer.

    NOTE: the target columns (Maker_Action, Risk_Category, Risk_Drivers,
    Investigation_Outcome, Red_Flag_Reason, Escalation_Level) are
    deliberately EXCLUDED here. The original version embedded them in the
    feature text, which leaked the labels into the features: the model
    learned to copy the labels back, inflating accuracy and failing on
    genuinely unlabeled input.

    Args:
        row: mapping (pandas Series or dict) with the TransactionData keys.

    Returns:
        A multi-line string of "Label: value" pairs.
    """
    return f"""
    Transaction ID: {row['Transaction_Id']}
    Origin: {row['Origin']}
    Designation: {row['Designation']}
    Keywords: {row['Keywords']}
    Name: {row['Name']}
    SWIFT Tag: {row['SWIFT_Tag']}
    Currency: {row['Currency']}
    Entity: {row['Entity']}
    Message: {row['Message']}
    City: {row['City']}
    Country: {row['Country']}
    State: {row['State']}
    Hit Type: {row['Hit_Type']}
    Record Matching String: {row['Record_Matching_String']}
    WatchList Match String: {row['WatchList_Match_String']}
    Payment Sender: {row['Payment_Sender_Name']}
    Payment Receiver: {row['Payment_Reciever_Name']}
    Swift Message Type: {row['Swift_Message_Type']}
    Text Sanction Data: {row['Text_Sanction_Data']}
    Matched Sanctioned Entity: {row['Matched_Sanctioned_Entity']}
    Risk Level: {row['Risk_Level']}
    Risk Score: {row['Risk_Score']}
    CDD Level: {row['CDD_Level']}
    PEP Status: {row['PEP_Status']}
    Sanction Description: {row['Sanction_Description']}
    Checker Notes: {row['Checker_Notes']}
    Sanction Context: {row['Sanction_Context']}
    Customer Type: {row['Customer_Type']}
    Industry: {row['Industry']}
    Transaction Type: {row['Transaction_Type']}
    Transaction Channel: {row['Transaction_Channel']}
    Geographic Origin: {row['Geographic_Origin']}
    Geographic Destination: {row['Geographic_Destination']}
    Alert Status: {row['Alert_Status']}
    Source of Funds: {row['Source_Of_Funds']}
    Purpose of Transaction: {row['Purpose_Of_Transaction']}
    Beneficial Owner: {row['Beneficial_Owner']}
    """


# ========== Root ==========
@app.get("/")
def root():
    """Liveness probe."""
    return {"message": "TF-IDF Logistic Regression API is running."}


# ========== API Routes ==========
@app.post("/train")
def train_model():
    """Train the TF-IDF + LogisticRegression multi-output pipeline from
    DATA_PATH and persist it to MODEL_PATH.

    Returns:
        Message plus held-out accuracy. For a multi-output classifier,
        Pipeline.score is subset accuracy: a sample counts as correct
        only when ALL six labels match.
    """
    df = pd.read_csv(DATA_PATH)
    df = df.fillna("")
    df["text_input"] = df.apply(create_text_input, axis=1)

    X = df["text_input"]
    y = df[TARGET_COLUMNS]

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    pipeline = Pipeline([
        ("vectorizer", TfidfVectorizer()),
        ("classifier", MultiOutputClassifier(LogisticRegression(max_iter=1000))),
    ])
    pipeline.fit(X_train, y_train)

    os.makedirs(MODEL_DIR, exist_ok=True)
    joblib.dump(pipeline, MODEL_PATH)

    accuracy = pipeline.score(X_test, y_test)
    return {"message": "Model trained and saved.", "accuracy": accuracy}


@app.post("/predict")
def predict(request: TransactionData):
    """Predict all six target labels for a single transaction.

    Raises:
        HTTPException 404: no trained model on disk (call /train first).
        HTTPException 500: any other failure during load/prediction.
    """
    # Explicit pre-check: a missing model is a client-recoverable state,
    # not an internal error.
    if not os.path.exists(MODEL_PATH):
        raise HTTPException(
            status_code=404,
            detail="Model not found. Train the model first via POST /train.",
        )
    try:
        model = joblib.load(MODEL_PATH)

        input_data = pd.DataFrame([request.dict()])
        input_data = input_data.fillna("")
        text_input = create_text_input(input_data.iloc[0])

        prediction = model.predict([text_input])[0]
        # Zip against TARGET_COLUMNS so every trained label is returned
        # (the original hand-built dict dropped Red_Flag_Reason).
        return dict(zip(TARGET_COLUMNS, prediction))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/validate")
def validate_input(request: TransactionData):
    """Return success iff the payload parses as TransactionData.

    FastAPI rejects invalid payloads with a 422 before this body runs,
    so reaching this line means the input is valid.
    """
    return {"message": "Input is valid."}


@app.get("/test")
def test_api():
    """Smoke-test endpoint."""
    return {"message": "Test successful."}