import json
import logging
import os
from contextlib import nullcontext
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import torch

logger = logging.getLogger(__name__)
|
|
@dataclass
class DynamicClassWeights:
    """Handles per-language class weights using dynamic batch statistics."""
    weights_file: str = 'weights/language_class_weights.json'

    def __post_init__(self):
        self.toxicity_labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
        self.language_columns = ['en', 'es', 'fr', 'it', 'tr', 'pt', 'ru']

        # Load per-language scaling factors: the stored positive-class weights
        # (data['weights'][lang][label]['1']) are averaged across labels.
        try:
            with open(self.weights_file, 'r') as f:
                data = json.load(f)
            self.lang_scaling = {}
            for lang in self.language_columns:
                if lang in data['weights']:
                    scales = [float(data['weights'][lang][label]['1'])
                              for label in self.toxicity_labels]
                    self.lang_scaling[lang] = sum(scales) / len(scales)
                else:
                    self.lang_scaling[lang] = 1.0
        except Exception as e:
            logger.warning(f"Could not load weights from {self.weights_file}: {str(e)}")
            self._initialize_defaults()

        # Running per-language label statistics (kept on CPU), updated with an
        # exponential moving average in _update_running_stats.
        self.running_stats = {lang: {
            'pos_counts': torch.zeros(len(self.toxicity_labels)),
            'total_counts': torch.zeros(len(self.toxicity_labels)),
            'smoothing_factor': 0.95
        } for lang in self.language_columns}

    def _initialize_defaults(self):
        """Initialize safe default scaling factors."""
        self.lang_scaling = {lang: 1.0 for lang in self.language_columns}

    def _update_running_stats(self, langs, labels):
        """Update running label statistics for each language in the batch."""
        unique_langs = set(langs)
        for lang in unique_langs:
            if lang not in self.running_stats:
                continue

            lang_mask = torch.tensor([l == lang for l in langs], dtype=torch.bool, device=labels.device)
            lang_labels = labels[lang_mask]

            if len(lang_labels) == 0:
                continue

            # Per-label positive and total counts for this language, moved to CPU
            # because the running statistics are stored on CPU.
            pos_count = lang_labels.sum(dim=0).float().cpu()
            total_count = torch.full_like(pos_count, len(lang_labels))

            # Exponential moving average update.
            alpha = self.running_stats[lang]['smoothing_factor']
            self.running_stats[lang]['pos_counts'] = (
                alpha * self.running_stats[lang]['pos_counts'] +
                (1 - alpha) * pos_count
            )
            self.running_stats[lang]['total_counts'] = (
                alpha * self.running_stats[lang]['total_counts'] +
                (1 - alpha) * total_count
            )

    def get_weights_for_batch(self, langs: List[str], labels: torch.Tensor, device: torch.device) -> Dict[str, torch.Tensor]:
        """
        Calculate dynamic weights and focal parameters from batch and historical statistics.

        Args:
            langs: List of language codes, one per sample
            labels: Binary labels tensor [batch_size, num_labels]
            device: Target device for the returned tensors
        Returns:
            Dict with 'weights', 'alpha', and 'gamma' tensors
        """
        try:
            batch_size = len(langs)
            num_labels = labels.size(1)

            # Fold this batch into the running per-language statistics.
            self._update_running_stats(langs, labels)

            # Per-language and batch-level positive ratios.
            lang_pos_ratios = {}
            batch_pos_ratios = torch.zeros(num_labels, device=device)
            lang_counts = {}

            for lang in set(langs):
                lang_mask = torch.tensor([l == lang for l in langs], dtype=torch.bool, device=device)
                if not lang_mask.any():
                    continue

                lang_labels = labels[lang_mask]
                lang_pos_ratio = lang_labels.float().mean(dim=0)
                lang_pos_ratios[lang] = lang_pos_ratio

                # Accumulate the batch-level ratio, weighted by each language's share of the batch.
                lang_count = lang_mask.sum()
                lang_counts[lang] = lang_count
                batch_pos_ratios += lang_pos_ratio * (lang_count / batch_size)

            weights = torch.ones(batch_size, num_labels, device=device)
            alpha = torch.zeros(num_labels, device=device)
            gamma = torch.zeros(num_labels, device=device)

            for i, (lang, label_vec) in enumerate(zip(langs, labels)):
                if lang not in self.running_stats:
                    continue

                # Historical positive ratio from the smoothed running statistics.
                hist_pos_ratio = (
                    self.running_stats[lang]['pos_counts'] /
                    (self.running_stats[lang]['total_counts'] + 1e-7)
                ).to(device)

                # Blend historical (70%) and current-batch (30%) positive ratios.
                current_pos_ratio = lang_pos_ratios.get(lang, batch_pos_ratios)
                combined_pos_ratio = 0.7 * hist_pos_ratio + 0.3 * current_pos_ratio

                # Rarer classes (lower positive ratio) receive larger class weights.
                log_ratio = torch.log1p(1.0 / (combined_pos_ratio + 1e-7))
                class_weights = torch.exp(log_ratio.clamp(-2, 2))

                # Per-sample weights, scaled by the language-level factor.
                weights[i] = class_weights * self.lang_scaling.get(lang, 1.0)

                # Focal parameters: alpha grows with class rarity, gamma tracks the log ratio.
                alpha_contrib = 1.0 / (combined_pos_ratio + 1e-7).clamp(0.05, 0.95)
                gamma_contrib = log_ratio.clamp(1.0, 4.0)

                # Accumulate, weighted by the language's share of the batch.
                weight = lang_counts.get(lang, 1) / batch_size
                alpha += alpha_contrib * weight
                gamma += gamma_contrib * weight

            # Language-specific per-label adjustments; entries follow the order of
            # self.toxicity_labels: [toxic, severe_toxic, obscene, threat, insult, identity_hate].
            class_adjustments = {
                'en': [1.0, 1.0, 0.9, 0.85, 1.1, 1.0],
                'ru': [1.0, 1.0, 1.0, 1.0, 0.9, 1.0],
                'tr': [1.0, 1.0, 1.0, 1.0, 0.9, 0.95],
                'es': [1.0, 1.0, 1.0, 1.0, 0.9, 1.0],
                'fr': [1.0, 1.0, 1.0, 1.0, 0.9, 1.0],
                'it': [1.0, 1.0, 1.0, 1.0, 0.9, 1.0],
                'pt': [1.0, 1.0, 1.0, 1.0, 0.9, 1.0]
            }

            for i, lang in enumerate(langs):
                if lang in class_adjustments:
                    weights[i] *= torch.tensor(class_adjustments[lang], device=device)

            # Normalise so the mean weight is 1, then clamp everything to safe ranges.
            weights = weights / weights.mean()

            return {
                'weights': weights.clamp(0.1, 10.0),
                'alpha': alpha.clamp(0.1, 5.0),
                'gamma': gamma.clamp(1.0, 4.0)
            }

        except Exception as e:
            logger.error(f"Error computing batch weights: {str(e)}")
            # Fall back to neutral weights and standard focal parameters; shapes are
            # recomputed from the inputs in case the failure happened before
            # batch_size/num_labels were assigned.
            fallback_batch = len(langs)
            fallback_labels = labels.size(1) if labels.dim() > 1 else len(self.toxicity_labels)
            return {
                'weights': torch.ones((fallback_batch, fallback_labels), device=device),
                'alpha': torch.full((fallback_labels,), 0.25, device=device),
                'gamma': torch.full((fallback_labels,), 2.0, device=device)
            }
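
# --- Illustrative usage sketch (not part of the training pipeline) -------------------
# A minimal example of how the tensors returned by get_weights_for_batch could feed a
# weighted focal BCE loss. The focal formulation below is an assumption for
# demonstration purposes; the actual loss used in training may combine the
# 'weights', 'alpha', and 'gamma' terms differently.
def _example_weighted_focal_loss(logits: torch.Tensor,
                                 labels: torch.Tensor,
                                 batch_weights: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Weighted focal BCE built from the 'weights', 'alpha', and 'gamma' tensors."""
    probs = torch.sigmoid(logits)
    # Probability assigned to the true class of each (sample, label) entry.
    pt = torch.where(labels > 0.5, probs, 1.0 - probs)
    # Apply alpha only to positives (here alpha acts as a rarity multiplier rather
    # than the standard [0, 1] focal alpha).
    alpha_t = torch.where(labels > 0.5, batch_weights['alpha'],
                          torch.ones_like(batch_weights['alpha']))
    focal_term = alpha_t * (1.0 - pt) ** batch_weights['gamma']
    bce = torch.nn.functional.binary_cross_entropy_with_logits(
        logits, labels.float(), reduction='none')
    return (batch_weights['weights'] * focal_term * bce).mean()
#
# Example call:
#   dcw = DynamicClassWeights()
#   bw = dcw.get_weights_for_batch(langs, labels, device)
#   loss = _example_weighted_focal_loss(logits, labels, bw)
# --------------------------------------------------------------------------------------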
|
|
@dataclass
class MetricsTracker:
    """Tracks training and validation metrics with error handling"""
    best_auc: float = 0.0
    train_losses: List[float] = field(default_factory=list)
    val_losses: List[float] = field(default_factory=list)
    val_aucs: List[float] = field(default_factory=list)
    epoch_times: List[float] = field(default_factory=list)

    def update_train(self, loss: float):
        """Update training metrics, ignoring invalid loss values"""
        try:
            if not isinstance(loss, (int, float)) or np.isnan(loss) or np.isinf(loss):
                print(f"Warning: Invalid loss value: {loss}")
                return
            self.train_losses.append(float(loss))
        except Exception as e:
            print(f"Warning: Could not update training metrics: {str(e)}")

    def update_validation(self, metrics: Dict) -> bool:
        """Update validation metrics; returns True if the best AUC improved"""
        try:
            if not isinstance(metrics, dict):
                raise ValueError("Metrics must be a dictionary")

            loss = metrics.get('loss', float('inf'))
            auc = metrics.get('auc', 0.0)

            if np.isnan(loss) or np.isinf(loss):
                print(f"Warning: Invalid loss value: {loss}")
                loss = float('inf')

            if np.isnan(auc) or np.isinf(auc):
                print(f"Warning: Invalid AUC value: {auc}")
                auc = 0.0

            self.val_losses.append(float(loss))
            self.val_aucs.append(float(auc))

            # Track the best AUC so callers can decide when to checkpoint.
            if auc > self.best_auc:
                self.best_auc = auc
                return True
            return False

        except Exception as e:
            print(f"Warning: Could not update validation metrics: {str(e)}")
            return False

    def update_time(self, epoch_time: float):
        """Update timing metrics, ignoring non-positive values"""
        try:
            if not isinstance(epoch_time, (int, float)) or epoch_time <= 0:
                print(f"Warning: Invalid epoch time: {epoch_time}")
                return
            self.epoch_times.append(float(epoch_time))
        except Exception as e:
            print(f"Warning: Could not update timing metrics: {str(e)}")

    def get_eta(self, current_epoch: int, total_epochs: int) -> str:
        """Estimate remaining training time from the average epoch time"""
        try:
            if not self.epoch_times:
                return "Calculating..."

            if current_epoch >= total_epochs:
                return "Complete"

            avg_epoch_time = sum(self.epoch_times) / len(self.epoch_times)
            remaining_epochs = total_epochs - current_epoch
            eta_seconds = avg_epoch_time * remaining_epochs

            hours = int(eta_seconds // 3600)
            minutes = int((eta_seconds % 3600) // 60)
            seconds = int(eta_seconds % 60)

            return f"{hours:02d}:{minutes:02d}:{seconds:02d}"

        except Exception as e:
            print(f"Warning: Could not calculate ETA: {str(e)}")
            return "Unknown"
|
|
@dataclass
class TrainingConfig:
    """Basic training configuration with consolidated default values"""

    # Model
    model_name: str = "xlm-roberta-large"
    max_length: int = 512
    hidden_size: int = 1024
    num_attention_heads: int = 16
    model_dropout: float = 0.0
    freeze_layers: int = 8

    # Data
    cache_dir: str = 'cached_dataset'
    label_columns: Optional[List[str]] = None

    # Optimisation
    batch_size: int = 128
    grad_accum_steps: int = 1
    epochs: int = 6
    lr: float = 2e-5
    num_cycles: int = 2
    weight_decay: float = 2e-7
    max_grad_norm: float = 1.0
    warmup_ratio: float = 0.1
    label_smoothing: float = 0.01
    min_lr_ratio: float = 0.01

    # System / performance
    activation_checkpointing: bool = True
    mixed_precision: str = "fp16"
    _num_workers: Optional[int] = None
    gc_frequency: int = 500
    tensor_float_32: bool = True
|
|
    def __post_init__(self):
        """Initialize and validate configuration"""
        # Label columns are fixed for this task.
        self.label_columns = [
            'toxic', 'severe_toxic', 'obscene',
            'threat', 'insult', 'identity_hate'
        ]

        # Reduce CUDA memory fragmentation.
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128,expandable_segments:True'

        # Learning rate sanity checks.
        if self.lr <= 0:
            raise ValueError(f"Learning rate must be positive, got {self.lr}")
        if self.lr < 1e-7:
            raise ValueError(f"Learning rate too small: {self.lr}")
        if self.lr > 1.0:
            raise ValueError(f"Learning rate too large: {self.lr}")

        # Weight decay sanity checks relative to the learning rate.
        if self.weight_decay > 0:
            wd_to_lr_ratio = self.weight_decay / self.lr
            if wd_to_lr_ratio > 0.1:
                logger.warning(
                    "Weight decay too high: %.2e (%.2fx learning rate). "
                    "Should be 0.01-0.1x learning rate.",
                    self.weight_decay, wd_to_lr_ratio
                )
            effective_lr = self.lr * (1 - self.weight_decay)
            if effective_lr < self.lr * 0.9:
                logger.warning(
                    "Weight decay %.2e reduces effective learning rate to %.2e (%.1f%% reduction)",
                    self.weight_decay, effective_lr, (1 - effective_lr / self.lr) * 100
                )

        # Device and precision setup.
        if torch.cuda.is_available():
            try:
                torch.cuda.init()
                # Leave a small share of GPU memory for other processes.
                torch.cuda.set_per_process_memory_fraction(0.95)
                self.device = torch.device('cuda')

                if self.mixed_precision == "bf16":
                    if not torch.cuda.is_bf16_supported():
                        print("Warning: BF16 not supported on this GPU. Falling back to FP16")
                        self.mixed_precision = "fp16"

                if self.tensor_float_32:
                    if torch.cuda.get_device_capability()[0] >= 8:
                        torch.backends.cuda.matmul.allow_tf32 = True
                        torch.backends.cudnn.allow_tf32 = True
                    else:
                        print("Warning: TF32 not supported on this GPU. Disabling.")
                        self.tensor_float_32 = False

            except Exception as e:
                print(f"Warning: CUDA initialization failed: {str(e)}")
                self.device = torch.device('cpu')
                self.mixed_precision = "no"
        else:
            self.device = torch.device('cpu')
            if self.mixed_precision != "no":
                print("Warning: Mixed precision not supported on CPU. Disabling.")
                self.mixed_precision = "no"

        # Ensure output directories exist.
        try:
            for directory in ["weights", "logs"]:
                dir_path = Path(directory)
                if not dir_path.exists():
                    dir_path.mkdir(parents=True)
                elif not dir_path.is_dir():
                    raise NotADirectoryError(f"{directory} exists but is not a directory")
        except Exception as e:
            print(f"Error creating directories: {str(e)}")
            raise

        # Derived attributes.
        self.toxicity_labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
        self.num_labels = len(self.toxicity_labels)

        self.use_mixed_precision = self.mixed_precision != "no"
|
|
    def validate_model_config(self, model):
        """Validate configuration against model architecture"""
        try:
            # Check freeze_layers against the encoder depth.
            if self.freeze_layers > 0:
                total_layers = len(list(model.base_model.encoder.layer))
                if self.freeze_layers > total_layers:
                    raise ValueError(f"Can't freeze {self.freeze_layers} layers in {total_layers}-layer model")
                logger.info(f"Freezing {self.freeze_layers} out of {total_layers} layers")

            # Flag parameter groups that combine non-zero weight decay with a low learning rate.
            param_groups = self.get_param_groups(model)
            if self.weight_decay > 0:
                low_lr_groups = [g for g in param_groups if g['lr'] < 0.01]
                if low_lr_groups:
                    logger.warning("Found parameter groups with low learning rates (< 0.01) and non-zero weight decay:")
                    for group in low_lr_groups:
                        logger.warning(f"Group with lr={group['lr']:.4f}")

            return True
        except Exception as e:
            logger.error(f"Model configuration validation failed: {str(e)}")
            raise
|
|
    @property
    def dtype(self) -> torch.dtype:
        """Get the appropriate dtype based on mixed precision settings"""
        if self.mixed_precision == "bf16":
            return torch.bfloat16
        elif self.mixed_precision == "fp16":
            return torch.float16
        return torch.float32

    def get_autocast_context(self):
        """Get the appropriate autocast context based on configuration."""
        if not self.use_mixed_precision:
            return nullcontext()
        dtype = torch.bfloat16 if self.mixed_precision == "bf16" else torch.float16
        return torch.autocast(device_type=self.device.type, dtype=dtype)

    def to_serializable_dict(self):
        """Convert config to a dictionary for saving."""
        config_dict = asdict(self)
        return config_dict

    def get_param_groups(self, model):
        """Get parameter groups with the base learning rate"""
        return [{'params': model.parameters(), 'lr': self.lr}]

    @property
    def use_amp(self):
        """Check if AMP should be used based on device and mixed precision setting"""
        return self.device.type == 'cuda' and self.mixed_precision != "no"

    @property
    def grad_norm_clip(self):
        """Adaptive gradient clipping threshold based on precision"""
        if self.mixed_precision == "bf16":
            return 1.5
        if self.mixed_precision == "fp16":
            return 1.0
        return 5.0
|
|
    @property
    def num_workers(self):
        """Dynamically adjust workers based on system resources"""
        if self._num_workers is None:
            cpu_count = os.cpu_count()
            if cpu_count is None:
                self._num_workers = 0
            else:
                # Cap at 4 workers and leave two CPUs free for the main process.
                self._num_workers = min(4, max(0, cpu_count - 2))
            logger.info(f"Dynamically set num_workers to {self._num_workers} (CPU count: {cpu_count})")
        return self._num_workers

    @num_workers.setter
    def num_workers(self, value):
        """Allow manual override of num_workers"""
        self._num_workers = value
        logger.info(f"Manually set num_workers to {value}")