"""Sample English from latest checkpoint using HuggingFace transformers.generate(). Wraps PostSemClawModel in a minimal GenerationMixin shim so we get: - Beam search (num_beams=4) - Top-k / top-p / temperature sampling - Repetition penalty - All the battle-tested stopping criteria Usage: python scripts/sample_english.py """ from __future__ import annotations import os import sys sys.stdout.reconfigure(line_buffering=True) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import torch import torch.nn as nn from transformers import ( GenerationConfig, GenerationMixin, PretrainedConfig, PreTrainedModel, ) from transformers.modeling_outputs import CausalLMOutputWithPast from hydra.config import PostSemClawConfig from hydra.model import PostSemClawModel from prepare import Tokenizer CKPT_PATH = os.path.expanduser("~/.cache/autoresearch/latest.pt") class _HydraGenConfig(PretrainedConfig): model_type = "hydra" def __init__(self, vocab_size: int = 65536, **kw): super().__init__(**kw) self.vocab_size = vocab_size self.num_hidden_layers = 4 self.hidden_size = 256 self.num_attention_heads = 4 class HydraForCausalLM(PreTrainedModel, GenerationMixin): """HF wrapper around PostSemClawModel so we can use .generate().""" config_class = _HydraGenConfig def __init__(self, gen_config, inner_model): super().__init__(gen_config) self.inner = inner_model # HF looks for these attrs self.config.vocab_size = gen_config.vocab_size def forward(self, input_ids, attention_mask=None, **kw): logits = self.inner(input_ids) return CausalLMOutputWithPast(loss=None, logits=logits, past_key_values=None) def prepare_inputs_for_generation(self, input_ids, **kw): # Our model has no KV cache — always feed full context return {"input_ids": input_ids} def get_input_embeddings(self): return self.inner.wte def can_generate(self) -> bool: return True @property def _supports_cache_class(self): return False def main() -> None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"[sample] device: {device}") tokenizer = Tokenizer.from_directory() vocab_size = tokenizer.get_vocab_size() bos = tokenizer.get_bos_token_id() ckpt = torch.load(CKPT_PATH, map_location="cpu", weights_only=False) cfg_dict = ckpt["config"] step = ckpt.get("step", "?") print(f"[sample] loaded step={step}") cfg = PostSemClawConfig(**cfg_dict) with torch.device("meta"): inner = PostSemClawModel(cfg) inner.to_empty(device=device) inner.load_state_dict(ckpt["model_state_dict"], strict=False) inner.eval() gen_cfg = _HydraGenConfig(vocab_size=vocab_size) # Set common pad/eos tokens so HF generate is happy (we use BOS as both) gen_cfg.bos_token_id = bos gen_cfg.eos_token_id = bos gen_cfg.pad_token_id = bos model = HydraForCausalLM(gen_cfg, inner).to(device) model.eval() print(f"[sample] model ready, vocab={vocab_size}") PROMPTS = [ "The capital of France is", "Paris is known for", "Once upon a time", "Water boils at", "Shakespeare wrote", "The theory of evolution was proposed by", "Einstein discovered that", "Photosynthesis is", ] # --- Greedy --- print("\n=== GREEDY (baseline) ===") gen_config = GenerationConfig( max_new_tokens=20, use_cache=False, do_sample=False, num_beams=1, bos_token_id=bos, eos_token_id=bos, pad_token_id=bos, ) for prompt in PROMPTS: ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long, device=device) with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): out = model.generate(ids, generation_config=gen_config) text = tokenizer.decode(out[0].tolist()) 
print(f' "{prompt}" -> "{text}"') # --- Beam search (4 beams) --- print("\n=== BEAM SEARCH (4 beams, length_penalty=1.0) ===") gen_config = GenerationConfig( max_new_tokens=20, use_cache=False, num_beams=4, do_sample=False, length_penalty=1.0, no_repeat_ngram_size=3, early_stopping=True, bos_token_id=bos, eos_token_id=bos, pad_token_id=bos, ) for prompt in PROMPTS[:4]: ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long, device=device) with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): out = model.generate(ids, generation_config=gen_config) text = tokenizer.decode(out[0].tolist()) print(f' "{prompt}" -> "{text}"') # --- Top-p sampling (nucleus, t=0.8, p=0.9) --- print("\n=== TOP-P SAMPLING (temperature=0.8, top_p=0.9) ===") gen_config = GenerationConfig( max_new_tokens=30, use_cache=False, do_sample=True, temperature=0.8, top_p=0.9, repetition_penalty=1.2, bos_token_id=bos, eos_token_id=bos, pad_token_id=bos, ) torch.manual_seed(42) for prompt in PROMPTS[:4]: ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long, device=device) with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): out = model.generate(ids, generation_config=gen_config) text = tokenizer.decode(out[0].tolist()) print(f' "{prompt}" -> "{text}"') print("\n[sample] done.") if __name__ == "__main__": main()