"""FastAPI service exposing train / validate / test / predict endpoints for a
sanctions-screening classifier (TF-IDF text features + one-vs-rest logistic
regression wrapped in a MultiOutputClassifier)."""

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import pandas as pd
import joblib
import os
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import LabelEncoder
from sklearn.multioutput import MultiOutputClassifier
import config
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

# --- Configuration ---
# Target columns predicted by the multi-output model. ORDER MATTERS: the same
# order is used to encode labels at training time and to decode the prediction
# vector at inference time, so every endpoint must use this one list.
LABEL_COLUMNS = [
    "Red_Flag_Reason",
    "Maker_Action",
    "Escalation_Level",
    "Risk_Category",
    "Risk_Drivers",
    "Investigation_Outcome",
]
# Free-text feature column fed to the TF-IDF vectorizer.
TEXT_COLUMN = "Sanction_Context"

# Shared artifact locations. /train writes these and /test and /predict read
# them; keeping a single set of paths avoids train/inference drift.
MODEL_DIR = "/tmp"
MODEL_PATH = os.path.join(MODEL_DIR, "logreg_model.pkl")
TFIDF_PATH = os.path.join(MODEL_DIR, "tfidf_vectorizer.pkl")
ENCODERS_PATH = os.path.join(MODEL_DIR, "label_encoders.pkl")

# --- FastAPI App ---
app = FastAPI()


# --- Input Schema ---
class TransactionData(BaseModel):
    """Full sanctions-screening transaction record accepted by /predict.

    Field names mirror the training CSV columns (including the original
    'Payment_Reciever_Name' spelling, kept for wire compatibility).
    """

    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    Payment_Sender_Name: Optional[str] = ""
    Payment_Reciever_Name: Optional[str] = ""
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int
    Red_Flag_Reason: str
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str
    Maker_Action: str
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str
    Risk_Drivers: str
    Alert_Status: str
    Investigation_Outcome: str
    Case_Owner_Analyst: str
    Escalation_Level: str
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool


class PredictionRequest(BaseModel):
    """Request wrapper for /predict."""

    transaction_data: TransactionData


class DataPathInput(BaseModel):
    """Request body carrying a server-side CSV path for /validate and /test."""

    data_path: str


# Human-readable label -> TransactionData field name, in the order the fields
# are laid out in the text document fed to the TF-IDF vectorizer at /predict.
_PREDICT_TEXT_FIELDS = [
    ("Transaction ID", "Transaction_Id"),
    ("Origin", "Origin"),
    ("Designation", "Designation"),
    ("Keywords", "Keywords"),
    ("Name", "Name"),
    ("SWIFT Tag", "SWIFT_Tag"),
    ("Currency", "Currency"),
    ("Entity", "Entity"),
    ("Message", "Message"),
    ("City", "City"),
    ("Country", "Country"),
    ("State", "State"),
    ("Hit Type", "Hit_Type"),
    ("Record Matching String", "Record_Matching_String"),
    ("WatchList Match String", "WatchList_Match_String"),
    ("Payment Sender", "Payment_Sender_Name"),
    ("Payment Receiver", "Payment_Reciever_Name"),
    ("Swift Message Type", "Swift_Message_Type"),
    ("Text Sanction Data", "Text_Sanction_Data"),
    ("Matched Sanctioned Entity", "Matched_Sanctioned_Entity"),
    ("Red Flag Reason", "Red_Flag_Reason"),
    ("Risk Level", "Risk_Level"),
    ("Risk Score", "Risk_Score"),
    ("CDD Level", "CDD_Level"),
    ("PEP Status", "PEP_Status"),
    ("Sanction Description", "Sanction_Description"),
    ("Checker Notes", "Checker_Notes"),
    ("Sanction Context", "Sanction_Context"),
    ("Maker Action", "Maker_Action"),
    ("Customer Type", "Customer_Type"),
    ("Industry", "Industry"),
    ("Transaction Type", "Transaction_Type"),
    ("Transaction Channel", "Transaction_Channel"),
    ("Geographic Origin", "Geographic_Origin"),
    ("Geographic Destination", "Geographic_Destination"),
    ("Risk Category", "Risk_Category"),
    ("Risk Drivers", "Risk_Drivers"),
    ("Alert Status", "Alert_Status"),
    ("Investigation Outcome", "Investigation_Outcome"),
    ("Source of Funds", "Source_Of_Funds"),
    ("Purpose of Transaction", "Purpose_Of_Transaction"),
    ("Beneficial Owner", "Beneficial_Owner"),
]


def _load_artifacts():
    """Load the fitted vectorizer, model and label encoders from disk.

    Raises FileNotFoundError when /train has not been run yet; callers map
    that to an HTTP 400 with a helpful message.
    """
    tfidf = joblib.load(TFIDF_PATH)
    model = joblib.load(MODEL_PATH)
    encoders = joblib.load(ENCODERS_PATH)
    return tfidf, model, encoders


# --- Root ---
@app.get("/")
def health_check():
    """Liveness probe."""
    return {"status": "healthy", "message": "LOGREG TF-IDF API is running"}


# --- Train ---
@app.post("/train")
def train():
    """Train the TF-IDF + multi-output logistic-regression pipeline.

    Reads the training CSV from ``config.DATA_PATH``, fits one LabelEncoder
    per label column plus a TF-IDF vectorizer and a MultiOutputClassifier,
    and saves all three artifacts to the module-level paths that /test and
    /predict load from. Returns per-label holdout accuracy.
    """
    try:
        df = pd.read_csv(config.DATA_PATH)
        # Drop rows missing the text feature (mirrors /test) or any label
        # value, which LabelEncoder cannot encode.
        df = df.dropna(subset=[TEXT_COLUMN] + LABEL_COLUMNS)

        X = df[TEXT_COLUMN]
        y = df[LABEL_COLUMNS]

        # One LabelEncoder per target column; the dict is persisted so the
        # inference endpoints can decode integer predictions back to strings.
        label_encoders = {}
        y_encoded = pd.DataFrame()
        for col in LABEL_COLUMNS:
            le = LabelEncoder()
            y_encoded[col] = le.fit_transform(y[col])
            label_encoders[col] = le

        # Train/test split (hyperparameters still come from config).
        X_train, X_test, y_train, y_test = train_test_split(
            X, y_encoded, test_size=config.TEST_SIZE, random_state=config.RANDOM_STATE
        )

        vectorizer = TfidfVectorizer(
            max_features=config.TFIDF_MAX_FEATURES,
            ngram_range=config.NGRAM_RANGE,
            stop_words='english' if config.USE_STOPWORDS else None,
        )
        X_train_vec = vectorizer.fit_transform(X_train)
        X_test_vec = vectorizer.transform(X_test)

        model = MultiOutputClassifier(LogisticRegression(max_iter=1000))
        model.fit(X_train_vec, y_train)

        y_pred = model.predict(X_test_vec)
        # y_pred is (n_samples, n_labels): column i holds predictions for
        # LABEL_COLUMNS[i]. Cast to float so the response is JSON-safe.
        accuracy = {
            col: float(accuracy_score(y_test[col], y_pred[:, i]))
            for i, col in enumerate(LABEL_COLUMNS)
        }

        # Save to the SAME paths the inference endpoints load from. (The
        # previous version saved to config.* paths while /test and /predict
        # read the /tmp paths, so inference could use stale artifacts.)
        joblib.dump(model, MODEL_PATH)
        joblib.dump(vectorizer, TFIDF_PATH)
        joblib.dump(label_encoders, ENCODERS_PATH)

        return {
            "message": "Training completed successfully.",
            "accuracy": accuracy,
        }
    except Exception as e:
        # Fail loudly with HTTP 500, consistent with the other endpoints,
        # instead of returning a 200 response carrying an "error" key.
        raise HTTPException(status_code=500, detail=f"Training error: {str(e)}")


# --- Validate (only structure check) ---
@app.post("/validate")
def validate_model(input: DataPathInput):
    """Check that the CSV at ``data_path`` contains the text and label columns."""
    try:
        df = pd.read_csv(input.data_path)
        required_columns = [TEXT_COLUMN] + LABEL_COLUMNS
        missing = [col for col in required_columns if col not in df.columns]
        if missing:
            return {"status": "Invalid input", "missing_columns": missing}
        return {"status": "Input is valid."}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Validation error: {str(e)}")


# --- Test ---
@app.post("/test")
def test_model(input: DataPathInput):
    """Run batch inference over a CSV and return a sample of decoded predictions."""
    try:
        df = pd.read_csv(input.data_path)
        df = df.dropna(subset=[TEXT_COLUMN])

        tfidf, model, encoders = _load_artifacts()

        X_vec = tfidf.transform(df[TEXT_COLUMN])
        preds = model.predict(X_vec)

        # Decode each row's integer label vector back to the original strings.
        decoded_preds = [
            {
                col: encoders[col].inverse_transform([label])[0]
                for col, label in zip(LABEL_COLUMNS, pred)
            }
            for pred in preds
        ]
        return {"predictions": decoded_preds[:8]}  # sample first 8 predictions
    except FileNotFoundError:
        raise HTTPException(
            status_code=400, detail="Model artifacts not found. Run /train first."
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


# --- Predict ---
@app.post("/predict")
def predict(request: PredictionRequest):
    """Predict all label columns for a single transaction record.

    The structured fields are flattened into one 'Label: value' per line text
    document, vectorized with the fitted TF-IDF model, and classified; the
    integer predictions are decoded back to their original string labels.
    """
    try:
        data = request.transaction_data.dict()

        # Structured field-based formatted text input for the vectorizer.
        text_input = "\n".join(
            f"{label}: {data.get(field, '')}" for label, field in _PREDICT_TEXT_FIELDS
        )

        tfidf, model, encoders = _load_artifacts()

        X_vec = tfidf.transform([text_input])
        pred = model.predict(X_vec)[0]

        decoded = {
            col: encoders[col].inverse_transform([p])[0]
            for col, p in zip(LABEL_COLUMNS, pred)
        }
        return {"prediction": decoded}
    except FileNotFoundError:
        raise HTTPException(
            status_code=400, detail="Model artifacts not found. Run /train first."
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))