"""Dump Python v18 integer-path first-token prediction for a fixed prompt.

Compare to C output on the same prompt with 1 new token.
"""
import torch
import numpy as np

from model_v18 import BitLMv18


# NOTE(review): weights_only=False makes torch.load run the full pickle
# machinery, which can execute arbitrary code embedded in the file. Only point
# this at checkpoints you produced yourself. It is kept as-is because the
# checkpoint also stores the training arguments under 'args'.
ckpt = torch.load('/root/bitnet1/ckpt/v18_binint_last.pt', map_location='cpu', weights_only=False)
cfg = ckpt['args']

# Rebuild the model with the hyper-parameters stored next to the weights so
# the architecture matches the state dict being loaded.
m = BitLMv18(
    vocab_size=cfg['vocab_size'],
    d_model=cfg['d_model'],
    n_layers=cfg['n_layers'],
    n_heads=cfg['n_heads'],
    d_ff=cfg['d_ff'],
    max_seq_len=cfg['seq_len'],
)
m.load_state_dict(ckpt['model'])
m.eval()


# The prompt is fed as raw UTF-8 byte values: one token id per byte, wrapped
# in an outer list so the tensor has a leading batch dimension of 1.
prompt = 'Once upon a time, there was a little girl named'
ids = torch.tensor([list(prompt.encode())], dtype=torch.long)
print(f"Prompt len: {ids.shape[1]}, prompt: {prompt!r}")


# Inference only, so gradient tracking is disabled for the forward pass.
with torch.no_grad():
    pred, int_logits = m.forward_bin_eval_argmax_next(ids)

# Prediction at the last position of the first (only) batch row.
next_id = int(pred[0, -1].item())
print(f"Next token ID (integer path): {next_id} = {chr(next_id)!r}")


# Show the highest-scoring candidates at the last position for comparison.
# Presumably int_logits is (batch, seq, vocab) -- confirm against model_v18.
last_logits = int_logits[0, -1]
# Clamp k so torch.topk does not raise when fewer than 5 entries exist.
k = min(5, last_logits.shape[-1])
scores, indices = torch.topk(last_logits, k)
for i, (s, idx) in enumerate(zip(scores.tolist(), indices.tolist())):
    print(f"  rank {i}: id={idx} ({chr(idx)!r}) logit={s}")