| """ |
| VicAI Utilities |
| Helper functions for training and evaluation. |
| """ |
|
|
import json
import logging
import math
import sys
from pathlib import Path
from typing import Dict, Optional
|
|
import torch
import torch.distributed as dist
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
|
|
|
|
| def get_logger(name: str, log_file: Optional[Path] = None) -> logging.Logger: |
| """Create a logger with file and console handlers.""" |
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.propagate = False  # avoid duplicate records via the root logger

    # Clear handlers left over from a previous call with the same name.
    logger.handlers.clear()
| |
| |
| formatter = logging.Formatter( |
| '%(asctime)s - %(name)s - %(levelname)s - %(message)s', |
| datefmt='%Y-%m-%d %H:%M:%S' |
| ) |
| |
| |
| console_handler = logging.StreamHandler(sys.stdout) |
| console_handler.setLevel(logging.INFO) |
| console_handler.setFormatter(formatter) |
| logger.addHandler(console_handler) |
| |
| |
| if log_file: |
| log_file.parent.mkdir(parents=True, exist_ok=True) |
| file_handler = logging.FileHandler(log_file) |
| file_handler.setLevel(logging.INFO) |
| file_handler.setFormatter(formatter) |
| logger.addHandler(file_handler) |
| |
| return logger |
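
# Example usage (a minimal sketch; the log path is hypothetical):
#
#     logger = get_logger("train", log_file=Path("logs/train.log"))
#     logger.info("starting run")
#
# Re-calling get_logger with the same name replaces the old handlers
# instead of stacking duplicates.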
|
|
|
|
| def save_checkpoint( |
| model, |
| optimizer, |
| scaler, |
| step: int, |
| loss: float, |
| path: Path, |
| ): |
| """Save model checkpoint.""" |
| path.parent.mkdir(parents=True, exist_ok=True) |
| |
| |
    # Unwrap DDP/DataParallel so the checkpoint loads into a plain module.
    if hasattr(model, 'module'):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
| |
| checkpoint = { |
| 'model': state_dict, |
| 'optimizer': optimizer.state_dict(), |
| 'scaler': scaler.state_dict() if scaler else None, |
| 'step': step, |
| 'loss': loss, |
| } |
| |
| torch.save(checkpoint, path) |
|
|
|
|
| def load_checkpoint( |
| model, |
| optimizer, |
| scaler, |
| path: str, |
| device, |
| ): |
| """Load model checkpoint.""" |
| checkpoint = torch.load(path, map_location=device) |
| |
| |
| state_dict = checkpoint['model'] |
| if hasattr(model, 'module'): |
| model.module.load_state_dict(state_dict) |
| else: |
| model.load_state_dict(state_dict) |
| |
| optimizer.load_state_dict(checkpoint['optimizer']) |
| |
| if scaler and checkpoint.get('scaler'): |
| scaler.load_state_dict(checkpoint['scaler']) |
| |
| return checkpoint.get('step', 0) |
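
# Example round trip (a sketch; `model`, `optimizer`, `scaler`, and the
# checkpoint path are assumed to exist):
#
#     save_checkpoint(model, optimizer, scaler, step=1000, loss=2.31,
#                     path=Path("ckpts/step_1000.pt"))
#     step = load_checkpoint(model, optimizer, scaler,
#                            "ckpts/step_1000.pt", device)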
|
|
|
|
def get_lr_scheduler(optimizer, args):
    """Create a learning-rate scheduler with linear warmup and cosine decay."""

    def lr_lambda(current_step):
        # Linear warmup from 0 up to the base learning rate.
        if current_step < args.warmup_steps:
            return current_step / max(1, args.warmup_steps)
        else:
            # Cosine decay from the base learning rate down to args.min_lr.
            progress = (current_step - args.warmup_steps) / max(1, args.max_steps - args.warmup_steps)
            progress = min(1.0, progress)
            cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
            min_ratio = args.min_lr / args.learning_rate
            return min_ratio + (1 - min_ratio) * cosine_decay

    return LambdaLR(optimizer, lr_lambda)
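
# Worked example (hypothetical hyperparameters): with warmup_steps=100,
# max_steps=1100, learning_rate=3e-4, min_lr=3e-5 (so min_lr/lr = 0.1),
# lr_lambda returns:
#     step   50 -> 50/100 = 0.5                 (mid-warmup)
#     step  600 -> progress = 0.5, cosine_decay = 0.5,
#                  multiplier = 0.1 + 0.9 * 0.5 = 0.55
#     step 1100 -> progress = 1.0, multiplier = 0.1 (the min_lr floor)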
|
|
|
|
| def configure_optimizers(model, args): |
| """Configure optimizer with weight decay.""" |
| |
| decay_params = [] |
| no_decay_params = [] |
| |
| for name, param in model.named_parameters(): |
| if not param.requires_grad: |
| continue |
| |
| |
        # By convention, biases, normalization weights, and embeddings
        # are excluded from weight decay.
        if 'bias' in name or 'norm' in name or 'embedding' in name:
| no_decay_params.append(param) |
| else: |
| decay_params.append(param) |
| |
| param_groups = [ |
| {'params': decay_params, 'weight_decay': args.weight_decay}, |
| {'params': no_decay_params, 'weight_decay': 0.0}, |
| ] |
| |
| optimizer = AdamW( |
| param_groups, |
| lr=args.learning_rate, |
| betas=(args.beta1, args.beta2), |
| eps=1e-8, |
| ) |
| |
| return optimizer |
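
# Example usage (hypothetical hyperparameters on an argparse namespace):
#
#     args.weight_decay = 0.1
#     args.learning_rate = 3e-4
#     args.beta1, args.beta2 = 0.9, 0.95
#     optimizer = configure_optimizers(model, args)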
|
|
|
|
def estimate_loss(model, data_loader, device, num_batches=10):
    """Estimate the average loss over up to `num_batches` batches."""
    model.eval()
    total_loss = 0.0
    batches_seen = 0

    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            if i >= num_batches:
                break

            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, targets=labels)
            total_loss += outputs['loss'].item()
            batches_seen += 1

    model.train()
    # Average over the batches actually consumed, which may be fewer
    # than `num_batches` for short loaders.
    return total_loss / max(1, batches_seen)
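
# Example (a sketch; assumes batches are dicts with 'input_ids' and
# 'labels', and that the model returns a dict with a 'loss' entry, as the
# call above expects):
#
#     val_loss = estimate_loss(model, val_loader, device, num_batches=20)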
|
|
|
|
def get_grad_norm(model):
    """Compute the global L2 norm over all parameter gradients."""
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total_norm += p.grad.detach().norm(2).item() ** 2
    return total_norm ** 0.5
|
|
|
|
def clip_gradients(model, max_norm):
    """Clip gradients by global norm; returns the pre-clip norm."""
    return torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
|
|
|
|
| class AverageMeter: |
| """Track running average of metrics.""" |
| |
| def __init__(self): |
| self.reset() |
| |
| def reset(self): |
| self.val = 0 |
| self.avg = 0 |
| self.sum = 0 |
| self.count = 0 |
| |
| def update(self, val, n=1): |
| self.val = val |
| self.sum += val * n |
| self.count += n |
| self.avg = self.sum / self.count |
|
|
|
|
| class EarlyStopping: |
| """Early stopping to prevent overfitting.""" |
| |
| def __init__(self, patience=5, min_delta=0.0): |
| self.patience = patience |
| self.min_delta = min_delta |
| self.counter = 0 |
| self.best_loss = None |
| self.early_stop = False |
| |
| def __call__(self, val_loss): |
| if self.best_loss is None: |
| self.best_loss = val_loss |
| elif val_loss > self.best_loss - self.min_delta: |
| self.counter += 1 |
| if self.counter >= self.patience: |
| self.early_stop = True |
| else: |
| self.best_loss = val_loss |
| self.counter = 0 |
| |
| return self.early_stop |
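
# Worked example: with patience=2 and min_delta=0.0, the loss sequence
# 1.0, 0.9, 0.95, 0.94 triggers the stop on the fourth check -- 0.9 becomes
# the best loss, then two consecutive non-improving checks exhaust patience.
#
#     stopper = EarlyStopping(patience=2)
#     for loss in (1.0, 0.9, 0.95, 0.94):
#         if stopper(loss):
#             break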
|
|
|
|
| def count_parameters(model): |
| """Count trainable parameters.""" |
| return sum(p.numel() for p in model.parameters() if p.requires_grad) |
|
|
|
|
| def format_num_parameters(num_params): |
| """Format parameter count for display.""" |
| if num_params >= 1e9: |
| return f"{num_params / 1e9:.2f}B" |
| elif num_params >= 1e6: |
| return f"{num_params / 1e6:.2f}M" |
| elif num_params >= 1e3: |
| return f"{num_params / 1e3:.2f}K" |
| else: |
| return str(num_params) |
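
# Examples:
#     format_num_parameters(7_000_000_000) -> "7.00B"
#     format_num_parameters(124_000_000)   -> "124.00M"
#     format_num_parameters(512)           -> "512"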
|
|
|
|
| def get_device_info(): |
| """Get information about available GPUs.""" |
| if not torch.cuda.is_available(): |
| return "No CUDA available" |
| |
| info = [] |
| for i in range(torch.cuda.device_count()): |
| props = torch.cuda.get_device_properties(i) |
| info.append( |
| f"GPU {i}: {props.name} ({props.total_memory / 1e9:.1f} GB)" |
| ) |
| |
| return "\n".join(info) |
|
|
|
|
| def print_model_summary(model): |
| """Print a summary of the model architecture.""" |
| print("\n" + "=" * 60) |
| print("MODEL SUMMARY") |
| print("=" * 60) |
| |
| total_params = 0 |
| trainable_params = 0 |
| |
| print(f"\n{'Layer':<40} {'Parameters':>15} {'Trainable':>10}") |
| print("-" * 70) |
| |
| for name, param in model.named_parameters(): |
| num_params = param.numel() |
| total_params += num_params |
| if param.requires_grad: |
| trainable_params += num_params |
| trainable = "Yes" |
| else: |
| trainable = "No" |
| |
| print(f"{name:<40} {num_params:>15,} {trainable:>10}") |
| |
| print("-" * 70) |
| print(f"{'Total':<40} {total_params:>15,}") |
| print(f"{'Trainable':<40} {trainable_params:>15,}") |
| print(f"{'Non-trainable':<40} {total_params - trainable_params:>15,}") |
| print("=" * 60 + "\n") |
|
|
|
|
def save_training_config(args, output_path: Path):
    """Save training configuration to JSON."""
    config = vars(args)
    with open(output_path, 'w') as f:
        # `default=str` handles values that JSON cannot serialize
        # directly, such as Path objects.
        json.dump(config, f, indent=2, default=str)
|
|
|
|
| def load_training_config(config_path: Path): |
| """Load training configuration from JSON.""" |
| with open(config_path, 'r') as f: |
| return json.load(f) |
|
|
|
|
def all_reduce_dict(data: Dict, device):
    """Average scalar dictionary values across all processes."""
    if not dist.is_available() or not dist.is_initialized():
        return data

    reduced_data = {}
    for key, value in data.items():
        if isinstance(value, (int, float)):
            # Use float so ReduceOp.AVG (NCCL-only) averages ints correctly.
            tensor = torch.tensor([value], device=device, dtype=torch.float32)
            dist.all_reduce(tensor, op=dist.ReduceOp.AVG)
            reduced_data[key] = tensor.item()
        else:
            reduced_data[key] = value

    return reduced_data
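
# Example (a sketch; only meaningful under torchrun/DDP, where each rank
# passes its local metrics and receives the cross-rank average):
#
#     metrics = all_reduce_dict({'loss': local_loss}, device)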
|
|
|
|
| def set_seed(seed: int): |
| """Set random seed for reproducibility.""" |
| import random |
| import numpy as np |
| |
| random.seed(seed) |
| np.random.seed(seed) |
| torch.manual_seed(seed) |
| torch.cuda.manual_seed_all(seed) |
| |
| |
    # Deterministic cuDNN kernels trade speed for reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
|
|
| def get_memory_usage(): |
| """Get current memory usage.""" |
| if torch.cuda.is_available(): |
| allocated = torch.cuda.memory_allocated() / 1e9 |
| reserved = torch.cuda.memory_reserved() / 1e9 |
| max_allocated = torch.cuda.max_memory_allocated() / 1e9 |
| return { |
| 'allocated_gb': allocated, |
| 'reserved_gb': reserved, |
| 'max_allocated_gb': max_allocated, |
| } |
| return {'allocated_gb': 0, 'reserved_gb': 0, 'max_allocated_gb': 0} |
|
|
|
|
| if __name__ == "__main__": |
| |
| logger = get_logger("test") |
| logger.info("Testing logger") |
| |
| print(get_device_info()) |
| |
| meter = AverageMeter() |
| for i in range(10): |
| meter.update(i) |
| print(f"Average: {meter.avg}") |
|
|