Taranpreet Singh commited on
Commit
7ed53c4
·
1 Parent(s): 98d799c

Fix: include training_utils module for HF Space runtime

Browse files
Files changed (1) hide show
  1. training_utils.py +150 -0
training_utils.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import json
4
+ import joblib
5
+ import logging
6
+ import numpy as np
7
+ from datetime import datetime
8
+ from sklearn.ensemble import RandomForestClassifier
9
+ from sklearn.model_selection import StratifiedKFold
10
+ from sklearn.metrics import (precision_recall_fscore_support, roc_auc_score,
11
+ average_precision_score, confusion_matrix, precision_recall_curve)
12
+
13
+ logger = logging.getLogger('nids')
14
+
15
+
16
+
17
+ class BinaryLabelEncoder:
18
+ """Simple encoder mapping: BENIGN -> 0, anything else -> 1.
19
+ Provides transform/inverse_transform and `classes_` compatible attribute.
20
+ """
21
+ def __init__(self):
22
+ self.classes_ = np.array([0, 1])
23
+
24
+ def transform(self, y_series):
25
+ # accept pandas Series or array-like labels
26
+ y_str = np.array(y_series).astype(str)
27
+ return (y_str != 'BENIGN').astype(int)
28
+
29
+ def inverse_transform(self, y_arr):
30
+ y = np.array(y_arr).astype(int)
31
+ return np.where(y == 0, 'BENIGN', 'ATTACK')
32
+
33
+
34
+ def validate_and_select_features(df, features):
35
+ missing = [c for c in features if c not in df.columns]
36
+ if missing:
37
+ raise ValueError(f"Missing feature columns: {missing}")
38
+ X = df[features].copy()
39
+ # drop constant features
40
+ nunique = X.nunique()
41
+ const_cols = nunique[nunique <= 1].index.tolist()
42
+ if const_cols:
43
+ logger.info('Dropping constant columns: %s', const_cols)
44
+ X.drop(columns=const_cols, inplace=True)
45
+ return X
46
+
47
+
48
+ def train_model_cv(df, features, target='Label', n_splits=5, n_estimators=100, max_depth=None, seed=42):
49
+ """Train RandomForest with StratifiedKFold and return best model plus metrics.
50
+
51
+ - Uses class_weight='balanced' to handle class imbalance (no SMOTE).
52
+ - Computes precision, recall, F1, PR-AUC, ROC-AUC and confusion matrices per fold.
53
+ """
54
+ # explicit binary encoding
55
+ encoder = BinaryLabelEncoder()
56
+ y_raw = df[target].astype(str)
57
+ y = encoder.transform(y_raw)
58
+ X = validate_and_select_features(df, features)
59
+
60
+ skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
61
+ fold_metrics = []
62
+ models = []
63
+
64
+ # collect validation probabilities and labels across folds for PR curve and thresholding
65
+ all_val_probas = []
66
+ all_val_labels = []
67
+
68
+ X_arr = X.values
69
+ y_arr = y.values
70
+
71
+ for fold, (train_idx, val_idx) in enumerate(skf.split(X_arr, y_arr), start=1):
72
+ clf = RandomForestClassifier(
73
+ n_estimators=n_estimators,
74
+ max_depth=(None if max_depth == 0 else max_depth),
75
+ class_weight='balanced',
76
+ random_state=seed,
77
+ n_jobs=-1,
78
+ )
79
+ clf.fit(X_arr[train_idx], y_arr[train_idx])
80
+
81
+ proba = clf.predict_proba(X_arr[val_idx])[:, 1]
82
+ preds = (proba >= 0.5).astype(int)
83
+
84
+ all_val_probas.extend(proba.tolist())
85
+ all_val_labels.extend(y_arr[val_idx].tolist())
86
+
87
+ prec, rec, f1, _ = precision_recall_fscore_support(y_arr[val_idx], preds, average='binary', zero_division=0)
88
+ pr_auc = average_precision_score(y_arr[val_idx], proba)
89
+ try:
90
+ roc = roc_auc_score(y_arr[val_idx], proba)
91
+ except Exception:
92
+ roc = float('nan')
93
+
94
+ cm = confusion_matrix(y_arr[val_idx], preds).tolist()
95
+ fold_metrics.append({
96
+ 'fold': fold,
97
+ 'precision': float(prec),
98
+ 'recall': float(rec),
99
+ 'f1': float(f1),
100
+ 'pr_auc': float(pr_auc),
101
+ 'roc_auc': float(roc),
102
+ 'confusion_matrix': cm
103
+ })
104
+ models.append(clf)
105
+ logger.info('Fold %d metrics: prec=%.3f rec=%.3f f1=%.3f pr_auc=%.3f', fold, prec, rec, f1, pr_auc)
106
+
107
+ # pick best model by f1
108
+ best_idx = int(np.argmax([m['f1'] for m in fold_metrics]))
109
+ best_model = models[best_idx]
110
+
111
+ # aggregate metrics
112
+ agg = {}
113
+ for k in ['precision', 'recall', 'f1', 'pr_auc', 'roc_auc']:
114
+ vals = [fm[k] for fm in fold_metrics if not np.isnan(fm[k])]
115
+ agg[f'{k}_mean'] = float(np.mean(vals)) if vals else float('nan')
116
+ agg[f'{k}_std'] = float(np.std(vals)) if vals else float('nan')
117
+
118
+ results = {'folds': fold_metrics, 'aggregate': agg}
119
+ # compute overall PR curve from CV validation outputs
120
+ all_val_probas = np.array(all_val_probas)
121
+ all_val_labels = np.array(all_val_labels)
122
+ precision, recall, pr_thresholds = precision_recall_curve(all_val_labels, all_val_probas)
123
+
124
+ results['pr_curve'] = {
125
+ 'precision': precision.tolist(),
126
+ 'recall': recall.tolist(),
127
+ 'thresholds': pr_thresholds.tolist()
128
+ }
129
+
130
+ # artifact hygiene: directories
131
+ ts = datetime.utcnow().isoformat() + 'Z'
132
+ results['timestamp'] = ts
133
+ results['seed'] = int(seed)
134
+ results['features'] = list(X.columns)
135
+ results['cv_validation_counts'] = int(len(all_val_labels))
136
+
137
+ models_dir = os.path.join('models')
138
+ metrics_dir = os.path.join('metrics')
139
+ os.makedirs(models_dir, exist_ok=True)
140
+ os.makedirs(metrics_dir, exist_ok=True)
141
+
142
+ model_path = os.path.join(models_dir, 'rf_model.joblib')
143
+ metrics_path = os.path.join(metrics_dir, 'training_metrics.json')
144
+
145
+ joblib.dump(best_model, model_path)
146
+ with open(metrics_path, 'w') as fh:
147
+ json.dump(results, fh, indent=2)
148
+ logger.info('Training complete. Metrics saved to %s, model saved to %s', metrics_path, model_path)
149
+
150
+ return best_model, results, X, y, all_val_probas, all_val_labels, encoder