Spaces:
Sleeping
Sleeping
| """ | |
| Common utilities for nanochat. | |
| """ | |
| import os | |
| import re | |
| import logging | |
| import urllib.request | |
| import torch | |
| import torch.distributed as dist | |
| from filelock import FileLock | |
| # The dtype used for compute (matmuls, activations). Master weights stay fp32 for optimizer precision. | |
| # Linear layers cast their weights to this dtype in forward, replacing torch.amp.autocast. | |
| # Override with NANOCHAT_DTYPE env var: "bfloat16", "float16", "float32" | |
| _DTYPE_MAP = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32} | |
| def _detect_compute_dtype(): | |
| env = os.environ.get("NANOCHAT_DTYPE") | |
| if env is not None: | |
| return _DTYPE_MAP[env], f"set via NANOCHAT_DTYPE={env}" | |
| if torch.cuda.is_available(): | |
| # bf16 requires SM 80+ (Ampere: A100, A10, etc.) | |
| # Older GPUs like V100 (SM 70) and T4 (SM 75) only have fp16 tensor cores | |
| capability = torch.cuda.get_device_capability() | |
| if capability >= (8, 0): | |
| return torch.bfloat16, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (bf16 supported)" | |
| # fp16 training requires GradScaler (not yet implemented), so fall back to fp32. | |
| # Users can still force fp16 via NANOCHAT_DTYPE=float16 if they know what they're doing. | |
| return torch.float32, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (pre-Ampere, bf16 not supported, using fp32)" | |
| return torch.float32, "auto-detected: no CUDA (CPU/MPS)" | |
| COMPUTE_DTYPE, COMPUTE_DTYPE_REASON = _detect_compute_dtype() | |
| class ColoredFormatter(logging.Formatter): | |
| """Custom formatter that adds colors to log messages.""" | |
| # ANSI color codes | |
| COLORS = { | |
| 'DEBUG': '\033[36m', # Cyan | |
| 'INFO': '\033[32m', # Green | |
| 'WARNING': '\033[33m', # Yellow | |
| 'ERROR': '\033[31m', # Red | |
| 'CRITICAL': '\033[35m', # Magenta | |
| } | |
| RESET = '\033[0m' | |
| BOLD = '\033[1m' | |
| def format(self, record): | |
| # Add color to the level name | |
| levelname = record.levelname | |
| if levelname in self.COLORS: | |
| record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}" | |
| # Format the message | |
| message = super().format(record) | |
| # Add color to specific parts of the message | |
| if levelname == 'INFO': | |
| # Highlight numbers and percentages | |
| message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message) | |
| message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message) | |
| return message | |
| def setup_default_logging(): | |
| handler = logging.StreamHandler() | |
| handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| handlers=[handler] | |
| ) | |
| setup_default_logging() | |
| logger = logging.getLogger(__name__) | |
| def get_base_dir(): | |
| # co-locate nanochat intermediates with other cached data in ~/.cache (by default) | |
| if os.environ.get("NANOCHAT_BASE_DIR"): | |
| nanochat_dir = os.environ.get("NANOCHAT_BASE_DIR") | |
| else: | |
| home_dir = os.path.expanduser("~") | |
| cache_dir = os.path.join(home_dir, ".cache") | |
| nanochat_dir = os.path.join(cache_dir, "nanochat") | |
| os.makedirs(nanochat_dir, exist_ok=True) | |
| return nanochat_dir | |
| def download_file_with_lock(url, filename, postprocess_fn=None): | |
| """ | |
| Downloads a file from a URL to a local path in the base directory. | |
| Uses a lock file to prevent concurrent downloads among multiple ranks. | |
| """ | |
| base_dir = get_base_dir() | |
| file_path = os.path.join(base_dir, filename) | |
| lock_path = file_path + ".lock" | |
| if os.path.exists(file_path): | |
| return file_path | |
| with FileLock(lock_path): | |
| # Only a single rank can acquire this lock | |
| # All other ranks block until it is released | |
| # Recheck after acquiring lock | |
| if os.path.exists(file_path): | |
| return file_path | |
| # Download the content as bytes | |
| print(f"Downloading {url}...") | |
| with urllib.request.urlopen(url) as response: | |
| content = response.read() # bytes | |
| # Write to local file | |
| with open(file_path, 'wb') as f: | |
| f.write(content) | |
| print(f"Downloaded to {file_path}") | |
| # Run the postprocess function if provided | |
| if postprocess_fn is not None: | |
| postprocess_fn(file_path) | |
| return file_path | |
| def print0(s="",**kwargs): | |
| ddp_rank = int(os.environ.get('RANK', 0)) | |
| if ddp_rank == 0: | |
| print(s, **kwargs) | |
| def print_banner(): | |
| # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/ | |
| banner = """ | |
| █████ █████ | |
| ░░███ ░░███ | |
| ████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████ | |
| ░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███░░░███░ | |
| ░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███ | |
| ░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███ | |
| ████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░███████ ░░█████ | |
| ░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░ | |
| """ | |
| print0(banner) | |
| def is_ddp_requested() -> bool: | |
| """ | |
| True if launched by torchrun (env present), even before init. | |
| Used to decide whether we *should* initialize a PG. | |
| """ | |
| return all(k in os.environ for k in ("RANK", "LOCAL_RANK", "WORLD_SIZE")) | |
| def is_ddp_initialized() -> bool: | |
| """ | |
| True if torch.distributed is available and the process group is initialized. | |
| Used at cleanup to avoid destroying a non-existent PG. | |
| """ | |
| return dist.is_available() and dist.is_initialized() | |
| def get_dist_info(): | |
| if is_ddp_requested(): | |
| # We rely on torchrun's env to decide if we SHOULD init. | |
| # (Initialization itself happens in compute init.) | |
| assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE']) | |
| ddp_rank = int(os.environ['RANK']) | |
| ddp_local_rank = int(os.environ['LOCAL_RANK']) | |
| ddp_world_size = int(os.environ['WORLD_SIZE']) | |
| return True, ddp_rank, ddp_local_rank, ddp_world_size | |
| else: | |
| return False, 0, 0, 1 | |
| def autodetect_device_type(): | |
| # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU | |
| if torch.cuda.is_available(): | |
| device_type = "cuda" | |
| elif torch.backends.mps.is_available(): | |
| device_type = "mps" | |
| else: | |
| device_type = "cpu" | |
| print0(f"Autodetected device type: {device_type}") | |
| return device_type | |
| def compute_init(device_type="cuda"): # cuda|cpu|mps | |
| """Basic initialization that we keep doing over and over, so make common.""" | |
| assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm" | |
| if device_type == "cuda": | |
| assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'" | |
| if device_type == "mps": | |
| assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'" | |
| # Reproducibility | |
| # Note that we set the global seeds here, but most of the code uses explicit rng objects. | |
| # The only place where global rng might be used is nn.Module initialization of the model weights. | |
| torch.manual_seed(42) | |
| if device_type == "cuda": | |
| torch.cuda.manual_seed(42) | |
| # skipping full reproducibility for now, possibly investigate slowdown later | |
| # torch.use_deterministic_algorithms(True) | |
| # Precision | |
| if device_type == "cuda": | |
| torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls, see https://docs.pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html | |
| # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA | |
| is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() | |
| if is_ddp_requested and device_type == "cuda": | |
| device = torch.device("cuda", ddp_local_rank) | |
| torch.cuda.set_device(device) # make "cuda" default to this device | |
| dist.init_process_group(backend="nccl", device_id=device) | |
| dist.barrier() | |
| else: | |
| device = torch.device(device_type) # mps|cpu | |
| if ddp_rank == 0: | |
| logger.info(f"Distributed world size: {ddp_world_size}") | |
| return is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size, device | |
| def compute_cleanup(): | |
| """Companion function to compute_init, to clean things up before script exit""" | |
| if is_ddp_initialized(): | |
| dist.destroy_process_group() | |
| class DummyWandb: | |
| """Useful if we wish to not use wandb but have all the same signatures""" | |
| def __init__(self): | |
| pass | |
| def log(self, *args, **kwargs): | |
| pass | |
| def finish(self): | |
| pass | |
| # hardcoded BF16 peak flops for various GPUs | |
| # inspired by torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/tools/utils.py | |
| # and PR: https://github.com/karpathy/nanochat/pull/147 | |
| def get_peak_flops(device_name: str) -> float: | |
| name = device_name.lower() | |
| # Table order matters: more specific patterns first. | |
| _PEAK_FLOPS_TABLE = ( | |
| # NVIDIA Blackwell | |
| (["gb200"], 2.5e15), | |
| (["grace blackwell"], 2.5e15), | |
| (["b200"], 2.25e15), | |
| (["b100"], 1.8e15), | |
| # NVIDIA Hopper | |
| (["h200", "nvl"], 836e12), | |
| (["h200", "pcie"], 836e12), | |
| (["h200"], 989e12), | |
| (["h100", "nvl"], 835e12), | |
| (["h100", "pcie"], 756e12), | |
| (["h100"], 989e12), | |
| (["h800", "nvl"], 989e12), | |
| (["h800"], 756e12), | |
| # NVIDIA Ampere data center | |
| (["a100"], 312e12), | |
| (["a800"], 312e12), | |
| (["a40"], 149.7e12), | |
| (["a30"], 165e12), | |
| # NVIDIA Ada data center | |
| (["l40s"], 362e12), | |
| (["l40-s"], 362e12), | |
| (["l40 s"], 362e12), | |
| (["l4"], 121e12), | |
| # AMD CDNA accelerators | |
| (["mi355"], 2.5e15), | |
| (["mi325"], 1.3074e15), | |
| (["mi300x"], 1.3074e15), | |
| (["mi300a"], 980.6e12), | |
| (["mi250x"], 383e12), | |
| (["mi250"], 362.1e12), | |
| # Consumer RTX | |
| (["5090"], 209.5e12), | |
| (["4090"], 165.2e12), | |
| (["3090"], 71e12), | |
| ) | |
| for patterns, flops in _PEAK_FLOPS_TABLE: | |
| if all(p in name for p in patterns): | |
| return flops | |
| if "data center gpu max 1550" in name: | |
| # Ponte Vecchio (PVC) - dynamic based on compute units | |
| max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units | |
| return 512 * max_comp_units * 1300 * 10**6 | |
| # Unknown GPU - return inf so MFU shows as 0% rather than a wrong guess | |
| logger.warning(f"Peak flops undefined for: {device_name}, MFU will show as 0%") | |
| return float('inf') | |