""" Utility functions for training. Single Responsibility: General utilities that don't fit elsewhere. """ import os import sys import torch from typing import Tuple, Optional from dataclasses import dataclass # ============================================================ # Logging # ============================================================ _verbose = True def setup_logging(verbose: bool = True): """Configure logging verbosity.""" global _verbose _verbose = verbose def log(msg: str): """Log message to stdout with flush.""" if _verbose: print(msg) sys.stdout.flush() # ============================================================ # Device Information # ============================================================ @dataclass class DeviceInfo: """Information about the compute device.""" device_type: str device_name: str vram_gb: float ram_total_gb: float ram_available_gb: float num_gpus: int def __str__(self) -> str: parts = [f"Device: {self.device_type}"] if self.device_name: parts.append(f"({self.device_name})") if self.vram_gb > 0: parts.append(f"VRAM: {self.vram_gb:.0f}GB") if self.num_gpus > 1: parts.append(f"x{self.num_gpus} GPUs") return " | ".join(parts) def get_device_info() -> DeviceInfo: """Get information about the compute device.""" device_type = "cpu" device_name = "" vram_gb = 0.0 num_gpus = 0 # CUDA if torch.cuda.is_available(): device_type = "cuda" num_gpus = torch.cuda.device_count() try: props = torch.cuda.get_device_properties(0) device_name = props.name vram_gb = props.total_memory / (1024**3) except Exception: pass # MPS elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device_type = "mps" device_name = "Apple Silicon" num_gpus = 1 # RAM info ram_total, ram_available = get_ram_info() return DeviceInfo( device_type=device_type, device_name=device_name, vram_gb=vram_gb, ram_total_gb=ram_total, ram_available_gb=ram_available, num_gpus=num_gpus, ) def get_ram_info() -> Tuple[float, float]: """Get RAM info in GB (total, available).""" try: import psutil total = psutil.virtual_memory().total / 1024**3 available = psutil.virtual_memory().available / 1024**3 return total, available except ImportError: pass try: import subprocess result = subprocess.run( ['free', '-b'], capture_output=True, text=True ) lines = result.stdout.strip().split('\n') if len(lines) >= 2: parts = lines[1].split() total = float(parts[1]) / 1024**3 available = float(parts[6]) / 1024**3 if len(parts) > 6 else float(parts[3]) / 1024**3 return total, available except Exception: pass return 0.0, 0.0 def log_memory_usage() -> str: """Get current memory usage string.""" parts = [] if torch.cuda.is_available(): used = torch.cuda.memory_allocated() / 1024**3 reserved = torch.cuda.memory_reserved() / 1024**3 parts.append(f"GPU: {used:.2f}GB / {reserved:.2f}GB") try: import psutil ram_used = psutil.virtual_memory().used / 1024**3 ram_total = psutil.virtual_memory().total / 1024**3 parts.append(f"RAM: {ram_used:.1f}GB / {ram_total:.1f}GB") except ImportError: pass return " | ".join(parts) # ============================================================ # Memory Management # ============================================================ def limit_ram_usage(max_ram_gb: float): """Limit RAM usage via resource limits.""" try: import resource max_bytes = int(max_ram_gb * 1024**3) resource.setrlimit(resource.RLIMIT_AS, (max_bytes, max_bytes)) except Exception: pass def setup_cuda_optimizations(vram_fraction: float = 0.80): """Configure CUDA optimizations.""" if not torch.cuda.is_available(): return torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.benchmark = True torch.set_float32_matmul_precision('high') try: torch.cuda.set_per_process_memory_fraction(vram_fraction) torch.cuda.empty_cache() except Exception: pass def should_enable_gradient_checkpointing( vram_gb: float, dynamic_decay: bool = False, threshold_fraction: float = 0.4 ) -> bool: """ Determine if gradient checkpointing should be enabled. Args: vram_gb: Total VRAM in GB dynamic_decay: Whether using dynamic decay (longer sequences over time) threshold_fraction: Fraction of VRAM that should be free Returns: Whether to enable gradient checkpointing """ if not torch.cuda.is_available(): return False # With dynamic_decay, sequences get longer over time if dynamic_decay and vram_gb <= 32: return True # Check available VRAM try: torch.cuda.empty_cache() free_bytes, total_bytes = torch.cuda.mem_get_info(0) free_gb = free_bytes / 1024**3 threshold_gb = vram_gb * threshold_fraction return free_gb < threshold_gb except Exception: # Conservative: enable if VRAM < 20GB return vram_gb < 20 # ============================================================ # Step Sharing (for DDP + DataLoader workers) # ============================================================ STEP_FILE = "/tmp/training_step.txt" def write_step(step: int): """Write current training step to file (main process only).""" try: with open(STEP_FILE, "w") as f: f.write(str(step)) except Exception: pass def read_step() -> int: """Read current training step from file.""" try: with open(STEP_FILE, "r") as f: return int(f.read().strip()) except Exception: return 0 # ============================================================ # HuggingFace Helpers # ============================================================ def setup_hf_login(): """Setup HuggingFace login from environment.""" hf_token = os.environ.get("HF_TOKEN") if hf_token: try: from huggingface_hub import login login(token=hf_token) return True except Exception: pass return False def load_tokenizer(model_path: str): """Load tokenizer with proper padding token.""" from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(model_path) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token return tokenizer