"""Dump Python v18 integer-path first-token prediction for a fixed prompt.

Compare to C output on the same prompt with 1 new token.
"""
import argparse

import numpy as np  # noqa: F401  (kept from the original script; currently unused)
import torch

DEFAULT_CKPT = '/root/bitnet1/ckpt/v18_binint_last.pt'
DEFAULT_PROMPT = 'Once upon a time, there was a little girl named'
DEFAULT_TOP_K = 5


def parse_args(argv=None):
    """Parse CLI options; every default reproduces the original hard-coded run.

    Args:
        argv: argument list to parse, or None to read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``ckpt``, ``prompt`` and ``top_k``.
    """
    parser = argparse.ArgumentParser(
        description='Dump v18 integer-path first-token prediction.')
    parser.add_argument('--ckpt', default=DEFAULT_CKPT,
                        help='path to the v18 checkpoint (.pt)')
    parser.add_argument('--prompt', default=DEFAULT_PROMPT,
                        help='prompt text; UTF-8 bytes are used as token IDs')
    parser.add_argument('--top-k', dest='top_k', type=int,
                        default=DEFAULT_TOP_K,
                        help='number of highest logits to print')
    args = parser.parse_args(argv)
    if args.top_k < 1:
        parser.error('--top-k must be >= 1')
    return args


def encode_prompt(prompt):
    """Byte-encode *prompt* into a ``(1, n_bytes)`` int64 tensor of token IDs.

    Each UTF-8 byte of the prompt becomes one token ID.

    Raises:
        ValueError: if the prompt encodes to zero bytes. There would be no
            last position to predict from.
    """
    data = prompt.encode('utf-8')
    if not data:
        raise ValueError('prompt must be non-empty')
    return torch.tensor([list(data)], dtype=torch.long)


def token_repr(token_id):
    """Render a token ID as ``repr(chr(token_id))`` for printing.

    IDs outside the valid code-point range (where ``chr`` would raise) are
    shown as ``<id N>`` so the dump never crashes on an unexpected value.
    """
    if 0 <= token_id <= 0x10FFFF:
        return repr(chr(token_id))
    return f'<id {token_id}>'


def load_model(ckpt_path):
    """Build a ``BitLMv18`` from the checkpoint at *ckpt_path* and set eval mode.

    The checkpoint is read via ``ckpt['args']`` (model hyper-parameters) and
    ``ckpt['model']`` (state dict).
    """
    # Imported here so the helpers above stay importable without the project
    # package on the path.
    from model_v18 import BitLMv18

    # weights_only=False unpickles arbitrary Python objects. Presumably this
    # is needed because ckpt['args'] is not a plain tensor container
    # (TODO confirm). Only load checkpoints you trust.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    cfg = ckpt['args']
    model = BitLMv18(vocab_size=cfg['vocab_size'], d_model=cfg['d_model'],
                     n_layers=cfg['n_layers'], n_heads=cfg['n_heads'],
                     d_ff=cfg['d_ff'], max_seq_len=cfg['seq_len'])
    model.load_state_dict(ckpt['model'])
    model.eval()
    return model


def main(argv=None):
    """Run the integer-path forward on the prompt and print the next-token dump."""
    args = parse_args(argv)
    model = load_model(args.ckpt)

    ids = encode_prompt(args.prompt)
    print(f"Prompt len: {ids.shape[1]}, prompt: {args.prompt!r}")

    with torch.no_grad():
        pred, int_logits = model.forward_bin_eval_argmax_next(ids)

    # Assumes pred is (batch, seq) and int_logits is (batch, seq, vocab);
    # TODO confirm against model_v18. Position -1 predicts the first new token.
    next_id = int(pred[0, -1].item())
    print(f"Next token ID (integer path): {next_id} = {token_repr(next_id)}")

    # Top-k logits at the last position; k is clamped so topk cannot fail on
    # a small vocabulary.
    last_logits = int_logits[0, -1]
    k = min(args.top_k, last_logits.shape[-1])
    scores, indices = torch.topk(last_logits, k)
    for rank, (score, idx) in enumerate(zip(scores.tolist(), indices.tolist())):
        print(f" rank {rank}: id={idx} ({token_repr(idx)}) logit={score}")


if __name__ == '__main__':
    main()