File size: 1,818 Bytes
0c59e63
 
906f3f9
 
0c59e63
 
 
906f3f9
0c59e63
 
 
 
906f3f9
0c59e63
906f3f9
 
0c59e63
 
 
906f3f9
 
0c59e63
 
906f3f9
0c59e63
906f3f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c59e63
 
 
 
906f3f9
 
0c59e63
 
906f3f9
 
0c59e63
 
906f3f9
 
 
 
 
 
 
0c59e63
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import logging
import os

import joblib
import pandas as pd
from fastapi import HTTPException
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.multioutput import MultiOutputClassifier
from sklearn.pipeline import Pipeline

from utils import create_text_input

# ========== Config ==========
# All paths are relative, so they resolve against the process's current
# working directory rather than this file's location.
DATA_PATH: str = "data/synthetic_transactions_samples_5000.csv"  # training CSV read by train_model
MODEL_DIR: str = "models"  # output directory; train_model creates it if missing
MODEL_PATH: str = os.path.join(MODEL_DIR, "logreg_model.pkl")  # joblib dump of the fitted pipeline

logger = logging.getLogger(__name__)

# Label columns predicted jointly by the multi-output classifier.
TARGET_COLUMNS = [
    "Maker_Action",
    "Escalation_Level",
    "Risk_Category",
    "Risk_Drivers",
    "Investigation_Outcome",
    "Red_Flag_Reason",
]


def _build_pipeline():
    """Return an unfitted TF-IDF -> multi-output logistic-regression pipeline."""
    return Pipeline([
        ("vectorizer", TfidfVectorizer()),
        # MultiOutputClassifier fits one independent LogisticRegression per target column.
        ("classifier", MultiOutputClassifier(LogisticRegression(max_iter=1000))),
    ])


def _evaluate(pipeline, X_test: pd.Series, y_test: pd.DataFrame) -> dict:
    """Score a fitted *pipeline* on held-out data with a single predict call.

    Returns:
        dict with
        - ``accuracy``: exact-match (subset) accuracy, i.e. the fraction of rows
          where every target is predicted correctly. This is the same quantity
          ``MultiOutputClassifier.score`` reports. Rounded to 4 places.
        - ``per_target_accuracy``: column name -> plain accuracy for that
          target alone, rounded to 4 places.
    """
    predictions = pd.DataFrame(
        pipeline.predict(X_test), index=y_test.index, columns=y_test.columns
    )
    hits = predictions.eq(y_test)  # element-wise correct/incorrect per row and target
    return {
        "accuracy": round(float(hits.all(axis=1).mean()), 4),
        "per_target_accuracy": {
            column: round(float(rate), 4) for column, rate in hits.mean().items()
        },
    }


def train_model():
    """Train the transaction classifier, evaluate it, and persist it to MODEL_PATH.

    Reads DATA_PATH and fills missing values with empty strings. It builds one
    text feature per row with ``create_text_input``, holds out 20% of rows
    (``random_state=42``) for evaluation, and fits the pipeline on the rest.
    The fitted pipeline is saved with joblib only after evaluation succeeds.

    Returns:
        dict with ``message``, ``accuracy`` (exact-match accuracy over all
        targets on the held-out split) and ``per_target_accuracy``
        (column -> accuracy).

    Raises:
        HTTPException: status 500 wrapping any failure. The original exception
            is chained as the cause and logged with its traceback.
    """
    try:
        # NOTE(review): fillna("") turns NaNs in numeric columns into "", which
        # produces mixed-type labels if a target column is numeric with gaps —
        # confirm all target columns are text.
        df = pd.read_csv(DATA_PATH).fillna("")
        df["text_input"] = df.apply(create_text_input, axis=1)

        X = df["text_input"]
        y = df[TARGET_COLUMNS]

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )

        pipeline = _build_pipeline()
        pipeline.fit(X_train, y_train)

        # Evaluate before saving, so a failure here cannot leave a freshly
        # overwritten model file behind a 500 response.
        metrics = _evaluate(pipeline, X_test, y_test)

        os.makedirs(MODEL_DIR, exist_ok=True)
        joblib.dump(pipeline, MODEL_PATH)
    except Exception as e:
        # HTTPException only carries str(e) to the client; keep the full
        # traceback in the server logs and chain the original cause.
        logger.exception("Model training failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"message": "Model trained and saved successfully.", **metrics}