#!/usr/bin/env python3
"""Fresh-process checkpoint evaluation for HF Jobs.
Downloads a checkpoint artifact uploaded by a prior training job and evaluates it
from a new Python process, avoiding post-training CUDA fragmentation in the
training container.
"""
from __future__ import annotations

import dataclasses
import json
import os
import sys
import time
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download

try:
    sys.stdout.reconfigure(line_buffering=True)  # type: ignore[attr-defined]
except Exception:
    pass


def _require_env(name: str) -> str:
    value = os.environ.get(name, '').strip()
    if not value:
        raise SystemExit(f'[ckpt_eval] missing required env {name}')
    return value


def _ckpt_path() -> Path:
    local = os.environ.get('HYDRA_EVAL_CKPT_PATH')
    if local:
        p = Path(local).expanduser()
        print(f'[ckpt_eval] using local checkpoint {p}', flush=True)
        return p
    repo_id = _require_env('HF_REPO_ID')
    explicit_path = os.environ.get('HYDRA_EVAL_CKPT_REPO_PATH', '').strip().lstrip('/')
    if explicit_path:
        path_in_repo = explicit_path
    else:
        source_job = _require_env('HYDRA_EVAL_CKPT_JOB_ID')
        filename = os.environ.get('HYDRA_EVAL_CKPT_NAME', 'pretrain_final.pt')
        path_in_repo = f'jobs/{source_job}/{filename}'
    print(f'[ckpt_eval] downloading {repo_id}/{path_in_repo}', flush=True)
    downloaded = hf_hub_download(
        repo_id=repo_id,
        filename=path_in_repo,
        repo_type='model',
        token=os.environ.get('HF_TOKEN'),
    )
    return Path(downloaded)


def main() -> int:
    t0 = time.time()
    print('[ckpt_eval] phase=start', flush=True)
    repo_root = Path('/workspace/feather') if Path('/workspace/feather').exists() else Path.cwd()
    os.chdir(repo_root)
    sys.path.insert(0, str(repo_root))

    # Imports after cwd is set so overlay modules win inside the image.
    import prepare as _prepare_mod
    from prepare import MAX_SEQ_LEN, Tokenizer
    from hydra.config import (
        D_MODEL, D_STATE, ENGRAM_KEY_DIM, ENGRAM_LAYER_IDX, ENGRAM_N_COLUMNS,
        EXPAND, HEADDIM, N_HEADS, N_LAYER, PostSemClawConfig,
    )
    from hydra.model import PostSemClawModel
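
    # config_from_dict (below) rebuilds the model config from the dict stored in
    # the checkpoint: keys unknown to PostSemClawConfig are dropped so checkpoints
    # saved with a slightly different config schema still load, and list-valued
    # layer specs are converted back to tuples for the dataclass fields.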
    def config_from_dict(payload: dict) -> PostSemClawConfig:
        field_names = {field.name for field in dataclasses.fields(PostSemClawConfig)}
        kwargs = {key: value for key, value in payload.items() if key in field_names}
        for key in ('hyena_layers', 'gdn_layers'):
            if key in kwargs and isinstance(kwargs[key], list):
                kwargs[key] = tuple(kwargs[key])
        return PostSemClawConfig(**kwargs)

    if os.environ.get('HYDRA_USE_NEMOTRON', '0') == '1':
        import prepare_nemotron as _p_nemo
        from prepare_nemotron import evaluate_bpb
        _p_nemo.ensure_tokenizer()
        import subsystems.sdr_retina as _sdr_retina
        _sdr_retina.build_retina()
    else:
        from prepare import evaluate_bpb

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'[ckpt_eval] device={device} cuda={int(torch.cuda.is_available())}', flush=True)
    torch.set_float32_matmul_precision('high')
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    ckpt = torch.load(str(_ckpt_path()), map_location='cpu', weights_only=False)
    tokenizer = Tokenizer.from_directory()
    vocab_size = tokenizer.get_vocab_size()
    cfg_payload = ckpt.get('config')
    if isinstance(cfg_payload, dict):
        config = config_from_dict(cfg_payload)
    else:
        config = PostSemClawConfig(
            sequence_len=MAX_SEQ_LEN,
            vocab_size=vocab_size,
            n_layer=N_LAYER,
            d_model=D_MODEL,
            d_state=D_STATE,
            headdim=HEADDIM,
            n_heads=N_HEADS,
            expand=EXPAND,
            engram_n_columns=ENGRAM_N_COLUMNS,
            engram_key_dim=ENGRAM_KEY_DIM,
            engram_layer_idx=ENGRAM_LAYER_IDX,
        )
    print(f'[ckpt_eval] checkpoint_step={ckpt.get("step")} vocab_size={vocab_size}', flush=True)
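
    # Instantiate the model on the meta device (no weight storage is allocated),
    # then materialize empty parameters directly on the target device; the
    # load_state_dict call below fills them from the checkpoint, so full weights
    # are never allocated twice.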
    with torch.device('meta'):
        model = PostSemClawModel(config)
    model.to_empty(device=device)
    missing, unexpected = model.load_state_dict(ckpt.get('model_state_dict', ckpt), strict=False)
    print(f'[ckpt_eval] load_state missing={len(missing)} unexpected={len(unexpected)}', flush=True)
    model.eval()
    if hasattr(model, 'set_bos_token_id'):
        model.set_bos_token_id(tokenizer.get_bos_token_id())
    del ckpt
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
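
    # The eval token budget is applied in two places: EVAL_TOKENS on the prepare
    # module for the default eval path, and HYDRA_STREAM_EVAL_TOKENS mirrored back
    # into the environment, presumably for eval code that re-reads the streaming
    # limit at runtime.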
    eval_tokens = int(os.environ.get('HYDRA_EVAL_TOKENS', os.environ.get('HYDRA_STREAM_EVAL_TOKENS', '262144')))
    eval_batch = int(os.environ.get('HYDRA_EVAL_BATCH', '1'))
    _prepare_mod.EVAL_TOKENS = eval_tokens
    os.environ['HYDRA_STREAM_EVAL_TOKENS'] = str(eval_tokens)
    print(f'[ckpt_eval] running eval tokens={eval_tokens} batch={eval_batch}', flush=True)

    with torch.no_grad(), torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=torch.cuda.is_available()):
        val_bpb = evaluate_bpb(model, tokenizer, eval_batch)
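
    # evaluate_bpb reports bits per byte (mean negative log2-likelihood per byte),
    # so 2 ** bpb below is the corresponding per-byte perplexity.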
    val_ppl = 2 ** val_bpb

    metrics = {
        'checkpoint_job_id': os.environ.get('HYDRA_EVAL_CKPT_JOB_ID'),
        'checkpoint_name': os.environ.get('HYDRA_EVAL_CKPT_NAME', 'pretrain_final.pt'),
        'checkpoint_repo_path': os.environ.get('HYDRA_EVAL_CKPT_REPO_PATH'),
        'eval_tokens': eval_tokens,
        'eval_batch': eval_batch,
        'val_bpb': float(val_bpb),
        'val_ppl': float(val_ppl),
        'seconds': round(time.time() - t0, 3),
    }
    print(f'[CKPT_EVAL_JSON] {json.dumps(metrics, sort_keys=True)}', flush=True)
    print('[ckpt_eval] phase=done', flush=True)
    return 0


if __name__ == '__main__':
    # Full-corpus streaming eval can leave HF datasets downloader/native threads
    # alive at interpreter shutdown after [CKPT_EVAL_JSON] is already flushed.
    # Exit the process directly so HF Jobs records the completed metric instead
    # of converting a post-metric PyGILState finalization abort into ERROR.
    _rc = main()
    sys.stdout.flush()
    sys.stderr.flush()
    os._exit(_rc)