"""Generate text samples from the trained WrinkleBrane model."""
import sys

# Make the in-repo package importable. NOTE: 'src' is resolved against the
# current working directory, so run this script from the repository root.
sys.path.insert(0, 'src')

import torch  # noqa: E402  (must follow the sys.path tweak above)

from wrinklebrane.standalone_model import WrinkleBraneModel  # noqa: E402
from wrinklebrane.data import encode_bytes, decode_tokens, BOS_ID  # noqa: E402


print("Loading best checkpoint...")
# weights_only=False lets torch.load unpickle arbitrary Python objects
# (presumably required because ckpt['config'] is a non-tensor object).
# Unpickling can execute code: only load checkpoints from a trusted source.
ckpt = torch.load('checkpoints/best_model.pt', weights_only=False, map_location='cpu')
config = ckpt['config']
model = WrinkleBraneModel(config)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()  # switch to inference mode before sampling

print(f"Model: {config.d_model}d, {config.n_layers}L, {config.n_heads}H")
# 'step' and 'val_loss' are read with .get(), i.e. treated as optional.
# Formatting a missing val_loss (None) with ':.4f' would raise TypeError,
# so fall back to a placeholder instead of crashing.
val_loss = ckpt.get('val_loss')
val_loss_text = f"{val_loss:.4f}" if val_loss is not None else "n/a"
print(f"Checkpoint step: {ckpt.get('step')}, val_loss: {val_loss_text}")
print()
|
|
def generate(prompt, max_tokens=200, temperature=0.8, eos_id=2):
    """Sample a continuation of *prompt* from the module-level ``model``.

    The prompt, prefixed with ``BOS_ID``, is run through
    ``model.forward_sequential`` once to build the carried state. New tokens
    are then sampled one at a time, and each sampled token is fed back
    together with that state.

    Args:
        prompt: Text to condition on; encoded with ``encode_bytes``.
        max_tokens: Maximum number of new tokens to sample.
        temperature: Softmax temperature. Values <= 0 pick the arg-max
            token (greedy decoding) instead of sampling, which avoids the
            division by zero a literal 0 would cause.
        eos_id: Token id that stops generation early. Defaults to 2, the
            value previously hard-coded here (presumably the end-of-sequence
            id; confirm against ``wrinklebrane.data``).

    Returns:
        ``decode_tokens`` applied to BOS + prompt tokens + every sampled
        token, including the stop token if one was sampled.
    """
    tokens = [BOS_ID] + encode_bytes(prompt)
    generated = list(tokens)

    with torch.no_grad():
        # Prefill: one pass over the whole prompt. The logits at the last
        # position already predict the first new token, so the last prompt
        # token must not be fed to the model a second time.
        input_ids = torch.tensor([tokens], dtype=torch.long)
        logits, states = model.forward_sequential(input_ids)

        current = None
        for step in range(max_tokens):
            if step > 0:
                # Advance the carried state by the token sampled last step.
                inp = torch.tensor([[current]], dtype=torch.long)
                logits, states = model.forward_sequential(inp, states)

            next_logits = logits[0, -1]
            if temperature > 0:
                probs = torch.softmax(next_logits / temperature, dim=-1)
                current = torch.multinomial(probs, 1).item()
            else:
                current = torch.argmax(next_logits).item()

            generated.append(current)
            if current == eos_id:
                break

    return decode_tokens(generated)
|
|
prompts = [
    "Once upon a time",
    "The cat sat on",
    "2 + 3 =",
    "def hello",
    "The little bear",
]

# For each prompt: print a header, then up to 400 characters of the
# generated text (150 new tokens), followed by a blank separator line.
for text in prompts:
    print(f"=== Prompt: '{text}' ===")
    completion = generate(text, max_tokens=150)
    print(completion[:400], end="\n\n")
|
|