"""
Model Trainer Module
====================

Provides model training functionality with progress tracking,
checkpointing, and experiment logging.
"""

import os

# Set environment variables before transformers import so they are already
# in place when the library is first loaded.
os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '3')
os.environ.setdefault('TRANSFORMERS_NO_TF', '1')

import gc
import json
import time
import logging
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Callable, Any
from dataclasses import asdict, dataclass, field

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
    TrainerCallback,
)
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

from .config import TrainingConfig

logger = logging.getLogger(__name__)


@dataclass
class TrainingMetrics:
    """Container for one snapshot of training / evaluation metrics."""

    epoch: int = 0
    train_loss: float = 0.0
    eval_loss: float = 0.0
    accuracy: float = 0.0
    precision: float = 0.0
    recall: float = 0.0
    f1: float = 0.0
    learning_rate: float = 0.0
    timestamp: str = ""  # filled with datetime.now().isoformat() by callers in this module

    def to_dict(self) -> dict:
        """Return the metrics as a plain dict (field order preserved)."""
        return asdict(self)


@dataclass
class TrainingProgress:
    """Container for training progress information."""

    status: str = "idle"  # idle, training, completed, failed
    current_epoch: int = 0
    total_epochs: int = 0
    current_step: int = 0
    total_steps: int = 0
    progress_percent: float = 0.0
    eta_seconds: float = 0.0
    metrics_history: List[TrainingMetrics] = field(default_factory=list)
    error_message: str = ""
    model_path: Optional[str] = None
    final_metrics: Optional[TrainingMetrics] = None
    start_time: float = 0.0  # time.time() at training start; 0 means "not started"
    end_time: float = 0.0  # time.time() at training end; 0 means "still running"

    def update_progress(self):
        """Recompute ``progress_percent`` from the step counters.

        Leaves the value untouched while ``total_steps`` is still 0.
        """
        if self.total_steps > 0:
            self.progress_percent = (self.current_step / self.total_steps) * 100

    def get_elapsed_time(self) -> float:
        """Return elapsed training time in seconds.

        Returns 0.0 if training never started. While training is still running
        (``end_time`` unset), measures up to the current wall-clock time.
        """
        if self.start_time == 0:
            return 0.0
        end = self.end_time if self.end_time > 0 else time.time()
        return end - self.start_time


class TextClassificationDataset(Dataset):
    """PyTorch Dataset that tokenizes texts lazily, one sample per access."""

    def __init__(self, texts: List[str], labels: List[int], tokenizer, max_length: int = 256):
        """
        Args:
            texts: Raw text samples (each is passed through ``str()`` before tokenizing).
            labels: Integer class label for each text, aligned by index.
            tokenizer: A tokenizer callable (e.g. one returned by ``AutoTokenizer``).
            max_length: Sequence length every sample is truncated/padded to.
        """
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and return the tensors the Trainer expects."""
        text = str(self.texts[idx])
        label = self.labels[idx]

        # Pad every sample to max_length so the default collator can stack them.
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        # return_tensors='pt' yields a leading batch dim of 1; flatten removes it.
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }


class ProgressCallback(TrainerCallback):
    """Trainer callback that mirrors training state into a ``TrainingProgress``."""

    def __init__(self, progress: TrainingProgress, update_callback: Optional[Callable] = None):
        """
        Args:
            progress: Object that is mutated in place as training advances.
            update_callback: Optional callable invoked with ``progress`` after every step.
        """
        self.progress = progress
        self.update_callback = update_callback

    def on_train_begin(self, args, state, control, **kwargs):
        self.progress.status = "training"
        self.progress.start_time = time.time()
        self.progress.total_steps = state.max_steps

    def on_step_end(self, args, state, control, **kwargs):
        """Update step counters and a linear ETA, then notify ``update_callback``."""
        self.progress.current_step = state.global_step
        self.progress.update_progress()

        # ETA assumes the average time per step so far holds for the remaining steps.
        if state.global_step > 0:
            elapsed = time.time() - self.progress.start_time
            steps_remaining = state.max_steps - state.global_step
            time_per_step = elapsed / state.global_step
            self.progress.eta_seconds = steps_remaining * time_per_step

        if self.update_callback:
            self.update_callback(self.progress)

    def on_epoch_end(self, args, state, control, **kwargs):
        self.progress.current_epoch = int(state.epoch)

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Append a metrics snapshot for every non-empty Trainer log event.

        Reads ``loss`` (falling back to ``train_loss``), ``eval_loss``, and the
        ``eval_accuracy`` / ``eval_precision`` / ``eval_recall`` / ``eval_f1``
        keys. These are the keys ``ModelTrainer.compute_metrics`` produces, with
        the ``eval_`` prefix that ``ModelTrainer.train`` also reads from
        ``evaluate()``. It also reads ``learning_rate``. Missing values default
        to 0.0.
        """
        if not logs:
            return
        metrics = TrainingMetrics(
            epoch=int(state.epoch) if state.epoch else 0,
            train_loss=logs.get('loss', logs.get('train_loss', 0.0)),
            eval_loss=logs.get('eval_loss', 0.0),
            accuracy=logs.get('eval_accuracy', 0.0),
            precision=logs.get('eval_precision', 0.0),
            recall=logs.get('eval_recall', 0.0),
            f1=logs.get('eval_f1', 0.0),
            learning_rate=logs.get('learning_rate', 0.0),
            timestamp=datetime.now().isoformat(),
        )
        self.progress.metrics_history.append(metrics)

    def on_train_end(self, args, state, control, **kwargs):
        # NOTE(review): this marks the run "completed" before ModelTrainer.train
        # has evaluated on the test set and saved the model (model_path is still
        # None at this point); train() later re-sets the status either way.
        self.progress.status = "completed"
        self.progress.end_time = time.time()
        self.progress.progress_percent = 100.0


class ModelTrainer:
    """
    Main trainer class for text classification models.

    Supports:
    - Multiple model architectures (BERT, RoBERTa, XLM-RoBERTa, etc.)
    - Progress tracking and callbacks
    - Checkpointing and model saving
    - Experiment logging
    """

    def __init__(self, config: TrainingConfig):
        """
        Initialize the trainer.

        Args:
            config: Training configuration
        """
        self.config = config
        self.model = None
        self.tokenizer = None
        self.trainer = None
        self.progress = TrainingProgress(total_epochs=config.num_epochs)

        self._setup_output_dir()

    def _setup_output_dir(self):
        """Create output directory for models and logs."""
        os.makedirs(self.config.output_dir, exist_ok=True)
        os.makedirs(os.path.join(self.config.output_dir, "logs"), exist_ok=True)

    def load_model(self, progress_callback: Optional[Callable] = None) -> bool:
        """
        Load model and tokenizer from ``config.model_name``.

        Args:
            progress_callback: Optional callable receiving human-readable status strings.

        Returns:
            True if successful, False otherwise. On failure, ``self.progress``
            is marked "failed" and ``error_message`` is set.
        """
        try:
            logger.info(f"Loading model: {self.config.model_name}")

            if progress_callback:
                progress_callback("Loading tokenizer...")

            self.tokenizer = AutoTokenizer.from_pretrained(
                self.config.model_name,
                use_fast=True
            )

            if progress_callback:
                progress_callback("Loading model...")

            # ignore_mismatched_sizes lets a checkpoint with a different number
            # of labels be loaded with a freshly sized classification head.
            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.config.model_name,
                num_labels=self.config.num_labels,
                ignore_mismatched_sizes=True
            )

            logger.info("Model and tokenizer loaded successfully")
            return True

        except Exception as e:
            # Boundary handler: report failure through progress instead of raising.
            logger.exception(f"Failed to load model: {str(e)}")
            self.progress.status = "failed"
            self.progress.error_message = str(e)
            return False

    def _split(self, texts, labels, test_size):
        """Split one dataset in two, stratifying by label whenever sklearn allows it.

        If there is more than one class but stratification is impossible
        (sklearn raises ValueError, e.g. when a class has a single sample),
        fall back to a plain seeded random split instead of aborting.
        """
        if len(set(labels)) > 1:
            try:
                return train_test_split(
                    texts, labels,
                    test_size=test_size,
                    random_state=self.config.random_seed,
                    stratify=labels,
                )
            except ValueError as exc:
                logger.warning(f"Stratified split not possible ({exc}); using random split")
        return train_test_split(
            texts, labels,
            test_size=test_size,
            random_state=self.config.random_seed,
            stratify=None,
        )

    def prepare_data(self, texts: List[str], labels: List[int]) -> Tuple[Dataset, Dataset, Dataset]:
        """
        Prepare datasets for training.

        The data is first split into train / remainder using ``config.train_split``.
        The remainder is then divided so that validation receives a
        ``config.validation_split`` share of the full data and the test set gets
        the rest.

        Args:
            texts: List of text samples
            labels: List of corresponding labels

        Returns:
            Tuple of (train_dataset, val_dataset, test_dataset)

        Raises:
            RuntimeError: if the tokenizer has not been loaded yet.
        """
        if self.tokenizer is None:
            # Datasets capture the tokenizer now; a None tokenizer would only
            # fail later, deep inside the training loop.
            raise RuntimeError("Tokenizer is not loaded; call load_model() first")

        # Split off training data
        train_texts, temp_texts, train_labels, temp_labels = self._split(
            texts, labels, test_size=(1 - self.config.train_split)
        )

        # Split validation and test from remaining data
        val_ratio = self.config.validation_split / (1 - self.config.train_split)
        val_texts, test_texts, val_labels, test_labels = self._split(
            temp_texts, temp_labels, test_size=(1 - val_ratio)
        )

        # Create datasets
        train_dataset = TextClassificationDataset(
            train_texts, train_labels, self.tokenizer, self.config.max_length
        )
        val_dataset = TextClassificationDataset(
            val_texts, val_labels, self.tokenizer, self.config.max_length
        )
        test_dataset = TextClassificationDataset(
            test_texts, test_labels, self.tokenizer, self.config.max_length
        )

        logger.info(f"Data split: train={len(train_dataset)}, val={len(val_dataset)}, test={len(test_dataset)}")

        return train_dataset, val_dataset, test_dataset

    def compute_metrics(self, eval_pred) -> Dict[str, float]:
        """Compute accuracy and weighted precision/recall/F1 for evaluation.

        Args:
            eval_pred: ``(predictions, labels)`` pair supplied by the Trainer.
                ``predictions`` may be a tuple of model outputs, in which case
                the first element is taken as the logits.

        Returns:
            Dict with 'accuracy', 'precision', 'recall' and 'f1' keys.
        """
        predictions, labels = eval_pred
        if isinstance(predictions, tuple):
            predictions = predictions[0]
        # Class index = argmax over the last (label) axis of the logits.
        predictions = np.argmax(predictions, axis=-1)

        accuracy = accuracy_score(labels, predictions)
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, predictions, average='weighted', zero_division=0
        )

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1
        }

    def _build_training_args(self, run_output_dir: str) -> TrainingArguments:
        """Translate ``self.config`` into Trainer arguments for one run directory."""
        return TrainingArguments(
            output_dir=run_output_dir,
            num_train_epochs=self.config.num_epochs,
            per_device_train_batch_size=self.config.batch_size,
            per_device_eval_batch_size=self.config.batch_size,
            warmup_ratio=self.config.warmup_ratio,
            weight_decay=self.config.weight_decay,
            learning_rate=self.config.learning_rate,
            logging_dir=os.path.join(run_output_dir, "logs"),
            logging_steps=self.config.logging_steps,
            eval_strategy=self.config.eval_strategy,
            # Saving on the same schedule as evaluation keeps checkpoints
            # aligned with evaluations for load_best_model_at_end.
            save_strategy=self.config.eval_strategy,
            load_best_model_at_end=self.config.save_best_model,
            metric_for_best_model="f1",
            greater_is_better=True,
            save_total_limit=2,
            fp16=self.config.use_fp16 and torch.cuda.is_available(),
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            report_to="none",  # Disable default reporting
            seed=self.config.random_seed,
            dataloader_pin_memory=False,  # For CPU compatibility
        )

    def train(self,
              texts: List[str],
              labels: List[int],
              progress_callback: Optional[Callable] = None,
              status_callback: Optional[Callable] = None) -> TrainingProgress:
        """
        Train the model, evaluate it on a held-out test split and save it.

        Each call writes into a fresh ``run_<timestamp>`` directory under
        ``config.output_dir``. That directory holds the config, the checkpoints,
        ``final_model/`` and ``metrics.json``.

        Args:
            texts: Training texts
            labels: Training labels
            progress_callback: Optional callback for progress updates
                (called with the TrainingProgress after every step)
            status_callback: Optional callback for status messages (strings)

        Returns:
            TrainingProgress object with training results. Errors are not
            raised: they are reported through ``status == "failed"`` and
            ``error_message``.
        """
        try:
            self.progress = TrainingProgress(total_epochs=self.config.num_epochs)

            # Load model if not already loaded (the tokenizer is needed too,
            # for building the datasets).
            if self.model is None or self.tokenizer is None:
                if status_callback:
                    status_callback("Loading model...")
                if not self.load_model(status_callback):
                    return self.progress

            # Prepare data
            if status_callback:
                status_callback("Preparing datasets...")

            train_dataset, val_dataset, test_dataset = self.prepare_data(texts, labels)

            # Create unique output directory for this run
            run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            run_output_dir = os.path.join(
                self.config.output_dir,
                f"run_{run_timestamp}"
            )
            os.makedirs(run_output_dir, exist_ok=True)

            # Save config
            config_path = os.path.join(run_output_dir, "training_config.json")
            with open(config_path, 'w', encoding='utf-8') as f:
                json.dump(self.config.to_dict(), f, indent=2, ensure_ascii=False)

            training_args = self._build_training_args(run_output_dir)

            # Create trainer with custom callback
            progress_tracker = ProgressCallback(self.progress, progress_callback)

            self.trainer = Trainer(
                model=self.model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=val_dataset,
                compute_metrics=self.compute_metrics,
                callbacks=[progress_tracker]
            )

            # Start training
            if status_callback:
                status_callback("Training started...")

            logger.info("Starting model training...")
            self.trainer.train()
            logger.info("Training loop completed successfully")

            # Evaluate on test set
            if status_callback:
                status_callback("Evaluating on test set...")

            logger.info("Starting test set evaluation...")
            test_results = self.trainer.evaluate(test_dataset)
            logger.info(f"Test evaluation completed: {test_results}")

            # Add final metrics
            final_metrics = TrainingMetrics(
                epoch=self.config.num_epochs,
                eval_loss=test_results.get('eval_loss', 0),
                accuracy=test_results.get('eval_accuracy', 0),
                precision=test_results.get('eval_precision', 0),
                recall=test_results.get('eval_recall', 0),
                f1=test_results.get('eval_f1', 0),
                timestamp=datetime.now().isoformat()
            )
            self.progress.metrics_history.append(final_metrics)

            # Save model
            if status_callback:
                status_callback("Saving model...")

            model_save_path = os.path.join(run_output_dir, "final_model")
            logger.info(f"Saving model to {model_save_path}...")
            os.makedirs(model_save_path, exist_ok=True)

            self.trainer.save_model(model_save_path)
            self.tokenizer.save_pretrained(model_save_path)
            logger.info(f"Model saved successfully to {model_save_path}")

            # Save training metrics
            metrics_path = os.path.join(run_output_dir, "metrics.json")
            with open(metrics_path, 'w', encoding='utf-8') as f:
                json.dump({
                    "final_metrics": final_metrics.to_dict(),
                    "history": [m.to_dict() for m in self.progress.metrics_history],
                    "test_results": test_results
                }, f, indent=2, ensure_ascii=False)

            self.progress.status = "completed"
            self.progress.model_path = model_save_path
            self.progress.final_metrics = final_metrics

            logger.info(f"Training completed! Model saved to {model_save_path}")

            return self.progress

        except Exception as e:
            # Boundary handler: keep the traceback in the log, report via progress.
            logger.exception(f"Training failed: {str(e)}")
            self.progress.status = "failed"
            self.progress.error_message = str(e)
            self.progress.end_time = time.time()
            return self.progress

    def get_model_path(self) -> Optional[str]:
        """Get path to the trained model (None until a run has completed)."""
        return self.progress.model_path

    def cleanup(self):
        """Release the trainer, model and tokenizer, and free cached GPU memory.

        The Trainer keeps its own reference to the model (it is constructed
        with ``model=self.model``). It must be dropped as well, otherwise the
        model's tensors stay reachable and ``torch.cuda.empty_cache()`` cannot
        return their memory.
        """
        self.trainer = None
        self.model = None
        self.tokenizer = None

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def create_trainer(config: TrainingConfig) -> ModelTrainer:
    """Factory function to create a ModelTrainer instance."""
    return ModelTrainer(config)