""" inference.py Inference (translation) for English→Bengali with full calculation logging. Supports greedy decoding and beam search, showing every step. """ import torch import torch.nn.functional as F import numpy as np import math from typing import Dict, List, Tuple, Optional from transformer import Transformer, CalcLog from vocab import get_vocabs, PAD_IDX, BOS_IDX, EOS_IDX # ───────────────────────────────────────────── # Greedy decoding with full logging # ───────────────────────────────────────────── def greedy_decode( model: Transformer, src: torch.Tensor, max_len: int = 20, device: str = "cpu", log: Optional[CalcLog] = None, ) -> Tuple[List[int], List[Dict]]: model.eval() src_v, tgt_v = get_vocabs() with torch.no_grad(): src_mask = model.make_src_mask(src) # ── Encode once ────────────────────── src_emb = model.src_embed(src) * math.sqrt(model.d_model) enc_x = model.src_pe(src_emb, log=log) enc_attn_weights = [] for i, layer in enumerate(model.encoder_layers): enc_x, ew = layer(enc_x, src_mask=src_mask, log=log if i == 0 else None, layer_idx=i) enc_attn_weights.append(ew.cpu().numpy()) if log: log.log("INFERENCE_ENCODER_done", enc_x[0, :, :8], note="Encoder finished. Output K,V will be reused for every decoder step.") # ── Auto-regressive decode ──────────── generated = [BOS_IDX] step_logs = [] for step in range(max_len): tgt_so_far = torch.tensor([generated], dtype=torch.long, device=device) tgt_mask = model.make_tgt_mask(tgt_so_far) tgt_emb = model.tgt_embed(tgt_so_far) * math.sqrt(model.d_model) dec_x = model.tgt_pe(tgt_emb) step_dec_cross = [] for i, layer in enumerate(model.decoder_layers): do_log = (log is not None) and (step < 3) and (i == 0) if do_log: log.log(f"INFERENCE_step{step}_dec_input", dec_x[0, :, :8], note=f"Decoder input at step {step}: tokens so far = " f"{tgt_v.tokens(generated)}") dec_x, mw, cw = layer( dec_x, enc_x, tgt_mask=tgt_mask, src_mask=src_mask, log=log if do_log else None, layer_idx=i, ) step_dec_cross.append(cw.cpu().numpy()) # Only look at last position last_logits = model.output_linear(dec_x[:, -1, :]) # (1, V) probs = F.softmax(last_logits, dim=-1)[0] # Top-5 predictions top5_probs, top5_ids = probs.topk(5) top5 = [ {"token": tgt_v.idx2token.get(idx.item(), "?"), "id": idx.item(), "prob": round(prob.item(), 4)} for prob, idx in zip(top5_probs, top5_ids) ] # Greedy: pick highest next_token = top5_ids[0].item() step_info = { "step": step, "tokens_so_far": tgt_v.tokens(generated), "top5": top5, "chosen_token": tgt_v.idx2token.get(next_token, "?"), "chosen_id": next_token, "chosen_prob": round(top5_probs[0].item(), 4), "cross_attn": step_dec_cross[0][0].tolist() if step_dec_cross else None, } step_logs.append(step_info) if log and step < 3: log.log(f"INFERENCE_step{step}_top5", top5, formula="P(next_token) = softmax(W_out · dec_out[-1])", note=f"Step {step}: top-5 candidates. 
# ─────────────────────────────────────────────
# Beam search
# ─────────────────────────────────────────────

def beam_search(
    model: Transformer,
    src: torch.Tensor,
    beam_size: int = 3,
    max_len: int = 20,
    device: str = "cpu",
    log: Optional[CalcLog] = None,
) -> Tuple[List[int], List[Dict]]:
    model.eval()
    src_v, tgt_v = get_vocabs()

    with torch.no_grad():
        src_mask = model.make_src_mask(src)

        # Encode once (with logging, same as greedy)
        src_emb = model.src_embed(src) * math.sqrt(model.d_model)
        enc_x = model.src_pe(src_emb, log=log)
        for i, layer in enumerate(model.encoder_layers):
            enc_x, _ = layer(enc_x, src_mask=src_mask,
                             log=log if i == 0 else None, layer_idx=i)

        if log:
            log.log("INFERENCE_ENCODER_done", enc_x[0, :, :8],
                    note="Encoder done. K,V reused for every beam decode step.")

        # Beams: list of (cumulative log-prob score, token_ids)
        beams = [(0.0, [BOS_IDX])]
        completed = []
        step_logs = []  # greedy-compatible format for decode_steps_html

        for step in range(max_len):
            if not beams:
                break

            candidates = []
            best_cross_attn = None  # captured from the top beam only

            for beam_idx, (score, tokens) in enumerate(beams):
                tgt_t = torch.tensor([tokens], dtype=torch.long, device=device)
                tgt_mask = model.make_tgt_mask(tgt_t)
                tgt_emb = model.tgt_embed(tgt_t) * math.sqrt(model.d_model)
                dec_x = model.tgt_pe(tgt_emb)

                step_dec_cross = []
                for i, layer in enumerate(model.decoder_layers):
                    do_log = (log is not None) and (step < 3) and (i == 0) and (beam_idx == 0)
                    dec_x, _, cw = layer(dec_x, enc_x, tgt_mask=tgt_mask, src_mask=src_mask,
                                         log=log if do_log else None, layer_idx=i)
                    step_dec_cross.append(cw.cpu().numpy())

                if beam_idx == 0:
                    best_cross_attn = step_dec_cross

                # Each beam proposes its beam_size best continuations
                last_logits = model.output_linear(dec_x[:, -1, :])
                log_probs = F.log_softmax(last_logits, dim=-1)[0]
                top_lp, top_id = log_probs.topk(beam_size)
                for lp, tid in zip(top_lp, top_id):
                    candidates.append((score + lp.item(), tokens + [tid.item()]))

            # Sort all candidates by cumulative score
            candidates.sort(key=lambda x: x[0], reverse=True)

            # Build greedy-compatible step_info from the top candidates
            tokens_so_far = tgt_v.tokens(beams[0][1])
            top5 = [
                {
                    "token": tgt_v.idx2token.get(toks[-1], "?"),
                    "id": toks[-1],
                    # display as a per-token probability: exp of the
                    # length-normalized log-prob, clamped to avoid underflow
                    "prob": round(math.exp(max(sc / max(len(toks) - 1, 1), -20)), 4),
                }
                for sc, toks in candidates[:5]
            ]
            best_sc, best_toks = candidates[0] if candidates else (0.0, [BOS_IDX, EOS_IDX])
            chosen_id = best_toks[-1]

            # cross-attn from decoder layer 0, batch 0 → (n_heads, step+1, T_src)
            cross_attn = None
            if best_cross_attn:
                attn = best_cross_attn[0][0]
                cross_attn = attn.tolist()

            step_logs.append({
                "step": step,
                "tokens_so_far": tokens_so_far,
                "top5": top5,
                "chosen_token": tgt_v.idx2token.get(chosen_id, "?"),
                "chosen_id": chosen_id,
                "chosen_prob": top5[0]["prob"] if top5 else 0.0,
                "cross_attn": cross_attn,
            })

            if log and step < 3:
                log.log(f"BEAM_step{step}_top_candidates", top5,
                        formula="score = Σ log P(token_i | prev, src)",
                        note=f"Step {step}: top beam candidates. "
                             f"Best: '{top5[0]['token'] if top5 else '?'}'")

            # Prune into next beams: keep unfinished hypotheses, move
            # EOS-terminated ones to `completed` with a length-normalized score
            beams = []
            for sc, toks in candidates[:beam_size * 2]:
                if toks[-1] == EOS_IDX:
                    completed.append((sc / len(toks), toks))
                elif len(beams) < beam_size:
                    beams.append((sc, toks))

            if len(completed) >= beam_size:
                break

    if completed:
        best = max(completed, key=lambda x: x[0])
        return best[1], step_logs
    elif beams:
        return beams[0][1] + [EOS_IDX], step_logs
    else:
        return [BOS_IDX, EOS_IDX], step_logs
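
# A minimal sketch of the scoring used above, on made-up per-token
# probabilities (the hypotheses, numbers, and helper name are illustrative,
# not model output). A beam's raw score is a sum of log-probabilities, so it
# shrinks with every extra token; length normalization divides by token count,
# so a longer hypothesis is not penalized just for having more factors < 1.
def _beam_score_demo() -> None:
    hyp_a = [0.6, 0.5]        # short hypothesis
    hyp_b = [0.6, 0.7, 0.65]  # longer hypothesis
    raw_a = sum(math.log(p) for p in hyp_a)  # ≈ -1.204
    raw_b = sum(math.log(p) for p in hyp_b)  # ≈ -1.298
    norm_a = raw_a / len(hyp_a)              # ≈ -0.602
    norm_b = raw_b / len(hyp_b)              # ≈ -0.433
    # Raw score prefers the short hypothesis A; normalized score prefers B.
    print(f"raw: A={raw_a:.3f} B={raw_b:.3f} | normalized: A={norm_a:.3f} B={norm_b:.3f}")
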
# ─────────────────────────────────────────────
# Full inference pipeline with visualization
# ─────────────────────────────────────────────

def visualize_inference(
    model: Transformer,
    en_sentence: str,
    device: str = "cpu",
    decode_method: str = "greedy",
) -> Dict:
    src_v, tgt_v = get_vocabs()
    log = CalcLog()

    src_ids = src_v.encode(en_sentence)
    log.log("INFERENCE_TOKENIZATION", {
        "sentence": en_sentence,
        "tokens": src_v.tokens(src_ids),  # actual vocab tokens for these ids
        "ids": src_ids,
    }, formula="word → vocab_id lookup",
       note="No ground-truth Bengali needed — model generates from scratch")

    src = torch.tensor([src_ids], dtype=torch.long, device=device)

    if decode_method == "beam":
        output_ids, step_logs = beam_search(model, src, beam_size=3, device=device, log=log)
        log.log("BEAM_SEARCH_complete", {
            "method": "beam search (beam=3)",
            "note": "Explores multiple hypotheses simultaneously — generally better quality",
        })
    else:
        output_ids, step_logs = greedy_decode(model, src, device=device, log=log)
        log.log("GREEDY_complete", {
            "method": "greedy decoding",
            "note": "Always picks the highest-probability token — fast but can miss optimal sequences",
        })

    translation = tgt_v.decode(output_ids)
    output_tokens = tgt_v.tokens(output_ids)

    log.log("FINAL_TRANSLATION", {
        "input": en_sentence,
        "output_ids": output_ids,
        "output_tokens": output_tokens,
        "translation": translation,
    }, note="Complete English→Bengali translation")

    return {
        "en_sentence": en_sentence,
        "translation": translation,
        "output_tokens": output_tokens,
        "output_ids": output_ids,
        "src_tokens": src_v.tokens(src_ids),
        "step_logs": step_logs,
        "calc_log": log.to_dict(),
        "decode_method": decode_method,
    }
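
# Minimal usage sketch, under assumptions: a trained model was saved as a full
# object with torch.save(model, "model.pt"). The checkpoint path and the example
# sentence are illustrative, not part of this module's documented interface.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # weights_only=False is required on PyTorch ≥ 2.6 to unpickle a full model
    model = torch.load("model.pt", map_location=device, weights_only=False)
    result = visualize_inference(model, "i love you", device=device, decode_method="greedy")
    print(result["translation"])
    for s in result["step_logs"]:
        print(f"step {s['step']}: {s['chosen_token']} (p={s['chosen_prob']})")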