songhieng commited on
Commit
bf40bd8
·
verified ·
1 Parent(s): 9aa4daf

Update src/mlops/trainer.py

Browse files
Files changed (1) hide show
  1. src/mlops/trainer.py +466 -466
src/mlops/trainer.py CHANGED
@@ -1,466 +1,466 @@
1
- """
2
- Model Trainer Module
3
- ====================
4
-
5
- Provides model training functionality with progress tracking,
6
- checkpointing, and experiment logging.
7
- """
8
-
9
- import os
10
- # Set environment variables before transformers import
11
- os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '3')
12
- os.environ.setdefault('TRANSFORMERS_NO_TF', '1')
13
-
14
- import json
15
- import time
16
- import logging
17
- from pathlib import Path
18
- from datetime import datetime
19
- from typing import Dict, List, Optional, Tuple, Callable, Any
20
- from dataclasses import dataclass, field
21
- import numpy as np
22
-
23
- import torch
24
- from torch.utils.data import Dataset, DataLoader
25
- from transformers import (
26
- AutoTokenizer,
27
- AutoModelForSequenceClassification,
28
- TrainingArguments,
29
- Trainer,
30
- EarlyStoppingCallback,
31
- TrainerCallback
32
- )
33
- from sklearn.model_selection import train_test_split
34
- from sklearn.metrics import accuracy_score, precision_recall_fscore_support
35
-
36
- from .config import TrainingConfig
37
-
38
- logger = logging.getLogger(__name__)
39
-
40
-
41
- @dataclass
42
- class TrainingMetrics:
43
- """Container for training metrics."""
44
- epoch: int = 0
45
- train_loss: float = 0.0
46
- eval_loss: float = 0.0
47
- accuracy: float = 0.0
48
- precision: float = 0.0
49
- recall: float = 0.0
50
- f1: float = 0.0
51
- learning_rate: float = 0.0
52
- timestamp: str = ""
53
-
54
- def to_dict(self) -> dict:
55
- return {
56
- "epoch": self.epoch,
57
- "train_loss": self.train_loss,
58
- "eval_loss": self.eval_loss,
59
- "accuracy": self.accuracy,
60
- "precision": self.precision,
61
- "recall": self.recall,
62
- "f1": self.f1,
63
- "learning_rate": self.learning_rate,
64
- "timestamp": self.timestamp
65
- }
66
-
67
-
68
- @dataclass
69
- class TrainingProgress:
70
- """Container for training progress information."""
71
- status: str = "idle" # idle, training, completed, failed
72
- current_epoch: int = 0
73
- total_epochs: int = 0
74
- current_step: int = 0
75
- total_steps: int = 0
76
- progress_percent: float = 0.0
77
- eta_seconds: float = 0.0
78
- metrics_history: List[TrainingMetrics] = field(default_factory=list)
79
- error_message: str = ""
80
- model_path: Optional[str] = None
81
- final_metrics: Optional[TrainingMetrics] = None
82
- start_time: float = 0.0
83
- end_time: float = 0.0
84
-
85
- def update_progress(self):
86
- """Update progress percentage."""
87
- if self.total_steps > 0:
88
- self.progress_percent = (self.current_step / self.total_steps) * 100
89
-
90
- def get_elapsed_time(self) -> float:
91
- """Get elapsed training time in seconds."""
92
- if self.start_time == 0:
93
- return 0.0
94
- end = self.end_time if self.end_time > 0 else time.time()
95
- return end - self.start_time
96
-
97
-
98
- class TextClassificationDataset(Dataset):
99
- """PyTorch Dataset for text classification."""
100
-
101
- def __init__(self, texts: List[str], labels: List[int],
102
- tokenizer, max_length: int = 256):
103
- self.texts = texts
104
- self.labels = labels
105
- self.tokenizer = tokenizer
106
- self.max_length = max_length
107
-
108
- def __len__(self):
109
- return len(self.texts)
110
-
111
- def __getitem__(self, idx):
112
- text = str(self.texts[idx])
113
- label = self.labels[idx]
114
-
115
- encoding = self.tokenizer(
116
- text,
117
- truncation=True,
118
- padding='max_length',
119
- max_length=self.max_length,
120
- return_tensors='pt'
121
- )
122
-
123
- return {
124
- 'input_ids': encoding['input_ids'].flatten(),
125
- 'attention_mask': encoding['attention_mask'].flatten(),
126
- 'labels': torch.tensor(label, dtype=torch.long)
127
- }
128
-
129
-
130
- class ProgressCallback(TrainerCallback):
131
- """Custom callback for tracking training progress."""
132
-
133
- def __init__(self, progress: TrainingProgress,
134
- update_callback: Optional[Callable] = None):
135
- self.progress = progress
136
- self.update_callback = update_callback
137
-
138
- def on_train_begin(self, args, state, control, **kwargs):
139
- self.progress.status = "training"
140
- self.progress.start_time = time.time()
141
- self.progress.total_steps = state.max_steps
142
-
143
- def on_step_end(self, args, state, control, **kwargs):
144
- self.progress.current_step = state.global_step
145
- self.progress.update_progress()
146
-
147
- # Calculate ETA
148
- if state.global_step > 0:
149
- elapsed = time.time() - self.progress.start_time
150
- steps_remaining = state.max_steps - state.global_step
151
- time_per_step = elapsed / state.global_step
152
- self.progress.eta_seconds = steps_remaining * time_per_step
153
-
154
- if self.update_callback:
155
- self.update_callback(self.progress)
156
-
157
- def on_epoch_end(self, args, state, control, **kwargs):
158
- self.progress.current_epoch = int(state.epoch)
159
-
160
- def on_log(self, args, state, control, logs=None, **kwargs):
161
- if logs:
162
- metrics = TrainingMetrics(
163
- epoch=int(state.epoch) if state.epoch else 0,
164
- train_loss=logs.get('loss', 0.0),
165
- eval_loss=logs.get('eval_loss', 0.0),
166
- learning_rate=logs.get('learning_rate', 0.0),
167
- timestamp=datetime.now().isoformat()
168
- )
169
- self.progress.metrics_history.append(metrics)
170
-
171
- def on_train_end(self, args, state, control, **kwargs):
172
- self.progress.status = "completed"
173
- self.progress.end_time = time.time()
174
- self.progress.progress_percent = 100.0
175
-
176
-
177
- class ModelTrainer:
178
- """
179
- Main trainer class for text classification models.
180
-
181
- Supports:
182
- - Multiple model architectures (BERT, RoBERTa, XLM-RoBERTa, etc.)
183
- - Progress tracking and callbacks
184
- - Checkpointing and model saving
185
- - Experiment logging
186
- """
187
-
188
- def __init__(self, config: TrainingConfig):
189
- """
190
- Initialize the trainer.
191
-
192
- Args:
193
- config: Training configuration
194
- """
195
- self.config = config
196
- self.model = None
197
- self.tokenizer = None
198
- self.trainer = None
199
- self.progress = TrainingProgress(total_epochs=config.num_epochs)
200
- self._setup_output_dir()
201
-
202
- def _setup_output_dir(self):
203
- """Create output directory for models and logs."""
204
- os.makedirs(self.config.output_dir, exist_ok=True)
205
- os.makedirs(os.path.join(self.config.output_dir, "logs"), exist_ok=True)
206
-
207
- def load_model(self, progress_callback: Optional[Callable] = None) -> bool:
208
- """
209
- Load model and tokenizer.
210
-
211
- Returns:
212
- True if successful, False otherwise
213
- """
214
- try:
215
- logger.info(f"Loading model: {self.config.model_name}")
216
-
217
- if progress_callback:
218
- progress_callback("Loading tokenizer...")
219
-
220
- self.tokenizer = AutoTokenizer.from_pretrained(
221
- self.config.model_name,
222
- use_fast=True
223
- )
224
-
225
- if progress_callback:
226
- progress_callback("Loading model...")
227
-
228
- self.model = AutoModelForSequenceClassification.from_pretrained(
229
- self.config.model_name,
230
- num_labels=self.config.num_labels,
231
- ignore_mismatched_sizes=True
232
- )
233
-
234
- logger.info("Model and tokenizer loaded successfully")
235
- return True
236
-
237
- except Exception as e:
238
- logger.error(f"Failed to load model: {str(e)}")
239
- self.progress.status = "failed"
240
- self.progress.error_message = str(e)
241
- return False
242
-
243
- def prepare_data(self, texts: List[str], labels: List[int]) -> Tuple[Dataset, Dataset, Dataset]:
244
- """
245
- Prepare datasets for training.
246
-
247
- Args:
248
- texts: List of text samples
249
- labels: List of corresponding labels
250
-
251
- Returns:
252
- Tuple of (train_dataset, val_dataset, test_dataset)
253
- """
254
- # Split data
255
- train_texts, temp_texts, train_labels, temp_labels = train_test_split(
256
- texts, labels,
257
- test_size=(1 - self.config.train_split),
258
- random_state=self.config.random_seed,
259
- stratify=labels if len(set(labels)) > 1 else None
260
- )
261
-
262
- # Split validation and test from remaining data
263
- val_ratio = self.config.validation_split / (1 - self.config.train_split)
264
- val_texts, test_texts, val_labels, test_labels = train_test_split(
265
- temp_texts, temp_labels,
266
- test_size=(1 - val_ratio),
267
- random_state=self.config.random_seed,
268
- stratify=temp_labels if len(set(temp_labels)) > 1 else None
269
- )
270
-
271
- # Create datasets
272
- train_dataset = TextClassificationDataset(
273
- train_texts, train_labels, self.tokenizer, self.config.max_length
274
- )
275
- val_dataset = TextClassificationDataset(
276
- val_texts, val_labels, self.tokenizer, self.config.max_length
277
- )
278
- test_dataset = TextClassificationDataset(
279
- test_texts, test_labels, self.tokenizer, self.config.max_length
280
- )
281
-
282
- logger.info(f"Data split: train={len(train_dataset)}, val={len(val_dataset)}, test={len(test_dataset)}")
283
-
284
- return train_dataset, val_dataset, test_dataset
285
-
286
- def compute_metrics(self, eval_pred) -> Dict[str, float]:
287
- """Compute metrics for evaluation."""
288
- predictions, labels = eval_pred
289
- predictions = np.argmax(predictions, axis=1)
290
-
291
- accuracy = accuracy_score(labels, predictions)
292
- precision, recall, f1, _ = precision_recall_fscore_support(
293
- labels, predictions, average='weighted', zero_division=0
294
- )
295
-
296
- return {
297
- 'accuracy': accuracy,
298
- 'precision': precision,
299
- 'recall': recall,
300
- 'f1': f1
301
- }
302
-
303
- def train(self, texts: List[str], labels: List[int],
304
- progress_callback: Optional[Callable] = None,
305
- status_callback: Optional[Callable] = None) -> TrainingProgress:
306
- """
307
- Train the model.
308
-
309
- Args:
310
- texts: Training texts
311
- labels: Training labels
312
- progress_callback: Optional callback for progress updates
313
- status_callback: Optional callback for status messages
314
-
315
- Returns:
316
- TrainingProgress object with training results
317
- """
318
- try:
319
- self.progress = TrainingProgress(total_epochs=self.config.num_epochs)
320
-
321
- # Load model if not already loaded
322
- if self.model is None:
323
- if status_callback:
324
- status_callback("Loading model...")
325
- if not self.load_model(status_callback):
326
- return self.progress
327
-
328
- # Prepare data
329
- if status_callback:
330
- status_callback("Preparing datasets...")
331
-
332
- train_dataset, val_dataset, test_dataset = self.prepare_data(texts, labels)
333
-
334
- # Create unique output directory for this run
335
- run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
336
- run_output_dir = os.path.join(
337
- self.config.output_dir,
338
- f"run_{run_timestamp}"
339
- )
340
- os.makedirs(run_output_dir, exist_ok=True)
341
-
342
- # Save config
343
- config_path = os.path.join(run_output_dir, "training_config.json")
344
- with open(config_path, 'w', encoding='utf-8') as f:
345
- json.dump(self.config.to_dict(), f, indent=2, ensure_ascii=False)
346
-
347
- # Setup training arguments
348
- training_args = TrainingArguments(
349
- output_dir=run_output_dir,
350
- num_train_epochs=self.config.num_epochs,
351
- per_device_train_batch_size=self.config.batch_size,
352
- per_device_eval_batch_size=self.config.batch_size,
353
- warmup_ratio=self.config.warmup_ratio,
354
- weight_decay=self.config.weight_decay,
355
- learning_rate=self.config.learning_rate,
356
- logging_dir=os.path.join(run_output_dir, "logs"),
357
- logging_steps=self.config.logging_steps,
358
- evaluation_strategy=self.config.evaluation_strategy,
359
- save_strategy=self.config.evaluation_strategy,
360
- load_best_model_at_end=self.config.save_best_model,
361
- metric_for_best_model="f1",
362
- greater_is_better=True,
363
- save_total_limit=2,
364
- fp16=self.config.use_fp16 and torch.cuda.is_available(),
365
- gradient_accumulation_steps=self.config.gradient_accumulation_steps,
366
- report_to="none", # Disable default reporting
367
- seed=self.config.random_seed,
368
- dataloader_pin_memory=False, # For CPU compatibility
369
- )
370
-
371
- # Create trainer with custom callback
372
- progress_tracker = ProgressCallback(self.progress, progress_callback)
373
-
374
- self.trainer = Trainer(
375
- model=self.model,
376
- args=training_args,
377
- train_dataset=train_dataset,
378
- eval_dataset=val_dataset,
379
- compute_metrics=self.compute_metrics,
380
- callbacks=[progress_tracker]
381
- )
382
-
383
- # Start training
384
- if status_callback:
385
- status_callback("Training started...")
386
-
387
- logger.info("Starting model training...")
388
- self.trainer.train()
389
- logger.info("Training loop completed successfully")
390
-
391
- # Evaluate on test set
392
- if status_callback:
393
- status_callback("Evaluating on test set...")
394
-
395
- logger.info("Starting test set evaluation...")
396
- test_results = self.trainer.evaluate(test_dataset)
397
- logger.info(f"Test evaluation completed: {test_results}")
398
-
399
- # Add final metrics
400
- final_metrics = TrainingMetrics(
401
- epoch=self.config.num_epochs,
402
- eval_loss=test_results.get('eval_loss', 0),
403
- accuracy=test_results.get('eval_accuracy', 0),
404
- precision=test_results.get('eval_precision', 0),
405
- recall=test_results.get('eval_recall', 0),
406
- f1=test_results.get('eval_f1', 0),
407
- timestamp=datetime.now().isoformat()
408
- )
409
- self.progress.metrics_history.append(final_metrics)
410
-
411
- # Save model
412
- if status_callback:
413
- status_callback("Saving model...")
414
-
415
- model_save_path = os.path.join(run_output_dir, "final_model")
416
- logger.info(f"Saving model to {model_save_path}...")
417
- os.makedirs(model_save_path, exist_ok=True)
418
- self.trainer.save_model(model_save_path)
419
- self.tokenizer.save_pretrained(model_save_path)
420
- logger.info(f"Model saved successfully to {model_save_path}")
421
-
422
- # Save training metrics
423
- metrics_path = os.path.join(run_output_dir, "metrics.json")
424
- with open(metrics_path, 'w', encoding='utf-8') as f:
425
- json.dump({
426
- "final_metrics": final_metrics.to_dict(),
427
- "history": [m.to_dict() for m in self.progress.metrics_history],
428
- "test_results": test_results
429
- }, f, indent=2, ensure_ascii=False)
430
-
431
- self.progress.status = "completed"
432
- self.progress.model_path = model_save_path
433
- self.progress.final_metrics = final_metrics
434
-
435
- logger.info(f"Training completed! Model saved to {model_save_path}")
436
-
437
- return self.progress
438
-
439
- except Exception as e:
440
- logger.error(f"Training failed: {str(e)}")
441
- self.progress.status = "failed"
442
- self.progress.error_message = str(e)
443
- self.progress.end_time = time.time()
444
- return self.progress
445
-
446
- def get_model_path(self) -> Optional[str]:
447
- """Get path to the trained model."""
448
- if hasattr(self.progress, 'model_path'):
449
- return self.progress.model_path
450
- return None
451
-
452
- def cleanup(self):
453
- """Cleanup resources."""
454
- if self.model is not None:
455
- del self.model
456
- self.model = None
457
- if self.tokenizer is not None:
458
- del self.tokenizer
459
- self.tokenizer = None
460
- if torch.cuda.is_available():
461
- torch.cuda.empty_cache()
462
-
463
-
464
- def create_trainer(config: TrainingConfig) -> ModelTrainer:
465
- """Factory function to create a ModelTrainer instance."""
466
- return ModelTrainer(config)
 
1
+ """
2
+ Model Trainer Module
3
+ ====================
4
+
5
+ Provides model training functionality with progress tracking,
6
+ checkpointing, and experiment logging.
7
+ """
8
+
9
+ import os
10
+ # Set environment variables before transformers import
11
+ os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '3')
12
+ os.environ.setdefault('TRANSFORMERS_NO_TF', '1')
13
+
14
+ import json
15
+ import time
16
+ import logging
17
+ from pathlib import Path
18
+ from datetime import datetime
19
+ from typing import Dict, List, Optional, Tuple, Callable, Any
20
+ from dataclasses import dataclass, field
21
+ import numpy as np
22
+
23
+ import torch
24
+ from torch.utils.data import Dataset, DataLoader
25
+ from transformers import (
26
+ AutoTokenizer,
27
+ AutoModelForSequenceClassification,
28
+ TrainingArguments,
29
+ Trainer,
30
+ EarlyStoppingCallback,
31
+ TrainerCallback
32
+ )
33
+ from sklearn.model_selection import train_test_split
34
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
35
+
36
+ from .config import TrainingConfig
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+
41
@dataclass
class TrainingMetrics:
    """Snapshot of training/evaluation metrics at one point in a run."""
    epoch: int = 0
    train_loss: float = 0.0
    eval_loss: float = 0.0
    accuracy: float = 0.0
    precision: float = 0.0
    recall: float = 0.0
    f1: float = 0.0
    learning_rate: float = 0.0
    timestamp: str = ""

    def to_dict(self) -> dict:
        """Return the metrics as a plain dict (e.g. for JSON serialization)."""
        keys = (
            "epoch", "train_loss", "eval_loss", "accuracy",
            "precision", "recall", "f1", "learning_rate", "timestamp",
        )
        return {key: getattr(self, key) for key in keys}
66
+
67
+
68
@dataclass
class TrainingProgress:
    """Mutable record describing the state of one training run."""
    status: str = "idle"  # one of: idle, training, completed, failed
    current_epoch: int = 0
    total_epochs: int = 0
    current_step: int = 0
    total_steps: int = 0
    progress_percent: float = 0.0
    eta_seconds: float = 0.0
    metrics_history: List[TrainingMetrics] = field(default_factory=list)
    error_message: str = ""
    model_path: Optional[str] = None
    final_metrics: Optional[TrainingMetrics] = None
    start_time: float = 0.0
    end_time: float = 0.0

    def update_progress(self):
        """Recompute progress_percent from current_step / total_steps."""
        if self.total_steps <= 0:
            return  # nothing known yet; leave percentage untouched
        self.progress_percent = (self.current_step / self.total_steps) * 100

    def get_elapsed_time(self) -> float:
        """Return seconds spent training so far (0.0 if never started)."""
        if not self.start_time:
            return 0.0
        # While still running, measure against the wall clock.
        finished_at = time.time() if self.end_time <= 0 else self.end_time
        return finished_at - self.start_time
96
+
97
+
98
class TextClassificationDataset(Dataset):
    """Map-style PyTorch Dataset over parallel lists of texts and labels.

    Tokenization happens lazily per item access, producing the tensors a
    HuggingFace sequence-classification model expects.
    """

    def __init__(self, texts: List[str], labels: List[int],
                 tokenizer, max_length: int = 256):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        sample_text = str(self.texts[idx])
        sample_label = self.labels[idx]

        # Pad/truncate every sample to a fixed max_length so batches stack.
        encoded = self.tokenizer(
            sample_text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )

        # flatten() drops the batch dimension added by return_tensors='pt'.
        item = {
            'input_ids': encoded['input_ids'].flatten(),
            'attention_mask': encoded['attention_mask'].flatten(),
        }
        item['labels'] = torch.tensor(sample_label, dtype=torch.long)
        return item
128
+
129
+
130
class ProgressCallback(TrainerCallback):
    """Bridges HF Trainer events into a shared TrainingProgress object.

    If ``update_callback`` is given it is invoked with the progress object
    after every optimizer step, so a UI can render live updates.
    """

    def __init__(self, progress: TrainingProgress,
                 update_callback: Optional[Callable] = None):
        self.progress = progress
        self.update_callback = update_callback

    def on_train_begin(self, args, state, control, **kwargs):
        self.progress.status = "training"
        self.progress.start_time = time.time()
        self.progress.total_steps = state.max_steps

    def on_step_end(self, args, state, control, **kwargs):
        progress = self.progress
        progress.current_step = state.global_step
        progress.update_progress()

        # ETA: extrapolate the average time-per-step over the remaining steps.
        steps_done = state.global_step
        if steps_done > 0:
            elapsed = time.time() - progress.start_time
            remaining = state.max_steps - steps_done
            progress.eta_seconds = remaining * (elapsed / steps_done)

        if self.update_callback:
            self.update_callback(progress)

    def on_epoch_end(self, args, state, control, **kwargs):
        self.progress.current_epoch = int(state.epoch)

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return
        # Record whatever the trainer logged; absent keys default to 0.0.
        self.progress.metrics_history.append(TrainingMetrics(
            epoch=int(state.epoch) if state.epoch else 0,
            train_loss=logs.get('loss', 0.0),
            eval_loss=logs.get('eval_loss', 0.0),
            learning_rate=logs.get('learning_rate', 0.0),
            timestamp=datetime.now().isoformat(),
        ))

    def on_train_end(self, args, state, control, **kwargs):
        self.progress.status = "completed"
        self.progress.end_time = time.time()
        self.progress.progress_percent = 100.0
175
+
176
+
177
class ModelTrainer:
    """
    Main trainer class for text classification models.

    Supports:
    - Multiple model architectures (BERT, RoBERTa, XLM-RoBERTa, etc.)
    - Progress tracking and callbacks
    - Checkpointing and model saving
    - Experiment logging

    Typical flow: construct with a TrainingConfig, then call ``train(texts,
    labels)``; the returned TrainingProgress carries status, metric history
    and the saved model path. Errors are reported via ``progress.status ==
    "failed"`` rather than raised to the caller.
    """

    def __init__(self, config: TrainingConfig):
        """
        Initialize the trainer.

        Args:
            config: Training configuration

        Side effect: creates the configured output directories immediately.
        """
        self.config = config
        self.model = None       # populated by load_model()
        self.tokenizer = None   # populated by load_model()
        self.trainer = None     # HF Trainer instance, created inside train()
        self.progress = TrainingProgress(total_epochs=config.num_epochs)
        self._setup_output_dir()

    def _setup_output_dir(self):
        """Create output directory for models and logs (idempotent)."""
        os.makedirs(self.config.output_dir, exist_ok=True)
        os.makedirs(os.path.join(self.config.output_dir, "logs"), exist_ok=True)

    def load_model(self, progress_callback: Optional[Callable] = None) -> bool:
        """
        Load model and tokenizer.

        Args:
            progress_callback: Optional callable invoked with short status
                strings ("Loading tokenizer...", "Loading model...").

        Returns:
            True if successful, False otherwise. On failure the exception is
            swallowed and recorded on ``self.progress`` instead of raised.
        """
        try:
            logger.info(f"Loading model: {self.config.model_name}")

            if progress_callback:
                progress_callback("Loading tokenizer...")

            self.tokenizer = AutoTokenizer.from_pretrained(
                self.config.model_name,
                use_fast=True
            )

            if progress_callback:
                progress_callback("Loading model...")

            # ignore_mismatched_sizes allows reusing a checkpoint whose
            # classification head has a different number of labels.
            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.config.model_name,
                num_labels=self.config.num_labels,
                ignore_mismatched_sizes=True
            )

            logger.info("Model and tokenizer loaded successfully")
            return True

        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            self.progress.status = "failed"
            self.progress.error_message = str(e)
            return False

    def prepare_data(self, texts: List[str], labels: List[int]) -> Tuple[Dataset, Dataset, Dataset]:
        """
        Prepare datasets for training.

        Performs a two-stage split: first train vs. rest (by
        ``config.train_split``), then the remainder into validation vs. test
        (by ``config.validation_split``). Splits are stratified on the labels
        when more than one class is present.

        Args:
            texts: List of text samples
            labels: List of corresponding labels

        Returns:
            Tuple of (train_dataset, val_dataset, test_dataset)
        """
        # Split data
        train_texts, temp_texts, train_labels, temp_labels = train_test_split(
            texts, labels,
            test_size=(1 - self.config.train_split),
            random_state=self.config.random_seed,
            stratify=labels if len(set(labels)) > 1 else None
        )

        # Split validation and test from remaining data.
        # val_ratio rescales validation_split to the size of the remainder.
        val_ratio = self.config.validation_split / (1 - self.config.train_split)
        val_texts, test_texts, val_labels, test_labels = train_test_split(
            temp_texts, temp_labels,
            test_size=(1 - val_ratio),
            random_state=self.config.random_seed,
            stratify=temp_labels if len(set(temp_labels)) > 1 else None
        )

        # Create datasets (tokenization happens lazily inside the Dataset)
        train_dataset = TextClassificationDataset(
            train_texts, train_labels, self.tokenizer, self.config.max_length
        )
        val_dataset = TextClassificationDataset(
            val_texts, val_labels, self.tokenizer, self.config.max_length
        )
        test_dataset = TextClassificationDataset(
            test_texts, test_labels, self.tokenizer, self.config.max_length
        )

        logger.info(f"Data split: train={len(train_dataset)}, val={len(val_dataset)}, test={len(test_dataset)}")

        return train_dataset, val_dataset, test_dataset

    def compute_metrics(self, eval_pred) -> Dict[str, float]:
        """Compute metrics for evaluation.

        Args:
            eval_pred: (logits, labels) pair as provided by the HF Trainer.

        Returns:
            Dict with accuracy and weighted-average precision/recall/f1.
        """
        predictions, labels = eval_pred
        # argmax over the class dimension converts logits to class ids
        predictions = np.argmax(predictions, axis=1)

        accuracy = accuracy_score(labels, predictions)
        # weighted average so class imbalance is reflected in the scores
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, predictions, average='weighted', zero_division=0
        )

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1
        }

    def train(self, texts: List[str], labels: List[int],
              progress_callback: Optional[Callable] = None,
              status_callback: Optional[Callable] = None) -> TrainingProgress:
        """
        Train the model.

        Orchestrates the whole run: model load (if needed), data split,
        per-run output directory, HF Trainer fit, test-set evaluation, and
        persistence of model + metrics. Any exception is captured on the
        returned progress object (status "failed") instead of propagating.

        Args:
            texts: Training texts
            labels: Training labels
            progress_callback: Optional callback for progress updates
            status_callback: Optional callback for status messages

        Returns:
            TrainingProgress object with training results
        """
        try:
            # Fresh progress object per run; ProgressCallback mutates it.
            self.progress = TrainingProgress(total_epochs=self.config.num_epochs)

            # Load model if not already loaded
            if self.model is None:
                if status_callback:
                    status_callback("Loading model...")
                if not self.load_model(status_callback):
                    # load_model already recorded the failure on progress
                    return self.progress

            # Prepare data
            if status_callback:
                status_callback("Preparing datasets...")

            train_dataset, val_dataset, test_dataset = self.prepare_data(texts, labels)

            # Create unique output directory for this run
            run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            run_output_dir = os.path.join(
                self.config.output_dir,
                f"run_{run_timestamp}"
            )
            os.makedirs(run_output_dir, exist_ok=True)

            # Save config alongside the run for reproducibility
            config_path = os.path.join(run_output_dir, "training_config.json")
            with open(config_path, 'w', encoding='utf-8') as f:
                json.dump(self.config.to_dict(), f, indent=2, ensure_ascii=False)

            # Setup training arguments.
            # NOTE(review): eval_strategy/save_strategy use the post-v4.41
            # transformers keyword names (renamed from evaluation_strategy);
            # assumes TrainingConfig exposes a matching `eval_strategy`
            # field — confirm against .config.
            training_args = TrainingArguments(
                output_dir=run_output_dir,
                num_train_epochs=self.config.num_epochs,
                per_device_train_batch_size=self.config.batch_size,
                per_device_eval_batch_size=self.config.batch_size,
                warmup_ratio=self.config.warmup_ratio,
                weight_decay=self.config.weight_decay,
                learning_rate=self.config.learning_rate,
                logging_dir=os.path.join(run_output_dir, "logs"),
                logging_steps=self.config.logging_steps,
                eval_strategy=self.config.eval_strategy,
                save_strategy=self.config.eval_strategy,
                load_best_model_at_end=self.config.save_best_model,
                metric_for_best_model="f1",
                greater_is_better=True,
                save_total_limit=2,
                fp16=self.config.use_fp16 and torch.cuda.is_available(),
                gradient_accumulation_steps=self.config.gradient_accumulation_steps,
                report_to="none",  # Disable default reporting
                seed=self.config.random_seed,
                dataloader_pin_memory=False,  # For CPU compatibility
            )

            # Create trainer with custom callback
            progress_tracker = ProgressCallback(self.progress, progress_callback)

            self.trainer = Trainer(
                model=self.model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=val_dataset,
                compute_metrics=self.compute_metrics,
                callbacks=[progress_tracker]
            )

            # Start training
            if status_callback:
                status_callback("Training started...")

            logger.info("Starting model training...")
            self.trainer.train()
            logger.info("Training loop completed successfully")

            # Evaluate on test set (held out from both training and
            # best-model selection)
            if status_callback:
                status_callback("Evaluating on test set...")

            logger.info("Starting test set evaluation...")
            test_results = self.trainer.evaluate(test_dataset)
            logger.info(f"Test evaluation completed: {test_results}")

            # Add final metrics (keys carry the Trainer's "eval_" prefix)
            final_metrics = TrainingMetrics(
                epoch=self.config.num_epochs,
                eval_loss=test_results.get('eval_loss', 0),
                accuracy=test_results.get('eval_accuracy', 0),
                precision=test_results.get('eval_precision', 0),
                recall=test_results.get('eval_recall', 0),
                f1=test_results.get('eval_f1', 0),
                timestamp=datetime.now().isoformat()
            )
            self.progress.metrics_history.append(final_metrics)

            # Save model
            if status_callback:
                status_callback("Saving model...")

            model_save_path = os.path.join(run_output_dir, "final_model")
            logger.info(f"Saving model to {model_save_path}...")
            os.makedirs(model_save_path, exist_ok=True)
            self.trainer.save_model(model_save_path)
            # Tokenizer is saved separately so the directory is
            # self-contained for from_pretrained()
            self.tokenizer.save_pretrained(model_save_path)
            logger.info(f"Model saved successfully to {model_save_path}")

            # Save training metrics
            metrics_path = os.path.join(run_output_dir, "metrics.json")
            with open(metrics_path, 'w', encoding='utf-8') as f:
                json.dump({
                    "final_metrics": final_metrics.to_dict(),
                    "history": [m.to_dict() for m in self.progress.metrics_history],
                    "test_results": test_results
                }, f, indent=2, ensure_ascii=False)

            self.progress.status = "completed"
            self.progress.model_path = model_save_path
            self.progress.final_metrics = final_metrics

            logger.info(f"Training completed! Model saved to {model_save_path}")

            return self.progress

        except Exception as e:
            # Broad catch by design: callers consume failures via the
            # progress object, not exceptions.
            logger.error(f"Training failed: {str(e)}")
            self.progress.status = "failed"
            self.progress.error_message = str(e)
            self.progress.end_time = time.time()
            return self.progress

    def get_model_path(self) -> Optional[str]:
        """Get path to the trained model (None until a run completes)."""
        # NOTE(review): hasattr is always True here since model_path is a
        # declared TrainingProgress field; kept as a defensive guard.
        if hasattr(self.progress, 'model_path'):
            return self.progress.model_path
        return None

    def cleanup(self):
        """Cleanup resources: drop model/tokenizer refs and free GPU cache."""
        if self.model is not None:
            del self.model
            self.model = None
        if self.tokenizer is not None:
            del self.tokenizer
            self.tokenizer = None
        if torch.cuda.is_available():
            # Return cached CUDA memory to the driver after dropping refs
            torch.cuda.empty_cache()
462
+
463
+
464
def create_trainer(config: TrainingConfig) -> ModelTrainer:
    """Factory function to create a ModelTrainer instance.

    Args:
        config: Training configuration forwarded to ModelTrainer.

    Returns:
        A new ModelTrainer. Note: construction creates the configured
        output directories as a side effect (ModelTrainer.__init__).
    """
    return ModelTrainer(config)