# LogReg_model / app.py
# (Hugging Face Space residue: uploaded by subbunanepalli, commit d133e6d "Update app.py")
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import pandas as pd
import joblib
import os
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.multioutput import MultiOutputClassifier
from sklearn.pipeline import Pipeline
# ========== Config ==========
# CSV of labelled synthetic transactions read by POST /train.
DATA_PATH = "data/synthetic_transactions_samples_5000.csv"
# Directory where the fitted pipeline is persisted between /train and /predict.
MODEL_DIR = "models"
# Single artifact: the whole TF-IDF + logistic-regression Pipeline, via joblib.
MODEL_PATH = os.path.join(MODEL_DIR, "logreg_model.pkl")
# ========== FastAPI Init ==========
app = FastAPI()
# ========== Input Schema ==========
class TransactionData(BaseModel):
    """Request schema for POST /predict and POST /validate.

    Mirrors one row of the training CSV. Most fields are free text that
    create_text_input() concatenates into the TF-IDF feature string, so
    field names here must stay in sync with the keys used there.
    """

    # --- Screening-hit identity ---
    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    # --- Geography of the hit record ---
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    # Optional parties; default to "" so they can be omitted by the caller.
    Payment_Sender_Name: Optional[str] = ""
    # NOTE(review): "Reciever" is misspelled, but the name is part of the
    # public API contract (and matched by create_text_input) — do not rename.
    Payment_Reciever_Name: Optional[str] = ""
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int
    Red_Flag_Reason: str
    # --- Risk assessment ---
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    # Dates are plain strings (no parsing/validation is performed here).
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str
    Maker_Action: str
    # --- Customer / transaction context ---
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    # --- Screening / case-management metadata ---
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str
    Risk_Drivers: str
    Alert_Status: str
    Investigation_Outcome: str
    Case_Owner_Analyst: str
    Escalation_Level: str
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool
# ========== Utils ==========
def create_text_input(row):
    """Build the free-text feature string for one transaction.

    Used both at training time (row-wise via ``DataFrame.apply`` in /train)
    and at prediction time (/predict), so the two paths always see the same
    feature layout.

    Parameters
    ----------
    row : mapping (pandas Series or dict)
        Must expose the keys listed in ``_FEATURE_FIELDS`` below.

    Returns
    -------
    str
        Newline-separated ``"Label: value"`` lines fed to the TF-IDF
        vectorizer.

    BUGFIX: the six prediction targets (Maker_Action, Escalation_Level,
    Risk_Category, Risk_Drivers, Investigation_Outcome, Red_Flag_Reason)
    were previously embedded in this text. Since /train uses exactly those
    columns as labels, the model was trained with its own answers in the
    input (label leakage), and /predict pointlessly required the caller to
    already know them. They are now excluded from the feature string.
    """
    # (display label, row key) pairs — the single source of truth for which
    # columns become model features. Target columns are deliberately absent.
    _FEATURE_FIELDS = [
        ("Transaction ID", "Transaction_Id"),
        ("Origin", "Origin"),
        ("Designation", "Designation"),
        ("Keywords", "Keywords"),
        ("Name", "Name"),
        ("SWIFT Tag", "SWIFT_Tag"),
        ("Currency", "Currency"),
        ("Entity", "Entity"),
        ("Message", "Message"),
        ("City", "City"),
        ("Country", "Country"),
        ("State", "State"),
        ("Hit Type", "Hit_Type"),
        ("Record Matching String", "Record_Matching_String"),
        ("WatchList Match String", "WatchList_Match_String"),
        ("Payment Sender", "Payment_Sender_Name"),
        # Key spelling matches the (misspelled) schema field — intentional.
        ("Payment Receiver", "Payment_Reciever_Name"),
        ("Swift Message Type", "Swift_Message_Type"),
        ("Text Sanction Data", "Text_Sanction_Data"),
        ("Matched Sanctioned Entity", "Matched_Sanctioned_Entity"),
        ("Risk Level", "Risk_Level"),
        ("Risk Score", "Risk_Score"),
        ("CDD Level", "CDD_Level"),
        ("PEP Status", "PEP_Status"),
        ("Sanction Description", "Sanction_Description"),
        ("Checker Notes", "Checker_Notes"),
        ("Sanction Context", "Sanction_Context"),
        ("Customer Type", "Customer_Type"),
        ("Industry", "Industry"),
        ("Transaction Type", "Transaction_Type"),
        ("Transaction Channel", "Transaction_Channel"),
        ("Geographic Origin", "Geographic_Origin"),
        ("Geographic Destination", "Geographic_Destination"),
        ("Alert Status", "Alert_Status"),
        ("Source of Funds", "Source_Of_Funds"),
        ("Purpose of Transaction", "Purpose_Of_Transaction"),
        ("Beneficial Owner", "Beneficial_Owner"),
    ]
    return "\n".join(f"{label}: {row[key]}" for label, key in _FEATURE_FIELDS)
# ========== Root ==========
@app.get("/")
def root():
    """Liveness endpoint: confirms the service is up."""
    banner = {"message": "TF-IDF Logistic Regression API is running."}
    return banner
# ========== API Routes ==========
@app.post("/train")
def train_model():
    """Fit and persist the TF-IDF + multi-output logistic-regression pipeline.

    Reads the CSV at DATA_PATH, builds one text blob per row, trains a
    MultiOutputClassifier (one LogisticRegression per target column), saves
    the whole pipeline to MODEL_PATH, and reports held-out accuracy.

    Returns
    -------
    dict with "message" and "accuracy".

    Raises
    ------
    HTTPException 404 when the training CSV is missing.
    """
    # BUGFIX: a missing data file previously escaped as an unhandled
    # FileNotFoundError (opaque 500 with a server traceback).
    if not os.path.exists(DATA_PATH):
        raise HTTPException(
            status_code=404,
            detail=f"Training data not found at {DATA_PATH}.",
        )
    df = pd.read_csv(DATA_PATH)
    df = df.fillna("")  # blank out NaNs so the text blob never contains "nan"
    df["text_input"] = df.apply(create_text_input, axis=1)
    X = df["text_input"]
    # Six label columns -> MultiOutputClassifier fits one classifier each.
    # /predict must return the values in this exact column order.
    y = df[["Maker_Action", "Escalation_Level", "Risk_Category",
            "Risk_Drivers", "Investigation_Outcome", "Red_Flag_Reason"]]
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    pipeline = Pipeline([
        ("vectorizer", TfidfVectorizer()),
        ("classifier", MultiOutputClassifier(LogisticRegression(max_iter=1000))),
    ])
    pipeline.fit(X_train, y_train)
    os.makedirs(MODEL_DIR, exist_ok=True)
    joblib.dump(pipeline, MODEL_PATH)
    # NOTE: for multi-output targets, .score() is subset accuracy — a sample
    # counts as correct only when ALL six labels are predicted correctly.
    accuracy = pipeline.score(X_test, y_test)
    return {"message": "Model trained and saved.", "accuracy": accuracy}
@app.post("/predict")
def predict(request: TransactionData):
    """Predict all six screening outcomes for one transaction.

    Loads the persisted pipeline, rebuilds the same text feature used at
    training time, and returns one predicted value per target column.

    Raises
    ------
    HTTPException 404 when no trained model exists yet.
    HTTPException 500 for any other failure (detail carries the message).
    """
    # BUGFIX: an absent model file previously surfaced as a generic 500
    # ("No such file or directory") instead of a clear client error.
    if not os.path.exists(MODEL_PATH):
        raise HTTPException(
            status_code=404,
            detail="Model not trained yet. Call /train first.",
        )
    try:
        model = joblib.load(MODEL_PATH)
        input_data = pd.DataFrame([request.dict()])
        input_data = input_data.fillna("")
        text_input = create_text_input(input_data.iloc[0])
        # predict() returns one row with one value per target column,
        # in the column order used by /train.
        prediction = model.predict([text_input])[0]
        return {
            "Maker_Action": prediction[0],
            "Escalation_Level": prediction[1],
            "Risk_Category": prediction[2],
            "Risk_Drivers": prediction[3],
            "Investigation_Outcome": prediction[4],
            # BUGFIX: the model is trained on six targets, but this sixth
            # value (Red_Flag_Reason) was silently dropped from the response.
            "Red_Flag_Reason": prediction[5],
        }
    except HTTPException:
        raise  # don't re-wrap deliberate HTTP errors as 500s
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/validate")
def validate_input(request: TransactionData):
    """Echo success; pydantic has already validated the payload (FastAPI
    returns 422 before this body runs if the schema doesn't match)."""
    return {"message": "Input is valid."}
@app.get("/test")
def test_api():
    """Simple smoke-test endpoint for manual checks."""
    response = {"message": "Test successful."}
    return response