ganeshkonapalli committed on
Commit
89e4a53
·
verified ·
1 Parent(s): fb5b58a

Create train_utils.py

Browse files
Files changed (1) hide show
  1. train_utils.py +111 -0
train_utils.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ from sklearn.metrics import classification_report
5
+ from tqdm import tqdm
6
+ import joblib
7
+
8
+ from config import LABEL_COLUMNS, MODEL_SAVE_DIR
9
+
10
+
11
def train_logreg_models(X, y, label_encoders, model_class):
    """
    Fit one independent model for every entry of ``LABEL_COLUMNS``.

    Args:
        X (array-like): Feature matrix shared by all models
            (e.g., TF-IDF vectors).
        y (DataFrame): Targets; indexed by label name to obtain the
            target column for each model.
        label_encoders (dict): Label encoders for each target. Accepted
            for interface compatibility; not read by this function.
        model_class: Zero-argument callable producing an unfitted
            estimator with a ``fit(X, y)`` method
            (e.g., the LogisticRegression class).

    Returns:
        dict: Fitted estimators keyed by label name.
    """
    fitted_by_label = {}
    for label in LABEL_COLUMNS:
        print(f"Training Logistic Regression model for {label}...")
        # A fresh estimator per label keeps the models independent.
        estimator = model_class()
        estimator.fit(X, y[label])
        fitted_by_label[label] = estimator
    return fitted_by_label
31
+
32
+
33
def evaluate_logreg_models(models, X_val, y_val, label_encoders):
    """
    Score each per-label model against the validation set.

    Args:
        models (dict): Trained models keyed by label name.
        X_val (array-like): Validation features passed to ``predict``.
        y_val (DataFrame): Validation labels; indexed by label name.
        label_encoders (dict): Encoders used for decoding. Accepted for
            interface compatibility; not read by this function.

    Returns:
        tuple: ``(reports, truths, predictions)`` where ``reports`` maps
        each label name to its ``classification_report`` dict, and
        ``truths`` / ``predictions`` are lists holding one list of values
        per label, in ``LABEL_COLUMNS`` order.
    """
    per_label_reports = {}
    all_expected = []
    all_predicted = []

    for label in LABEL_COLUMNS:
        estimator = models[label]
        expected = y_val[label]
        predicted = estimator.predict(X_val)

        # Collect plain Python lists so callers do not depend on
        # pandas / numpy container types.
        all_expected.append(expected.tolist())
        all_predicted.append(predicted.tolist())

        # zero_division=0 reports 0 for classes without predictions
        # rather than emitting warnings.
        per_label_reports[label] = classification_report(
            expected,
            predicted,
            output_dict=True,
            zero_division=0,
        )

    return per_label_reports, all_expected, all_predicted
64
+
65
+
66
def summarize_metrics(metrics):
    """
    Flatten per-field classification reports into a summary table.

    Args:
        metrics (dict): Maps a field name to a report dict. Each report
            is read for its ``'weighted avg'`` sub-dict (keys
            ``'precision'``, ``'recall'``, ``'f1-score'``, ``'support'``)
            and its top-level ``'accuracy'`` entry. Any missing entry,
            including the whole ``'weighted avg'`` sub-dict, falls back
            to 0.

    Returns:
        pandas.DataFrame: One row per field with the columns ``Field``,
        ``Precision``, ``Recall``, ``F1-Score``, ``Accuracy`` and
        ``Support``. The columns are present even when ``metrics`` is
        empty.
    """
    columns = ["Field", "Precision", "Recall", "F1-Score", "Accuracy", "Support"]
    rows = []
    for field, report in metrics.items():
        # Tolerate a missing 'weighted avg' section the same way the
        # individual metric lookups tolerate missing keys.
        weighted = report.get('weighted avg') or {}
        rows.append({
            "Field": field,
            "Precision": weighted.get('precision', 0),
            "Recall": weighted.get('recall', 0),
            "F1-Score": weighted.get('f1-score', 0),
            "Accuracy": report.get('accuracy', 0),
            "Support": weighted.get('support', 0),
        })
    # Passing the columns explicitly keeps the schema stable for empty input.
    return pd.DataFrame(rows, columns=columns)
83
+
84
+
85
def save_logreg_models(models, model_name):
    """
    Serialize the models with joblib to ``<MODEL_SAVE_DIR>/<model_name>.pkl``.

    The destination directory is created if it does not exist yet, so a
    first-time save does not fail on a missing folder.

    Args:
        models: Object to persist (the dict of trained models).
        model_name (str): File stem; ``.pkl`` is appended.
    """
    model_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}.pkl")
    # joblib.dump opens the target file directly and does not create
    # missing parent directories.
    parent_dir = os.path.dirname(model_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    joblib.dump(models, model_path)
    print(f"Saved Logistic Regression models to {model_path}")
89
+
90
+
91
def load_logreg_models(model_name):
    """
    Load models previously stored at ``<MODEL_SAVE_DIR>/<model_name>.pkl``.

    Args:
        model_name (str): File stem; ``.pkl`` is appended.

    Returns:
        The object deserialized by ``joblib.load``.

    Raises:
        FileNotFoundError: If no file exists at the resolved path.
    """
    pkl_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}.pkl")
    if os.path.exists(pkl_path):
        # NOTE(review): joblib.load unpickles the file; only load
        # artifacts from a trusted location.
        restored = joblib.load(pkl_path)
        print(f"Loaded Logistic Regression models from {pkl_path}")
        return restored
    raise FileNotFoundError(f"Model not found at {pkl_path}")
98
+
99
+
100
def predict_logreg_probabilities(models, X):
    """
    Collect ``predict_proba`` output from every per-label model.

    Args:
        models (dict): Trained models keyed by label name.
        X (array-like): Features to score.

    Returns:
        list: One ``predict_proba`` result per label, ordered as
        ``LABEL_COLUMNS``.
    """
    return [models[label].predict_proba(X) for label in LABEL_COLUMNS]