subbunanepalli committed on
Commit
906f3f9
·
verified ·
1 Parent(s): f8cea34

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +35 -13
train.py CHANGED
@@ -1,40 +1,62 @@
1
- import pandas as pd
2
- from fastapi import HTTPException
3
  import os
4
  import joblib
 
 
5
  from sklearn.pipeline import Pipeline
6
- from sklearn.linear_model import LogisticRegression
7
  from sklearn.model_selection import train_test_split
8
  from sklearn.feature_extraction.text import TfidfVectorizer
 
9
  from sklearn.multioutput import MultiOutputClassifier
10
 
11
  from utils import create_text_input
12
 
 
13
  DATA_PATH = "data/synthetic_transactions_samples_5000.csv"
14
- MODEL_PATH = "models/logreg_model.pkl"
 
15
 
16
  def train_model():
17
  try:
18
- df = pd.read_csv(DATA_PATH)
19
- df = df.fillna("")
20
  df["text_input"] = df.apply(create_text_input, axis=1)
21
 
 
22
  X = df["text_input"]
23
- y = df[["Maker_Action", "Escalation_Level", "Risk_Category", "Risk_Drivers", "Investigation_Outcome", "Red_Flag_Reason"]]
24
-
25
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
26
-
 
 
 
 
 
 
 
 
 
 
 
27
  pipeline = Pipeline([
28
  ("vectorizer", TfidfVectorizer()),
29
  ("classifier", MultiOutputClassifier(LogisticRegression(max_iter=1000)))
30
  ])
 
 
31
  pipeline.fit(X_train, y_train)
32
 
33
- os.makedirs("models", exist_ok=True)
 
34
  joblib.dump(pipeline, MODEL_PATH)
35
 
36
- acc = pipeline.score(X_test, y_test)
 
 
 
 
 
 
37
 
38
- return {"message": "Model trained successfully.", "accuracy": acc}
39
  except Exception as e:
40
  raise HTTPException(status_code=500, detail=str(e))
 
 
 
1
  import os
2
  import joblib
3
+ import pandas as pd
4
+ from fastapi import HTTPException
5
  from sklearn.pipeline import Pipeline
 
6
  from sklearn.model_selection import train_test_split
7
  from sklearn.feature_extraction.text import TfidfVectorizer
8
+ from sklearn.linear_model import LogisticRegression
9
  from sklearn.multioutput import MultiOutputClassifier
10
 
11
  from utils import create_text_input
12
 
13
+ # ========== Config ==========
14
  DATA_PATH = "data/synthetic_transactions_samples_5000.csv"
15
+ MODEL_DIR = "models"
16
+ MODEL_PATH = os.path.join(MODEL_DIR, "logreg_model.pkl")
17
 
18
  def train_model():
19
  try:
20
+ # Load and preprocess data
21
+ df = pd.read_csv(DATA_PATH).fillna("")
22
  df["text_input"] = df.apply(create_text_input, axis=1)
23
 
24
+ # Features and targets
25
  X = df["text_input"]
26
+ y = df[[
27
+ "Maker_Action",
28
+ "Escalation_Level",
29
+ "Risk_Category",
30
+ "Risk_Drivers",
31
+ "Investigation_Outcome",
32
+ "Red_Flag_Reason"
33
+ ]]
34
+
35
+ # Train/test split
36
+ X_train, X_test, y_train, y_test = train_test_split(
37
+ X, y, test_size=0.2, random_state=42
38
+ )
39
+
40
+ # Pipeline: TF-IDF + MultiOutput LR
41
  pipeline = Pipeline([
42
  ("vectorizer", TfidfVectorizer()),
43
  ("classifier", MultiOutputClassifier(LogisticRegression(max_iter=1000)))
44
  ])
45
+
46
+ # Train
47
  pipeline.fit(X_train, y_train)
48
 
49
+ # Save model
50
+ os.makedirs(MODEL_DIR, exist_ok=True)
51
  joblib.dump(pipeline, MODEL_PATH)
52
 
53
+ # Evaluate
54
+ accuracy = pipeline.score(X_test, y_test)
55
+
56
+ return {
57
+ "message": "Model trained and saved successfully.",
58
+ "accuracy": round(accuracy, 4)
59
+ }
60
 
 
61
  except Exception as e:
62
  raise HTTPException(status_code=500, detail=str(e))