# TFIDF_LOGREG / app.py
# FastAPI service exposing train/validate/test/predict endpoints for a
# TF-IDF + multi-output Logistic Regression sanctions-screening model.
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import pandas as pd
import joblib
import os
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import LabelEncoder
from sklearn.multioutput import MultiOutputClassifier
import config
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
# --- Configuration ---
# Target columns the multi-output classifier predicts.
LABEL_COLUMNS = [
    "Red_Flag_Reason", "Maker_Action", "Escalation_Level",
    "Risk_Category", "Risk_Drivers", "Investigation_Outcome"
]
# Free-text column used as the TF-IDF feature source.
TEXT_COLUMN = "Sanction_Context"
# Model artifacts live under /tmp (ephemeral — lost on container restart).
# NOTE(review): /test and /predict load from these paths, but train() saves
# to config.MODEL_PATH / config.TFIDF_VECTORIZER_PATH / config.LABEL_ENCODERS_PATH —
# confirm the config module points at the same files.
MODEL_DIR = "/tmp"
MODEL_PATH = os.path.join(MODEL_DIR, "logreg_model.pkl")
TFIDF_PATH = os.path.join(MODEL_DIR, "tfidf_vectorizer.pkl")
ENCODERS_PATH = os.path.join(MODEL_DIR, "label_encoders.pkl")

# --- FastAPI App ---
app = FastAPI()
# --- Input Schema ---
class TransactionData(BaseModel):
    """Full sanctions-screening transaction record accepted by /predict.

    Field names mirror the training CSV's columns. Only a subset of these
    fields is rendered into the model's text input (see predict()); the
    rest are accepted for schema completeness.
    """

    # Transaction / hit identifiers
    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    # Screening-hit descriptors
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    # Payment parties (optional; default to empty string).
    # NOTE: "Reciever" misspelling is part of the public schema — do not rename.
    Payment_Sender_Name: Optional[str] = ""
    Payment_Reciever_Name: Optional[str] = ""
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int
    # Risk / due-diligence attributes
    Red_Flag_Reason: str
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    # Review dates (kept as strings; format not enforced here)
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str
    Maker_Action: str
    # Customer and transaction context
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    # Screening metadata
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    # Case-management fields
    Risk_Category: str
    Risk_Drivers: str
    Alert_Status: str
    Investigation_Outcome: str
    Case_Owner_Analyst: str
    Escalation_Level: str
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool
class PredictionRequest(BaseModel):
    """Request body for /predict: wraps a single transaction record."""
    transaction_data: TransactionData
class DataPathInput(BaseModel):
    """Request body for /validate and /test: server-side path to a CSV file."""
    data_path: str
# --- Root ---
@app.get("/")
def health_check():
    """Liveness probe: report that the API process is up."""
    payload = {
        "status": "healthy",
        "message": "LOGREG TF-IDF API is running",
    }
    return payload
# --- Train ---
@app.post("/train")
def train():
    """Train and persist the TF-IDF + multi-output logistic-regression model.

    Reads the training CSV from ``config.DATA_PATH``, label-encodes each
    column in ``config.LABEL_COLUMNS``, fits a TF-IDF vectorizer on
    ``config.TEXT_COLUMN`` and a ``MultiOutputClassifier`` wrapping
    ``LogisticRegression``, then saves the model, vectorizer and label
    encoders to the paths declared in ``config``.

    Returns:
        dict with a success message and per-label held-out accuracy.

    Raises:
        HTTPException(500): on any failure (I/O, missing columns, fit errors).
    """
    try:
        df = pd.read_csv(config.DATA_PATH)

        # TfidfVectorizer cannot handle NaN text, and rows missing any
        # target label cannot be used for supervised training.
        df = df.dropna(subset=[config.TEXT_COLUMN] + list(config.LABEL_COLUMNS))

        X = df[config.TEXT_COLUMN]
        y = df[config.LABEL_COLUMNS]

        # One LabelEncoder per target column; persisted so /test and
        # /predict can decode integer predictions back to label strings.
        label_encoders = {}
        y_encoded = pd.DataFrame(index=y.index)
        for col in config.LABEL_COLUMNS:
            le = LabelEncoder()
            y_encoded[col] = le.fit_transform(y[col])
            label_encoders[col] = le

        X_train, X_test, y_train, y_test = train_test_split(
            X, y_encoded, test_size=config.TEST_SIZE, random_state=config.RANDOM_STATE
        )

        vectorizer = TfidfVectorizer(
            max_features=config.TFIDF_MAX_FEATURES,
            ngram_range=config.NGRAM_RANGE,
            stop_words='english' if config.USE_STOPWORDS else None
        )
        X_train_vec = vectorizer.fit_transform(X_train)
        X_test_vec = vectorizer.transform(X_test)

        # One independent logistic regression per label column.
        model = MultiOutputClassifier(LogisticRegression(max_iter=1000))
        model.fit(X_train_vec, y_train)

        # predict() returns an (n_samples, n_labels) array, so column i
        # holds the predictions for label i.
        y_pred = model.predict(X_test_vec)
        accuracy = {
            col: accuracy_score(y_test[col], y_pred[:, i])
            for i, col in enumerate(config.LABEL_COLUMNS)
        }

        # NOTE(review): saved to config.* paths while /test and /predict
        # load from module-level MODEL_PATH/TFIDF_PATH/ENCODERS_PATH —
        # confirm config points at the same files.
        joblib.dump(model, config.MODEL_PATH)
        joblib.dump(vectorizer, config.TFIDF_VECTORIZER_PATH)
        joblib.dump(label_encoders, config.LABEL_ENCODERS_PATH)

        return {
            "message": "Training completed successfully.",
            "accuracy": accuracy
        }
    except Exception as e:
        # Surface failures as a proper 500 instead of a 200 response with
        # an "error" payload — consistent with the other endpoints.
        raise HTTPException(status_code=500, detail=str(e))
# --- Validate (only structure check) ---
@app.post("/validate")
def validate_model(input: DataPathInput):
    """Check that the CSV at ``input.data_path`` has all required columns.

    Structure check only — verifies the text column and every label
    column exist; does not validate the values themselves.

    Raises:
        HTTPException(500): if the CSV cannot be read.
    """
    try:
        df = pd.read_csv(input.data_path)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Validation error: {str(e)}")

    required_columns = [TEXT_COLUMN] + LABEL_COLUMNS
    missing = [col for col in required_columns if col not in df.columns]
    # Status strings previously carried a stray leading space (apparently
    # a stripped emoji); trimmed here.
    if missing:
        return {"status": "Invalid input", "missing_columns": missing}
    return {"status": "Input is valid."}
# --- Test ---
@app.post("/test")
def test_model(input: DataPathInput):
    """Run the persisted pipeline over a CSV and return sample predictions.

    Loads the saved TF-IDF vectorizer, model and label encoders, predicts
    all label columns for every row with non-null text, and returns the
    first 8 decoded predictions.

    Raises:
        HTTPException(500): if the CSV or any model artifact cannot be
            loaded, or prediction fails.
    """
    try:
        df = pd.read_csv(input.data_path)
        # TfidfVectorizer cannot handle NaN input.
        df = df.dropna(subset=[TEXT_COLUMN])

        tfidf = joblib.load(TFIDF_PATH)
        model = joblib.load(MODEL_PATH)
        encoders = joblib.load(ENCODERS_PATH)

        X_vec = tfidf.transform(df[TEXT_COLUMN])
        preds = model.predict(X_vec)  # shape (n_samples, n_labels)

        # Decode one label column at a time: one inverse_transform call
        # per label instead of one per cell.
        decoded_cols = {
            col: encoders[col].inverse_transform(preds[:, i])
            for i, col in enumerate(LABEL_COLUMNS)
        }
        decoded_preds = [
            {col: decoded_cols[col][row] for col in LABEL_COLUMNS}
            for row in range(preds.shape[0])
        ]
        # Return only the first 8 predictions as a sample (the original
        # comment said 5 but the code always returned 8).
        return {"predictions": decoded_preds[:8]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# --- Predict ---
@app.post("/predict")
def predict(request: PredictionRequest):
    """Predict all six label columns for a single transaction.

    Renders selected transaction fields into the structured text layout
    used at training time, vectorizes it with the persisted TF-IDF
    vectorizer, and decodes the model's integer predictions back to
    label strings.

    Raises:
        HTTPException(500): if model artifacts cannot be loaded or
            prediction fails.
    """
    try:
        # Read fields straight from the validated pydantic payload. The
        # original round-trip through a one-row DataFrame plus positional
        # Series indexing (series[0]) was fragile and is deprecated in
        # modern pandas.
        data = request.transaction_data.dict()

        # (display label, field name) pairs, in training-time order.
        field_layout = [
            ("Transaction ID", "Transaction_Id"),
            ("Origin", "Origin"),
            ("Designation", "Designation"),
            ("Keywords", "Keywords"),
            ("Name", "Name"),
            ("SWIFT Tag", "SWIFT_Tag"),
            ("Currency", "Currency"),
            ("Entity", "Entity"),
            ("Message", "Message"),
            ("City", "City"),
            ("Country", "Country"),
            ("State", "State"),
            ("Hit Type", "Hit_Type"),
            ("Record Matching String", "Record_Matching_String"),
            ("WatchList Match String", "WatchList_Match_String"),
            ("Payment Sender", "Payment_Sender_Name"),
            ("Payment Receiver", "Payment_Reciever_Name"),
            ("Swift Message Type", "Swift_Message_Type"),
            ("Text Sanction Data", "Text_Sanction_Data"),
            ("Matched Sanctioned Entity", "Matched_Sanctioned_Entity"),
            ("Red Flag Reason", "Red_Flag_Reason"),
            ("Risk Level", "Risk_Level"),
            ("Risk Score", "Risk_Score"),
            ("CDD Level", "CDD_Level"),
            ("PEP Status", "PEP_Status"),
            ("Sanction Description", "Sanction_Description"),
            ("Checker Notes", "Checker_Notes"),
            ("Sanction Context", "Sanction_Context"),
            ("Maker Action", "Maker_Action"),
            ("Customer Type", "Customer_Type"),
            ("Industry", "Industry"),
            ("Transaction Type", "Transaction_Type"),
            ("Transaction Channel", "Transaction_Channel"),
            ("Geographic Origin", "Geographic_Origin"),
            ("Geographic Destination", "Geographic_Destination"),
            ("Risk Category", "Risk_Category"),
            ("Risk Drivers", "Risk_Drivers"),
            ("Alert Status", "Alert_Status"),
            ("Investigation Outcome", "Investigation_Outcome"),
            ("Source of Funds", "Source_Of_Funds"),
            ("Purpose of Transaction", "Purpose_Of_Transaction"),
            ("Beneficial Owner", "Beneficial_Owner"),
        ]
        text_input = "\n".join(
            f"{label}: {data.get(field, '')}" for label, field in field_layout
        )

        # Load persisted artifacts.
        tfidf = joblib.load(TFIDF_PATH)
        model = joblib.load(MODEL_PATH)
        encoders = joblib.load(ENCODERS_PATH)

        # Single-row prediction, then decode each label column.
        X_vec = tfidf.transform([text_input])
        pred = model.predict(X_vec)[0]
        decoded = {
            col: encoders[col].inverse_transform([p])[0]
            for col, p in zip(LABEL_COLUMNS, pred)
        }
        return {"prediction": decoded}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))