feather-a10g-large-runtime / overlay /scripts /hf_checkpoint_eval.py
Jackoatmon's picture
Update Feather a10g-large training runtime image
4175f18 verified
#!/usr/bin/env python3
"""Fresh-process checkpoint evaluation for HF Jobs.
Downloads a checkpoint artifact uploaded by a prior training job and evaluates it
from a new Python process, avoiding post-training CUDA fragmentation in the
training container.
"""
from __future__ import annotations
import dataclasses
import json
import os
import sys
import time
from pathlib import Path
import torch
from huggingface_hub import hf_hub_download
# Best effort: switch stdout to line buffering so progress lines reach the job
# log as they are printed. Any failure (e.g. a replacement stdout object that
# has no ``reconfigure``) is ignored and the default buffering is kept.
try:
    getattr(sys.stdout, 'reconfigure')(line_buffering=True)
except Exception:
    pass
def _require_env(name: str) -> str:
    """Return the whitespace-stripped value of environment variable *name*.

    Raises:
        SystemExit: with a ``[ckpt_eval]`` message when the variable is unset
            or contains only whitespace.
    """
    stripped = (os.getenv(name) or '').strip()
    if stripped:
        return stripped
    raise SystemExit(f'[ckpt_eval] missing required env {name}')
def _ckpt_path() -> Path:
    """Resolve the checkpoint file to evaluate and return a local path to it.

    Resolution order:

    1. ``HYDRA_EVAL_CKPT_PATH``: a local file path (``~`` is expanded). No
       download is performed.
    2. ``HYDRA_EVAL_CKPT_REPO_PATH``: an explicit path inside the model repo
       named by ``HF_REPO_ID`` (leading ``/`` characters are dropped).
    3. ``jobs/<HYDRA_EVAL_CKPT_JOB_ID>/<HYDRA_EVAL_CKPT_NAME>`` inside
       ``HF_REPO_ID``. The file name defaults to ``pretrain_final.pt``.

    Every env value is whitespace-stripped, and a blank value is treated as
    unset (the same rule ``_require_env`` applies). Options 2 and 3 fetch the
    file with ``hf_hub_download`` (passing ``HF_TOKEN`` as the token) and
    return the local path it reports.

    Raises:
        SystemExit: if a variable required by the selected option is missing.
    """
    # Strip like the other env reads so a whitespace-only value counts as
    # "not set" instead of becoming a bogus path.
    local = os.environ.get('HYDRA_EVAL_CKPT_PATH', '').strip()
    if local:
        p = Path(local).expanduser()
        print(f'[ckpt_eval] using local checkpoint {p}', flush=True)
        return p
    repo_id = _require_env('HF_REPO_ID')
    explicit_path = os.environ.get('HYDRA_EVAL_CKPT_REPO_PATH', '').strip().lstrip('/')
    if explicit_path:
        path_in_repo = explicit_path
    else:
        source_job = _require_env('HYDRA_EVAL_CKPT_JOB_ID')
        # A blank name would produce 'jobs/<id>/' (a directory, not a file),
        # so fall back to the default just as for an unset variable.
        filename = os.environ.get('HYDRA_EVAL_CKPT_NAME', '').strip() or 'pretrain_final.pt'
        path_in_repo = f'jobs/{source_job}/{filename}'
    print(f'[ckpt_eval] downloading {repo_id}/{path_in_repo}', flush=True)
    downloaded = hf_hub_download(
        repo_id=repo_id,
        filename=path_in_repo,
        repo_type='model',
        token=os.environ.get('HF_TOKEN'),
    )
    return Path(downloaded)
def _env_int(names: tuple[str, ...], default: int) -> int:
    """Return the integer value of the first non-blank env var in *names*.

    Variables are checked in order. A variable that is unset or holds only
    whitespace is skipped, so a blank value falls back instead of crashing
    on ``int('')``. Returns *default* when none of them is set.

    Raises:
        SystemExit: if the first non-blank value is not an integer.
    """
    for name in names:
        raw = os.environ.get(name, '').strip()
        if raw:
            try:
                return int(raw)
            except ValueError:
                raise SystemExit(f'[ckpt_eval] env {name} must be an integer, got {raw!r}') from None
    return default


def main() -> int:
    """Load a checkpoint into a fresh model, run BPB evaluation, print metrics.

    The statements run in this order, and the order matters:

    1. chdir into the repo root (``/workspace/feather`` when it exists, else
       the current directory) and prepend it to ``sys.path`` before importing
       the project modules ``prepare`` and ``hydra``.
    2. Take ``evaluate_bpb`` from ``prepare_nemotron`` when
       ``HYDRA_USE_NEMOTRON=1`` (after calling its ``ensure_tokenizer`` and
       ``subsystems.sdr_retina.build_retina``). Otherwise take it from
       ``prepare``.
    3. Load the checkpoint on CPU and rebuild ``PostSemClawConfig`` from its
       ``config`` dict when present, else from the ``hydra.config``
       constants. Build the model on the meta device, allocate it on the
       target device, then load the weights with ``strict=False``.
    4. Evaluate with ``HYDRA_EVAL_TOKENS`` tokens (fallback
       ``HYDRA_STREAM_EVAL_TOKENS``, default 262144) and batch size
       ``HYDRA_EVAL_BATCH`` (default 1). This runs under ``no_grad`` and,
       when CUDA is available, bf16 autocast.
    5. Print one ``[CKPT_EVAL_JSON] {...}`` line with the metrics.

    Returns:
        0 on success. Failures surface as exceptions or ``SystemExit``.
    """
    # Monotonic clock: the elapsed 'seconds' metric must not jump with NTP.
    t0 = time.perf_counter()
    print('[ckpt_eval] phase=start', flush=True)
    repo_root = Path('/workspace/feather') if Path('/workspace/feather').exists() else Path.cwd()
    os.chdir(repo_root)
    sys.path.insert(0, str(repo_root))
    # Imports after cwd is set so overlay modules win inside the image.
    import prepare as _prepare_mod
    from prepare import MAX_SEQ_LEN, Tokenizer
    from hydra.config import (
        D_MODEL, D_STATE, ENGRAM_KEY_DIM, ENGRAM_LAYER_IDX, ENGRAM_N_COLUMNS,
        EXPAND, HEADDIM, N_HEADS, N_LAYER, PostSemClawConfig,
    )
    from hydra.model import PostSemClawModel

    def config_from_dict(payload: dict) -> PostSemClawConfig:
        """Build a config from a checkpoint's saved dict, dropping unknown keys.

        List values for ``hyena_layers``/``gdn_layers`` are converted to
        tuples. Presumably they were serialised as lists while the dataclass
        expects tuples; TODO confirm.
        """
        field_names = {field.name for field in dataclasses.fields(PostSemClawConfig)}
        kwargs = {key: value for key, value in payload.items() if key in field_names}
        for key in ('hyena_layers', 'gdn_layers'):
            if key in kwargs and isinstance(kwargs[key], list):
                kwargs[key] = tuple(kwargs[key])
        return PostSemClawConfig(**kwargs)

    if os.environ.get('HYDRA_USE_NEMOTRON', '0') == '1':
        import prepare_nemotron as _p_nemo
        from prepare_nemotron import evaluate_bpb
        _p_nemo.ensure_tokenizer()
        import subsystems.sdr_retina as _sdr_retina
        _sdr_retina.build_retina()
    else:
        from prepare import evaluate_bpb
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'[ckpt_eval] device={device} cuda={int(torch.cuda.is_available())}', flush=True)
    torch.set_float32_matmul_precision('high')
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    # NOTE(security): weights_only=False unpickles arbitrary Python objects
    # from the (possibly downloaded) artifact. Only point this at checkpoints
    # from a trusted repo/job.
    ckpt = torch.load(str(_ckpt_path()), map_location='cpu', weights_only=False)
    tokenizer = Tokenizer.from_directory()
    vocab_size = tokenizer.get_vocab_size()
    cfg_payload = ckpt.get('config')
    if isinstance(cfg_payload, dict):
        config = config_from_dict(cfg_payload)
    else:
        config = PostSemClawConfig(
            sequence_len=MAX_SEQ_LEN,
            vocab_size=vocab_size,
            n_layer=N_LAYER,
            d_model=D_MODEL,
            d_state=D_STATE,
            headdim=HEADDIM,
            n_heads=N_HEADS,
            expand=EXPAND,
            engram_n_columns=ENGRAM_N_COLUMNS,
            engram_key_dim=ENGRAM_KEY_DIM,
            engram_layer_idx=ENGRAM_LAYER_IDX,
        )
    print(f'[ckpt_eval] checkpoint_step={ckpt.get("step")} vocab_size={vocab_size}', flush=True)
    # Construct on the meta device (no memory), then allocate real but
    # UNINITIALISED storage on the target device. Values come only from
    # load_state_dict below.
    with torch.device('meta'):
        model = PostSemClawModel(config)
    model.to_empty(device=device)
    missing, unexpected = model.load_state_dict(ckpt.get('model_state_dict', ckpt), strict=False)
    print(f'[ckpt_eval] load_state missing={len(missing)} unexpected={len(unexpected)}', flush=True)
    if missing or unexpected:
        # With to_empty(), every missing key is a tensor full of arbitrary
        # memory during eval. Name the offenders so a garbage metric can be
        # traced back to them.
        # NOTE(review): non-persistent buffers never appear in a state dict
        # and would also stay uninitialised; confirm the model has none.
        print(
            f'[ckpt_eval] WARNING missing keys (first 8): {missing[:8]} '
            f'unexpected keys (first 8): {unexpected[:8]}',
            flush=True,
        )
    model.eval()
    if hasattr(model, 'set_bos_token_id'):
        model.set_bos_token_id(tokenizer.get_bos_token_id())
    del ckpt
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Blank values fall through to the next variable / default instead of
    # crashing on int('').
    eval_tokens = _env_int(('HYDRA_EVAL_TOKENS', 'HYDRA_STREAM_EVAL_TOKENS'), 262144)
    eval_batch = _env_int(('HYDRA_EVAL_BATCH',), 1)
    # Push the budget into both the `prepare` module global and the env var.
    # NOTE(review): in Nemotron mode prepare_nemotron was imported above,
    # before this env write; confirm it reads the budget at call time.
    _prepare_mod.EVAL_TOKENS = eval_tokens
    os.environ['HYDRA_STREAM_EVAL_TOKENS'] = str(eval_tokens)
    print(f'[ckpt_eval] running eval tokens={eval_tokens} batch={eval_batch}', flush=True)
    with torch.no_grad(), torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=torch.cuda.is_available()):
        val_bpb = evaluate_bpb(model, tokenizer, eval_batch)
    # Assuming val_bpb is bits per byte, this is a per-byte perplexity, not a
    # per-token one.
    val_ppl = 2 ** val_bpb
    metrics = {
        'checkpoint_job_id': os.environ.get('HYDRA_EVAL_CKPT_JOB_ID'),
        'checkpoint_name': os.environ.get('HYDRA_EVAL_CKPT_NAME', '').strip() or 'pretrain_final.pt',
        'checkpoint_repo_path': os.environ.get('HYDRA_EVAL_CKPT_REPO_PATH'),
        'eval_tokens': eval_tokens,
        'eval_batch': eval_batch,
        'val_bpb': float(val_bpb),
        'val_ppl': float(val_ppl),
        'seconds': round(time.perf_counter() - t0, 3),
    }
    print(f'[CKPT_EVAL_JSON] {json.dumps(metrics, sort_keys=True)}', flush=True)
    print('[ckpt_eval] phase=done', flush=True)
    return 0
if __name__ == '__main__':
    # Leave via os._exit() rather than a normal interpreter shutdown. The
    # full-corpus streaming eval can leave HF datasets downloader/native
    # threads running after the [CKPT_EVAL_JSON] line has been printed. Their
    # teardown can then abort in PyGILState finalisation, which HF Jobs
    # would record as ERROR despite the metric being complete. os._exit()
    # does not flush stdio buffers, so flush stdout and stderr by hand first.
    exit_code = main()
    for stream in (sys.stdout, sys.stderr):
        stream.flush()
    os._exit(exit_code)