# Source: AstralPotato — "Upload en-ms Transformer (6+2 Tied, 16K BPE, chrF 45.62)"
# Commit: e7f17a4 (verified)
"""
Evaluation module – greedy / beam-search decoding + chrF scoring.
=================================================================
Provides:
β€’ ``greedy_decode`` – auto-regressive greedy decoding.
β€’ ``beam_search_decode`` – beam search with length normalisation.
β€’ ``translate`` – end-to-end: raw English string β†’ Malay string.
β€’ ``compute_chrf`` – corpus-level chrF score via *sacrebleu*.
β€’ ``evaluate`` – decode the full validation set, compute chrF,
and print sample translations.
"""
from __future__ import annotations
import re
from typing import List, Optional
import torch
import torch.nn as nn
from tokenizers import Tokenizer
import sacrebleu
# ──────────────────────────────────────────────────────────────────────
# 0. Post-processing: fix tokenizer spacing artefacts
# ──────────────────────────────────────────────────────────────────────
def postprocess_translation(text: str) -> str:
    """
    Normalise the spacing artefacts left by the BPE tokenizer decode.

    Applied in order:
    1. Drop whitespace before closing punctuation (", tuan ." → ", tuan.")
    2. Drop whitespace after opening brackets/quotes
    3. Join spaced hyphens ("brother - in - arms" → "brother-in-arms")
    4. Collapse runs of whitespace into a single space
    5. Strip and capitalise the first character
    """
    # (pattern, replacement) pairs; order matters — brackets are tightened
    # before hyphens and the final whitespace squeeze.
    rules = (
        (r'\s+([.,?!;:)\]}"\'…])', r'\1'),  # no space before closing punctuation
        (r'([(\[{"\'])\s+', r'\1'),         # no space after opening bracket/quote
        (r'\s*-\s*', '-'),                  # tighten compound-word hyphens
        (r'\s{2,}', ' '),                   # squeeze repeated whitespace
    )
    for pattern, replacement in rules:
        text = re.sub(pattern, replacement, text)
    cleaned = text.strip()
    # Capitalise only when non-empty; `[:1]` keeps the empty case safe.
    return cleaned[:1].upper() + cleaned[1:] if cleaned else cleaned
# ──────────────────────────────────────────────────────────────────────
# 1. Greedy decoding
# ──────────────────────────────────────────────────────────────────────
@torch.no_grad()
def greedy_decode(
    model: nn.Module,
    src: torch.Tensor,
    bos_id: int,
    eos_id: int,
    pad_id: int = 0,
    max_len: int = 128,
) -> torch.Tensor:
    """
    Decode one source sequence greedily, one token at a time.

    Parameters
    ----------
    model : TransformerTranslator
    src : (1, src_len) source token IDs.
    bos_id : beginning-of-sentence token ID.
    eos_id : end-of-sentence token ID.
    pad_id : padding token ID.
    max_len : maximum decoding steps.

    Returns
    -------
    (1, out_len) generated token IDs (including [BOS], up to [EOS]).
    """
    model.eval()
    device = src.device
    # Encode once; re-use the memory (and its padding mask) every step.
    pad_mask = src.eq(pad_id)
    memory = model.encode(src, src_key_padding_mask=pad_mask)
    # Seed the target with [BOS].
    generated = torch.full((1, 1), bos_id, dtype=torch.long, device=device)
    for _ in range(max_len - 1):
        step_logits = model.decode(
            generated, memory,
            memory_key_padding_mask=pad_mask,
        )  # (1, cur_len, vocab)
        # Pick the most likely next token from the last position.
        choice = step_logits[:, -1, :].argmax(dim=-1, keepdim=True)  # (1, 1)
        generated = torch.cat((generated, choice), dim=1)
        if choice.item() == eos_id:
            break
    return generated
# ──────────────────────────────────────────────────────────────────────
# 1b. Beam-search decoding
# ──────────────────────────────────────────────────────────────────────
@torch.no_grad()
def beam_search_decode(
    model: nn.Module,
    src: torch.Tensor,
    bos_id: int,
    eos_id: int,
    pad_id: int = 0,
    max_len: int = 128,
    beam_width: int = 5,
    length_penalty: float = 0.6,
) -> torch.Tensor:
    """
    Beam-search decoding for a single source sequence.

    Parameters
    ----------
    model : TransformerTranslator
    src : (1, src_len) source token IDs.
    bos_id, eos_id, pad_id : special token IDs.
    max_len : maximum decoding steps.
    beam_width : number of beams kept alive at each step.
    length_penalty : α for length normalisation: score / len**α.

    Returns
    -------
    (1, out_len) best hypothesis token IDs (including [BOS], up to [EOS]).

    Notes
    -----
    Fix: finished beams were previously appended to ``completed`` twice —
    once at the all-finished early exit and again after the loop (and again
    when the candidate pool emptied). The duplicates never changed the
    ``max`` winner, but the hypothesis pool is now kept duplicate-free.
    """
    device = src.device
    model.eval()

    def norm_score(entry):
        # GNMT-style length normalisation: log-prob / len**α.
        score, tokens = entry
        return score / (len(tokens) ** length_penalty)

    # Encode source once; reuse memory + padding mask for every beam step.
    src_pad_mask = (src == pad_id)
    memory = model.encode(src, src_key_padding_mask=src_pad_mask)

    beams = [(0.0, [bos_id])]   # (cumulative log-prob, token list)
    completed = []              # hypotheses that have emitted [EOS]
    for _ in range(max_len - 1):
        candidates = []
        for score, tokens in beams:
            if tokens[-1] == eos_id:
                # Retire the finished beam exactly once.
                completed.append((score, tokens))
                continue
            ys = torch.tensor([tokens], dtype=torch.long, device=device)
            logits = model.decode(
                ys, memory,
                memory_key_padding_mask=src_pad_mask,
            )  # (1, cur_len, vocab)
            log_probs = torch.log_softmax(logits[:, -1, :], dim=-1).squeeze(0)
            topk_log_probs, topk_ids = log_probs.topk(beam_width)
            for lp, tok in zip(topk_log_probs.tolist(), topk_ids.tolist()):
                candidates.append((score + lp, tokens + [tok]))
        if not candidates:
            # Every beam was already finished and has been retired above;
            # clear the pool so the final extend adds nothing twice.
            beams = []
            break
        # Keep the top beam_width by length-normalised score.
        candidates.sort(key=norm_score, reverse=True)
        beams = candidates[:beam_width]
        if all(tokens[-1] == eos_id for _, tokens in beams):
            break  # all survivors finished; harvested by the extend below
    # Harvest surviving beams exactly once (finished or length-capped).
    completed.extend(beams)
    best = max(completed, key=norm_score)
    return torch.tensor([best[1]], dtype=torch.long, device=device)
# ──────────────────────────────────────────────────────────────────────
# 2. Translate a raw string
# ──────────────────────────────────────────────────────────────────────
def translate(
    model: nn.Module,
    sentence: str,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    bos_id: int,
    eos_id: int,
    pad_id: int = 0,
    max_len: int = 128,
    device: Optional[torch.device] = None,
    beam_width: int = 1,
    length_penalty: float = 0.6,
) -> str:
    """Translate one English sentence into Malay.

    ``beam_width == 1`` selects greedy decoding; larger values switch to
    beam search with the given ``length_penalty``.
    """
    if device is None:
        device = next(model.parameters()).device
    # Source string → token IDs → (1, src_len) tensor on the model's device.
    token_ids = src_tokenizer.encode(sentence).ids
    src_batch = torch.tensor([token_ids], dtype=torch.long, device=device)
    # Choose the decoding strategy.
    if beam_width > 1:
        generated = beam_search_decode(
            model, src_batch, bos_id, eos_id,
            pad_id=pad_id, max_len=max_len,
            beam_width=beam_width, length_penalty=length_penalty,
        )
    else:
        generated = greedy_decode(
            model, src_batch, bos_id, eos_id,
            pad_id=pad_id, max_len=max_len,
        )
    # IDs → raw text (specials skipped) → spacing clean-up.
    raw_text = tgt_tokenizer.decode(
        generated.squeeze(0).tolist(), skip_special_tokens=True
    )
    return postprocess_translation(raw_text)
# ──────────────────────────────────────────────────────────────────────
# 3. Corpus-level chrF
# ──────────────────────────────────────────────────────────────────────
def compute_chrf(hypotheses: List[str], references: List[str]) -> sacrebleu.CHRFScore:
    """
    Score a decoded corpus against its references with chrF.

    Parameters
    ----------
    hypotheses : list[str]
        System outputs (decoded translations).
    references : list[str]
        Gold reference translations, one per hypothesis.

    Returns
    -------
    sacrebleu.CHRFScore – exposes the 0–100 score via ``.score``.
    """
    # sacrebleu expects a list of reference *streams*, hence the nesting:
    # one stream containing all references.
    ref_streams = [references]
    return sacrebleu.corpus_chrf(hypotheses, ref_streams)
# ──────────────────────────────────────────────────────────────────────
# 4. Full evaluation driver
# ──────────────────────────────────────────────────────────────────────
def evaluate(
    model: nn.Module,
    hf_dataset,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    src_lang: str = "en",
    tgt_lang: str = "ms",
    bos_id: int = 5,
    eos_id: int = 6,
    pad_id: int = 0,
    max_len: int = 128,
    device: Optional[torch.device] = None,
    num_samples: int = 5,
    beam_width: int = 1,
    length_penalty: float = 0.6,
) -> float:
    """
    Translate every example in *hf_dataset*, report corpus chrF, and print
    ``num_samples`` side-by-side sample translations.

    ``beam_width == 1`` → greedy decoding; ``beam_width > 1`` → beam search.

    Returns
    -------
    chrf_score : float (0–100)
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    hypotheses: List[str] = []
    references: List[str] = []
    # Decode the whole dataset, collecting hypothesis/reference pairs.
    for example in hf_dataset:
        pair = example["translation"]
        hypotheses.append(
            translate(
                model, pair[src_lang],
                src_tokenizer, tgt_tokenizer,
                bos_id, eos_id, pad_id, max_len, device,
                beam_width=beam_width,
                length_penalty=length_penalty,
            )
        )
        references.append(pair[tgt_lang])
    chrf = compute_chrf(hypotheses, references)
    # Report the score plus a handful of truncated sample translations.
    print(f"\n{'='*60}")
    print(f"chrF Score: {chrf.score:.2f}")
    print(f"{'='*60}")
    for idx in range(min(num_samples, len(hypotheses))):
        sample_src = hf_dataset[idx]["translation"][src_lang]
        print(f"\n[{idx}] SRC: {sample_src[:120]}")
        print(f" REF: {references[idx][:120]}")
        print(f" HYP: {hypotheses[idx][:120]}")
    return chrf.score