File size: 1,103 Bytes
4754707
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
"""Dump Python v18 integer-path first-token prediction for a fixed prompt.
Compare to C output on the same prompt with 1 new token."""
import torch
import numpy as np
from model_v18 import BitLMv18

# Load the training checkpoint onto the CPU.
# NOTE(security): weights_only=False makes torch.load fall back to full
# unpickling, which can execute arbitrary code embedded in the file. Only run
# this against a checkpoint you produced yourself or otherwise trust.
ckpt = torch.load('/root/bitnet1/ckpt/v18_binint_last.pt', map_location='cpu', weights_only=False)
cfg = ckpt['args']

# Rebuild the model from the hyperparameters saved next to the weights, then
# restore the weights and switch to eval mode before running inference.
m = BitLMv18(
    vocab_size=cfg['vocab_size'],
    d_model=cfg['d_model'],
    n_layers=cfg['n_layers'],
    n_heads=cfg['n_heads'],
    d_ff=cfg['d_ff'],
    max_seq_len=cfg['seq_len'],
)
m.load_state_dict(ckpt['model'])
m.eval()

prompt = 'Once upon a time, there was a little girl named'
# One token id per UTF-8 byte of the prompt, wrapped in an outer list so the
# tensor has a leading dimension of size 1.
ids = torch.tensor([list(prompt.encode())], dtype=torch.long)
print(f"Prompt len: {ids.shape[1]}, prompt: {prompt!r}")

with torch.no_grad():
    pred, int_logits = m.forward_bin_eval_argmax_next(ids)
# NOTE(review): [0, -1] is presumably (first batch row, last sequence
# position) -- confirm against forward_bin_eval_argmax_next.
next_id = int(pred[0, -1].item())
# chr() renders the id as a Unicode code point; this matches the byte value
# only for ids below 128 (pure ASCII).
print(f"Next token ID (integer path): {next_id} = {chr(next_id)!r}")

# Top-5 logits
last_logits = int_logits[0, -1]
# torch.topk raises if k exceeds the size of the dimension it selects from, so
# clamp k for models whose output dimension is smaller than 5.
k = min(5, last_logits.shape[-1])
scores, indices = torch.topk(last_logits, k)
for i, (s, idx) in enumerate(zip(scores.tolist(), indices.tolist())):
    print(f"  rank {i}: id={idx} ({chr(idx)!r}) logit={s}")