#!/usr/bin/env python3
"""Load a HYDRA checkpoint and expose it through the transformers ``generate`` API.

Checkpoint resolution prefers ``best_bpb.pt``, then ``pretrain_final.pt``, then
``latest.pt`` inside the autoresearch cache directory. If the resolved file is
missing locally, ``hydrate_checkpoint`` is asked to fetch it using the output
repo reported by ``resolve_routing``.
"""
from __future__ import annotations

import os
from pathlib import Path
from typing import Callable

import torch

from scripts.benchmark_checkpoint import hydrate_checkpoint
from scripts.hf_routing import resolve_routing


def _default_cache_dir() -> Path:
    """Return the default autoresearch cache directory, ``~/.cache/autoresearch``."""
    return Path(os.path.expanduser("~/.cache/autoresearch"))


def _resolve_cache_dir(cache_dir: Path | None) -> Path:
    """Return ``cache_dir`` (user-expanded) if given, else the default cache directory."""
    return Path(cache_dir).expanduser() if cache_dir else _default_cache_dir()


def default_checkpoint_path() -> Path:
    """Return the last-resort checkpoint path, ``~/.cache/autoresearch/latest.pt``."""
    return _default_cache_dir() / "latest.pt"


def checkpoint_candidates(*, cache_dir: Path | None = None) -> list[Path]:
    """List checkpoint files to probe, in priority order.

    Args:
        cache_dir: Directory to look in; defaults to ``~/.cache/autoresearch``.

    Returns:
        ``[best_bpb.pt, pretrain_final.pt, latest.pt]`` joined onto the directory.
    """
    base = _resolve_cache_dir(cache_dir)
    return [
        base / "best_bpb.pt",
        base / "pretrain_final.pt",
        base / "latest.pt",
    ]


def resolve_checkpoint_path(explicit_path: Path | None, *, cache_dir: Path | None = None) -> Path:
    """Pick the checkpoint path to load.

    Args:
        explicit_path: If not ``None``, returned as-is (after ``~`` expansion),
            whether or not it exists.
        cache_dir: Directory searched for the candidates of
            :func:`checkpoint_candidates`.

    Returns:
        The explicit path, else the first existing candidate, else
        ``<cache_dir>/latest.pt``. The fallback may not exist; callers
        hydrate it into its parent directory.
    """
    if explicit_path is not None:
        return Path(explicit_path).expanduser()
    for candidate in checkpoint_candidates(cache_dir=cache_dir):
        if candidate.exists():
            return candidate
    # Fall back inside the *requested* cache dir so hydration targets it too;
    # with cache_dir=None this equals default_checkpoint_path().
    return _resolve_cache_dir(cache_dir) / "latest.pt"


def validate_checkpoint_compatibility(
    *,
    baseline_arch: str,
    missing_keys: list[str],
    unexpected_keys: list[str],
    total_model_keys: int,
) -> None:
    """Raise if a checkpoint's state-dict keys do not fit the instantiated model.

    For ``baseline_arch == "transformer"`` any missing or unexpected key is
    fatal. Otherwise the combined mismatch count may not exceed
    ``max(8, total_model_keys // 2)``. The check is skipped when
    ``total_model_keys`` is 0.

    Raises:
        RuntimeError: on an incompatible checkpoint.
    """
    if baseline_arch == "transformer" and (missing_keys or unexpected_keys):
        raise RuntimeError(
            "checkpoint incompatible with transformer baseline architecture; "
            "use a transformer-trained checkpoint or keep HYDRA_BASELINE_ARCH=mamba3"
        )
    mismatch_count = len(missing_keys) + len(unexpected_keys)
    if total_model_keys > 0 and mismatch_count > max(8, total_model_keys // 2):
        raise RuntimeError("checkpoint incompatible with requested model architecture")


def generate_from_callable(
    generator: Callable[[str], str] | Callable[..., str],
    prompt: str,
    *,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> str:
    """Call ``generator(prompt, max_new_tokens=..., temperature=..., top_p=...)``.

    Returns:
        The generator's result converted to ``str`` and stripped of
        surrounding whitespace.

    NOTE: the callable returned by :func:`build_hydra_generator` accepts only
    ``prompt`` (its sampling settings are bound at build time). Passing it here
    would raise ``TypeError``.
    """
    text = generator(
        prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    return str(text).strip()


def load_hydra_causal_lm(checkpoint_path: Path | None = None, device: str | None = None):
    """Load a HYDRA checkpoint wrapped as a transformers causal LM.

    Args:
        checkpoint_path: Explicit checkpoint file. When ``None``, the cache
            candidates of :func:`resolve_checkpoint_path` are used.
        device: Target device. Defaults to ``"cuda"`` if available, else ``"cpu"``.

    Returns:
        ``(tokenizer, model, bos, resolved_device, GenerationConfig)``. The
        ``GenerationConfig`` class is returned because ``transformers`` is only
        imported inside this function.

    Raises:
        FileNotFoundError: if the checkpoint is still missing after the
            hydration attempt.
        RuntimeError: if the checkpoint's keys do not fit the model
            (see :func:`validate_checkpoint_compatibility`).
    """
    ckpt_path = resolve_checkpoint_path(checkpoint_path)
    if not ckpt_path.exists():
        # Read the token once; both the routing lookup and the download use it.
        hf_token = os.environ.get("HF_TOKEN")
        hydrated = hydrate_checkpoint(
            cache_dir=ckpt_path.parent,
            output_repo=resolve_routing(token=hf_token).output_repo,
            token=hf_token,
        )
        if hydrated is not None:
            ckpt_path = Path(hydrated)
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # Heavy and project-local imports are deferred until a checkpoint is known to exist.
    from transformers import GenerationConfig, GenerationMixin, PretrainedConfig, PreTrainedModel
    from transformers.modeling_outputs import CausalLMOutputWithPast

    from hydra.config import PostSemClawConfig
    from hydra.model import PostSemClawModel
    from prepare import Tokenizer

    resolved_device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    class _HydraGenConfig(PretrainedConfig):
        """Minimal config letting transformers' generation machinery wrap HYDRA."""

        model_type = "hydra"

        def __init__(self, vocab_size: int = 65536, **kw):
            super().__init__(**kw)
            self.vocab_size = vocab_size

    class HydraForCausalLM(PreTrainedModel, GenerationMixin):
        """Adapter exposing ``PostSemClawModel`` through ``GenerationMixin.generate``."""

        config_class = _HydraGenConfig

        def __init__(self, gen_config, inner_model):
            super().__init__(gen_config)
            self.inner = inner_model
            self.config.vocab_size = gen_config.vocab_size

        def forward(self, input_ids, attention_mask=None, **kw):
            # No KV cache: past_key_values is always None and
            # prepare_inputs_for_generation re-feeds the whole sequence each
            # step. attention_mask is ignored, so padded batches are presumably
            # unsupported — TODO confirm.
            logits = self.inner(input_ids)
            return CausalLMOutputWithPast(loss=None, logits=logits, past_key_values=None)

        def prepare_inputs_for_generation(self, input_ids, **kw):
            return {"input_ids": input_ids}

        def get_input_embeddings(self):
            return self.inner.wte

        def can_generate(self) -> bool:
            return True

        @property
        def _supports_cache_class(self):
            return False

    tokenizer = Tokenizer.from_directory()
    vocab_size = tokenizer.get_vocab_size()
    bos = tokenizer.get_bos_token_id()

    # SECURITY: weights_only=False unpickles arbitrary objects. The file may
    # have been fetched by hydrate_checkpoint above, so only load checkpoints
    # from repositories you trust.
    ckpt = torch.load(str(ckpt_path), map_location="cpu", weights_only=False)
    cfg = PostSemClawConfig(**ckpt["config"])

    # Build the module on the meta device (no real allocation), then
    # materialize uninitialized storage on the target device before loading
    # the checkpoint into it.
    with torch.device("meta"):
        inner = PostSemClawModel(cfg)
    inner.to_empty(device=resolved_device)
    # NOTE(review): with strict=False, every tensor in `missing` keeps the
    # uninitialized memory left by to_empty(). Buffers registered with
    # persistent=False are never in the state dict either, so they also stay
    # uninitialized unless the model recomputes them — confirm this is safe.
    # validate_checkpoint_compatibility tolerates up to
    # max(8, total_keys // 2) mismatches for non-transformer baselines.
    missing, unexpected = inner.load_state_dict(ckpt["model_state_dict"], strict=False)
    validate_checkpoint_compatibility(
        baseline_arch=os.environ.get("HYDRA_BASELINE_ARCH", "mamba3").strip().lower(),
        missing_keys=list(missing),
        unexpected_keys=list(unexpected),
        total_model_keys=len(inner.state_dict()),
    )
    inner.eval()

    gen_cfg = _HydraGenConfig(vocab_size=vocab_size)
    # The tokenizer's BOS token doubles as EOS and PAD for generation.
    gen_cfg.bos_token_id = bos
    gen_cfg.eos_token_id = bos
    gen_cfg.pad_token_id = bos

    model = HydraForCausalLM(gen_cfg, inner).to(resolved_device)
    model.eval()
    return tokenizer, model, bos, resolved_device, GenerationConfig


def build_hydra_generator(
    *,
    checkpoint_path: Path | None = None,
    device: str | None = None,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> Callable[[str], str]:
    """Load HYDRA and return a ``prompt -> text`` generation function.

    Args:
        checkpoint_path: Forwarded to :func:`load_hydra_causal_lm`.
        device: Forwarded to :func:`load_hydra_causal_lm`.
        max_new_tokens: Maximum number of tokens generated per call.
        temperature: ``> 0`` enables sampling at this temperature; otherwise
            greedy decoding is used.
        top_p: Nucleus-sampling threshold, applied only when sampling.

    Returns:
        A function taking a prompt string and returning the decoded output.
    """
    tokenizer, model, bos, resolved_device, GenerationConfig = load_hydra_causal_lm(
        checkpoint_path=checkpoint_path, device=device
    )

    do_sample = temperature > 0.0
    # Pass sampling knobs only when sampling. Greedy decoding ignores them, and
    # transformers warns when temperature/top_p are set with do_sample=False
    # (e.g. temperature=0.0).
    sampling_kwargs = {"temperature": temperature, "top_p": top_p} if do_sample else {}
    # Prompt-independent, so build it once; generate() copies it internally.
    gen_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        use_cache=False,
        do_sample=do_sample,
        bos_token_id=bos,
        eos_token_id=bos,
        pad_token_id=bos,
        **sampling_kwargs,
    )
    use_cuda_autocast = str(resolved_device).startswith("cuda")

    def _generate(prompt: str) -> str:
        token_ids = tokenizer.encode(prompt)
        if len(token_ids) == 0:
            # An empty prompt would hand generate() a zero-length sequence;
            # seed it with BOS (the document delimiter used for EOS/PAD too).
            token_ids = [bos]
        ids = torch.tensor([token_ids], dtype=torch.long, device=resolved_device)
        if use_cuda_autocast:
            with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                out = model.generate(ids, generation_config=gen_config)
        else:
            with torch.no_grad():
                out = model.generate(ids, generation_config=gen_config)
        # generate() returns prompt + continuation for decoder-only models, so
        # the decoded text includes the prompt (and a trailing BOS/EOS token if
        # generation stopped on it) — callers may need to strip those.
        return tokenizer.decode(out[0].tolist())

    return _generate