Spaces:
Runtime error
Runtime error
File size: 6,352 Bytes
#!/usr/bin/env python3
from __future__ import annotations
import os
from pathlib import Path
from typing import Callable
import torch
from scripts.benchmark_checkpoint import hydrate_checkpoint
from scripts.hf_routing import resolve_routing
def default_checkpoint_path() -> Path:
    """Return the fallback checkpoint location inside the user's cache.

    The leading ``~`` is expanded with :func:`os.path.expanduser`. The file
    is not checked for existence.
    """
    expanded = os.path.expanduser("~/.cache/autoresearch/latest.pt")
    return Path(expanded)
def checkpoint_candidates(*, cache_dir: Path | None = None) -> list[Path]:
    """List the checkpoint files to probe, in the order they are listed.

    Args:
        cache_dir: Directory that holds the checkpoints. When falsy
            (e.g. ``None``), ``~/.cache/autoresearch`` is used instead.

    Returns:
        Paths for ``best_bpb.pt``, ``pretrain_final.pt`` and ``latest.pt``
        under the chosen directory. Existence is not checked here.
    """
    if cache_dir:
        root = cache_dir
    else:
        root = Path(os.path.expanduser("~/.cache/autoresearch"))
    filenames = ("best_bpb.pt", "pretrain_final.pt", "latest.pt")
    return [root / filename for filename in filenames]
def resolve_checkpoint_path(explicit_path: Path | None, *, cache_dir: Path | None = None) -> Path:
    """Pick the checkpoint file to load.

    Args:
        explicit_path: Caller-supplied path. When not ``None`` it is returned
            as-is, without checking that it exists.
        cache_dir: Forwarded to :func:`checkpoint_candidates` to choose the
            directory that is probed.

    Returns:
        ``explicit_path`` if given; otherwise the first candidate that exists
        on disk; otherwise :func:`default_checkpoint_path`. Note that this
        last fallback does not take ``cache_dir`` into account.
    """
    if explicit_path is not None:
        return explicit_path

    first_existing = next(
        (path for path in checkpoint_candidates(cache_dir=cache_dir) if path.exists()),
        None,
    )
    if first_existing is None:
        return default_checkpoint_path()
    return first_existing
def validate_checkpoint_compatibility(
    *,
    baseline_arch: str,
    missing_keys: list[str],
    unexpected_keys: list[str],
    total_model_keys: int,
) -> None:
    """Reject checkpoints whose state-dict keys do not fit the model.

    Args:
        baseline_arch: Architecture name; ``"transformer"`` enables the
            strict check (any key mismatch is an error).
        missing_keys: Keys the model expects but the checkpoint lacks.
        unexpected_keys: Keys in the checkpoint the model does not know.
        total_model_keys: Number of keys in the model's state dict. A value
            of zero or less disables the tolerance check.

    Raises:
        RuntimeError: If ``baseline_arch`` is ``"transformer"`` and any key
            is missing or unexpected, or if the total number of mismatched
            keys exceeds ``max(8, total_model_keys // 2)``.
    """
    any_mismatch = bool(missing_keys) or bool(unexpected_keys)
    if any_mismatch and baseline_arch == "transformer":
        raise RuntimeError(
            "checkpoint incompatible with transformer baseline architecture; "
            "use a transformer-trained checkpoint or keep HYDRA_BASELINE_ARCH=mamba3"
        )

    if total_model_keys <= 0:
        return

    # Other architectures tolerate some drift: at least 8 keys, or up to half
    # of the model's keys, may mismatch before the checkpoint is rejected.
    tolerated = max(8, total_model_keys // 2)
    if len(missing_keys) + len(unexpected_keys) > tolerated:
        raise RuntimeError("checkpoint incompatible with requested model architecture")
def generate_from_callable(
    generator: Callable[[str], str] | Callable[..., str],
    prompt: str,
    *,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> str:
    """Run ``generator`` on ``prompt`` and return its output as stripped text.

    The sampling options are forwarded as keyword arguments, but only those
    the callable can actually receive. A prompt-only callable (the
    ``Callable[[str], str]`` half of the annotation) is therefore invoked as
    ``generator(prompt)`` instead of failing with ``TypeError``.

    Args:
        generator: Callable taking the prompt positionally and, optionally,
            any of ``max_new_tokens``, ``temperature`` and ``top_p`` as
            keyword arguments.
        prompt: Text passed as the first positional argument.
        max_new_tokens: Forwarded when the callable accepts it.
        temperature: Forwarded when the callable accepts it.
        top_p: Forwarded when the callable accepts it.

    Returns:
        ``str(result).strip()`` of whatever the callable returned.
    """
    # Local import keeps this block self-contained; only needed here.
    import inspect

    sampling_kwargs = {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }

    try:
        parameters = list(inspect.signature(generator).parameters.values())
    except (TypeError, ValueError):
        # Signature cannot be introspected (e.g. some C callables): keep the
        # historical behaviour and pass every option through.
        accepted = sampling_kwargs
    else:
        if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in parameters):
            accepted = sampling_kwargs
        else:
            keyword_capable = {
                p.name
                for p in parameters
                if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
            }
            accepted = {
                name: value
                for name, value in sampling_kwargs.items()
                if name in keyword_capable
            }

    text = generator(prompt, **accepted)
    return str(text).strip()
def _locate_checkpoint_file(checkpoint_path: Path | None) -> Path:
    """Return a checkpoint path that exists on disk, or raise.

    The path is first chosen by ``resolve_checkpoint_path``. If that file is
    absent, ``hydrate_checkpoint`` is called with the file's parent directory,
    the ``output_repo`` from ``resolve_routing`` and the ``HF_TOKEN``
    environment variable; a non-``None`` result replaces the path.

    Raises:
        FileNotFoundError: If no file exists at the final path.
    """
    located = resolve_checkpoint_path(checkpoint_path)
    if located.exists():
        return located

    fetched = hydrate_checkpoint(
        cache_dir=located.parent,
        output_repo=resolve_routing(token=os.environ.get("HF_TOKEN")).output_repo,
        token=os.environ.get("HF_TOKEN"),
    )
    if fetched is not None:
        located = fetched
    if not located.exists():
        raise FileNotFoundError(f"Checkpoint not found: {located}")
    return located


def load_hydra_causal_lm(checkpoint_path: Path | None = None, device: str | None = None):
    """Load the tokenizer and a ``generate``-capable wrapper around the model.

    Args:
        checkpoint_path: Explicit checkpoint file, or ``None`` to search the
            cache (and hydrate a checkpoint if nothing is found locally).
        device: Target device string. When falsy, ``"cuda"`` is used if
            available, otherwise ``"cpu"``.

    Returns:
        A 5-tuple ``(tokenizer, model, bos, resolved_device,
        GenerationConfig)`` where ``model`` is in eval mode on
        ``resolved_device`` and ``GenerationConfig`` is the transformers class
        itself, handed back so callers need not import it.

    Raises:
        FileNotFoundError: If no checkpoint file can be found or hydrated.
        RuntimeError: If the checkpoint's keys are incompatible with the
            model (see ``validate_checkpoint_compatibility``).
    """
    weights_file = _locate_checkpoint_file(checkpoint_path)

    # Heavy and project-local imports are deferred until a checkpoint is known
    # to exist, so a missing checkpoint fails before they are attempted.
    from transformers import GenerationConfig, GenerationMixin, PretrainedConfig, PreTrainedModel
    from transformers.modeling_outputs import CausalLMOutputWithPast

    from hydra.config import PostSemClawConfig
    from hydra.model import PostSemClawModel
    from prepare import Tokenizer

    if device:
        resolved_device = device
    elif torch.cuda.is_available():
        resolved_device = "cuda"
    else:
        resolved_device = "cpu"

    class _HydraGenConfig(PretrainedConfig):
        """Minimal transformers config carrying only the vocabulary size."""

        model_type = "hydra"

        def __init__(self, vocab_size: int = 65536, **kwargs):
            super().__init__(**kwargs)
            self.vocab_size = vocab_size

    class HydraForCausalLM(PreTrainedModel, GenerationMixin):
        """Adapter exposing the inner model through the transformers generate API."""

        config_class = _HydraGenConfig

        def __init__(self, gen_config, inner_model):
            super().__init__(gen_config)
            self.inner = inner_model
            self.config.vocab_size = gen_config.vocab_size

        def forward(self, input_ids, attention_mask=None, **kwargs):
            # attention_mask and any extra keyword arguments are accepted but
            # ignored; the inner model only receives the token ids.
            logits = self.inner(input_ids)
            return CausalLMOutputWithPast(loss=None, logits=logits, past_key_values=None)

        def prepare_inputs_for_generation(self, input_ids, **kwargs):
            # Only the token ids are forwarded; no cache state is passed on.
            return {"input_ids": input_ids}

        def get_input_embeddings(self):
            return self.inner.wte

        def can_generate(self) -> bool:
            return True

        @property
        def _supports_cache_class(self):
            return False

    tokenizer = Tokenizer.from_directory()
    vocab_size = tokenizer.get_vocab_size()
    bos = tokenizer.get_bos_token_id()

    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, and this file may have been produced by hydrate_checkpoint.
    # Only load checkpoints from trusted sources.
    checkpoint = torch.load(str(weights_file), map_location="cpu", weights_only=False)
    model_cfg = PostSemClawConfig(**checkpoint["config"])

    # Build the module on the meta device (no real storage), then allocate
    # storage directly on the target device before loading the weights.
    with torch.device("meta"):
        inner = PostSemClawModel(model_cfg)
    inner.to_empty(device=resolved_device)
    # NOTE(review): with strict=False, tensors absent from the checkpoint keep
    # whatever to_empty allocated. Confirm that is acceptable for the
    # mismatches validate_checkpoint_compatibility tolerates.
    absent_keys, surplus_keys = inner.load_state_dict(checkpoint["model_state_dict"], strict=False)

    baseline_arch = os.environ.get("HYDRA_BASELINE_ARCH", "mamba3").strip().lower()
    validate_checkpoint_compatibility(
        baseline_arch=baseline_arch,
        missing_keys=list(absent_keys),
        unexpected_keys=list(surplus_keys),
        total_model_keys=len(inner.state_dict()),
    )
    inner.eval()

    gen_cfg = _HydraGenConfig(vocab_size=vocab_size)
    # The BOS id doubles as the EOS and PAD id.
    for special_token_attr in ("bos_token_id", "eos_token_id", "pad_token_id"):
        setattr(gen_cfg, special_token_attr, bos)

    model = HydraForCausalLM(gen_cfg, inner).to(resolved_device)
    model.eval()
    return tokenizer, model, bos, resolved_device, GenerationConfig
def build_hydra_generator(
    *,
    checkpoint_path: Path | None = None,
    device: str | None = None,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
):
    """Load the model once and return a ``prompt -> text`` closure.

    Args:
        checkpoint_path: Forwarded to ``load_hydra_causal_lm``.
        device: Forwarded to ``load_hydra_causal_lm``.
        max_new_tokens: Generation length limit used for every call.
        temperature: Sampling temperature; sampling is enabled only when
            this is greater than ``0.0``.
        top_p: Nucleus-sampling threshold passed to the generation config.

    Returns:
        A function taking a prompt string and returning the decoded first
        sequence produced by ``model.generate``.
        NOTE(review): the whole output sequence is decoded, so the text
        presumably includes the prompt — confirm against callers.
    """
    tokenizer, model, bos, resolved_device, generation_config_cls = load_hydra_causal_lm(
        checkpoint_path=checkpoint_path,
        device=device,
    )
    # The device never changes after loading, so decide once whether the
    # bfloat16 autocast path applies.
    on_cuda = str(resolved_device).startswith("cuda")

    def _generate(prompt: str) -> str:
        prompt_ids = torch.tensor(
            [tokenizer.encode(prompt)],
            dtype=torch.long,
            device=resolved_device,
        )
        gen_config = generation_config_cls(
            max_new_tokens=max_new_tokens,
            use_cache=False,
            do_sample=temperature > 0.0,
            temperature=temperature,
            top_p=top_p,
            bos_token_id=bos,
            eos_token_id=bos,
            pad_token_id=bos,
        )
        with torch.no_grad():
            if on_cuda:
                with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                    sequences = model.generate(prompt_ids, generation_config=gen_config)
            else:
                sequences = model.generate(prompt_ids, generation_config=gen_config)
        return tokenizer.decode(sequences[0].tolist())

    return _generate
|