# LogReg / train.py
# Scraped page header: subbunanepalli — "Update train.py", commit 906f3f9 (verified)
import os
import joblib
import pandas as pd
from fastapi import HTTPException
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier
from utils import create_text_input
# ---------- Configuration ----------
# All paths are relative, so they resolve against the process's current
# working directory at the time they are used.
MODEL_DIR = "models"
MODEL_PATH = os.path.join(MODEL_DIR, "logreg_model.pkl")  # where the fitted pipeline is dumped
DATA_PATH = "data/synthetic_transactions_samples_5000.csv"  # training CSV read by train_model
# Label columns predicted jointly by the multi-output classifier; the
# training CSV must contain every one of them.
_TARGET_COLUMNS = [
    "Maker_Action",
    "Escalation_Level",
    "Risk_Category",
    "Risk_Drivers",
    "Investigation_Outcome",
    "Red_Flag_Reason",
]


def _load_training_data():
    """Read ``DATA_PATH`` and return the ``(X, y)`` training data.

    Missing values are replaced with empty strings, then the per-row text
    feature is built by applying ``create_text_input`` to each row.

    Returns:
        tuple: ``(X, y)`` where ``X`` is the ``text_input`` Series and ``y``
        is the DataFrame of ``_TARGET_COLUMNS``.

    Raises:
        ValueError: if the CSV has no rows or lacks any target column.
    """
    df = pd.read_csv(DATA_PATH).fillna("")
    if df.empty:
        # Without this guard, an empty file fails later inside
        # DataFrame.apply / train_test_split with an obscure message.
        raise ValueError(f"Training data at {DATA_PATH!r} contains no rows.")

    missing = [col for col in _TARGET_COLUMNS if col not in df.columns]
    if missing:
        raise ValueError(
            f"Training data at {DATA_PATH!r} is missing target column(s): "
            + ", ".join(missing)
        )

    # axis=1: create_text_input receives one row at a time.
    df["text_input"] = df.apply(create_text_input, axis=1)
    return df["text_input"], df[_TARGET_COLUMNS]


def _build_pipeline():
    """Return an unfitted TF-IDF + multi-output LogisticRegression pipeline."""
    return Pipeline([
        ("vectorizer", TfidfVectorizer()),
        # MultiOutputClassifier fits one LogisticRegression per target column.
        ("classifier", MultiOutputClassifier(LogisticRegression(max_iter=1000))),
    ])


def train_model():
    """Train the classifier, save it to ``MODEL_PATH``, and report test accuracy.

    Loads ``DATA_PATH``, holds out 20% of the rows (``random_state=42``) for
    evaluation, and fits the pipeline on the remainder. The fitted pipeline
    is written with joblib to ``MODEL_PATH``, creating ``MODEL_DIR`` if needed.

    Returns:
        dict: ``{"message": str, "accuracy": float}``. ``accuracy`` is the
        exact-match (subset) accuracy on the held-out split, rounded to 4
        decimals. A row counts as correct only if every target column is
        predicted correctly.

    Raises:
        HTTPException: status 500 wrapping any error raised while loading,
            training, saving or evaluating. ``detail`` is the original
            error's message and the original exception is chained as
            ``__cause__``.
    """
    try:
        X, y = _load_training_data()
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )

        pipeline = _build_pipeline()
        pipeline.fit(X_train, y_train)

        os.makedirs(MODEL_DIR, exist_ok=True)
        joblib.dump(pipeline, MODEL_PATH)

        # Pipeline.score delegates to MultiOutputClassifier.score, which is
        # subset accuracy over all targets, not a per-column average.
        accuracy = pipeline.score(X_test, y_test)
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "message": "Model trained and saved successfully.",
        "accuracy": round(float(accuracy), 4),
    }