| """ |
| Utility functions for training. |
| |
| Single Responsibility: General utilities that don't fit elsewhere. |
| """ |
|
|
| import os |
| import sys |
| import torch |
| from typing import Tuple, Optional |
| from dataclasses import dataclass |
|
|
|
|
| |
| |
| |
# Module-wide verbosity flag: log() prints only while this is true.
# Changed at runtime through setup_logging().
_verbose = True
|
|
|
|
def setup_logging(verbose: bool = True) -> None:
    """Set the module-wide verbosity switch.

    Args:
        verbose: Value stored in the module-level ``_verbose`` flag. When it
            is true, ``log`` prints its messages; when false they are dropped.
    """
    global _verbose
    _verbose = verbose
|
|
|
|
def log(msg: str) -> None:
    """Print ``msg`` to stdout and flush, unless logging is muted.

    Nothing is written while the module-level ``_verbose`` flag is false.

    Args:
        msg: Text to print.
    """
    if not _verbose:
        return
    print(msg)
    # Flush immediately so the message is not held back in the stdout buffer.
    sys.stdout.flush()
|
|
|
|
| |
| |
| |
@dataclass
class DeviceInfo:
    """Description of a compute device together with host RAM figures.

    Attributes:
        device_type: Device kind label; always shown as ``Device: <type>``.
        device_name: Device name; left out of the summary when empty.
        vram_gb: Device memory in GB; shown in the summary only when positive.
        ram_total_gb: Total host RAM in GB (not part of the summary).
        ram_available_gb: Available host RAM in GB (not part of the summary).
        num_gpus: GPU count; shown in the summary only when greater than one.
    """

    device_type: str
    device_name: str
    vram_gb: float
    ram_total_gb: float
    ram_available_gb: float
    num_gpus: int

    def __str__(self) -> str:
        """Return a one-line summary whose segments are joined by ``" | "``."""

        def segments():
            # The device type is always present; the rest are optional.
            yield f"Device: {self.device_type}"
            if self.device_name:
                yield f"({self.device_name})"
            if self.vram_gb > 0:
                yield f"VRAM: {self.vram_gb:.0f}GB"
            if self.num_gpus > 1:
                yield f"x{self.num_gpus} GPUs"

        return " | ".join(segments())
|
|
|
|
def get_device_info() -> DeviceInfo:
    """Probe torch for an accelerator and combine it with host RAM figures.

    CUDA is checked first, then the MPS backend. If neither is available the
    result describes a CPU with no device name, no VRAM and zero GPUs.

    Returns:
        A ``DeviceInfo`` whose RAM fields come from ``get_ram_info()``.
    """
    fields = {
        "device_type": "cpu",
        "device_name": "",
        "vram_gb": 0.0,
        "num_gpus": 0,
    }

    if torch.cuda.is_available():
        fields["device_type"] = "cuda"
        fields["num_gpus"] = torch.cuda.device_count()
        try:
            # Name and memory are taken from device index 0 only.
            gpu0 = torch.cuda.get_device_properties(0)
            fields["device_name"] = gpu0.name
            fields["vram_gb"] = gpu0.total_memory / (1024**3)
        except Exception:
            # Best effort: keep whatever defaults remain if the query fails.
            pass
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        fields.update(device_type="mps", device_name="Apple Silicon", num_gpus=1)

    fields["ram_total_gb"], fields["ram_available_gb"] = get_ram_info()
    return DeviceInfo(**fields)
|
|
|
|
def get_ram_info() -> Tuple[float, float]:
    """Return host RAM as ``(total, available)`` in GB (bytes / 1024**3).

    ``psutil`` is used when it can be imported. Otherwise the output of
    ``free -b`` is parsed as a fallback. If neither source yields a result
    the function returns ``(0.0, 0.0)`` instead of raising.

    Returns:
        A ``(total_gb, available_gb)`` tuple of floats.
    """
    bytes_per_gb = 1024**3

    try:
        import psutil
    except ImportError:
        psutil = None

    if psutil is not None:
        # One snapshot, so both figures come from the same reading.
        mem = psutil.virtual_memory()
        return mem.total / bytes_per_gb, mem.available / bytes_per_gb

    import subprocess

    try:
        result = subprocess.run(
            ['free', '-b'],
            capture_output=True,
            text=True,
            timeout=5,  # never let a stuck helper process block the caller
        )
        lines = result.stdout.strip().split('\n')
        if len(lines) >= 2:
            # lines[0] is the header row, lines[1] the memory row.
            parts = lines[1].split()
            total = float(parts[1]) / bytes_per_gb
            # Field 6 is read as "available"; shorter rows fall back to
            # field 3. NOTE(review): presumably this matches current
            # procps-ng column order - confirm for older `free` versions.
            column = 6 if len(parts) > 6 else 3
            return total, float(parts[column]) / bytes_per_gb
    except (OSError, subprocess.SubprocessError, ValueError, IndexError):
        # Missing binary, timeout or unparsable output: report "unknown".
        pass

    return 0.0, 0.0
|
|
|
|
def log_memory_usage() -> str:
    """Build a one-line summary of current GPU and host memory usage.

    Despite the name, nothing is printed here; the text is returned so the
    caller can decide where to send it.

    Returns:
        Segments joined by ``" | "``: ``GPU: <allocated>GB / <reserved>GB``
        when CUDA is available and ``RAM: <used>GB / <total>GB`` when psutil
        can be imported. An empty string if neither applies.
    """
    bytes_per_gb = 1024**3
    parts = []

    if torch.cuda.is_available():
        used = torch.cuda.memory_allocated() / bytes_per_gb
        reserved = torch.cuda.memory_reserved() / bytes_per_gb
        parts.append(f"GPU: {used:.2f}GB / {reserved:.2f}GB")

    try:
        import psutil
    except ImportError:
        psutil = None

    if psutil is not None:
        # One snapshot, so used and total come from the same reading.
        mem = psutil.virtual_memory()
        ram_used = mem.used / bytes_per_gb
        ram_total = mem.total / bytes_per_gb
        parts.append(f"RAM: {ram_used:.1f}GB / {ram_total:.1f}GB")

    return " | ".join(parts)
|
|
|
|
| |
| |
| |
def limit_ram_usage(max_ram_gb: float) -> bool:
    """Cap the process address space (``RLIMIT_AS``) at ``max_ram_gb`` GB.

    Both the soft and the hard limit are set to the same byte count. The call
    is best effort: it never raises, and reports through its return value
    whether the limit was applied.

    Args:
        max_ram_gb: Memory budget in GB (multiplied by 1024**3 to get bytes).
            Values that do not convert to a positive byte count are ignored,
            because a zero-byte address-space limit would leave the process
            unable to allocate memory.

    Returns:
        True if the limit was applied; False if the request was ignored or
        the platform rejected it (for example when the ``resource`` module
        is unavailable).
    """
    try:
        max_bytes = int(max_ram_gb * 1024**3)
    except (TypeError, ValueError, OverflowError):
        # None, NaN or infinity cannot be turned into a byte count.
        return False

    if max_bytes <= 0:
        return False

    try:
        import resource
        resource.setrlimit(resource.RLIMIT_AS, (max_bytes, max_bytes))
    except (ImportError, AttributeError, ValueError, OSError, OverflowError):
        return False
    return True
|
|
|
|
def setup_cuda_optimizations(vram_fraction: float = 0.80) -> None:
    """Apply global torch settings for CUDA runs; no-op without CUDA.

    Enables TF32 for CUDA matmul and cuDNN, turns on the cuDNN benchmark
    flag, sets float32 matmul precision to ``'high'``, then tries to cap the
    per-process memory fraction and empty the CUDA cache.

    Args:
        vram_fraction: Value passed to
            ``torch.cuda.set_per_process_memory_fraction``.
    """
    if torch.cuda.is_available():
        backends = torch.backends
        backends.cuda.matmul.allow_tf32 = True
        backends.cudnn.allow_tf32 = True
        backends.cudnn.benchmark = True
        torch.set_float32_matmul_precision('high')

        try:
            torch.cuda.set_per_process_memory_fraction(vram_fraction)
            torch.cuda.empty_cache()
        except Exception:
            # Best effort: the memory cap is optional.
            pass
|
|
|
|
def should_enable_gradient_checkpointing(
    vram_gb: float,
    dynamic_decay: bool = False,
    threshold_fraction: float = 0.4
) -> bool:
    """
    Decide whether gradient checkpointing should be switched on.

    Rules, applied in order:
      1. Without CUDA the answer is always False.
      2. With ``dynamic_decay`` and at most 32 GB of VRAM the answer is True.
      3. Otherwise it is True when the free memory on device 0 is below
         ``vram_gb * threshold_fraction``.
      4. If the free-memory query fails, it is True when ``vram_gb < 20``.

    Args:
        vram_gb: Total VRAM in GB
        dynamic_decay: Whether using dynamic decay (longer sequences over time)
        threshold_fraction: Fraction of VRAM that should be free

    Returns:
        Whether to enable gradient checkpointing
    """
    if not torch.cuda.is_available():
        return False

    if dynamic_decay and vram_gb <= 32:
        return True

    try:
        # Release cached blocks first so the free figure is not understated.
        torch.cuda.empty_cache()
        free_bytes, _total_bytes = torch.cuda.mem_get_info(0)
        return free_bytes / 1024**3 < vram_gb * threshold_fraction
    except Exception:
        # No live measurement available: decide on the nominal size alone.
        return vram_gb < 20
|
|
|
|
| |
| |
| |
# Path of the file through which write_step() and read_step() share the
# current training step.
STEP_FILE = "/tmp/training_step.txt"
|
|
|
|
def write_step(step: int) -> None:
    """Write current training step to ``STEP_FILE`` (main process only).

    The value is first written to a temporary file next to ``STEP_FILE`` and
    then moved into place with ``os.replace``. A reader that opens
    ``STEP_FILE`` while a write is in progress therefore sees either the
    previous value or the new one, not a truncated file.

    Failures are swallowed: step tracking is best effort.

    Args:
        step: Step number; stored as ``str(step)``.
    """
    # The process id keeps temp names from colliding between processes.
    tmp_path = f"{STEP_FILE}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "w") as f:
            f.write(str(step))
        os.replace(tmp_path, STEP_FILE)
    except (OSError, ValueError):
        # Do not leave a stray temp file behind after a failed attempt.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
|
|
|
|
def read_step() -> int:
    """Return the training step stored in ``STEP_FILE``.

    Returns:
        The integer parsed from the file, or 0 when the file cannot be read
        or its content is not an integer.
    """
    step = 0
    try:
        with open(STEP_FILE, "r") as step_file:
            raw = step_file.read()
        step = int(raw.strip())
    except Exception:
        # Any failure falls back to step 0.
        pass
    return step
|
|
|
|
| |
| |
| |
def setup_hf_login() -> bool:
    """Log in to HuggingFace with the token from the ``HF_TOKEN`` variable.

    Returns:
        True when a token was present and ``huggingface_hub.login`` completed
        without raising. False when the variable is unset or empty, or when
        importing ``huggingface_hub`` or logging in raised.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        return False

    try:
        from huggingface_hub import login
        login(token=token)
    except Exception:
        return False
    return True
|
|
|
|
def load_tokenizer(model_path: str):
    """Load the tokenizer for ``model_path`` and make sure a pad token is set.

    Args:
        model_path: Identifier passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        The loaded tokenizer. If it had no pad token, its EOS token is
        assigned as the pad token.
    """
    from transformers import AutoTokenizer

    loaded = AutoTokenizer.from_pretrained(model_path)
    if loaded.pad_token is not None:
        return loaded

    # No dedicated pad token: reuse the end-of-sequence token.
    loaded.pad_token = loaded.eos_token
    return loaded
|
|