feather-a10-runtime / overlay /scripts /hydra_generation.py
Jackoatmon's picture
Update Feather training runtime image
951f760 verified
#!/usr/bin/env python3
from __future__ import annotations
import os
from pathlib import Path
from typing import Callable
import torch
from scripts.benchmark_checkpoint import hydrate_checkpoint
from scripts.hf_routing import resolve_routing
def default_checkpoint_path() -> Path:
    """Return the fallback checkpoint location inside the user's cache.

    The path is ``~/.cache/autoresearch/latest.pt`` with ``~`` expanded via
    ``os.path.expanduser``. The file is not checked for existence.
    """
    expanded = os.path.expanduser("~/.cache/autoresearch/latest.pt")
    return Path(expanded)
def checkpoint_candidates(*, cache_dir: Path | None = None) -> list[Path]:
base = cache_dir or Path(os.path.expanduser("~/.cache/autoresearch"))
return [
base / "best_bpb.pt",
base / "pretrain_final.pt",
base / "latest.pt",
]
def resolve_checkpoint_path(explicit_path: Path | None, *, cache_dir: Path | None = None) -> Path:
if explicit_path is not None:
return explicit_path
for candidate in checkpoint_candidates(cache_dir=cache_dir):
if candidate.exists():
return candidate
return default_checkpoint_path()
def validate_checkpoint_compatibility(
    *,
    baseline_arch: str,
    missing_keys: list[str],
    unexpected_keys: list[str],
    total_model_keys: int,
) -> None:
    """Reject checkpoints whose state-dict keys do not fit the model.

    Two rules are applied in order:

    1. For ``baseline_arch == "transformer"`` any missing or unexpected key
       is an error.
    2. For every architecture, when ``total_model_keys`` is positive, the
       number of missing plus unexpected keys must not exceed
       ``max(8, total_model_keys // 2)``.

    Args:
        baseline_arch: Architecture name; only ``"transformer"`` is special.
        missing_keys: Keys the model expects but the checkpoint lacks.
        unexpected_keys: Keys in the checkpoint the model does not have.
        total_model_keys: Number of entries in the model's state dict.

    Raises:
        RuntimeError: If either rule is violated.
    """
    any_mismatch = bool(missing_keys) or bool(unexpected_keys)
    if any_mismatch and baseline_arch == "transformer":
        raise RuntimeError(
            "checkpoint incompatible with transformer baseline architecture; "
            "use a transformer-trained checkpoint or keep HYDRA_BASELINE_ARCH=mamba3"
        )

    mismatches = len(missing_keys) + len(unexpected_keys)
    if total_model_keys <= 0:
        # Without a key count there is no threshold to compare against.
        return

    tolerated = max(8, total_model_keys // 2)
    if mismatches > tolerated:
        raise RuntimeError("checkpoint incompatible with requested model architecture")
def generate_from_callable(
generator: Callable[[str], str] | Callable[..., str],
prompt: str,
*,
max_new_tokens: int,
temperature: float,
top_p: float,
) -> str:
text = generator(
prompt,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
)
return str(text).strip()
def load_hydra_causal_lm(checkpoint_path: Path | None = None, device: str | None = None):
    """Load the checkpointed model and wrap it for ``transformers`` generation.

    The checkpoint is located with ``resolve_checkpoint_path``. If that file
    does not exist, ``hydrate_checkpoint`` is asked to supply one (using the
    ``HF_TOKEN`` environment variable and the routing's ``output_repo``)
    before the function gives up.

    Args:
        checkpoint_path: Explicit checkpoint file, or ``None`` to auto-resolve.
        device: Target device string. When falsy, ``"cuda"`` is used if
            ``torch.cuda.is_available()``, otherwise ``"cpu"``.

    Returns:
        The 5-tuple ``(tokenizer, model, bos, resolved_device,
        GenerationConfig)``; the last element is the ``transformers``
        ``GenerationConfig`` class itself, handed back so callers need not
        import it.

    Raises:
        FileNotFoundError: If no checkpoint file exists after hydration.
        RuntimeError: Propagated from ``validate_checkpoint_compatibility``
            when the checkpoint keys do not fit the model.
    """
    ckpt_file = resolve_checkpoint_path(checkpoint_path)
    if not ckpt_file.exists():
        routing = resolve_routing(token=os.environ.get("HF_TOKEN"))
        fetched = hydrate_checkpoint(
            cache_dir=ckpt_file.parent,
            output_repo=routing.output_repo,
            token=os.environ.get("HF_TOKEN"),
        )
        if fetched is not None:
            ckpt_file = fetched
    if not ckpt_file.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_file}")

    # Heavy / project-local imports are deferred until a checkpoint is known
    # to exist.
    from transformers import GenerationConfig, GenerationMixin, PretrainedConfig, PreTrainedModel
    from transformers.modeling_outputs import CausalLMOutputWithPast
    from hydra.config import PostSemClawConfig
    from hydra.model import PostSemClawModel
    from prepare import Tokenizer

    if device:
        resolved_device = device
    else:
        resolved_device = "cuda" if torch.cuda.is_available() else "cpu"

    class _HydraGenConfig(PretrainedConfig):
        """Minimal config carrying only what generation needs."""

        model_type = "hydra"

        def __init__(self, vocab_size: int = 65536, **kwargs):
            super().__init__(**kwargs)
            self.vocab_size = vocab_size

    class HydraForCausalLM(PreTrainedModel, GenerationMixin):
        """Adapter exposing the inner model through the ``generate`` API."""

        config_class = _HydraGenConfig

        def __init__(self, gen_config, inner_model):
            super().__init__(gen_config)
            self.inner = inner_model
            self.config.vocab_size = gen_config.vocab_size

        def forward(self, input_ids, attention_mask=None, **kwargs):
            # Only input_ids reach the inner model; attention_mask and any
            # other keyword arguments are accepted and ignored.
            return CausalLMOutputWithPast(
                loss=None,
                logits=self.inner(input_ids),
                past_key_values=None,
            )

        def prepare_inputs_for_generation(self, input_ids, **kwargs):
            # No cache: the full sequence is passed on every step.
            return {"input_ids": input_ids}

        def get_input_embeddings(self):
            return self.inner.wte

        def can_generate(self) -> bool:
            return True

        @property
        def _supports_cache_class(self):
            return False

    tokenizer = Tokenizer.from_directory()
    vocab_size = tokenizer.get_vocab_size()
    bos = tokenizer.get_bos_token_id()

    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects; only safe if the checkpoint source is trusted.
    checkpoint = torch.load(str(ckpt_file), map_location="cpu", weights_only=False)
    model_config = PostSemClawConfig(**checkpoint["config"])

    # Build on the meta device, then allocate storage on the target device.
    with torch.device("meta"):
        inner = PostSemClawModel(model_config)
    inner.to_empty(device=resolved_device)

    # NOTE(review): to_empty() leaves storage uninitialised and strict=False
    # tolerates missing keys, so any tensor absent from the checkpoint
    # presumably keeps arbitrary values - confirm this is acceptable.
    load_report = inner.load_state_dict(checkpoint["model_state_dict"], strict=False)
    missing, unexpected = load_report
    validate_checkpoint_compatibility(
        baseline_arch=os.environ.get("HYDRA_BASELINE_ARCH", "mamba3").strip().lower(),
        missing_keys=list(missing),
        unexpected_keys=list(unexpected),
        total_model_keys=len(inner.state_dict()),
    )
    inner.eval()

    gen_cfg = _HydraGenConfig(vocab_size=vocab_size)
    # The BOS id doubles as the EOS and PAD id.
    for token_field in ("bos_token_id", "eos_token_id", "pad_token_id"):
        setattr(gen_cfg, token_field, bos)

    wrapper = HydraForCausalLM(gen_cfg, inner)
    model = wrapper.to(resolved_device)
    model.eval()
    return tokenizer, model, bos, resolved_device, GenerationConfig
def build_hydra_generator(
    *,
    checkpoint_path: Path | None = None,
    device: str | None = None,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
):
    """Load the model once and return a prompt-to-text closure.

    Args:
        checkpoint_path: Forwarded to ``load_hydra_causal_lm``.
        device: Forwarded to ``load_hydra_causal_lm``.
        max_new_tokens: Generation length limit baked into the closure.
        temperature: Sampling temperature; sampling is enabled only when
            this is greater than ``0.0``.
        top_p: Nucleus-sampling threshold baked into the closure.

    Returns:
        A function taking a single ``prompt`` string and returning the
        decoded first sequence produced by ``model.generate``.
    """
    loaded = load_hydra_causal_lm(checkpoint_path=checkpoint_path, device=device)
    tokenizer, model, bos, resolved_device, GenerationConfig = loaded

    def _generate(prompt: str) -> str:
        prompt_ids = tokenizer.encode(prompt)
        batch = torch.tensor([prompt_ids], dtype=torch.long, device=resolved_device)
        sampling = GenerationConfig(
            max_new_tokens=max_new_tokens,
            use_cache=False,
            do_sample=temperature > 0.0,
            temperature=temperature,
            top_p=top_p,
            bos_token_id=bos,
            eos_token_id=bos,
            pad_token_id=bos,
        )
        on_cuda = str(resolved_device).startswith("cuda")
        with torch.no_grad():
            if on_cuda:
                # bfloat16 autocast is applied on CUDA devices only.
                with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                    generated = model.generate(batch, generation_config=sampling)
            else:
                generated = model.generate(batch, generation_config=sampling)
        return tokenizer.decode(generated[0].tolist())

    return _generate