| """Sample English from latest checkpoint using HuggingFace transformers.generate(). | |
| Wraps PostSemClawModel in a minimal GenerationMixin shim so we get: | |
| - Beam search (num_beams=4) | |
| - Top-k / top-p / temperature sampling | |
| - Repetition penalty | |
| - All the battle-tested stopping criteria | |
| Usage: python scripts/sample_english.py | |
| """ | |
from __future__ import annotations

import os
import sys

sys.stdout.reconfigure(line_buffering=True)
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
import torch.nn as nn
from transformers import (
    GenerationConfig,
    GenerationMixin,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast

from hydra.config import PostSemClawConfig
from hydra.model import PostSemClawModel
from prepare import Tokenizer

CKPT_PATH = os.path.expanduser("~/.cache/autoresearch/latest.pt")


class _HydraGenConfig(PretrainedConfig):
    model_type = "hydra"

    def __init__(self, vocab_size: int = 65536, **kw):
        super().__init__(**kw)
        self.vocab_size = vocab_size
        self.num_hidden_layers = 4
        self.hidden_size = 256
        self.num_attention_heads = 4


class HydraForCausalLM(PreTrainedModel, GenerationMixin):
    """HF wrapper around PostSemClawModel so we can use .generate()."""

    config_class = _HydraGenConfig
    # The inner model has no KV cache, so opt out of HF's Cache classes. HF reads this
    # as a boolean class attribute; defining it as a method would evaluate as truthy.
    _supports_cache_class = False

    def __init__(self, gen_config, inner_model):
        super().__init__(gen_config)
        self.inner = inner_model
        # HF looks for these attrs
        self.config.vocab_size = gen_config.vocab_size

    def forward(self, input_ids, attention_mask=None, **kw):
        logits = self.inner(input_ids)
        return CausalLMOutputWithPast(loss=None, logits=logits, past_key_values=None)

    def prepare_inputs_for_generation(self, input_ids, **kw):
        # No KV cache, so always feed the full context.
        return {"input_ids": input_ids}

    def get_input_embeddings(self):
        return self.inner.wte

    def can_generate(self) -> bool:
        return True
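

# Note: the shim above covers the minimal GenerationMixin contract that generate()
# relies on here: forward() returns an output whose .logits has shape
# [batch, seq_len, vocab_size], and prepare_inputs_for_generation() rebuilds the model
# inputs at every decoding step. Because there is no KV cache, every GenerationConfig
# below passes use_cache=False so the full context is re-encoded each step.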


def main() -> None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[sample] device: {device}")

    tokenizer = Tokenizer.from_directory()
    vocab_size = tokenizer.get_vocab_size()
    bos = tokenizer.get_bos_token_id()

    ckpt = torch.load(CKPT_PATH, map_location="cpu", weights_only=False)
    cfg_dict = ckpt["config"]
    step = ckpt.get("step", "?")
    print(f"[sample] loaded step={step}")

    cfg = PostSemClawConfig(**cfg_dict)
    # Build on the meta device (no allocation), then materialise empty tensors on the
    # target device and fill them from the checkpoint.
    with torch.device("meta"):
        inner = PostSemClawModel(cfg)
    inner.to_empty(device=device)
    inner.load_state_dict(ckpt["model_state_dict"], strict=False)
    inner.eval()

    gen_cfg = _HydraGenConfig(vocab_size=vocab_size)
    # Set common pad/eos tokens so HF generate is happy (we use BOS as both)
    gen_cfg.bos_token_id = bos
    gen_cfg.eos_token_id = bos
    gen_cfg.pad_token_id = bos

    model = HydraForCausalLM(gen_cfg, inner).to(device)
    model.eval()
    print(f"[sample] model ready, vocab={vocab_size}")

    PROMPTS = [
        "The capital of France is",
        "Paris is known for",
        "Once upon a time",
        "Water boils at",
        "Shakespeare wrote",
        "The theory of evolution was proposed by",
        "Einstein discovered that",
        "Photosynthesis is",
    ]

    # --- Greedy ---
    print("\n=== GREEDY (baseline) ===")
    gen_config = GenerationConfig(
        max_new_tokens=20, use_cache=False,
        do_sample=False,
        num_beams=1,
        bos_token_id=bos, eos_token_id=bos, pad_token_id=bos,
    )
    for prompt in PROMPTS:
        ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long, device=device)
        # autocast follows the active device so the script also runs on CPU-only hosts.
        with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
            out = model.generate(ids, generation_config=gen_config)
        text = tokenizer.decode(out[0].tolist())
        print(f' "{prompt}" -> "{text}"')

    # --- Beam search (4 beams) ---
    print("\n=== BEAM SEARCH (4 beams, length_penalty=1.0) ===")
    gen_config = GenerationConfig(
        max_new_tokens=20, use_cache=False,
        num_beams=4,
        do_sample=False,
        length_penalty=1.0,
        no_repeat_ngram_size=3,
        early_stopping=True,
        bos_token_id=bos, eos_token_id=bos, pad_token_id=bos,
    )
    for prompt in PROMPTS[:4]:
        ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long, device=device)
        with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
            out = model.generate(ids, generation_config=gen_config)
        text = tokenizer.decode(out[0].tolist())
        print(f' "{prompt}" -> "{text}"')

    # --- Top-p sampling (nucleus, temperature=0.8, top_p=0.9) ---
    print("\n=== TOP-P SAMPLING (temperature=0.8, top_p=0.9) ===")
    gen_config = GenerationConfig(
        max_new_tokens=30, use_cache=False,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
        repetition_penalty=1.2,
        bos_token_id=bos, eos_token_id=bos, pad_token_id=bos,
    )
    torch.manual_seed(42)
    for prompt in PROMPTS[:4]:
        ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long, device=device)
        with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
            out = model.generate(ids, generation_config=gen_config)
        text = tokenizer.decode(out[0].tolist())
        print(f' "{prompt}" -> "{text}"')
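
    # --- Top-k sampling (temperature=0.8, top_k=50) ---
    # The module docstring lists top-k sampling alongside top-p; this pass is an
    # illustrative sketch of it using the same loop as above. The temperature and
    # top_k values here are assumed defaults, not tuned for this model.
    print("\n=== TOP-K SAMPLING (temperature=0.8, top_k=50) ===")
    gen_config = GenerationConfig(
        max_new_tokens=30, use_cache=False,
        do_sample=True,
        temperature=0.8,
        top_k=50,
        bos_token_id=bos, eos_token_id=bos, pad_token_id=bos,
    )
    torch.manual_seed(42)
    for prompt in PROMPTS[:4]:
        ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long, device=device)
        with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
            out = model.generate(ids, generation_config=gen_config)
        text = tokenizer.decode(out[0].tolist())
        print(f' "{prompt}" -> "{text}"')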
| print("\n[sample] done.") | |
| if __name__ == "__main__": | |
| main() | |