#!/usr/bin/env python3
"""Generate text samples from the trained WrinkleBrane model.

Loads ``checkpoints/best_model.pt`` (relative to the current working
directory), rebuilds the model from the config stored in the checkpoint,
and prints a sampled continuation for each of a few fixed prompts.
"""
import sys

# Make the in-repo package importable (path is relative to the CWD).
sys.path.insert(0, 'src')

import torch

from wrinklebrane.standalone_model import WrinkleBraneModel
from wrinklebrane.data import encode_bytes, decode_tokens, BOS_ID

# Token id that stops sampling. The original script hard-coded 2.
# NOTE(review): confirm this matches the tokenizer's EOS id in wrinklebrane.data.
_EOS_ID = 2

# Load best checkpoint
print("Loading best checkpoint...")
# weights_only=False lets torch.load unpickle arbitrary objects (needed here for
# the stored config object). Only load checkpoints from a trusted source.
ckpt = torch.load('checkpoints/best_model.pt', weights_only=False, map_location='cpu')
config = ckpt['config']
model = WrinkleBraneModel(config)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()

print(f"Model: {config.d_model}d, {config.n_layers}L, {config.n_heads}H")
# The checkpoint may not record a validation loss; formatting None with
# ':.4f' would raise TypeError, so fall back to a placeholder.
_val_loss = ckpt.get('val_loss')
_val_loss_str = f"{_val_loss:.4f}" if _val_loss is not None else "n/a"
print(f"Checkpoint step: {ckpt.get('step')}, val_loss: {_val_loss_str}")
print()


def _sample_next(next_logits, temperature):
    """Pick the next token id from a 1-D logits vector.

    Samples from softmax(logits / temperature) when temperature > 0;
    otherwise returns the argmax (greedy), avoiding a division by zero.
    """
    if temperature > 0:
        probs = torch.softmax(next_logits / temperature, dim=-1)
        return torch.multinomial(probs, 1).item()
    return int(torch.argmax(next_logits).item())


def generate(prompt, max_tokens=200, temperature=0.8):
    """Sample a continuation of *prompt* from the loaded model.

    Args:
        prompt: Text to condition on. It is encoded with ``encode_bytes`` and
            prefixed with ``BOS_ID``.
        max_tokens: Maximum number of new tokens to sample.
        temperature: Softmax temperature. Values <= 0 select the most likely
            token at every step (greedy decoding).

    Returns:
        ``decode_tokens`` applied to the prompt tokens (BOS included) followed
        by the sampled tokens. Sampling stops after the EOS token is emitted
        or after ``max_tokens`` new tokens.
    """
    tokens = [BOS_ID] + encode_bytes(prompt)
    generated = list(tokens)

    with torch.no_grad():
        # Prefill: run the whole prompt once. The returned logits already give
        # the distribution for the token after the last prompt token, so that
        # token must not be fed to the model a second time.
        input_ids = torch.tensor([tokens], dtype=torch.long)
        logits, states = model.forward_sequential(input_ids)

        for step in range(max_tokens):
            current = _sample_next(logits[0, -1], temperature)
            generated.append(current)
            if current == _EOS_ID or step + 1 == max_tokens:
                # Done: skip the forward pass whose output would be unused.
                break
            # Feed only the newly sampled token; prior context is carried in `states`.
            inp = torch.tensor([[current]], dtype=torch.long)
            logits, states = model.forward_sequential(inp, states)

    return decode_tokens(generated)


# Generate samples
prompts = [
    "Once upon a time",
    "The cat sat on",
    "2 + 3 =",
    "def hello",
    "The little bear",
]

for prompt in prompts:
    print(f"=== Prompt: '{prompt}' ===")
    sample = generate(prompt, max_tokens=150)
    print(sample[:400])
    print()