WrinkleBrane / Wrinkle /09_standalone_model /generate_samples.py
WCNegentropy's picture
Upload 510 files
3d7f6c5 verified
#!/usr/bin/env python3
"""Generate text samples from the trained WrinkleBrane model."""
import sys
sys.path.insert(0, 'src')
import torch
from wrinklebrane.standalone_model import WrinkleBraneModel
from wrinklebrane.data import encode_bytes, decode_tokens, BOS_ID
# Load the best checkpoint produced by training.
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python
# objects (the 'config' entry is an object read via attribute access below, which
# presumably is why full unpickling is needed). Unpickling can execute code, so
# only load checkpoints from a trusted source.
print("Loading best checkpoint...")
ckpt = torch.load('checkpoints/best_model.pt', weights_only=False, map_location='cpu')
config = ckpt['config']
model = WrinkleBraneModel(config)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()  # switch to inference mode before sampling
print(f"Model: {config.d_model}d, {config.n_layers}L, {config.n_heads}H")
# 'step' and 'val_loss' are read with .get(), so either may be missing.
# Formatting None with ':.4f' raises TypeError, so use a placeholder instead.
val_loss = ckpt.get('val_loss')
val_loss_str = f"{val_loss:.4f}" if val_loss is not None else "n/a"
print(f"Checkpoint step: {ckpt.get('step')}, val_loss: {val_loss_str}")
print()
def generate(prompt, max_tokens=200, temperature=0.8, eos_id=2):
    """Sample a continuation of ``prompt`` from the module-level ``model``.

    The prompt is byte-encoded and prefixed with ``BOS_ID``. It is run through
    ``model.forward_sequential`` once to build the carried ``states``. New
    tokens are then produced one at a time; each sampled token is fed back
    together with the current ``states``.

    Args:
        prompt: Text to condition on.
        max_tokens: Maximum number of new tokens to produce.
        temperature: Softmax temperature applied to the next-token logits.
            Values <= 0 select the argmax token (greedy decoding) instead of
            sampling, which would otherwise divide by zero.
        eos_id: Token id that stops generation early. The original script
            hard-coded 2 and labelled it EOS (confirm against
            ``wrinklebrane.data``).

    Returns:
        ``decode_tokens`` applied to the BOS token, the prompt tokens and the
        newly produced tokens (including the stop token if one was produced).
    """
    tokens = [BOS_ID] + encode_bytes(prompt)
    generated = list(tokens)
    with torch.no_grad():
        # Prefill over the whole prompt. The logits at the last position
        # already predict the first new token, so sample from them directly.
        # Re-feeding the last prompt token would process it a second time on
        # top of a state that has already consumed it.
        logits, states = model.forward_sequential(
            torch.tensor([tokens], dtype=torch.long)
        )
        for step in range(max_tokens):
            next_logits = logits[0, -1]
            if temperature > 0:
                probs = torch.softmax(next_logits / temperature, dim=-1)
                current = torch.multinomial(probs, 1).item()
            else:
                current = int(next_logits.argmax().item())
            generated.append(current)
            if current == eos_id:
                break
            # Advance the state with the new token, except after the final
            # iteration, whose logits would never be used.
            if step + 1 < max_tokens:
                logits, states = model.forward_sequential(
                    torch.tensor([[current]], dtype=torch.long), states
                )
    return decode_tokens(generated)
# Run each demo prompt through generate() (up to 150 new tokens) and print the
# decoded text, truncated to its first 400 characters, under a header line.
prompts = [
    "Once upon a time",
    "The cat sat on",
    "2 + 3 =",
    "def hello",
    "The little bear",
]
for prompt_text in prompts:
    header = f"=== Prompt: '{prompt_text}' ==="
    print(header)
    completion = generate(prompt_text, max_tokens=150)
    preview = completion[:400]
    print(preview)
    print()