vidyasagar786 committed on
Commit
907307c
Β·
verified Β·
1 Parent(s): 4d3516d

Upload train_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_model.py +207 -0
train_model.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import joblib
3
+ import pandas as pd
4
+ import numpy as np
5
+ import xgboost as xgb
6
+ import matplotlib.pyplot as plt
7
+
8
+ from tqdm.auto import tqdm
9
+ from sklearn.feature_extraction.text import TfidfVectorizer
10
+ from sklearn.metrics import classification_report
11
+ from sklearn.preprocessing import StandardScaler
12
+ from sklearn.metrics import confusion_matrix
13
+ from scipy.sparse import hstack, csr_matrix
14
+
15
# ===============================
# PATHS
# ===============================

# All three splits live in the same processed-data directory.
_DATA_DIR = "/Users/vidyasagarkaruturi/Downloads/machine learning/src/data/processed"

TRAIN_PATH = f"{_DATA_DIR}/train.csv"
VAL_PATH = f"{_DATA_DIR}/val.csv"
TEST_PATH = f"{_DATA_DIR}/test.csv"

# Output path for the trained model bundle (model + vectorizers + scaler).
MODEL_SAVE_PATH = "document_classifier_xgb.pkl"
24
+
25
# ===============================
# LOAD DATA
# ===============================

print("πŸ“‚ Loading data...")

# NOTE(review): assumes each CSV carries "text" and "label" columns plus the
# numeric feature columns used further below — confirm against the
# preprocessing step that produced these files.
train_df = pd.read_csv(TRAIN_PATH)
val_df = pd.read_csv(VAL_PATH)
test_df = pd.read_csv(TEST_PATH)

# Replace missing text cells with empty strings so the TF-IDF vectorizers
# never receive NaN values.
X_train_text = train_df["text"].fillna("")
X_val_text = val_df["text"].fillna("")
X_test_text = test_df["text"].fillna("")

y_train = train_df["label"]
y_val = val_df["label"]
y_test = test_df["label"]

print("βœ… Data loaded successfully")
44
+
45
# ===============================
# TF-IDF FEATURES
# ===============================

print("🧠 Creating TF-IDF features...")

# Word-level view: unigrams and bigrams, English stop words removed.
word_vectorizer = TfidfVectorizer(
    max_features=40000,
    ngram_range=(1, 2),
    stop_words="english",
)

# Character-level view: 3- to 5-grams, more robust to typos / OCR noise.
char_vectorizer = TfidfVectorizer(
    analyzer="char",
    ngram_range=(3, 5),
    max_features=20000,
)

# Vocabularies are fitted on the training split only; val/test are merely
# transformed with the fitted vocabularies.
X_train_word = word_vectorizer.fit_transform(X_train_text)
X_train_char = char_vectorizer.fit_transform(X_train_text)

X_val_word = word_vectorizer.transform(X_val_text)
X_val_char = char_vectorizer.transform(X_val_text)
X_test_word = word_vectorizer.transform(X_test_text)
X_test_char = char_vectorizer.transform(X_test_text)

# Concatenate the word and character blocks side by side (sparse).
X_train_text_features = hstack([X_train_word, X_train_char])
X_val_text_features = hstack([X_val_word, X_val_char])
X_test_text_features = hstack([X_test_word, X_test_char])

print("βœ… Text features ready")
76
+
77
# ===============================
# NUMERIC FEATURES
# ===============================

print("πŸ”’ Adding numeric features...")

# Hand-crafted document statistics produced upstream by preprocessing.
numeric_cols = [
    "char_count",
    "digit_count",
    "uppercase_count",
    "currency_count",
    "line_count",
]

# Standardize using training-split statistics only, then convert to CSR so
# the numeric block can be hstacked with the sparse TF-IDF matrices.
scaler = StandardScaler()

X_train_num = csr_matrix(scaler.fit_transform(train_df[numeric_cols]))
X_val_num = csr_matrix(scaler.transform(val_df[numeric_cols]))
X_test_num = csr_matrix(scaler.transform(test_df[numeric_cols]))

# Final design matrices: [word tf-idf | char tf-idf | scaled numeric].
X_train = hstack([X_train_text_features, X_train_num])
X_val = hstack([X_val_text_features, X_val_num])
X_test = hstack([X_test_text_features, X_test_num])

print("βœ… Feature matrix ready")
107
# ===============================
# MODEL
# ===============================

print("πŸš€ Starting training...")

# Upper bound on boosting rounds; early stopping may end training sooner.
N_ESTIMATORS = 400
115
class TqdmCallback(xgb.callback.TrainingCallback):
    """XGBoost training callback that drives a tqdm progress bar,
    advancing it by one unit per boosting round."""

    def __init__(self, total):
        # One bar unit per tree; `total` is the planned number of rounds.
        self._bar = tqdm(total=total, desc="Training Progress", unit="trees")

    def after_iteration(self, model, epoch, evals_log):
        self._bar.update(1)
        # Returning False tells XGBoost to continue training.
        return False

    def after_training(self, model):
        self._bar.close()
        # The callback API requires the model to be returned.
        return model
126
+
127
# Gradient-boosted trees over the combined sparse features. `hist` is the
# fast histogram tree method; multi-class log loss serves as both the
# evaluation metric and the early-stopping criterion.
xgb_params = dict(
    n_estimators=N_ESTIMATORS,
    max_depth=6,
    learning_rate=0.1,
    tree_method="hist",
    eval_metric="mlogloss",
    early_stopping_rounds=30,
    callbacks=[TqdmCallback(N_ESTIMATORS)],
)
model = xgb.XGBClassifier(**xgb_params)

start_time = time.time()

# validation_0 = training curve, validation_1 = validation curve; XGBoost
# monitors the LAST eval set for early stopping.
model.fit(
    X_train,
    y_train,
    eval_set=[(X_train, y_train), (X_val, y_val)],
    verbose=False,
)

elapsed = round(time.time() - start_time, 2)
print(f"\n⏱ Training completed in {elapsed} seconds")
147
+
148
# ===============================
# EVALUATION
# ===============================

# Per-class precision/recall/F1 on both held-out splits.
for header, X_split, y_split in (
    ("\nπŸ“Š Validation Performance:", X_val, y_val),
    ("\nπŸ“Š Test Performance:", X_test, y_test),
):
    print(header)
    print(classification_report(y_split, model.predict(X_split)))
159
+
160
# ===============================
# TRAINING CURVE
# ===============================

# evals_result() holds the per-round metric history recorded by fit():
# validation_0 -> training split, validation_1 -> validation split.
history = model.evals_result()
train_loss = history["validation_0"]["mlogloss"]
val_loss = history["validation_1"]["mlogloss"]

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(train_loss, label="Train Loss")
ax.plot(val_loss, label="Validation Loss")
ax.set_xlabel("Boosting Rounds")
ax.set_ylabel("Log Loss")
ax.set_title("Training Curve")
ax.legend()
fig.savefig("training_curve.png", dpi=150, bbox_inches="tight")
plt.close(fig)
print("πŸ“ˆ Training curve saved to training_curve.png")
179
+
180
# ===============================
# FEATURE IMPORTANCE
# ===============================

# BUG FIX: `xgb.plot_importance` creates its own new figure when no `ax`
# is supplied, so the previous `plt.figure(figsize=(10, 8))` produced an
# empty figure that was never closed (leak) and whose figsize was silently
# ignored. Passing an explicit Axes draws into the figure we actually save.
fig, ax = plt.subplots(figsize=(10, 8))
xgb.plot_importance(model, max_num_features=20, ax=ax)
ax.set_title("Top 20 Important Features")
fig.savefig("feature_importance.png", dpi=150, bbox_inches="tight")
plt.close(fig)
print("πŸ“Š Feature importance saved to feature_importance.png")
190
+
191
# ===============================
# SAVE MODEL
# ===============================

# Clear callbacks before saving: TqdmCallback holds an open file handle
# (TextIOWrapper) that joblib/pickle cannot serialize.
model.set_params(callbacks=[])

# Bundle everything inference needs: the classifier plus the fitted
# featurization objects, so new text can be transformed identically.
artifacts = {
    "model": model,
    "word_vectorizer": word_vectorizer,
    "char_vectorizer": char_vectorizer,
    "scaler": scaler,
}
joblib.dump(artifacts, MODEL_SAVE_PATH)

print(f"\nπŸ’Ύ Model saved to {MODEL_SAVE_PATH}")
print("πŸ”₯ All done!")