Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| """Fresh-process checkpoint evaluation for HF Jobs. | |
| Downloads a checkpoint artifact uploaded by a prior training job and evaluates it | |
| from a new Python process, avoiding post-training CUDA fragmentation in the | |
| training container. | |
| """ | |
| from __future__ import annotations | |
| import dataclasses | |
| import json | |
| import os | |
| import sys | |
| import time | |
| from pathlib import Path | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
# Best-effort: switch stdout to line buffering so each printed line is flushed
# as soon as it ends. If the stream has no ``reconfigure`` method or rejects
# the call, the failure is ignored and stdout is left as it was.
try:
    sys.stdout.reconfigure(line_buffering=True)  # type: ignore[attr-defined]
except Exception:
    pass
def _require_env(name: str) -> str:
    """Return the whitespace-stripped value of environment variable ``name``.

    Args:
        name: Name of the environment variable to read.

    Returns:
        The variable's value with leading/trailing whitespace removed.

    Raises:
        SystemExit: If the variable is unset, empty, or only whitespace.
    """
    raw = os.environ.get(name, '')
    stripped = raw.strip()
    if stripped:
        return stripped
    raise SystemExit(f'[ckpt_eval] missing required env {name}')
def _ckpt_path() -> Path:
    """Resolve the checkpoint file to evaluate, downloading it if needed.

    Resolution order:

    1. ``HYDRA_EVAL_CKPT_PATH`` -- a local file path, returned after
       ``~`` expansion; nothing is downloaded.
    2. ``HYDRA_EVAL_CKPT_REPO_PATH`` -- an explicit path inside the
       ``HF_REPO_ID`` model repo (surrounding whitespace and leading
       slashes are removed).
    3. ``jobs/<HYDRA_EVAL_CKPT_JOB_ID>/<HYDRA_EVAL_CKPT_NAME>`` inside
       ``HF_REPO_ID``; the file name defaults to ``pretrain_final.pt``.

    Returns:
        Path to the checkpoint on the local filesystem.

    Raises:
        SystemExit: If ``HF_REPO_ID`` (cases 2 and 3) or
            ``HYDRA_EVAL_CKPT_JOB_ID`` (case 3) is missing or blank.
    """
    local_override = os.environ.get('HYDRA_EVAL_CKPT_PATH')
    if local_override:
        local_ckpt = Path(local_override).expanduser()
        print(f'[ckpt_eval] using local checkpoint {local_ckpt}', flush=True)
        return local_ckpt

    repo_id = _require_env('HF_REPO_ID')

    # An explicit repo path wins; otherwise derive it from the source job id.
    path_in_repo = os.environ.get('HYDRA_EVAL_CKPT_REPO_PATH', '').strip().lstrip('/')
    if not path_in_repo:
        source_job = _require_env('HYDRA_EVAL_CKPT_JOB_ID')
        filename = os.environ.get('HYDRA_EVAL_CKPT_NAME', 'pretrain_final.pt')
        path_in_repo = f'jobs/{source_job}/{filename}'

    print(f'[ckpt_eval] downloading {repo_id}/{path_in_repo}', flush=True)
    return Path(
        hf_hub_download(
            repo_id=repo_id,
            filename=path_in_repo,
            repo_type='model',
            token=os.environ.get('HF_TOKEN'),
        )
    )
def main() -> int:
    """Evaluate one checkpoint and print its metrics as a ``[CKPT_EVAL_JSON]`` line.

    Steps, in order:

    1. ``chdir`` into ``/workspace/feather`` when it exists (otherwise stay in
       the current directory) and put that directory first on ``sys.path``
       before importing the project modules.
    2. Load the checkpoint resolved by ``_ckpt_path()`` onto the CPU.
    3. Build ``PostSemClawConfig`` from the checkpoint's ``config`` dict when
       present, otherwise from the constants in ``hydra.config``.
    4. Construct the model on the ``meta`` device, materialise it on the
       target device, and load the state dict non-strictly.
    5. Run ``evaluate_bpb`` under ``no_grad`` (and bf16 autocast when CUDA is
       available) and print the metrics.

    Environment variables read here: ``HYDRA_USE_NEMOTRON``,
    ``HYDRA_EVAL_TOKENS`` (falling back to ``HYDRA_STREAM_EVAL_TOKENS``, then
    ``262144``), ``HYDRA_EVAL_BATCH`` (default ``1``), plus the
    ``HYDRA_EVAL_CKPT_*`` values echoed into the metrics.

    Returns:
        ``0`` once the metrics line has been printed.

    Raises:
        SystemExit: Propagated from ``_ckpt_path()`` when a required
            environment variable is missing.
    """
    t0 = time.time()
    print('[ckpt_eval] phase=start', flush=True)

    repo_root = Path('/workspace/feather') if Path('/workspace/feather').exists() else Path.cwd()
    os.chdir(repo_root)
    sys.path.insert(0, str(repo_root))

    # Imports after cwd is set so overlay modules win inside the image.
    import prepare as _prepare_mod
    from prepare import MAX_SEQ_LEN, Tokenizer
    from hydra.config import (
        D_MODEL, D_STATE, ENGRAM_KEY_DIM, ENGRAM_LAYER_IDX, ENGRAM_N_COLUMNS,
        EXPAND, HEADDIM, N_HEADS, N_LAYER, PostSemClawConfig,
    )
    from hydra.model import PostSemClawModel

    def config_from_dict(payload: dict) -> PostSemClawConfig:
        """Build a config from a saved dict, ignoring keys that are not config fields."""
        field_names = {field.name for field in dataclasses.fields(PostSemClawConfig)}
        kwargs = {key: value for key, value in payload.items() if key in field_names}
        # Saved lists are converted back to tuples for these two fields.
        for key in ('hyena_layers', 'gdn_layers'):
            if key in kwargs and isinstance(kwargs[key], list):
                kwargs[key] = tuple(kwargs[key])
        return PostSemClawConfig(**kwargs)

    if os.environ.get('HYDRA_USE_NEMOTRON', '0') == '1':
        import prepare_nemotron as _p_nemo
        from prepare_nemotron import evaluate_bpb
        _p_nemo.ensure_tokenizer()
        import subsystems.sdr_retina as _sdr_retina
        _sdr_retina.build_retina()
    else:
        from prepare import evaluate_bpb

    cuda_available = torch.cuda.is_available()
    device = torch.device('cuda' if cuda_available else 'cpu')
    print(f'[ckpt_eval] device={device} cuda={int(cuda_available)}', flush=True)
    torch.set_float32_matmul_precision('high')
    if cuda_available:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # SECURITY: weights_only=False runs the full pickle machinery, which can
    # execute arbitrary code embedded in the file. Only point this script at
    # checkpoints from a trusted repo/job.
    # NOTE(review): switch to weights_only=True if the checkpoint is known to
    # contain only tensors and plain containers -- confirm against the trainer.
    ckpt = torch.load(str(_ckpt_path()), map_location='cpu', weights_only=False)

    tokenizer = Tokenizer.from_directory()
    vocab_size = tokenizer.get_vocab_size()

    cfg_payload = ckpt.get('config')
    if isinstance(cfg_payload, dict):
        config = config_from_dict(cfg_payload)
    else:
        # No saved config: fall back to the constants from hydra.config.
        config = PostSemClawConfig(
            sequence_len=MAX_SEQ_LEN,
            vocab_size=vocab_size,
            n_layer=N_LAYER,
            d_model=D_MODEL,
            d_state=D_STATE,
            headdim=HEADDIM,
            n_heads=N_HEADS,
            expand=EXPAND,
            engram_n_columns=ENGRAM_N_COLUMNS,
            engram_key_dim=ENGRAM_KEY_DIM,
            engram_layer_idx=ENGRAM_LAYER_IDX,
        )
    print(f'[ckpt_eval] checkpoint_step={ckpt.get("step")} vocab_size={vocab_size}', flush=True)

    # Build on the meta device (no storage allocated), then allocate
    # *uninitialised* storage on the target device. Every tensor must
    # therefore be overwritten by the state dict below.
    with torch.device('meta'):
        model = PostSemClawModel(config)
    model.to_empty(device=device)

    # A checkpoint without a 'model_state_dict' entry is treated as a raw state dict.
    missing, unexpected = model.load_state_dict(ckpt.get('model_state_dict', ckpt), strict=False)
    print(f'[ckpt_eval] load_state missing={len(missing)} unexpected={len(unexpected)}', flush=True)
    # strict=False tolerates mismatches, but after to_empty() a missing key is
    # left holding uninitialised memory, so name the keys instead of only
    # counting them.
    # NOTE(review): non-persistent buffers are never reported as missing;
    # confirm the model re-creates any it relies on.
    if missing:
        print(
            f'[ckpt_eval] WARNING missing keys keep uninitialized storage '
            f'(first {min(len(missing), 10)}): {list(missing[:10])}',
            flush=True,
        )
    if unexpected:
        print(
            f'[ckpt_eval] WARNING unexpected keys ignored '
            f'(first {min(len(unexpected), 10)}): {list(unexpected[:10])}',
            flush=True,
        )

    model.eval()
    if hasattr(model, 'set_bos_token_id'):
        model.set_bos_token_id(tokenizer.get_bos_token_id())

    # Drop the CPU copy of the checkpoint before evaluation starts.
    del ckpt
    if cuda_available:
        torch.cuda.empty_cache()

    eval_tokens = int(os.environ.get('HYDRA_EVAL_TOKENS', os.environ.get('HYDRA_STREAM_EVAL_TOKENS', '262144')))
    eval_batch = int(os.environ.get('HYDRA_EVAL_BATCH', '1'))
    # Publish the token budget through both the prepare module attribute and
    # the environment variable.
    # NOTE(review): presumably evaluate_bpb reads one of these -- confirm.
    _prepare_mod.EVAL_TOKENS = eval_tokens
    os.environ['HYDRA_STREAM_EVAL_TOKENS'] = str(eval_tokens)
    print(f'[ckpt_eval] running eval tokens={eval_tokens} batch={eval_batch}', flush=True)

    with torch.no_grad(), torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=cuda_available):
        val_bpb = evaluate_bpb(model, tokenizer, eval_batch)
    val_ppl = 2 ** val_bpb

    metrics = {
        'checkpoint_job_id': os.environ.get('HYDRA_EVAL_CKPT_JOB_ID'),
        'checkpoint_name': os.environ.get('HYDRA_EVAL_CKPT_NAME', 'pretrain_final.pt'),
        'checkpoint_repo_path': os.environ.get('HYDRA_EVAL_CKPT_REPO_PATH'),
        'eval_tokens': eval_tokens,
        'eval_batch': eval_batch,
        'val_bpb': float(val_bpb),
        'val_ppl': float(val_ppl),
        'seconds': round(time.time() - t0, 3),
    }
    print(f'[CKPT_EVAL_JSON] {json.dumps(metrics, sort_keys=True)}', flush=True)
    print('[ckpt_eval] phase=done', flush=True)
    return 0
if __name__ == '__main__':
    # Full-corpus streaming eval can leave HF datasets downloader/native
    # threads alive at interpreter shutdown, after [CKPT_EVAL_JSON] has
    # already been flushed. Skip normal interpreter finalization and exit the
    # process directly, so HF Jobs records the completed metric instead of
    # turning a post-metric PyGILState finalization abort into ERROR.
    exit_code = main()
    # os._exit() does not flush stdio buffers, so flush both streams first.
    for stream in (sys.stdout, sys.stderr):
        stream.flush()
    os._exit(exit_code)