Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| from __future__ import annotations | |
| import os | |
| from pathlib import Path | |
| from typing import Callable | |
| import torch | |
| from scripts.benchmark_checkpoint import hydrate_checkpoint | |
| from scripts.hf_routing import resolve_routing | |
def default_checkpoint_path() -> Path:
    """Return the fallback checkpoint location inside the user's cache.

    The result is ``~/.cache/autoresearch/latest.pt`` with ``~`` expanded by
    ``os.path.expanduser``. The file is not checked for existence.
    """
    expanded = os.path.expanduser("~/.cache/autoresearch/latest.pt")
    return Path(expanded)
| def checkpoint_candidates(*, cache_dir: Path | None = None) -> list[Path]: | |
| base = cache_dir or Path(os.path.expanduser("~/.cache/autoresearch")) | |
| return [ | |
| base / "best_bpb.pt", | |
| base / "pretrain_final.pt", | |
| base / "latest.pt", | |
| ] | |
def resolve_checkpoint_path(explicit_path: Path | None, *, cache_dir: Path | None = None) -> Path:
    """Pick the checkpoint path to load.

    Args:
        explicit_path: Caller-chosen path. Returned unchanged (without an
            existence check) whenever it is not ``None``.
        cache_dir: Forwarded to ``checkpoint_candidates`` to select the
            directory that is probed.

    Returns:
        ``explicit_path`` if given; otherwise the first candidate that exists
        on disk; otherwise ``default_checkpoint_path()``. Note that this last
        fallback does not take ``cache_dir`` into account.
    """
    if explicit_path is not None:
        return explicit_path
    on_disk = (path for path in checkpoint_candidates(cache_dir=cache_dir) if path.exists())
    first_found = next(on_disk, None)
    if first_found is None:
        return default_checkpoint_path()
    return first_found
def validate_checkpoint_compatibility(
    *,
    baseline_arch: str,
    missing_keys: list[str],
    unexpected_keys: list[str],
    total_model_keys: int,
) -> None:
    """Reject a checkpoint whose state-dict keys do not fit the model.

    Args:
        baseline_arch: Architecture name; ``"transformer"`` is held to an
            exact key match.
        missing_keys: Model keys absent from the checkpoint.
        unexpected_keys: Checkpoint keys the model does not have.
        total_model_keys: Number of keys in the model's state dict; the
            tolerance check is skipped when this is not positive.

    Raises:
        RuntimeError: If ``baseline_arch`` is ``"transformer"`` and any key is
            missing or unexpected, or if the number of mismatched keys exceeds
            ``max(8, total_model_keys // 2)``.
    """
    mismatches = len(missing_keys) + len(unexpected_keys)
    any_mismatch = bool(missing_keys or unexpected_keys)

    # The transformer baseline tolerates no key drift at all.
    if any_mismatch and baseline_arch == "transformer":
        raise RuntimeError(
            "checkpoint incompatible with transformer baseline architecture; "
            "use a transformer-trained checkpoint or keep HYDRA_BASELINE_ARCH=mamba3"
        )

    if total_model_keys <= 0:
        return

    # Other architectures allow some drift: up to half the keys, never fewer than 8.
    tolerated = max(8, total_model_keys // 2)
    if mismatches > tolerated:
        raise RuntimeError("checkpoint incompatible with requested model architecture")
| def generate_from_callable( | |
| generator: Callable[[str], str] | Callable[..., str], | |
| prompt: str, | |
| *, | |
| max_new_tokens: int, | |
| temperature: float, | |
| top_p: float, | |
| ) -> str: | |
| text = generator( | |
| prompt, | |
| max_new_tokens=max_new_tokens, | |
| temperature=temperature, | |
| top_p=top_p, | |
| ) | |
| return str(text).strip() | |
def load_hydra_causal_lm(checkpoint_path: Path | None = None, device: str | None = None):
    """Load a Hydra checkpoint and wrap it so ``transformers`` can generate from it.

    Steps: resolve the checkpoint path (trying ``hydrate_checkpoint`` when the
    file is absent), build ``PostSemClawModel`` from the checkpoint's
    ``"config"`` entry, load ``"model_state_dict"`` non-strictly, validate the
    key mismatch, and wrap the model in a ``PreTrainedModel`` /
    ``GenerationMixin`` subclass.

    Args:
        checkpoint_path: Explicit checkpoint file; ``None`` means use
            ``resolve_checkpoint_path``'s search order.
        device: Target device string; ``None`` selects ``"cuda"`` when
            available, else ``"cpu"``.

    Returns:
        A 5-tuple ``(tokenizer, model, bos, resolved_device, GenerationConfig)``
        where ``GenerationConfig`` is the ``transformers`` class itself (handed
        back so callers need not import ``transformers`` themselves).

    Raises:
        FileNotFoundError: If no checkpoint file exists after the hydration
            attempt.
        RuntimeError: From ``validate_checkpoint_compatibility`` when the
            checkpoint keys do not fit the model.
    """
    ckpt_path = resolve_checkpoint_path(checkpoint_path)
    if not ckpt_path.exists():
        hf_token = os.environ.get("HF_TOKEN")  # read once, used for routing and download
        hydrated = hydrate_checkpoint(
            cache_dir=ckpt_path.parent,
            output_repo=resolve_routing(token=hf_token).output_repo,
            token=hf_token,
        )
        if hydrated is not None:
            # Coerce so the `.exists()` check below also works if a plain
            # string path is returned.
            ckpt_path = Path(hydrated)
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # Heavy / project-local imports are deferred until a checkpoint is known to exist.
    from transformers import GenerationConfig, GenerationMixin, PretrainedConfig, PreTrainedModel
    from transformers.modeling_outputs import CausalLMOutputWithPast
    from hydra.config import PostSemClawConfig
    from hydra.model import PostSemClawModel
    from prepare import Tokenizer

    resolved_device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    class _HydraGenConfig(PretrainedConfig):
        """Minimal config carrying only what generation needs."""

        model_type = "hydra"

        def __init__(self, vocab_size: int = 65536, **kw):
            super().__init__(**kw)
            self.vocab_size = vocab_size

    class HydraForCausalLM(PreTrainedModel, GenerationMixin):
        """Adapter exposing the inner model through the causal-LM interface."""

        config_class = _HydraGenConfig
        # Must be a plain attribute: written as a method it would be a bound
        # method object, which is always truthy wherever the flag is tested.
        _supports_cache_class = False

        def __init__(self, gen_config, inner_model):
            super().__init__(gen_config)
            self.inner = inner_model
            self.config.vocab_size = gen_config.vocab_size

        def forward(self, input_ids, attention_mask=None, **kw):
            # attention_mask and extra kwargs are accepted but not used.
            logits = self.inner(input_ids)
            return CausalLMOutputWithPast(loss=None, logits=logits, past_key_values=None)

        def prepare_inputs_for_generation(self, input_ids, **kw):
            # No cache: the full sequence is fed on every step.
            return {"input_ids": input_ids}

        def get_input_embeddings(self):
            return self.inner.wte

        @classmethod
        def can_generate(cls) -> bool:
            # Classmethod so it works when called on the class or an instance.
            return True

    tokenizer = Tokenizer.from_directory()
    vocab_size = tokenizer.get_vocab_size()
    bos = tokenizer.get_bos_token_id()

    # SECURITY: weights_only=False performs a full unpickle, which can execute
    # arbitrary code. The file may have been downloaded by hydrate_checkpoint,
    # so only point this at trusted repositories/checkpoints.
    ckpt = torch.load(str(ckpt_path), map_location="cpu", weights_only=False)
    cfg = PostSemClawConfig(**ckpt["config"])

    # Build on the meta device (no allocation), then allocate storage on the
    # target device; the values come from the state dict loaded below.
    with torch.device("meta"):
        inner = PostSemClawModel(cfg)
    inner.to_empty(device=resolved_device)
    missing, unexpected = inner.load_state_dict(ckpt["model_state_dict"], strict=False)
    # NOTE(review): tensors listed in `missing` keep the uninitialized storage
    # from to_empty(); confirm that the mismatches tolerated below are safe.
    validate_checkpoint_compatibility(
        baseline_arch=os.environ.get("HYDRA_BASELINE_ARCH", "mamba3").strip().lower(),
        missing_keys=list(missing),
        unexpected_keys=list(unexpected),
        total_model_keys=len(inner.state_dict()),
    )
    inner.eval()

    # The BOS id doubles as EOS and PAD.
    gen_cfg = _HydraGenConfig(vocab_size=vocab_size)
    gen_cfg.bos_token_id = bos
    gen_cfg.eos_token_id = bos
    gen_cfg.pad_token_id = bos

    model = HydraForCausalLM(gen_cfg, inner).to(resolved_device)
    model.eval()
    return tokenizer, model, bos, resolved_device, GenerationConfig
def build_hydra_generator(
    *,
    checkpoint_path: Path | None = None,
    device: str | None = None,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
):
    """Load the Hydra model once and return a ``prompt -> text`` closure.

    Args:
        checkpoint_path: Forwarded to ``load_hydra_causal_lm``.
        device: Forwarded to ``load_hydra_causal_lm``.
        max_new_tokens: Number of tokens to generate per call.
        temperature: Sampling temperature; values ``<= 0`` select greedy
            decoding.
        top_p: Nucleus-sampling threshold, used only when sampling.

    Returns:
        A function taking a prompt string and returning the decoded output of
        ``model.generate`` (the whole returned sequence is decoded).
    """
    from contextlib import nullcontext  # local import: not among the module's top-level imports

    tokenizer, model, bos, resolved_device, GenerationConfig = load_hydra_causal_lm(
        checkpoint_path=checkpoint_path,
        device=device,
    )
    on_cuda = str(resolved_device).startswith("cuda")
    do_sample = temperature > 0.0
    # temperature/top_p only apply when sampling; passing them in greedy mode
    # (notably temperature=0.0) just trips GenerationConfig's validation.
    sampling_kwargs = {"temperature": temperature, "top_p": top_p} if do_sample else {}

    def _generate(prompt: str) -> str:
        token_ids = list(tokenizer.encode(prompt))
        if not token_ids:
            # An empty prompt would yield a zero-length input tensor; seed
            # generation with the BOS token instead.
            token_ids = [bos]
        ids = torch.tensor([token_ids], dtype=torch.long, device=resolved_device)
        gen_config = GenerationConfig(
            max_new_tokens=max_new_tokens,
            use_cache=False,
            do_sample=do_sample,
            bos_token_id=bos,
            eos_token_id=bos,
            pad_token_id=bos,
            **sampling_kwargs,
        )
        # bfloat16 autocast on CUDA only; elsewhere run without autocast.
        autocast = (
            torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
            if on_cuda
            else nullcontext()
        )
        with torch.no_grad(), autocast:
            out = model.generate(ids, generation_config=gen_config)
        return tokenizer.decode(out[0].tolist())

    return _generate