"""Sample English from latest checkpoint using HuggingFace transformers.generate().
Wraps PostSemClawModel in a minimal GenerationMixin shim so we get:
- Beam search (num_beams=4)
- Top-k / top-p / temperature sampling
- Repetition penalty
- All the battle-tested stopping criteria
Usage: python scripts/sample_english.py
"""
from __future__ import annotations

import os
import sys

sys.stdout.reconfigure(line_buffering=True)
# Make the repo root importable when this file is run as a script.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
import torch.nn as nn
from transformers import (
    GenerationConfig,
    GenerationMixin,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast

from hydra.config import PostSemClawConfig
from hydra.model import PostSemClawModel
from prepare import Tokenizer
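
# Latest checkpoint written by the training run.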
CKPT_PATH = os.path.expanduser("~/.cache/autoresearch/latest.pt")


class _HydraGenConfig(PretrainedConfig):
    model_type = "hydra"

    def __init__(self, vocab_size: int = 65536, **kw):
        super().__init__(**kw)
        self.vocab_size = vocab_size
        # Placeholder dims for HF bookkeeping; the real model is built from the
        # checkpoint config in main().
        self.num_hidden_layers = 4
        self.hidden_size = 256
        self.num_attention_heads = 4


class HydraForCausalLM(PreTrainedModel, GenerationMixin):
    """HF wrapper around PostSemClawModel so we can use .generate()."""

    config_class = _HydraGenConfig

    def __init__(self, gen_config, inner_model):
        super().__init__(gen_config)
        self.inner = inner_model
        # HF looks for these attrs
        self.config.vocab_size = gen_config.vocab_size

    def forward(self, input_ids, attention_mask=None, **kw):
        logits = self.inner(input_ids)
        return CausalLMOutputWithPast(loss=None, logits=logits, past_key_values=None)

    def prepare_inputs_for_generation(self, input_ids, **kw):
        # Our model has no KV cache, so always feed the full context.
        return {"input_ids": input_ids}
    def get_input_embeddings(self):
        return self.inner.wte

    def can_generate(self) -> bool:
        return True

    @property
    def _supports_cache_class(self):
        return False


def main() -> None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[sample] device: {device}")

    tokenizer = Tokenizer.from_directory()
    vocab_size = tokenizer.get_vocab_size()
    bos = tokenizer.get_bos_token_id()

    ckpt = torch.load(CKPT_PATH, map_location="cpu", weights_only=False)
    cfg_dict = ckpt["config"]
    step = ckpt.get("step", "?")
    print(f"[sample] loaded step={step}")

    cfg = PostSemClawConfig(**cfg_dict)
    # Build params on the meta device first, then allocate on the target device
    # and load the checkpoint weights.
    with torch.device("meta"):
        inner = PostSemClawModel(cfg)
    inner.to_empty(device=device)
    inner.load_state_dict(ckpt["model_state_dict"], strict=False)
    inner.eval()
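
    # Wrap the raw model in the HF shim so GenerationMixin drives decoding.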
    gen_cfg = _HydraGenConfig(vocab_size=vocab_size)
    # Set common pad/eos tokens so HF generate is happy (we use BOS as both)
    gen_cfg.bos_token_id = bos
    gen_cfg.eos_token_id = bos
    gen_cfg.pad_token_id = bos
    model = HydraForCausalLM(gen_cfg, inner).to(device)
    model.eval()
    print(f"[sample] model ready, vocab={vocab_size}")
    PROMPTS = [
        "The capital of France is",
        "Paris is known for",
        "Once upon a time",
        "Water boils at",
        "Shakespeare wrote",
        "The theory of evolution was proposed by",
        "Einstein discovered that",
        "Photosynthesis is",
    ]
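
    # Three decoding passes over the same prompts: greedy, beam search, nucleus sampling.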
    # --- Greedy ---
    print("\n=== GREEDY (baseline) ===")
    gen_config = GenerationConfig(
        max_new_tokens=20, use_cache=False,
        do_sample=False,
        num_beams=1,
        bos_token_id=bos, eos_token_id=bos, pad_token_id=bos,
    )
    for prompt in PROMPTS:
        ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long, device=device)
        with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
            out = model.generate(ids, generation_config=gen_config)
        text = tokenizer.decode(out[0].tolist())
        print(f' "{prompt}" -> "{text}"')

    # --- Beam search (4 beams) ---
    print("\n=== BEAM SEARCH (4 beams, length_penalty=1.0) ===")
    gen_config = GenerationConfig(
        max_new_tokens=20, use_cache=False,
        num_beams=4,
        do_sample=False,
        length_penalty=1.0,
        no_repeat_ngram_size=3,
        early_stopping=True,
        bos_token_id=bos, eos_token_id=bos, pad_token_id=bos,
    )
    for prompt in PROMPTS[:4]:
        ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long, device=device)
        with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
            out = model.generate(ids, generation_config=gen_config)
        text = tokenizer.decode(out[0].tolist())
        print(f' "{prompt}" -> "{text}"')

    # --- Top-p sampling (nucleus, t=0.8, p=0.9) ---
    print("\n=== TOP-P SAMPLING (temperature=0.8, top_p=0.9) ===")
    gen_config = GenerationConfig(
        max_new_tokens=30, use_cache=False,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
        repetition_penalty=1.2,
        bos_token_id=bos, eos_token_id=bos, pad_token_id=bos,
    )
    torch.manual_seed(42)
    for prompt in PROMPTS[:4]:
        ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long, device=device)
        with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
            out = model.generate(ids, generation_config=gen_config)
        text = tokenizer.decode(out[0].tolist())
        print(f' "{prompt}" -> "{text}"')

    print("\n[sample] done.")


if __name__ == "__main__":
    main()