# LOG_REGRESSION / app.py
# Hugging Face Space header (scrape residue, kept as a comment so the file parses):
# author: subbunanepalli — "Update app.py" — commit 065b4c9 (verified)
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, Dict, Any
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier
from sklearn.metrics import accuracy_score
import joblib
import os
import config
# FastAPI application instance; all endpoints below are registered on it.
app = FastAPI()
# Load model and vectorizer if exist
def load_model_vectorizer():
    """Load the persisted classifier and TF-IDF vectorizer from disk.

    Returns:
        A ``(model, vectorizer)`` pair, or ``(None, None)`` when either
        artifact file is missing (i.e. training has not been run yet).
    """
    have_model = os.path.exists(config.MODEL_PATH)
    have_tfidf = os.path.exists(config.TFIDF_PATH)
    if not (have_model and have_tfidf):
        return None, None
    return joblib.load(config.MODEL_PATH), joblib.load(config.TFIDF_PATH)
# Pydantic model for prediction input (add all your fields as needed)
class TransactionData(BaseModel):
    """Schema for one transaction-screening record posted to /predict.

    Field names mirror the keys used when building the model's text input;
    all fields are required strings unless annotated otherwise.
    """
    Transaction_Id: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    # Optional sender/receiver names default to empty strings.
    Payment_Sender_Name: Optional[str] = ""
    # NOTE(review): "Reciever" typo is part of the public API; renaming it
    # would break existing clients, so it is left as-is.
    Payment_Reciever_Name: Optional[str] = ""
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Red_Flag_Reason: str
    Risk_Level: str
    # Only non-string field; rendered into the text input via str().
    Risk_Score: float
    CDD_Level: str
    PEP_Status: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str
@app.get("/")
async def root():
    """Landing endpoint: returns a static welcome payload."""
    greeting = "Welcome to the Logistic Regression API"
    return {"message": greeting}
# --- TRAIN ---
@app.post("/train")
def train():
    """Train the multi-label model from the CSV at config.DATA_PATH.

    TF-IDF-vectorizes the text column, fits one logistic regression per
    label column (via MultiOutputClassifier), reports held-out accuracy
    per label, and persists both model and vectorizer to disk.

    Returns:
        ``{"message": ..., "accuracy": {label: score}}`` on success, or
        ``{"error": str}`` with HTTP 200 on failure (kept for backward
        compatibility with existing clients).
    """
    try:
        os.makedirs(config.MODEL_SAVE_DIR, exist_ok=True)
        df = pd.read_csv(config.DATA_PATH)
        # fillna("") guards TfidfVectorizer, which raises on NaN entries.
        X = df[config.TEXT_COLUMN].fillna("")
        y = df[config.LABEL_COLUMNS]
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )
        vectorizer = TfidfVectorizer()
        X_train_vec = vectorizer.fit_transform(X_train)
        X_test_vec = vectorizer.transform(X_test)
        model = MultiOutputClassifier(LogisticRegression(max_iter=1000))
        model.fit(X_train_vec, y_train)
        # predict() returns an (n_samples, n_labels) array; column i
        # aligns with y.columns[i].
        y_pred = model.predict(X_test_vec)
        accuracy = {
            col: accuracy_score(y_test[col], y_pred[:, i])
            for i, col in enumerate(y.columns)
        }
        joblib.dump(model, config.MODEL_PATH)
        joblib.dump(vectorizer, config.TFIDF_PATH)
        return {"message": "Training completed", "accuracy": accuracy}
    except Exception as e:
        # NOTE(review): errors are reported in-band with HTTP 200 to match
        # the original contract; an HTTPException(500) would be a breaking change.
        return {"error": str(e)}
# --- PREDICT ---
class PredictionRequest(BaseModel):
    """Request envelope for /predict: wraps a single TransactionData record."""
    transaction_data: TransactionData
@app.post("/predict")
def predict(request: PredictionRequest):
    """Predict every configured label for one transaction record.

    Renders the transaction into the same text layout used at training
    time, vectorizes it with the persisted TF-IDF vectorizer, and runs
    the multi-output classifier.

    Returns:
        ``{label: prediction}`` for each column in config.LABEL_COLUMNS,
        or ``{"error": str}`` with HTTP 200 on failure (original contract).
    """
    try:
        # Plain dict is sufficient here; the original built a one-row
        # DataFrame only to read scalar values straight back out.
        data = request.transaction_data.dict()
        text_input = _format_transaction_text(data)
        # Explicit missing-artifact check instead of letting joblib raise a
        # cryptic FileNotFoundError that got stringified into the response.
        model, vectorizer = load_model_vectorizer()
        if model is None or vectorizer is None:
            return {"error": "Model is not trained yet. Please train first."}
        X_vec = vectorizer.transform([text_input])
        preds = model.predict(X_vec)[0]
        return dict(zip(config.LABEL_COLUMNS, preds))
    except Exception as e:
        return {"error": str(e)}


def _format_transaction_text(data):
    """Render the transaction dict into the text layout fed to the vectorizer.

    The line order and labels must stay byte-identical to the training-time
    format. Keys absent from TransactionData (Maker_Action onward) always
    render as empty strings given the current schema; they are kept so the
    layout matches whatever the model was trained on.
    """
    def g(key):
        # Missing keys render as "", matching the old DataFrame .get default.
        return data.get(key, '')
    return f"""
Transaction ID: {g('Transaction_Id')}
Origin: {g('Origin')}
Designation: {g('Designation')}
Keywords: {g('Keywords')}
Name: {g('Name')}
SWIFT Tag: {g('SWIFT_Tag')}
Currency: {g('Currency')}
Entity: {g('Entity')}
Message: {g('Message')}
City: {g('City')}
Country: {g('Country')}
State: {g('State')}
Hit Type: {g('Hit_Type')}
Record Matching String: {g('Record_Matching_String')}
WatchList Match String: {g('WatchList_Match_String')}
Payment Sender: {g('Payment_Sender_Name')}
Payment Receiver: {g('Payment_Reciever_Name')}
Swift Message Type: {g('Swift_Message_Type')}
Text Sanction Data: {g('Text_Sanction_Data')}
Matched Sanctioned Entity: {g('Matched_Sanctioned_Entity')}
Red Flag Reason: {g('Red_Flag_Reason')}
Risk Level: {g('Risk_Level')}
Risk Score: {g('Risk_Score')}
CDD Level: {g('CDD_Level')}
PEP Status: {g('PEP_Status')}
Sanction Description: {g('Sanction_Description')}
Checker Notes: {g('Checker_Notes')}
Sanction Context: {g('Sanction_Context')}
Maker Action: {g('Maker_Action')}
Customer Type: {g('Customer_Type')}
Industry: {g('Industry')}
Transaction Type: {g('Transaction_Type')}
Transaction Channel: {g('Transaction_Channel')}
Geographic Origin: {g('Geographic_Origin')}
Geographic Destination: {g('Geographic_Destination')}
Risk Category: {g('Risk_Category')}
Risk Drivers: {g('Risk_Drivers')}
Alert Status: {g('Alert_Status')}
Investigation Outcome: {g('Investigation_Outcome')}
Source of Funds: {g('Source_Of_Funds')}
Purpose of Transaction: {g('Purpose_Of_Transaction')}
Beneficial Owner: {g('Beneficial_Owner')}
"""
# --- TEST ---
@app.get("/test")
def test():
    """Evaluate the persisted model over the full dataset.

    Raises:
        HTTPException 404 if the data file is missing; 400 if the model
        artifacts have not been trained/saved yet.

    Returns:
        ``{"accuracy": {label: score}}`` for every configured label.
    """
    if not os.path.exists(config.DATA_PATH):
        raise HTTPException(status_code=404, detail="Test data file not found")
    frame = pd.read_csv(config.DATA_PATH)
    model, vectorizer = load_model_vectorizer()
    if model is None or vectorizer is None:
        raise HTTPException(status_code=400, detail="Model is not trained yet. Please train first.")
    texts = frame[config.TEXT_COLUMN]
    labels = frame[config.LABEL_COLUMNS]
    predictions = model.predict(vectorizer.transform(texts))
    scores = {}
    for idx, column in enumerate(labels.columns):
        scores[column] = accuracy_score(labels[column], [row[idx] for row in predictions])
    return {"accuracy": scores}
# --- VALIDATE ---
@app.post("/validate")
def validate(data: dict):
    """Batch-predict labels for a list of raw input dicts.

    Expects ``{"inputs": [ {field: value, ...}, ... ]}``; each item is
    rendered as "key: value" lines (insertion order of the dict) and
    classified.

    Raises:
        HTTPException 400 if no inputs are given or the model is untrained.

    Returns:
        ``{"results": [ {label: prediction}, ... ]}`` in input order.
    """
    inputs = data.get("inputs")
    if not inputs:
        raise HTTPException(status_code=400, detail="No inputs provided")
    model, vectorizer = load_model_vectorizer()
    if model is None or vectorizer is None:
        raise HTTPException(status_code=400, detail="Model is not trained yet. Please train first.")
    # Build all texts first so transform/predict run once over the whole
    # batch instead of once per item — identical results, one sparse
    # matrix and one predict call.
    texts = [
        "".join(f"{key}: {value}\n" for key, value in item.items())
        for item in inputs
    ]
    batch_preds = model.predict(vectorizer.transform(texts))
    results = [dict(zip(config.LABEL_COLUMNS, preds)) for preds in batch_preds]
    return {"results": results}