# Source: omini-model/training/utils.py
# Last commit: e20f447 (marcos) - "feat: Refactor training with SOLID principles and add optimizations"
# File size: 6.98 kB
"""
Utility functions for training.
Single Responsibility: General utilities that don't fit elsewhere.
"""
import os
import sys
import torch
from typing import Tuple, Optional
from dataclasses import dataclass
# ============================================================
# Logging
# ============================================================
# Module-wide verbosity switch; it is only changed through setup_logging().
_verbose: bool = True


def setup_logging(verbose: bool = True):
    """Store the requested verbosity in the module-level ``_verbose`` flag.

    Args:
        verbose: New value for the flag. ``True`` (the default) turns
            logging output on, ``False`` turns it off.
    """
    # The flag lives at module scope, so rebinding it needs ``global``.
    global _verbose
    _verbose = verbose
def log(msg: str):
    """Print ``msg`` to stdout and flush right away.

    Nothing is printed while the module-level ``_verbose`` flag is falsy.

    Args:
        msg: Text to print.
    """
    if not _verbose:
        return
    # Flush on every call so the message is not held back in the stdout buffer.
    print(msg, flush=True)
# ============================================================
# Device Information
# ============================================================
@dataclass
class DeviceInfo:
    """Summary of the compute device and host memory."""

    # Backend identifier, e.g. "cuda", "mps" or "cpu".
    device_type: str
    # Human-readable device name; an empty string hides it in str().
    device_name: str
    # Device memory in GB; values <= 0 are hidden in str().
    vram_gb: float
    # Total host RAM in GB.
    ram_total_gb: float
    # Available host RAM in GB.
    ram_available_gb: float
    # Number of GPUs; only shown in str() when greater than one.
    num_gpus: int

    def __str__(self) -> str:
        """Return a one-line summary joined with ``" | "``.

        The device type is always included. The name, VRAM and GPU count
        are each added only when they carry information (a non-empty name,
        positive VRAM, more than one GPU). RAM figures are not included.
        """

        def segments():
            # Yielded lazily so each f-string is only formatted when shown.
            yield f"Device: {self.device_type}"
            if self.device_name:
                yield f"({self.device_name})"
            if self.vram_gb > 0:
                yield f"VRAM: {self.vram_gb:.0f}GB"
            if self.num_gpus > 1:
                yield f"x{self.num_gpus} GPUs"

        return " | ".join(segments())
def get_device_info() -> DeviceInfo:
    """Probe the available accelerator and host RAM.

    CUDA is checked first, then MPS; if neither is available the result
    describes a CPU with no GPUs. For CUDA the name and memory size are
    read from device 0 only.

    Returns:
        A ``DeviceInfo`` filled with the detected device and RAM figures.
    """
    # CPU defaults, overwritten below when an accelerator is found.
    kind, name, vram, gpu_count = "cpu", "", 0.0, 0

    if torch.cuda.is_available():
        kind = "cuda"
        gpu_count = torch.cuda.device_count()
        try:
            first_gpu = torch.cuda.get_device_properties(0)
            name = first_gpu.name
            vram = first_gpu.total_memory / (1024**3)
        except Exception:
            # Best effort: keep whatever defaults could not be replaced.
            pass
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        kind, name, gpu_count = "mps", "Apple Silicon", 1

    total_ram, free_ram = get_ram_info()

    return DeviceInfo(
        device_type=kind,
        device_name=name,
        vram_gb=vram,
        ram_total_gb=total_ram,
        ram_available_gb=free_ram,
        num_gpus=gpu_count,
    )
def get_ram_info() -> Tuple[float, float]:
    """Return host RAM as ``(total, available)`` in GB (1024**3 bytes).

    ``psutil`` is used when it can be imported. Otherwise the second line
    of ``free -b`` output is parsed. If neither source yields numbers the
    result is ``(0.0, 0.0)``.

    Returns:
        Tuple of ``(total_gb, available_gb)``.
    """
    bytes_per_gb = 1024**3

    try:
        import psutil
    except ImportError:
        psutil = None

    if psutil is not None:
        # One snapshot, so both figures describe the same moment.
        memory = psutil.virtual_memory()
        return memory.total / bytes_per_gb, memory.available / bytes_per_gb

    import subprocess

    try:
        result = subprocess.run(
            ['free', '-b'],
            capture_output=True, text=True
        )
        lines = result.stdout.strip().split('\n')
        if len(lines) >= 2:
            parts = lines[1].split()
            total = float(parts[1]) / bytes_per_gb
            # Column 6 is used when present; shorter rows fall back to
            # column 3.
            available_field = parts[6] if len(parts) > 6 else parts[3]
            return total, float(available_field) / bytes_per_gb
    except (OSError, ValueError, IndexError):
        # `free` missing or not runnable, or its output is not in the
        # expected layout: fall through to the "unknown" result.
        pass

    return 0.0, 0.0
def log_memory_usage() -> str:
    """Build a short string describing current GPU and RAM usage.

    The GPU part (only when CUDA is available) shows allocated / reserved
    memory as reported by ``torch.cuda``. The RAM part (only when
    ``psutil`` can be imported) shows used / total host memory. Parts are
    joined with ``" | "``.

    Returns:
        The usage string, or an empty string when neither source is
        available.
    """
    bytes_per_gb = 1024**3
    parts = []

    if torch.cuda.is_available():
        used = torch.cuda.memory_allocated() / bytes_per_gb
        reserved = torch.cuda.memory_reserved() / bytes_per_gb
        parts.append(f"GPU: {used:.2f}GB / {reserved:.2f}GB")

    try:
        import psutil
    except ImportError:
        psutil = None

    if psutil is not None:
        # One snapshot, so used and total describe the same moment.
        memory = psutil.virtual_memory()
        ram_used = memory.used / bytes_per_gb
        ram_total = memory.total / bytes_per_gb
        parts.append(f"RAM: {ram_used:.1f}GB / {ram_total:.1f}GB")

    return " | ".join(parts)
# ============================================================
# Memory Management
# ============================================================
def limit_ram_usage(max_ram_gb: float):
    """Cap the process address space via ``resource.RLIMIT_AS``.

    Both the soft and the hard limit are set to ``max_ram_gb`` converted
    to bytes. This is best effort: any failure (for example the
    ``resource`` module being unavailable, or the limit being rejected)
    is ignored silently.

    Args:
        max_ram_gb: Desired cap in GB (1024**3 bytes).
    """
    try:
        import resource

        byte_cap = int(max_ram_gb * 1024**3)
        soft_and_hard = (byte_cap, byte_cap)
        resource.setrlimit(resource.RLIMIT_AS, soft_and_hard)
    except Exception:
        # Deliberately ignored: the limit is optional.
        pass
def setup_cuda_optimizations(vram_fraction: float = 0.80):
    """Apply global CUDA performance settings; does nothing without CUDA.

    Turns on the TF32 flags for matmul and cuDNN, turns on cuDNN
    benchmark mode, and sets float32 matmul precision to ``'high'``. It
    then tries to cap this process's share of GPU memory and empty the
    CUDA cache; failures in that last step are ignored.

    Args:
        vram_fraction: Fraction passed to
            ``torch.cuda.set_per_process_memory_fraction``.
    """
    if torch.cuda.is_available():
        backends = torch.backends
        backends.cuda.matmul.allow_tf32 = True
        backends.cudnn.allow_tf32 = True
        backends.cudnn.benchmark = True
        torch.set_float32_matmul_precision('high')

        try:
            torch.cuda.set_per_process_memory_fraction(vram_fraction)
            torch.cuda.empty_cache()
        except Exception:
            # The memory cap is optional; keep going if it is rejected.
            pass
def should_enable_gradient_checkpointing(
    vram_gb: float,
    dynamic_decay: bool = False,
    threshold_fraction: float = 0.4
) -> bool:
    """
    Decide whether gradient checkpointing should be turned on.

    The decision is made in this order:
      1. No CUDA -> never enable.
      2. ``dynamic_decay`` with at most 32 GB of VRAM -> always enable.
      3. Otherwise enable when the free memory on device 0 is below
         ``vram_gb * threshold_fraction``.
      4. If free memory cannot be queried, enable when ``vram_gb < 20``.

    Args:
        vram_gb: Total VRAM in GB
        dynamic_decay: Whether using dynamic decay (longer sequences over time)
        threshold_fraction: Fraction of ``vram_gb`` that must be free for
            checkpointing to stay disabled

    Returns:
        Whether to enable gradient checkpointing
    """
    if not torch.cuda.is_available():
        return False

    # Sequences grow over time with dynamic decay, so be cautious on
    # cards with 32 GB or less.
    if dynamic_decay and vram_gb <= 32:
        return True

    try:
        torch.cuda.empty_cache()
        # Only the free amount is needed; the reported total is ignored.
        free_bytes, _ = torch.cuda.mem_get_info(0)
        required_free_gb = vram_gb * threshold_fraction
        return free_bytes / 1024**3 < required_free_gb
    except Exception:
        # Conservative fallback when the query fails.
        return vram_gb < 20
# ============================================================
# Step Sharing (for DDP + DataLoader workers)
# ============================================================
STEP_FILE = "/tmp/training_step.txt"


def write_step(step: int) -> None:
    """Publish the current training step to ``STEP_FILE`` (main process only).

    The value is written to a temporary file next to ``STEP_FILE`` and
    then moved into place with ``os.replace``. A reader that opens
    ``STEP_FILE`` during the update therefore sees either the previous
    complete value or the new one, never a truncated or empty file.

    This is best effort: any failure is ignored and the previous contents
    of ``STEP_FILE`` (if any) are left untouched.

    Args:
        step: Step number to store, written as its ``str()`` form.
    """
    # The PID in the name keeps concurrent writers off each other's temp file.
    tmp_path = f"{STEP_FILE}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "w") as f:
            f.write(str(step))
        os.replace(tmp_path, STEP_FILE)
    except Exception:
        # Step sharing is optional; just make sure no temp file is left behind.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
def read_step() -> int:
    """Return the training step stored in ``STEP_FILE``.

    Returns:
        The integer parsed from the file (surrounding whitespace is
        ignored), or ``0`` when the file cannot be read or parsed.
    """
    try:
        with open(STEP_FILE, "r") as handle:
            contents = handle.read()
        return int(contents.strip())
    except Exception:
        # Missing, unreadable or malformed file: report step 0.
        return 0
# ============================================================
# HuggingFace Helpers
# ============================================================
def setup_hf_login():
    """Log in to the HuggingFace Hub using the ``HF_TOKEN`` environment variable.

    Returns:
        ``True`` when a token was found and ``huggingface_hub.login``
        completed without raising; ``False`` when the variable is unset
        or empty, or when importing or calling ``login`` failed.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        return False

    try:
        from huggingface_hub import login

        login(token=token)
    except Exception:
        # A missing package or a rejected login both count as "not logged in".
        return False
    return True
def load_tokenizer(model_path: str):
    """Load a tokenizer and make sure it has a padding token.

    Args:
        model_path: Identifier or path passed to
            ``AutoTokenizer.from_pretrained``.

    Returns:
        The loaded tokenizer. If it had no ``pad_token``, the
        ``eos_token`` is assigned as the padding token.
    """
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(model_path)
    if tok.pad_token is None:
        # Reuse the end-of-sequence token for padding.
        tok.pad_token = tok.eos_token
    return tok