"""
inference.py
Inference (translation) for English→Bengali with full calculation logging.
Supports greedy decoding and beam search, showing every step.
"""
import torch
import torch.nn.functional as F
import math
from typing import Dict, List, Tuple, Optional
from transformer import Transformer, CalcLog
from vocab import get_vocabs, BOS_IDX, EOS_IDX
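
# Interfaces assumed from transformer.py / vocab.py, inferred from the calls
# below rather than a definitive spec: encoder layers return (output,
# self_attn_weights), decoder layers return (output, self_attn, cross_attn);
# CalcLog.log(name, value, formula=None, note=None); vocab objects expose
# encode/decode/tokens and an idx2token dict.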
# ─────────────────────────────────────────────
# Greedy decoding with full logging
# ─────────────────────────────────────────────
def greedy_decode(
model: Transformer,
src: torch.Tensor,
max_len: int = 20,
device: str = "cpu",
log: Optional[CalcLog] = None,
) -> Tuple[List[int], List[Dict]]:
model.eval()
src_v, tgt_v = get_vocabs()
with torch.no_grad():
src_mask = model.make_src_mask(src)
# ── Encode once ──────────────────────
src_emb = model.src_embed(src) * math.sqrt(model.d_model)
enc_x = model.src_pe(src_emb, log=log)
        # Encoder self-attention weights are not needed for decoding,
        # so they are discarded here.
        for i, layer in enumerate(model.encoder_layers):
            enc_x, _ = layer(enc_x, src_mask=src_mask,
                             log=log if i == 0 else None, layer_idx=i)
if log:
log.log("INFERENCE_ENCODER_done", enc_x[0, :, :8],
note="Encoder finished. Output K,V will be reused for every decoder step.")
# ── Auto-regressive decode ────────────
generated = [BOS_IDX]
step_logs = []
for step in range(max_len):
tgt_so_far = torch.tensor([generated], dtype=torch.long, device=device)
tgt_mask = model.make_tgt_mask(tgt_so_far)
tgt_emb = model.tgt_embed(tgt_so_far) * math.sqrt(model.d_model)
dec_x = model.tgt_pe(tgt_emb)
step_dec_cross = []
for i, layer in enumerate(model.decoder_layers):
do_log = (log is not None) and (step < 3) and (i == 0)
if do_log:
log.log(f"INFERENCE_step{step}_dec_input", dec_x[0, :, :8],
note=f"Decoder input at step {step}: tokens so far = "
f"{tgt_v.tokens(generated)}")
dec_x, mw, cw = layer(
dec_x, enc_x,
tgt_mask=tgt_mask, src_mask=src_mask,
log=log if do_log else None,
layer_idx=i,
)
step_dec_cross.append(cw.cpu().numpy())
# Only look at last position
last_logits = model.output_linear(dec_x[:, -1, :]) # (1, V)
probs = F.softmax(last_logits, dim=-1)[0]
# Top-5 predictions
top5_probs, top5_ids = probs.topk(5)
top5 = [
{"token": tgt_v.idx2token.get(idx.item(), "?"),
"id": idx.item(),
"prob": round(prob.item(), 4)}
for prob, idx in zip(top5_probs, top5_ids)
]
# Greedy: pick highest
next_token = top5_ids[0].item()
step_info = {
"step": step,
"tokens_so_far": tgt_v.tokens(generated),
"top5": top5,
"chosen_token": tgt_v.idx2token.get(next_token, "?"),
"chosen_id": next_token,
"chosen_prob": round(top5_probs[0].item(), 4),
"cross_attn": step_dec_cross[0][0].tolist()
if step_dec_cross else None,
}
step_logs.append(step_info)
if log and step < 3:
log.log(f"INFERENCE_step{step}_top5", top5,
formula="P(next_token) = softmax(W_out · dec_out[-1])",
note=f"Step {step}: top-5 candidates. Chosen: {step_info['chosen_token']} ({step_info['chosen_prob']:.4f})")
generated.append(next_token)
if next_token == EOS_IDX:
break
return generated, step_logs
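
# A minimal, self-contained illustration of the greedy selection rule above,
# using toy logits rather than real model output (the values are made up):
def _greedy_step_example() -> None:
    logits = torch.tensor([2.0, 0.5, -1.0, 3.0])  # pretend 4-token vocabulary
    probs = F.softmax(logits, dim=-1)             # ≈ [0.251, 0.056, 0.012, 0.681]
    top_probs, top_ids = probs.topk(2)            # mirrors the top-5 logging above
    print("greedy pick:", top_ids[0].item(),      # → 3, the argmax token
          "prob ≈", round(top_probs[0].item(), 3))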
# ─────────────────────────────────────────────
# Beam search
# ─────────────────────────────────────────────
def beam_search(
model: Transformer,
src: torch.Tensor,
beam_size: int = 3,
max_len: int = 20,
device: str = "cpu",
log: Optional[CalcLog] = None,
) -> Tuple[List[int], List[Dict]]:
model.eval()
src_v, tgt_v = get_vocabs()
with torch.no_grad():
src_mask = model.make_src_mask(src)
# Encode (with logging, same as greedy)
src_emb = model.src_embed(src) * math.sqrt(model.d_model)
enc_x = model.src_pe(src_emb, log=log)
for i, layer in enumerate(model.encoder_layers):
enc_x, _ = layer(enc_x, src_mask=src_mask,
log=log if i == 0 else None, layer_idx=i)
if log:
log.log("INFERENCE_ENCODER_done", enc_x[0, :, :8],
note="Encoder done. K,V reused for every beam decode step.")
        # Beams: list of (cumulative log-prob score, token_ids)
        beams = [(0.0, [BOS_IDX])]
        completed = []  # finished hypotheses as (length-normalized score, tokens)
        step_logs = []  # greedy-compatible format for decode_steps_html
for step in range(max_len):
if not beams:
break
candidates = []
best_cross_attn = None # capture from top beam only
for beam_idx, (score, tokens) in enumerate(beams):
tgt_t = torch.tensor([tokens], dtype=torch.long, device=device)
tgt_mask = model.make_tgt_mask(tgt_t)
tgt_emb = model.tgt_embed(tgt_t) * math.sqrt(model.d_model)
dec_x = model.tgt_pe(tgt_emb)
step_dec_cross = []
for i, layer in enumerate(model.decoder_layers):
do_log = (log is not None) and (step < 3) and (i == 0) and (beam_idx == 0)
dec_x, _, cw = layer(dec_x, enc_x,
tgt_mask=tgt_mask, src_mask=src_mask,
log=log if do_log else None, layer_idx=i)
step_dec_cross.append(cw.cpu().numpy())
if beam_idx == 0:
best_cross_attn = step_dec_cross
last_logits = model.output_linear(dec_x[:, -1, :])
log_probs = F.log_softmax(last_logits, dim=-1)[0]
top_lp, top_id = log_probs.topk(beam_size)
for lp, tid in zip(top_lp, top_id):
candidates.append((score + lp.item(), tokens + [tid.item()]))
# Sort all candidates
candidates.sort(key=lambda x: x[0], reverse=True)
# Build greedy-compatible step_info from top candidates
tokens_so_far = tgt_v.tokens(beams[0][1])
            # "prob" is exp of the length-normalized beam score (len(toks) - 1
            # excludes BOS; the value is floored at -20 before exp so the
            # display never collapses to 0.0). It is comparable across
            # hypotheses of different lengths, not a raw next-token probability.
            top5 = [
                {
                    "token": tgt_v.idx2token.get(toks[-1], "?"),
                    "id": toks[-1],
                    "prob": round(math.exp(max(sc / max(len(toks) - 1, 1), -20)), 4),
                }
                for sc, toks in candidates[:5]
            ]
            _, best_toks = candidates[0] if candidates else (0.0, [BOS_IDX, EOS_IDX])
            chosen_id = best_toks[-1]
            # Cross-attention from the top beam's first decoder layer, batch dim
            # dropped → (n_heads, step+1, T_src), matching the greedy format
            cross_attn = None
            if best_cross_attn:
                attn = best_cross_attn[0][0]  # layer 0, batch 0
                cross_attn = attn.tolist()
step_logs.append({
"step": step,
"tokens_so_far": tokens_so_far,
"top5": top5,
"chosen_token": tgt_v.idx2token.get(chosen_id, "?"),
"chosen_id": chosen_id,
"chosen_prob": top5[0]["prob"] if top5 else 0.0,
"cross_attn": cross_attn,
})
if log and step < 3:
log.log(f"BEAM_step{step}_top_candidates", top5,
formula="score = Σ log P(token_i | prev, src)",
note=f"Step {step}: top beam candidates. Best: '{top5[0]['token'] if top5 else '?'}'")
            # Prune: scan up to 2×beam_size candidates so the beam can be
            # refilled even when several hypotheses end in EOS
            beams = []
            for sc, toks in candidates[:beam_size * 2]:
                if toks[-1] == EOS_IDX:
                    completed.append((sc / len(toks), toks))  # length-normalized
                elif len(beams) < beam_size:
                    beams.append((sc, toks))
            if len(completed) >= beam_size:
                break
if completed:
best = max(completed, key=lambda x: x[0])
return best[1], step_logs
        elif beams:
            # Hit max_len without EOS: close off the best open hypothesis
            return beams[0][1] + [EOS_IDX], step_logs
else:
return [BOS_IDX, EOS_IDX], step_logs
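
# A toy illustration (made-up log-probs) of why completed hypotheses are
# compared by length-normalized score: raw sums of log-probs only get more
# negative as a hypothesis grows, so without normalization shorter outputs
# would always win.
def _beam_score_example() -> None:
    short = [-0.5, -0.5]             # 2 generated tokens
    longer = [-0.4, -0.4, -0.4]      # 3 generated tokens
    for name, lps in (("short", short), ("longer", longer)):
        raw = sum(lps)               # score = Σ log P(token_i | prev, src)
        norm = raw / len(lps)        # what `completed` stores above
        print(f"{name}: raw={raw:.2f}  normalized={norm:.2f}")
    # raw favors "short" (-1.00 > -1.20); normalized favors "longer" (-0.40 > -0.50)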
# ─────────────────────────────────────────────
# Full inference pipeline with visualization
# ─────────────────────────────────────────────
def visualize_inference(
model: Transformer,
en_sentence: str,
device: str = "cpu",
decode_method: str = "greedy",
) -> Dict:
src_v, tgt_v = get_vocabs()
log = CalcLog()
src_ids = src_v.encode(en_sentence)
log.log("INFERENCE_TOKENIZATION", {
"sentence": en_sentence,
"tokens": en_sentence.lower().split(),
"ids": src_ids,
}, formula="word → vocab_id lookup",
note="No ground-truth Bengali needed — model generates from scratch")
src = torch.tensor([src_ids], dtype=torch.long, device=device)
if decode_method == "beam":
output_ids, step_logs = beam_search(model, src, beam_size=3,
device=device, log=log)
log.log("BEAM_SEARCH_complete", {
"method": "beam search (beam=3)",
"note": "Explores multiple hypotheses simultaneously — generally better quality"
})
else:
output_ids, step_logs = greedy_decode(model, src, device=device, log=log)
log.log("GREEDY_complete", {
"method": "greedy decoding",
"note": "Always picks highest probability token — fast but can miss optimal sequences"
})
translation = tgt_v.decode(output_ids)
output_tokens = tgt_v.tokens(output_ids)
log.log("FINAL_TRANSLATION", {
"input": en_sentence,
"output_ids": output_ids,
"output_tokens": output_tokens,
"translation": translation,
}, note="Complete English→Bengali translation")
return {
"en_sentence": en_sentence,
"translation": translation,
"output_tokens": output_tokens,
"output_ids": output_ids,
"src_tokens": src_v.tokens(src_ids),
"step_logs": step_logs,
"calc_log": log.to_dict(),
"decode_method": decode_method,
}
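
# ─────────────────────────────────────────────
# Quick demo (sketch)
# ─────────────────────────────────────────────
# A usage sketch, not part of the pipeline above: it assumes `Transformer`
# is default-constructible and that trained weights are loaded elsewhere;
# the example sentence is arbitrary. Adapt to how the model is actually
# built and trained in this repo.
if __name__ == "__main__":
    model = Transformer()  # assumption: default args; load a checkpoint in practice
    for method in ("greedy", "beam"):
        result = visualize_inference(model, "i eat rice", decode_method=method)
        print(f"[{method}] {result['en_sentence']} → {result['translation']}")
        for s in result["step_logs"]:
            print(f"  step {s['step']}: chose {s['chosen_token']!r} "
                  f"(p={s['chosen_prob']})")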