ayushm98 committed on
Commit
06b9814
·
1 Parent(s): ad8fa3f

feat: add DistilBERT training script

Browse files

- Fine-tune distilbert-base-uncased for binary classification
- Use HuggingFace Trainer with early stopping
- Track accuracy, F1, precision, recall metrics
- Save model and training metrics to artifacts

Files changed (1) hide show
  1. ml/training/train.py +277 -0
ml/training/train.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Train DistilBERT for complexity classification."""
2
+
3
+ import json
4
+ import os
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import torch
9
+ from datasets import DatasetDict
10
+ from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
11
+ from transformers import (
12
+ AutoModelForSequenceClassification,
13
+ AutoTokenizer,
14
+ DataCollatorWithPadding,
15
+ EarlyStoppingCallback,
16
+ Trainer,
17
+ TrainingArguments,
18
+ )
19
+
20
+ # Add parent directory to path for imports
21
+ import sys
22
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
23
+
24
+ from ml.data.load_dataset import load_arc_dataset, load_easy2hard_bench
25
+
26
+
27
def compute_metrics(eval_pred) -> dict:
    """Score a batch of predictions with accuracy, F1, precision, and recall.

    Args:
        eval_pred: ``(logits, labels)`` pair as handed over by the HF Trainer.

    Returns:
        Mapping of metric name to float score (binary averaging for
        F1/precision/recall).
    """
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)

    scores = {"accuracy": accuracy_score(labels, preds)}
    # The remaining three metrics share the same call shape.
    for name, metric_fn in (
        ("f1", f1_score),
        ("precision", precision_score),
        ("recall", recall_score),
    ):
        scores[name] = metric_fn(labels, preds, average="binary")
    return scores
38
+
39
+
40
def tokenize_dataset(
    dataset: DatasetDict,
    tokenizer: AutoTokenizer,
    max_length: int = 128,
) -> DatasetDict:
    """Tokenize every split of *dataset*, truncating to *max_length* tokens.

    Padding is intentionally skipped here — DataCollatorWithPadding pads
    each batch dynamically at training time, which keeps the cached
    dataset compact.
    """

    def _encode(batch):
        return tokenizer(
            batch["text"],
            padding=False,  # dynamic padding happens in the data collator
            truncation=True,
            max_length=max_length,
        )

    return dataset.map(
        _encode,
        batched=True,
        remove_columns=["text", "difficulty_score"],
        desc="Tokenizing",
    )
63
+
64
+
65
def train_complexity_classifier(
    model_name: str = "distilbert-base-uncased",
    dataset_type: str = "arc",
    max_samples: int | None = 5000,
    output_dir: str = "ml/artifacts/complexity-classifier",
    num_epochs: int = 5,
    batch_size: int = 16,
    learning_rate: float = 2e-5,
    max_length: int = 128,
    seed: int = 42,
) -> dict:
    """
    Train a DistilBERT model for complexity classification.

    Args:
        model_name: HuggingFace model name
        dataset_type: "easy2hard" or "arc"
        max_samples: Maximum training samples (None for all)
        output_dir: Directory to save model
        num_epochs: Number of training epochs
        batch_size: Training batch size
        learning_rate: Learning rate
        max_length: Maximum sequence length
        seed: Random seed

    Returns:
        Dictionary with training metrics
    """
    # Seed torch and numpy for reproducibility; the Trainer itself is
    # seeded separately via TrainingArguments(seed=...) below.
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Keep the str parameter intact; work with a Path locally instead of
    # rebinding `output_dir` to a different type.
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    print("Training complexity classifier")
    print(f"  Model: {model_name}")
    print(f"  Dataset: {dataset_type}")
    print(f"  Output: {out_path}")
    print()

    # Load dataset. NOTE: any value other than "easy2hard" silently falls
    # back to the ARC dataset — kept for backward compatibility.
    if dataset_type == "easy2hard":
        dataset = load_easy2hard_bench(max_samples=max_samples, seed=seed)
    else:
        dataset = load_arc_dataset(max_samples=max_samples, seed=seed)

    # Load tokenizer and model with a binary label space.
    print(f"\nLoading model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=2,
        id2label={0: "simple", 1: "complex"},
        label2id={"simple": 0, "complex": 1},
    )

    # Tokenize dataset (truncation only; padding handled by the collator).
    print("\nTokenizing dataset...")
    tokenized_dataset = tokenize_dataset(dataset, tokenizer, max_length)

    # Data collator for dynamic per-batch padding.
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    # Training arguments: evaluate/save per epoch and keep the best-F1
    # checkpoint so early stopping can restore it.
    training_args = TrainingArguments(
        output_dir=str(out_path / "checkpoints"),
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=learning_rate,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=num_epochs,
        weight_decay=0.01,
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        logging_dir=str(out_path / "logs"),
        logging_steps=50,
        seed=seed,
        report_to="none",  # Disable wandb/tensorboard
    )

    # Create trainer with early stopping after 2 epochs without F1 improvement.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
    )

    # Train
    print("\nStarting training...")
    train_result = trainer.train()

    # Evaluate the best checkpoint on the held-out test split.
    print("\nEvaluating on test set...")
    test_metrics = trainer.evaluate(tokenized_dataset["test"])

    # Save the model and tokenizer together so the directory is
    # self-contained for `from_pretrained`.
    print(f"\nSaving model to {out_path}")
    trainer.save_model(str(out_path))
    tokenizer.save_pretrained(str(out_path))

    # Assemble and persist a metrics summary alongside the model.
    metrics = {
        "train": {
            "loss": train_result.training_loss,
            # Early stopping may end before num_epochs; record the actual count.
            "epochs": train_result.metrics.get("epoch", num_epochs),
        },
        "test": {
            "accuracy": test_metrics["eval_accuracy"],
            "f1": test_metrics["eval_f1"],
            "precision": test_metrics["eval_precision"],
            "recall": test_metrics["eval_recall"],
            "loss": test_metrics["eval_loss"],
        },
        "config": {
            "model_name": model_name,
            "dataset_type": dataset_type,
            "max_samples": max_samples,
            "num_epochs": num_epochs,
            "batch_size": batch_size,
            "learning_rate": learning_rate,
            "max_length": max_length,
        },
    }

    with open(out_path / "metrics.json", "w") as f:
        json.dump(metrics, f, indent=2)

    print("\n" + "=" * 50)
    print("Training complete!")
    print("=" * 50)
    print("\nTest Results:")
    print(f"  Accuracy:  {test_metrics['eval_accuracy']:.4f}")
    print(f"  F1 Score:  {test_metrics['eval_f1']:.4f}")
    print(f"  Precision: {test_metrics['eval_precision']:.4f}")
    print(f"  Recall:    {test_metrics['eval_recall']:.4f}")
    print(f"\nModel saved to: {out_path}")

    return metrics
211
+
212
+
213
if __name__ == "__main__":
    import argparse

    # CLI front-end: every flag maps 1:1 onto a keyword argument of
    # train_complexity_classifier.
    cli = argparse.ArgumentParser(description="Train complexity classifier")
    cli.add_argument(
        "--model",
        type=str,
        default="distilbert-base-uncased",
        help="HuggingFace model name",
    )
    cli.add_argument(
        "--dataset",
        choices=["easy2hard", "arc"],
        default="arc",
        help="Dataset to use",
    )
    cli.add_argument(
        "--max-samples",
        type=int,
        default=5000,
        help="Maximum samples (None for all)",
    )
    cli.add_argument(
        "--output-dir",
        type=str,
        default="ml/artifacts/complexity-classifier",
        help="Output directory",
    )
    cli.add_argument("--epochs", type=int, default=5, help="Number of epochs")
    cli.add_argument("--batch-size", type=int, default=16, help="Batch size")
    cli.add_argument("--lr", type=float, default=2e-5, help="Learning rate")
    cli.add_argument(
        "--max-length",
        type=int,
        default=128,
        help="Maximum sequence length",
    )

    opts = cli.parse_args()

    train_complexity_classifier(
        model_name=opts.model,
        dataset_type=opts.dataset,
        max_samples=opts.max_samples,
        output_dir=opts.output_dir,
        num_epochs=opts.epochs,
        batch_size=opts.batch_size,
        learning_rate=opts.lr,
        max_length=opts.max_length,
    )