"""Dump Python v18 integer-path first-token prediction for a fixed prompt.

Compare to C output on the same prompt with 1 new token.

Usage:
    python <this script> [--ckpt PATH] [--prompt TEXT] [--top-k K]

Run with no arguments to reproduce the original fixed-prompt dump.
"""
import argparse

import numpy as np  # noqa: F401  (kept from the original script; currently unused)
import torch

DEFAULT_CKPT = '/root/bitnet1/ckpt/v18_binint_last.pt'
DEFAULT_PROMPT = 'Once upon a time, there was a little girl named'
DEFAULT_TOP_K = 5


def encode_prompt(prompt, vocab_size=None, max_len=None):
    """Encode *prompt* as a ``(1, n)`` ``torch.long`` tensor of UTF-8 byte ids.

    Each token id is one raw byte of ``prompt.encode()`` (UTF-8), i.e. the
    byte-level tokenization the original script used.

    Args:
        prompt: Text to encode.
        vocab_size: If given, every byte id must be ``< vocab_size``.
        max_len: If given, the encoded length must be ``<= max_len``.

    Returns:
        A ``torch.long`` tensor of shape ``(1, len(prompt.encode()))``.

    Raises:
        ValueError: If the prompt is empty, too long, or contains a byte id
            that does not fit in the vocabulary.
    """
    byte_ids = list(prompt.encode())
    if not byte_ids:
        # An empty (1, 0) batch leaves no last position to predict from.
        raise ValueError("prompt must not be empty")
    if max_len is not None and len(byte_ids) > max_len:
        raise ValueError(
            f"prompt is {len(byte_ids)} bytes but the model was built with "
            f"max_seq_len={max_len}"
        )
    if vocab_size is not None and max(byte_ids) >= vocab_size:
        # Non-ASCII prompts produce byte ids >= 128, which a smaller
        # vocabulary cannot represent.
        raise ValueError(
            f"prompt contains byte id {max(byte_ids)} >= vocab_size={vocab_size}"
        )
    return torch.tensor([byte_ids], dtype=torch.long)


def main(argv=None):
    """Load the v18 checkpoint, run the integer path once, and print results.

    Prints the prompt length, the next-token id predicted by
    ``forward_bin_eval_argmax_next`` and the top-k entries of the returned
    integer logits at the last position.

    Args:
        argv: Argument list for ``argparse``. ``None`` means ``sys.argv[1:]``.
    """
    # Project-local import done lazily so the helpers above can be imported
    # (e.g. for testing) without the model package on the path.
    from model_v18 import BitLMv18

    parser = argparse.ArgumentParser(
        description="Dump Python v18 integer-path first-token prediction."
    )
    parser.add_argument('--ckpt', default=DEFAULT_CKPT,
                        help="checkpoint path (default: %(default)s)")
    parser.add_argument('--prompt', default=DEFAULT_PROMPT,
                        help="prompt text, encoded as UTF-8 bytes")
    parser.add_argument('--top-k', type=int, default=DEFAULT_TOP_K,
                        help="number of top logits to print (default: %(default)s)")
    args = parser.parse_args(argv)
    if args.top_k < 1:
        parser.error("--top-k must be >= 1")

    # weights_only=False unpickles arbitrary Python objects (needed for the
    # checkpoint's 'args' entry); only load checkpoints you trust.
    ckpt = torch.load(args.ckpt, map_location='cpu', weights_only=False)
    cfg = ckpt['args']
    m = BitLMv18(vocab_size=cfg['vocab_size'], d_model=cfg['d_model'],
                 n_layers=cfg['n_layers'], n_heads=cfg['n_heads'],
                 d_ff=cfg['d_ff'], max_seq_len=cfg['seq_len'])
    m.load_state_dict(ckpt['model'])
    m.eval()

    ids = encode_prompt(args.prompt, vocab_size=cfg['vocab_size'],
                        max_len=cfg['seq_len'])
    print(f"Prompt len: {ids.shape[1]}, prompt: {args.prompt!r}")

    with torch.no_grad():
        pred, int_logits = m.forward_bin_eval_argmax_next(ids)

    # pred / int_logits are indexed as [batch, position, ...]; take batch 0 at
    # the last prompt position. NOTE(review): exact shapes come from
    # model_v18 -- confirm there.
    next_id = int(pred[0, -1].item())
    # chr() renders the byte id as a Latin-1 code point; ids >= 128 that belong
    # to a multi-byte UTF-8 sequence will not display as the decoded text.
    print(f"Next token ID (integer path): {next_id} = {chr(next_id)!r}")

    # Top-k logits; clamp k so torch.topk cannot exceed the vocabulary size.
    last_logits = int_logits[0, -1]
    k = min(args.top_k, last_logits.shape[-1])
    scores, indices = torch.topk(last_logits, k)
    for rank, (score, idx) in enumerate(zip(scores.tolist(), indices.tolist())):
        print(f" rank {rank}: id={idx} ({chr(idx)!r}) logit={score}")


if __name__ == '__main__':
    main()