File size: 6,352 Bytes
951f760
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
#!/usr/bin/env python3
from __future__ import annotations

import os
from pathlib import Path
from typing import Callable

import torch

from scripts.benchmark_checkpoint import hydrate_checkpoint
from scripts.hf_routing import resolve_routing


def default_checkpoint_path() -> Path:
    """Return ``~/.cache/autoresearch/latest.pt`` with ``~`` expanded.

    Used as the last-resort location when no other checkpoint file is found.
    """
    home_relative = "~/.cache/autoresearch/latest.pt"
    return Path(os.path.expanduser(home_relative))


def checkpoint_candidates(*, cache_dir: Path | None = None) -> list[Path]:
    """List checkpoint files to probe, most preferred first.

    Args:
        cache_dir: Directory to search. When falsy, ``~/.cache/autoresearch``
            (with ``~`` expanded) is used instead.

    Returns:
        ``best_bpb.pt``, ``pretrain_final.pt`` and ``latest.pt`` inside that
        directory, in that order.
    """
    root = cache_dir or Path(os.path.expanduser("~/.cache/autoresearch"))
    preferred_names = ("best_bpb.pt", "pretrain_final.pt", "latest.pt")
    return [root / file_name for file_name in preferred_names]


def resolve_checkpoint_path(explicit_path: Path | None, *, cache_dir: Path | None = None) -> Path:
    """Choose which checkpoint file to load.

    Args:
        explicit_path: Caller-supplied path. When not ``None`` it is returned
            unchanged, whether or not it exists.
        cache_dir: Directory to probe. ``None`` means the default
            ``~/.cache/autoresearch`` directory.

    Returns:
        The first existing file from ``checkpoint_candidates(cache_dir=...)``.
        If none exists, returns ``latest.pt`` inside the directory that was
        probed. That may not exist either; callers check ``.exists()``.
    """
    if explicit_path is not None:
        return explicit_path
    for candidate in checkpoint_candidates(cache_dir=cache_dir):
        if candidate.exists():
            return candidate
    if cache_dir is None:
        return default_checkpoint_path()
    # Keep the fallback inside the requested directory. Callers may try to
    # hydrate a missing checkpoint into ``path.parent``, so falling back to the
    # default cache dir would silently redirect that download elsewhere.
    return cache_dir / "latest.pt"


def validate_checkpoint_compatibility(
    *,
    baseline_arch: str,
    missing_keys: list[str],
    unexpected_keys: list[str],
    total_model_keys: int,
) -> None:
    """Reject a state dict that does not fit the instantiated model.

    Args:
        baseline_arch: Architecture name. ``"transformer"`` demands an exact
            key match.
        missing_keys: Model keys absent from the checkpoint.
        unexpected_keys: Checkpoint keys the model does not have.
        total_model_keys: Number of entries in the model's state dict. A
            non-positive value disables the tolerance check below.

    Raises:
        RuntimeError: Raised in either of two cases:
            - for the transformer baseline, on any mismatch at all;
            - otherwise, when more than ``max(8, total_model_keys // 2)``
              keys are missing or unexpected combined.
    """
    any_mismatch = bool(missing_keys) or bool(unexpected_keys)
    if any_mismatch and baseline_arch == "transformer":
        raise RuntimeError(
            "checkpoint incompatible with transformer baseline architecture; "
            "use a transformer-trained checkpoint or keep HYDRA_BASELINE_ARCH=mamba3"
        )

    mismatch_total = len(missing_keys) + len(unexpected_keys)
    if total_model_keys <= 0:
        # Nothing to compare against; skip the ratio check.
        return
    tolerated = max(8, total_model_keys // 2)
    if mismatch_total > tolerated:
        raise RuntimeError("checkpoint incompatible with requested model architecture")


def generate_from_callable(
    generator: Callable[[str], str] | Callable[..., str],
    prompt: str,
    *,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> str:
    """Run ``generator`` on ``prompt`` and return its output as stripped text.

    Both generator shapes from the annotation are supported:

    * a callable accepting ``max_new_tokens``, ``temperature`` and ``top_p``
      keywords: the keywords are forwarded;
    * a prompt-only ``Callable[[str], str]`` (e.g. a closure with its sampling
      settings already bound): it is called with the prompt alone, because
      forwarding the keywords would raise ``TypeError``.

    If the generator's signature cannot be introspected, the keywords are
    forwarded unconditionally.

    Returns:
        ``str(result).strip()``.
    """
    import inspect  # local import: only needed for the signature probe below

    sampling_kwargs = {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    try:
        signature = inspect.signature(generator)
    except (TypeError, ValueError):
        # Not introspectable (some builtins / C callables, or not callable at
        # all, in which case the call below raises the natural TypeError).
        signature = None

    forward_kwargs = True
    if signature is not None:
        try:
            signature.bind(prompt, **sampling_kwargs)
        except TypeError:
            forward_kwargs = False

    if forward_kwargs:
        text = generator(prompt, **sampling_kwargs)
    else:
        text = generator(prompt)
    return str(text).strip()


def load_hydra_causal_lm(checkpoint_path: Path | None = None, device: str | None = None):
    """Load a HYDRA checkpoint and wrap it as a Hugging Face-style causal LM.

    Args:
        checkpoint_path: Checkpoint file to load. ``None`` lets
            ``resolve_checkpoint_path`` pick one from the default cache dir.
        device: Torch device string. Defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``.

    Returns:
        ``(tokenizer, model, bos, resolved_device, GenerationConfig)``. The
        transformers import happens lazily inside this function, and the
        ``GenerationConfig`` class itself is returned for callers to use.

    Raises:
        FileNotFoundError: The checkpoint is absent locally and
            ``hydrate_checkpoint`` did not provide one.
        RuntimeError: Raised by ``validate_checkpoint_compatibility`` when the
            state dict does not match the model.
    """
    ckpt_path = resolve_checkpoint_path(checkpoint_path)
    if not ckpt_path.exists():
        # Try to populate the cache dir. hydrate_checkpoint presumably fetches
        # from the routed HF output repo using HF_TOKEN (verify). It returns a
        # path, or None when nothing was hydrated.
        hydrated = hydrate_checkpoint(
            cache_dir=ckpt_path.parent,
            output_repo=resolve_routing(token=os.environ.get("HF_TOKEN")).output_repo,
            token=os.environ.get("HF_TOKEN"),
        )
        if hydrated is not None:
            ckpt_path = hydrated
        if not ckpt_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # Heavy / project imports are deferred until a checkpoint is known to exist.
    from transformers import GenerationConfig, GenerationMixin, PretrainedConfig, PreTrainedModel
    from transformers.modeling_outputs import CausalLMOutputWithPast

    from hydra.config import PostSemClawConfig
    from hydra.model import PostSemClawModel
    from prepare import Tokenizer

    resolved_device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    # Minimal config so transformers' PreTrainedModel/GenerationMixin machinery
    # has a config object with a vocab_size to work with.
    class _HydraGenConfig(PretrainedConfig):
        model_type = "hydra"

        def __init__(self, vocab_size: int = 65536, **kw):
            super().__init__(**kw)
            self.vocab_size = vocab_size

    # Thin adapter exposing the native model through transformers' generate().
    class HydraForCausalLM(PreTrainedModel, GenerationMixin):
        config_class = _HydraGenConfig

        def __init__(self, gen_config, inner_model):
            super().__init__(gen_config)
            self.inner = inner_model
            self.config.vocab_size = gen_config.vocab_size

        def forward(self, input_ids, attention_mask=None, **kw):
            # No KV cache: the whole sequence is re-run on every step, and
            # attention_mask / extra kwargs are ignored.
            # Assumes self.inner(input_ids) returns logits shaped
            # (batch, seq, vocab). TODO confirm.
            logits = self.inner(input_ids)
            return CausalLMOutputWithPast(loss=None, logits=logits, past_key_values=None)

        def prepare_inputs_for_generation(self, input_ids, **kw):
            # Pass only the token ids through, so forward() always sees the
            # full sequence generated so far.
            return {"input_ids": input_ids}

        def get_input_embeddings(self):
            # Assumes the native model stores its token embedding as `wte`.
            return self.inner.wte

        # NOTE(review): transformers declares can_generate as a classmethod.
        # This instance-level override only works when called on an instance;
        # confirm against the pinned transformers version.
        def can_generate(self) -> bool:
            return True

        # Presumably tells transformers not to build a Cache object for this
        # model (it has no cache support). Verify for the pinned version.
        @property
        def _supports_cache_class(self):
            return False

    tokenizer = Tokenizer.from_directory()
    vocab_size = tokenizer.get_vocab_size()
    bos = tokenizer.get_bos_token_id()
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects. The file may have been fetched by hydrate_checkpoint above, so
    # only load checkpoints from trusted sources.
    ckpt = torch.load(str(ckpt_path), map_location="cpu", weights_only=False)
    cfg = PostSemClawConfig(**ckpt["config"])
    # Construct on the meta device (no storage allocated). to_empty() then
    # allocates uninitialized storage on the target device, and
    # load_state_dict fills in the real values.
    # NOTE(review): with strict=False, any missing keys tolerated by the check
    # below keep uninitialized memory. The same applies to non-persistent
    # buffers unless PostSemClawModel recomputes them.
    with torch.device("meta"):
        inner = PostSemClawModel(cfg)
    inner.to_empty(device=resolved_device)
    missing, unexpected = inner.load_state_dict(ckpt["model_state_dict"], strict=False)
    # HYDRA_BASELINE_ARCH (default "mamba3") selects how strict the key check is.
    validate_checkpoint_compatibility(
        baseline_arch=os.environ.get("HYDRA_BASELINE_ARCH", "mamba3").strip().lower(),
        missing_keys=list(missing),
        unexpected_keys=list(unexpected),
        total_model_keys=len(inner.state_dict()),
    )
    inner.eval()

    gen_cfg = _HydraGenConfig(vocab_size=vocab_size)
    # BOS doubles as EOS and PAD, so generation stops when the model emits BOS.
    # Presumably BOS acts as the document separator for this tokenizer (confirm).
    gen_cfg.bos_token_id = bos
    gen_cfg.eos_token_id = bos
    gen_cfg.pad_token_id = bos
    model = HydraForCausalLM(gen_cfg, inner).to(resolved_device)
    model.eval()
    return tokenizer, model, bos, resolved_device, GenerationConfig


def build_hydra_generator(
    *,
    checkpoint_path: Path | None = None,
    device: str | None = None,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
):
    """Load the HYDRA model once and return a ``prompt -> text`` closure.

    Args:
        checkpoint_path: Passed to ``load_hydra_causal_lm``. ``None`` means
            auto-resolve.
        device: Passed to ``load_hydra_causal_lm``. ``None`` means
            auto-select.
        max_new_tokens: Maximum number of tokens generated per call.
        temperature: Sampling temperature. Values ``<= 0`` select greedy
            decoding.
        top_p: Nucleus-sampling threshold, used only when sampling.

    Returns:
        ``_generate(prompt) -> str``, which decodes the first sequence returned
        by ``model.generate``. For decoder-only models, ``generate`` returns the
        prompt ids followed by the new ids, so the text presumably includes the
        prompt.
    """
    import contextlib  # local import: used for the non-CUDA no-op context

    tokenizer, model, bos, resolved_device, GenerationConfig = load_hydra_causal_lm(
        checkpoint_path=checkpoint_path, device=device
    )

    do_sample = temperature > 0.0
    # Pass the sampling knobs only when actually sampling. With
    # do_sample=False, transformers ignores temperature/top_p and warns about
    # them, and temperature=0.0 is not a valid sampling value anyway.
    sampling_kwargs = {"temperature": temperature, "top_p": top_p} if do_sample else {}
    on_cuda = str(resolved_device).startswith("cuda")

    def _generate(prompt: str) -> str:
        ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long, device=resolved_device)
        gen_config = GenerationConfig(
            max_new_tokens=max_new_tokens,
            use_cache=False,  # the wrapped model has no KV cache
            do_sample=do_sample,
            bos_token_id=bos,
            eos_token_id=bos,
            pad_token_id=bos,
            **sampling_kwargs,
        )
        # bf16 autocast on CUDA only; other devices run without autocast.
        autocast_ctx = (
            torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
            if on_cuda
            else contextlib.nullcontext()
        )
        with torch.no_grad(), autocast_ctx:
            out = model.generate(ids, generation_config=gen_config)
        return tokenizer.decode(out[0].tolist())

    return _generate