#!/usr/bin/env python3 """Fresh-process checkpoint evaluation for HF Jobs. Downloads a checkpoint artifact uploaded by a prior training job and evaluates it from a new Python process, avoiding post-training CUDA fragmentation in the training container. """ from __future__ import annotations import dataclasses import json import os import sys import time from pathlib import Path import torch from huggingface_hub import hf_hub_download try: sys.stdout.reconfigure(line_buffering=True) # type: ignore[attr-defined] except Exception: pass def _require_env(name: str) -> str: value = os.environ.get(name, '').strip() if not value: raise SystemExit(f'[ckpt_eval] missing required env {name}') return value def _ckpt_path() -> Path: local = os.environ.get('HYDRA_EVAL_CKPT_PATH') if local: p = Path(local).expanduser() print(f'[ckpt_eval] using local checkpoint {p}', flush=True) return p repo_id = _require_env('HF_REPO_ID') explicit_path = os.environ.get('HYDRA_EVAL_CKPT_REPO_PATH', '').strip().lstrip('/') if explicit_path: path_in_repo = explicit_path else: source_job = _require_env('HYDRA_EVAL_CKPT_JOB_ID') filename = os.environ.get('HYDRA_EVAL_CKPT_NAME', 'pretrain_final.pt') path_in_repo = f'jobs/{source_job}/{filename}' print(f'[ckpt_eval] downloading {repo_id}/{path_in_repo}', flush=True) downloaded = hf_hub_download( repo_id=repo_id, filename=path_in_repo, repo_type='model', token=os.environ.get('HF_TOKEN'), ) return Path(downloaded) def main() -> int: t0 = time.time() print('[ckpt_eval] phase=start', flush=True) repo_root = Path('/workspace/feather') if Path('/workspace/feather').exists() else Path.cwd() os.chdir(repo_root) sys.path.insert(0, str(repo_root)) # Imports after cwd is set so overlay modules win inside the image. import prepare as _prepare_mod from prepare import MAX_SEQ_LEN, Tokenizer from hydra.config import ( D_MODEL, D_STATE, ENGRAM_KEY_DIM, ENGRAM_LAYER_IDX, ENGRAM_N_COLUMNS, EXPAND, HEADDIM, N_HEADS, N_LAYER, PostSemClawConfig, ) from hydra.model import PostSemClawModel def config_from_dict(payload: dict) -> PostSemClawConfig: field_names = {field.name for field in dataclasses.fields(PostSemClawConfig)} kwargs = {key: value for key, value in payload.items() if key in field_names} for key in ('hyena_layers', 'gdn_layers'): if key in kwargs and isinstance(kwargs[key], list): kwargs[key] = tuple(kwargs[key]) return PostSemClawConfig(**kwargs) if os.environ.get('HYDRA_USE_NEMOTRON', '0') == '1': import prepare_nemotron as _p_nemo from prepare_nemotron import evaluate_bpb _p_nemo.ensure_tokenizer() import subsystems.sdr_retina as _sdr_retina _sdr_retina.build_retina() else: from prepare import evaluate_bpb device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f'[ckpt_eval] device={device} cuda={int(torch.cuda.is_available())}', flush=True) torch.set_float32_matmul_precision('high') if torch.cuda.is_available(): torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True ckpt = torch.load(str(_ckpt_path()), map_location='cpu', weights_only=False) tokenizer = Tokenizer.from_directory() vocab_size = tokenizer.get_vocab_size() cfg_payload = ckpt.get('config') if isinstance(cfg_payload, dict): config = config_from_dict(cfg_payload) else: config = PostSemClawConfig( sequence_len=MAX_SEQ_LEN, vocab_size=vocab_size, n_layer=N_LAYER, d_model=D_MODEL, d_state=D_STATE, headdim=HEADDIM, n_heads=N_HEADS, expand=EXPAND, engram_n_columns=ENGRAM_N_COLUMNS, engram_key_dim=ENGRAM_KEY_DIM, engram_layer_idx=ENGRAM_LAYER_IDX, ) print(f'[ckpt_eval] checkpoint_step={ckpt.get("step")} vocab_size={vocab_size}', flush=True) with torch.device('meta'): model = PostSemClawModel(config) model.to_empty(device=device) missing, unexpected = model.load_state_dict(ckpt.get('model_state_dict', ckpt), strict=False) print(f'[ckpt_eval] load_state missing={len(missing)} unexpected={len(unexpected)}', flush=True) model.eval() if hasattr(model, 'set_bos_token_id'): model.set_bos_token_id(tokenizer.get_bos_token_id()) del ckpt if torch.cuda.is_available(): torch.cuda.empty_cache() eval_tokens = int(os.environ.get('HYDRA_EVAL_TOKENS', os.environ.get('HYDRA_STREAM_EVAL_TOKENS', '262144'))) eval_batch = int(os.environ.get('HYDRA_EVAL_BATCH', '1')) _prepare_mod.EVAL_TOKENS = eval_tokens os.environ['HYDRA_STREAM_EVAL_TOKENS'] = str(eval_tokens) print(f'[ckpt_eval] running eval tokens={eval_tokens} batch={eval_batch}', flush=True) with torch.no_grad(), torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=torch.cuda.is_available()): val_bpb = evaluate_bpb(model, tokenizer, eval_batch) val_ppl = 2 ** val_bpb metrics = { 'checkpoint_job_id': os.environ.get('HYDRA_EVAL_CKPT_JOB_ID'), 'checkpoint_name': os.environ.get('HYDRA_EVAL_CKPT_NAME', 'pretrain_final.pt'), 'checkpoint_repo_path': os.environ.get('HYDRA_EVAL_CKPT_REPO_PATH'), 'eval_tokens': eval_tokens, 'eval_batch': eval_batch, 'val_bpb': float(val_bpb), 'val_ppl': float(val_ppl), 'seconds': round(time.time() - t0, 3), } print(f'[CKPT_EVAL_JSON] {json.dumps(metrics, sort_keys=True)}', flush=True) print('[ckpt_eval] phase=done', flush=True) return 0 if __name__ == '__main__': # Full-corpus streaming eval can leave HF datasets downloader/native threads # alive at interpreter shutdown after [CKPT_EVAL_JSON] is already flushed. # Exit the process directly so HF Jobs records the completed metric instead # of converting a post-metric PyGILState finalization abort into ERROR. _rc = main() sys.stdout.flush() sys.stderr.flush() os._exit(_rc)