subbunanepalli commited on
Commit
0c59e63
·
verified ·
1 Parent(s): 78fc4f9

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +40 -0
train.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from fastapi import HTTPException
3
+ import os
4
+ import joblib
5
+ from sklearn.pipeline import Pipeline
6
+ from sklearn.linear_model import LogisticRegression
7
+ from sklearn.model_selection import train_test_split
8
+ from sklearn.feature_extraction.text import TfidfVectorizer
9
+ from sklearn.multioutput import MultiOutputClassifier
10
+
11
+ from utils import create_text_input
12
+
13
+ DATA_PATH = "data/synthetic_transactions_samples_5000.csv"
14
+ MODEL_PATH = "models/logreg_model.pkl"
15
+
16
+ def train_model():
17
+ try:
18
+ df = pd.read_csv(DATA_PATH)
19
+ df = df.fillna("")
20
+ df["text_input"] = df.apply(create_text_input, axis=1)
21
+
22
+ X = df["text_input"]
23
+ y = df[["Maker_Action", "Escalation_Level", "Risk_Category", "Risk_Drivers", "Investigation_Outcome", "Alert_Status"]]
24
+
25
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
26
+
27
+ pipeline = Pipeline([
28
+ ("vectorizer", TfidfVectorizer()),
29
+ ("classifier", MultiOutputClassifier(LogisticRegression(max_iter=1000)))
30
+ ])
31
+ pipeline.fit(X_train, y_train)
32
+
33
+ os.makedirs("models", exist_ok=True)
34
+ joblib.dump(pipeline, MODEL_PATH)
35
+
36
+ acc = pipeline.score(X_test, y_test)
37
+
38
+ return {"message": "Model trained successfully.", "accuracy": acc}
39
+ except Exception as e:
40
+ raise HTTPException(status_code=500, detail=str(e))