from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, Dict, Any
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier
from sklearn.metrics import accuracy_score
import joblib
import os
import config

app = FastAPI()


def load_model_vectorizer():
    """Load the persisted model and TF-IDF vectorizer from disk.

    Returns:
        (model, vectorizer) if both artifact files exist, else (None, None).
    """
    if os.path.exists(config.MODEL_PATH) and os.path.exists(config.TFIDF_PATH):
        model = joblib.load(config.MODEL_PATH)
        vectorizer = joblib.load(config.TFIDF_PATH)
        return model, vectorizer
    return None, None


def _to_native(value):
    """Convert a numpy scalar to its native Python equivalent.

    model.predict returns numpy scalars (e.g. np.int64), which FastAPI's
    JSON encoder cannot serialize; plain Python values pass through unchanged.
    """
    return value.item() if hasattr(value, "item") else value


class TransactionData(BaseModel):
    """Input schema for a single transaction to score."""

    Transaction_Id: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    Payment_Sender_Name: Optional[str] = ""
    Payment_Reciever_Name: Optional[str] = ""  # NOTE: field name keeps original spelling for API compatibility
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Red_Flag_Reason: str
    Risk_Level: str
    Risk_Score: float
    CDD_Level: str
    PEP_Status: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str


@app.get("/")
async def root():
    """Health/landing endpoint."""
    return {"message": "Welcome to the Logistic Regression API"}


# --- TRAIN ---
@app.post("/train")
def train():
    """Train the multi-output logistic-regression model from the CSV at
    config.DATA_PATH and persist the model + vectorizer to disk.

    Returns per-label hold-out accuracy, or {"error": ...} on failure.
    """
    try:
        os.makedirs(config.MODEL_SAVE_DIR, exist_ok=True)
        df = pd.read_csv(config.DATA_PATH)
        X = df[config.TEXT_COLUMN]
        y = df[config.LABEL_COLUMNS]

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )

        vectorizer = TfidfVectorizer()
        X_train_vec = vectorizer.fit_transform(X_train)
        X_test_vec = vectorizer.transform(X_test)

        model = MultiOutputClassifier(LogisticRegression(max_iter=1000))
        model.fit(X_train_vec, y_train)

        # predict returns an (n_samples, n_labels) array; score each label column.
        y_pred = model.predict(X_test_vec)
        accuracy = {
            col: accuracy_score(y_test[col], y_pred[:, i])
            for i, col in enumerate(y.columns)
        }

        joblib.dump(model, config.MODEL_PATH)
        joblib.dump(vectorizer, config.TFIDF_PATH)
        return {"message": "Training completed", "accuracy": accuracy}
    except Exception as e:
        # Preserve original contract: report failures as an error payload.
        return {"error": str(e)}


# --- PREDICT ---
class PredictionRequest(BaseModel):
    """Wrapper payload for the /predict endpoint."""

    transaction_data: TransactionData


# (display label, DataFrame column) pairs, in the exact order the model's
# training text was assembled. Columns absent from the input render as "".
_PREDICT_FIELDS = [
    ("Transaction ID", "Transaction_Id"),
    ("Origin", "Origin"),
    ("Designation", "Designation"),
    ("Keywords", "Keywords"),
    ("Name", "Name"),
    ("SWIFT Tag", "SWIFT_Tag"),
    ("Currency", "Currency"),
    ("Entity", "Entity"),
    ("Message", "Message"),
    ("City", "City"),
    ("Country", "Country"),
    ("State", "State"),
    ("Hit Type", "Hit_Type"),
    ("Record Matching String", "Record_Matching_String"),
    ("WatchList Match String", "WatchList_Match_String"),
    ("Payment Sender", "Payment_Sender_Name"),
    ("Payment Receiver", "Payment_Reciever_Name"),
    ("Swift Message Type", "Swift_Message_Type"),
    ("Text Sanction Data", "Text_Sanction_Data"),
    ("Matched Sanctioned Entity", "Matched_Sanctioned_Entity"),
    ("Red Flag Reason", "Red_Flag_Reason"),
    ("Risk Level", "Risk_Level"),
    ("Risk Score", "Risk_Score"),
    ("CDD Level", "CDD_Level"),
    ("PEP Status", "PEP_Status"),
    ("Sanction Description", "Sanction_Description"),
    ("Checker Notes", "Checker_Notes"),
    ("Sanction Context", "Sanction_Context"),
    ("Maker Action", "Maker_Action"),
    ("Customer Type", "Customer_Type"),
    ("Industry", "Industry"),
    ("Transaction Type", "Transaction_Type"),
    ("Transaction Channel", "Transaction_Channel"),
    ("Geographic Origin", "Geographic_Origin"),
    ("Geographic Destination", "Geographic_Destination"),
    ("Risk Category", "Risk_Category"),
    ("Risk Drivers", "Risk_Drivers"),
    ("Alert Status", "Alert_Status"),
    ("Investigation Outcome", "Investigation_Outcome"),
    ("Source of Funds", "Source_Of_Funds"),
    ("Purpose of Transaction", "Purpose_Of_Transaction"),
    ("Beneficial Owner", "Beneficial_Owner"),
]


def _build_text_input(input_data: pd.DataFrame) -> str:
    """Assemble the free-text model input from a one-row DataFrame.

    Missing columns fall back to an empty string, matching the original
    `.get(col, [''])[0]` behavior.
    """
    lines = []
    for label, column in _PREDICT_FIELDS:
        value = input_data.get(column, [""])[0]
        lines.append(f"{label}: {value}")
    return "\n".join(lines)


@app.post("/predict")
def predict(request: PredictionRequest):
    """Score a single transaction and return one prediction per label.

    Returns {label: prediction} keyed by config.LABEL_COLUMNS, or
    {"error": ...} on failure.
    """
    try:
        input_data = pd.DataFrame([request.transaction_data.dict()])
        text_input = _build_text_input(input_data)

        # Use the shared loader so a missing model yields a clear message
        # instead of a raw joblib/file-not-found error string.
        model, vectorizer = load_model_vectorizer()
        if model is None or vectorizer is None:
            return {"error": "Model is not trained yet. Please train first."}

        X_vec = vectorizer.transform([text_input])
        preds = model.predict(X_vec)[0]

        # Cast numpy scalars to native types so the response JSON-serializes.
        return {
            label: _to_native(pred)
            for label, pred in zip(config.LABEL_COLUMNS, preds)
        }
    except Exception as e:
        return {"error": str(e)}


# --- TEST ---
@app.get("/test")
def test():
    """Evaluate the persisted model on the full dataset at config.DATA_PATH."""
    if not os.path.exists(config.DATA_PATH):
        raise HTTPException(status_code=404, detail="Test data file not found")
    df = pd.read_csv(config.DATA_PATH)

    model, vectorizer = load_model_vectorizer()
    if model is None or vectorizer is None:
        raise HTTPException(
            status_code=400, detail="Model is not trained yet. Please train first."
        )

    X = df[config.TEXT_COLUMN]
    y = df[config.LABEL_COLUMNS]
    X_vec = vectorizer.transform(X)
    y_pred = model.predict(X_vec)

    accuracy = {
        col: accuracy_score(y[col], y_pred[:, i])
        for i, col in enumerate(y.columns)
    }
    return {"accuracy": accuracy}


# --- VALIDATE ---
@app.post("/validate")
def validate(data: dict):
    """Batch-predict for a list of free-form input dicts.

    Expects {"inputs": [{field: value, ...}, ...]}; each item is flattened
    to "key: value" lines and scored. Returns {"results": [...]}.
    """
    inputs = data.get("inputs")
    if not inputs:
        raise HTTPException(status_code=400, detail="No inputs provided")

    model, vectorizer = load_model_vectorizer()
    if model is None or vectorizer is None:
        raise HTTPException(
            status_code=400, detail="Model is not trained yet. Please train first."
        )

    results = []
    for item in inputs:
        # Flatten the dict into the "key: value" text format the model expects.
        text_input = "".join(f"{key}: {value}\n" for key, value in item.items())
        X_vec = vectorizer.transform([text_input])
        preds = model.predict(X_vec)[0]
        results.append(
            {
                label: _to_native(pred)
                for label, pred in zip(config.LABEL_COLUMNS, preds)
            }
        )
    return {"results": results}