"""
Evaluation module – greedy / beam-search decoding + chrF scoring.
=================================================================

Provides:
  • ``greedy_decode``      – auto-regressive greedy decoding.
  • ``beam_search_decode`` – beam search with length normalisation.
  • ``translate``          – end-to-end: raw English string → Malay string.
  • ``compute_chrf``       – corpus-level chrF score via *sacrebleu*.
  • ``evaluate``           – decode the full validation set, compute chrF,
                             and print sample translations.
"""

from __future__ import annotations

import re
from typing import List, Optional

import torch
import torch.nn as nn
from tokenizers import Tokenizer
import sacrebleu


# ──────────────────────────────────────────────────────────────────────
# 0. Post-processing: fix tokenizer spacing artefacts
# ──────────────────────────────────────────────────────────────────────

# A pair of straight double quotes with optional padding just inside them,
# e.g. '" hello "'. Double quotes are symmetric, so the only way to tell an
# opening quote from a closing one is to pair them left to right.
_QUOTED_SPAN_RE = re.compile(r'"\s*([^"]*?)\s*"')
# Whitespace before closing punctuation / brackets / apostrophes.
_SPACE_BEFORE_CLOSE_RE = re.compile(r'\s+([.,?!;:)\]}\'…])')
# Whitespace after opening brackets / apostrophes.
_SPACE_AFTER_OPEN_RE = re.compile(r'([(\[{\'])\s+')
# A hyphen with optional whitespace on either side.
_SPACED_HYPHEN_RE = re.compile(r'\s*-\s*')
# Runs of two or more whitespace characters.
_MULTI_SPACE_RE = re.compile(r'\s{2,}')


def postprocess_translation(text: str) -> str:
    """
    Clean up raw tokenizer decode output.

    1. Tighten paired double quotes: ``said " hi " .`` → ``said "hi" .``
    2. Remove spaces before punctuation / closing brackets / apostrophes
       (``, tuan .`` → ``, tuan.``).
    3. Remove spaces after opening brackets / apostrophes
       (``( ya`` → ``(ya``, ``don ' t`` → ``don't``).
    4. Join spaced hyphens (``kanak - kanak`` → ``kanak-kanak``).
    5. Collapse runs of whitespace, strip, and upper-case the first character.

    Straight double quotes are handled by pairing rather than by the
    before/after rules. Treating ``"`` as both an opening and a closing mark
    glued it to the words on *both* sides (``said " hi "`` became
    ``said"hi"``). An unpaired ``"`` is left untouched. The apostrophe keeps
    the glue-both-sides behaviour, which suits contractions.
    """
    text = _QUOTED_SPAN_RE.sub(r'"\1"', text)
    text = _SPACE_BEFORE_CLOSE_RE.sub(r'\1', text)
    text = _SPACE_AFTER_OPEN_RE.sub(r'\1', text)
    # NOTE: this also joins spaced dashes used as punctuation
    # ("ya - tidak" → "ya-tidak"); the rule targets compound words.
    text = _SPACED_HYPHEN_RE.sub('-', text)
    text = _MULTI_SPACE_RE.sub(' ', text)
    text = text.strip()
    if text:
        # Only the very first character is considered: a leading quote or
        # bracket means no letter gets capitalised.
        text = text[0].upper() + text[1:]
    return text
# ──────────────────────────────────────────────────────────────────────
# 1. Greedy decoding
# ──────────────────────────────────────────────────────────────────────

@torch.no_grad()
def greedy_decode(
    model: nn.Module,
    src: torch.Tensor,
    bos_id: int,
    eos_id: int,
    pad_id: int = 0,
    max_len: int = 128,
) -> torch.Tensor:
    """
    Auto-regressive greedy decoding for a single source sequence.

    Parameters
    ----------
    model : TransformerTranslator
        Must expose ``encode(src, src_key_padding_mask=...)`` and
        ``decode(ys, memory, memory_key_padding_mask=...)``; ``decode`` is
        expected to return logits shaped (1, cur_len, vocab).
    src : (1, src_len) source token IDs.
    bos_id : beginning-of-sentence token ID.
    eos_id : end-of-sentence token ID.
    pad_id : padding token ID; positions equal to it are passed to the model
        as the key-padding mask.
    max_len : maximum output length, counting the leading [BOS].

    Returns
    -------
    (1, out_len) generated token IDs, starting with [BOS] and ending with
    [EOS] if it was produced within ``max_len`` tokens.
    """
    device = src.device
    model.eval()

    # Encode the source once; the same padding mask is reused on every step.
    src_pad_mask = src == pad_id
    memory = model.encode(src, src_key_padding_mask=src_pad_mask)

    ys = torch.tensor([[bos_id]], dtype=torch.long, device=device)
    for _ in range(max_len - 1):
        logits = model.decode(ys, memory, memory_key_padding_mask=src_pad_mask)
        # Only the distribution at the last position is needed.
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # (1, 1)
        ys = torch.cat([ys, next_token], dim=1)
        if next_token.item() == eos_id:
            break
    return ys


# ──────────────────────────────────────────────────────────────────────
# 1b. Beam-search decoding
# ──────────────────────────────────────────────────────────────────────

@torch.no_grad()
def beam_search_decode(
    model: nn.Module,
    src: torch.Tensor,
    bos_id: int,
    eos_id: int,
    pad_id: int = 0,
    max_len: int = 128,
    beam_width: int = 5,
    length_penalty: float = 0.6,
) -> torch.Tensor:
    """
    Beam-search decoding for a single source sequence.

    Each beam is decoded with its own forward pass. A hypothesis that emits
    [EOS] leaves the beam and is moved to the finished pool; the beam is
    refilled from live hypotheses. The search stops when every kept
    hypothesis ends in [EOS], or after ``max_len - 1`` steps.

    Parameters
    ----------
    model : TransformerTranslator (see ``greedy_decode``).
    src : (1, src_len) source token IDs.
    bos_id, eos_id, pad_id : special token IDs.
    max_len : maximum hypothesis length, counting the leading [BOS].
    beam_width : number of hypotheses kept per step. Each live hypothesis is
        expanded with its top ``min(beam_width, vocab)`` next tokens.
    length_penalty : α for length normalisation. Hypotheses are ranked by
        ``sum_log_prob / len(tokens) ** α``, where ``len`` counts [BOS] (and
        [EOS] when present).

    Returns
    -------
    (1, out_len) best hypothesis token IDs (including [BOS], up to [EOS]).
    """
    device = src.device
    model.eval()

    # Encode the source once; reused for every hypothesis.
    src_pad_mask = src == pad_id
    memory = model.encode(src, src_key_padding_mask=src_pad_mask)

    def normalised(hyp):
        """Length-normalised score used both for pruning and final choice."""
        score, tokens = hyp
        return score / (len(tokens) ** length_penalty)

    # Each hypothesis: (summed log-prob, token id list).
    beams = [(0.0, [bos_id])]
    completed = []

    for _ in range(max_len - 1):
        candidates = []
        for score, tokens in beams:
            if tokens[-1] == eos_id:
                # Finished hypotheses are not expanded further.
                completed.append((score, tokens))
                continue
            ys = torch.tensor([tokens], dtype=torch.long, device=device)
            logits = model.decode(
                ys, memory, memory_key_padding_mask=src_pad_mask,
            )  # (1, cur_len, vocab)
            log_probs = torch.log_softmax(logits[:, -1, :], dim=-1).squeeze(0)
            # topk raises if k exceeds the vocabulary size.
            k = min(beam_width, log_probs.size(-1))
            topk_log_probs, topk_ids = log_probs.topk(k)
            for lp, tok in zip(topk_log_probs.tolist(), topk_ids.tolist()):
                candidates.append((score + lp, tokens + [tok]))

        if not candidates:
            # Every beam was already finished and has been moved to
            # `completed`; clear so they are not appended a second time.
            beams = []
            break

        candidates.sort(key=normalised, reverse=True)
        beams = candidates[:beam_width]

        if all(tokens[-1] == eos_id for _, tokens in beams):
            break

    # Beams still held (finished at the last step or cut off by max_len).
    completed.extend(beams)

    best = max(completed, key=normalised)
    return torch.tensor([best[1]], dtype=torch.long, device=device)


# ──────────────────────────────────────────────────────────────────────
# 2. Translate a raw string
# ──────────────────────────────────────────────────────────────────────

def translate(
    model: nn.Module,
    sentence: str,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    bos_id: int,
    eos_id: int,
    pad_id: int = 0,
    max_len: int = 128,
    device: Optional[torch.device] = None,
    beam_width: int = 1,
    length_penalty: float = 0.6,
) -> str:
    """
    Translate a single English sentence to Malay.

    Set ``beam_width=1`` for greedy decoding, ``>1`` for beam search.

    Returns the detokenised output (special tokens skipped) after
    ``postprocess_translation``. Returns ``""`` when the source tokenises to
    no IDs at all, since there is nothing to condition on.
    """
    if device is None:
        device = next(model.parameters()).device

    src_ids = src_tokenizer.encode(sentence).ids
    if not src_ids:
        # A (1, 0) source tensor is not something the model is guaranteed to
        # handle; short-circuit instead of decoding from nothing.
        return ""
    src = torch.tensor([src_ids], dtype=torch.long, device=device)

    if beam_width > 1:
        out_ids = beam_search_decode(
            model, src, bos_id, eos_id, pad_id, max_len,
            beam_width=beam_width, length_penalty=length_penalty,
        )
    else:
        out_ids = greedy_decode(model, src, bos_id, eos_id, pad_id, max_len)

    # IDs → string (special tokens skipped), then fix spacing artefacts.
    raw = tgt_tokenizer.decode(out_ids.squeeze(0).tolist(), skip_special_tokens=True)
    return postprocess_translation(raw)


# ──────────────────────────────────────────────────────────────────────
# 3. Corpus-level chrF
# ──────────────────────────────────────────────────────────────────────

def compute_chrf(hypotheses: List[str], references: List[str]) -> sacrebleu.CHRFScore:
    """
    Compute corpus-level chrF score.

    Parameters
    ----------
    hypotheses : list[str]
        System outputs (decoded translations).
    references : list[str]
        Gold reference translations, aligned one-to-one with *hypotheses*.

    Returns
    -------
    The sacrebleu chrF result; its ``.score`` attribute is on a 0–100 scale.
    """
    # sacrebleu expects a list of reference *streams*; we have exactly one.
    return sacrebleu.corpus_chrf(hypotheses, [references])


# ──────────────────────────────────────────────────────────────────────
# 4. Full evaluation driver
# ──────────────────────────────────────────────────────────────────────

def evaluate(
    model: nn.Module,
    hf_dataset,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    src_lang: str = "en",
    tgt_lang: str = "ms",
    bos_id: int = 5,
    eos_id: int = 6,
    pad_id: int = 0,
    max_len: int = 128,
    device: Optional[torch.device] = None,
    num_samples: int = 5,
    beam_width: int = 1,
    length_penalty: float = 0.6,
) -> float:
    """
    Decode every example in *hf_dataset*, compute corpus chrF, and print
    ``num_samples`` side-by-side translations.

    Set beam_width=1 for greedy, >1 for beam search.

    Parameters
    ----------
    hf_dataset : iterable of examples
        Each example must provide ``example["translation"][src_lang]`` and
        ``example["translation"][tgt_lang]``. The dataset is iterated once
        and never indexed, so one-shot / streaming iterables also work.
    bos_id, eos_id, pad_id : special token IDs forwarded to ``translate``.
        NOTE(review): defaults 5 / 6 / 0 must match the trained tokenizer's
        vocabulary — confirm.

    Returns
    -------
    chrf_score : float (0–100)
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()

    # Sources are kept so the sample printout does not re-index the dataset.
    sources: List[str] = []
    hypotheses: List[str] = []
    references: List[str] = []

    for example in hf_dataset:
        pair = example["translation"]
        src_text = pair[src_lang]
        sources.append(src_text)
        references.append(pair[tgt_lang])
        hypotheses.append(
            translate(
                model,
                src_text,
                src_tokenizer,
                tgt_tokenizer,
                bos_id,
                eos_id,
                pad_id,
                max_len,
                device,
                beam_width=beam_width,
                length_penalty=length_penalty,
            )
        )

    chrf = compute_chrf(hypotheses, references)

    # Print samples
    print(f"\n{'='*60}")
    print(f"chrF Score: {chrf.score:.2f}")
    print(f"{'='*60}")
    for i in range(min(num_samples, len(hypotheses))):
        print(f"\n[{i}] SRC: {sources[i][:120]}")
        print(f" REF: {references[i][:120]}")
        print(f" HYP: {hypotheses[i][:120]}")

    return chrf.score