File size: 1,712 Bytes
3d7f6c5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 | #!/usr/bin/env python3
"""Generate text samples from the trained WrinkleBrane model."""
import sys
sys.path.insert(0, 'src')
import torch
from wrinklebrane.standalone_model import WrinkleBraneModel
from wrinklebrane.data import encode_bytes, decode_tokens, BOS_ID
# Load best checkpoint and print a short summary of it.
print("Loading best checkpoint...")
# NOTE(security): weights_only=False unpickles arbitrary Python objects. It is
# likely required here because the checkpoint stores a config *object* (read via
# attribute access below), but it means this must only ever load trusted files.
ckpt = torch.load('checkpoints/best_model.pt', weights_only=False, map_location='cpu')
config = ckpt['config']
model = WrinkleBraneModel(config)
model.load_state_dict(ckpt['model_state_dict'])
# Switch to eval mode (affects layers such as dropout, if the model has any).
model.eval()
print(f"Model: {config.d_model}d, {config.n_layers}L, {config.n_heads}H")
# 'val_loss' may be absent from the checkpoint; formatting None with ":.4f"
# would raise TypeError, so fall back to a placeholder instead of crashing.
val_loss = ckpt.get('val_loss')
val_loss_str = f"{val_loss:.4f}" if val_loss is not None else "n/a"
print(f"Checkpoint step: {ckpt.get('step')}, val_loss: {val_loss_str}")
print()
def generate(prompt, max_tokens=200, temperature=0.8, eos_id=2):
    """Sample a text continuation of ``prompt`` from the module-level ``model``.

    The prompt is prefixed with ``BOS_ID`` and run through
    ``model.forward_sequential`` once (prefill) to build the carried ``states``.
    New tokens are then sampled one at a time. Each sampled token, and only
    that token, is fed back together with the carried state.

    Args:
        prompt: Text prompt, encoded with ``encode_bytes``.
        max_tokens: Maximum number of new tokens to sample.
        temperature: Softmax temperature. Values <= 0 select the argmax token
            (greedy decoding) instead of dividing by zero.
        eos_id: Token id that stops generation early. The EOS token itself is
            kept in the output. Defaults to 2, the previously hard-coded value.

    Returns:
        ``decode_tokens`` applied to BOS + prompt tokens + sampled tokens.
    """
    tokens = [BOS_ID] + encode_bytes(prompt)
    generated = list(tokens)
    current = None
    with torch.no_grad():
        # Prefill: consume the whole prompt once. The logits at the last
        # position already predict the first new token. Feeding the last prompt
        # token again would make the model process it twice.
        logits, states = model.forward_sequential(
            torch.tensor([tokens], dtype=torch.long)
        )
        for step in range(max_tokens):
            if step > 0:
                # Advance the state by exactly the token sampled last step.
                logits, states = model.forward_sequential(
                    torch.tensor([[current]], dtype=torch.long), states
                )
            next_logits = logits[0, -1]
            if temperature > 0:
                probs = torch.softmax(next_logits / temperature, dim=-1)
                current = torch.multinomial(probs, 1).item()
            else:
                current = int(torch.argmax(next_logits).item())
            generated.append(current)
            if current == eos_id:
                break
    return decode_tokens(generated)
# Generate samples: one continuation per demo prompt. The prompts cover
# narrative text, arithmetic and code.
prompts = [
    "Once upon a time",
    "The cat sat on",
    "2 + 3 =",
    "def hello",
    "The little bear",
]
for prompt in prompts:
    # Print the header before generating, so it appears even while sampling runs.
    print(f"=== Prompt: '{prompt}' ===")
    sample = generate(prompt, max_tokens=150)
    # Echo at most 400 characters of the decoded text, then a blank separator line.
    truncated = sample[:400]
    print(truncated)
    print()
|