|
|
"""
|
|
|
Evaluation module — greedy / beam-search decoding + chrF scoring.
|
|
|
=================================================================
|
|
|
Provides:
|
|
|
• ``greedy_decode`` — auto-regressive greedy decoding.
|
|
|
• ``beam_search_decode`` — beam search with length normalisation.
|
|
|
• ``translate`` — end-to-end: raw English string → Malay string.
|
|
|
• ``compute_chrf`` — corpus-level chrF score via *sacrebleu*.
|
|
|
• ``evaluate`` — decode the full validation set, compute chrF,
|
|
|
and print sample translations.
|
|
|
"""
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
import re
|
|
|
from typing import List, Optional
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from tokenizers import Tokenizer
|
|
|
|
|
|
import sacrebleu
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def postprocess_translation(text: str) -> str:
    """
    Clean up raw tokenizer decode output:

    1. Remove spaces before punctuation ( ", tuan ." → ", tuan.")
    2. Remove spaces after opening brackets/quotes
    3. Remove spaces around hyphens ("ke - 5" → "ke-5")
    4. Collapse multiple spaces
    5. Strip and capitalise the first letter
    """
    # 1) no space before closing punctuation — includes closing
    #    brackets/quotes and the ellipsis character (was mojibake "β¦").
    text = re.sub(r'\s+([.,?!;:)\]}"\'…])', r'\1', text)

    # 2) no space right after an opening bracket or quote.
    text = re.sub(r'([(\[{"\'])\s+', r'\1', text)

    # 3) tighten hyphenated compounds the tokenizer split apart.
    text = re.sub(r'\s*-\s*', '-', text)

    # 4) collapse whitespace runs left over from token joining.
    text = re.sub(r'\s{2,}', ' ', text)

    text = text.strip()
    # 5) sentence-case the result (no-op on an empty string).
    if text:
        text = text[0].upper() + text[1:]
    return text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def greedy_decode(
    model: nn.Module,
    src: torch.Tensor,
    bos_id: int,
    eos_id: int,
    pad_id: int = 0,
    max_len: int = 128,
) -> torch.Tensor:
    """
    Decode one source sequence greedily (argmax at every step).

    Parameters
    ----------
    model : TransformerTranslator
    src : (1, src_len) source token IDs.
    bos_id : beginning-of-sentence token ID.
    eos_id : end-of-sentence token ID.
    pad_id : padding token ID.
    max_len : hard cap on the generated length (including [BOS]).

    Returns
    -------
    (1, out_len) generated token IDs, starting with [BOS] and ending at
    [EOS] if it was produced within ``max_len`` steps.
    """
    model.eval()
    device = src.device

    # Run the encoder exactly once; its output is reused at every step.
    pad_mask = src.eq(pad_id)
    memory = model.encode(src, src_key_padding_mask=pad_mask)

    # Start the hypothesis with a lone [BOS] token.
    generated = torch.full((1, 1), bos_id, dtype=torch.long, device=device)

    for _ in range(1, max_len):
        logits = model.decode(
            generated, memory,
            memory_key_padding_mask=pad_mask,
        )
        # Only the newest position matters; pick its highest-scoring token.
        best = logits[:, -1, :].argmax(dim=-1).unsqueeze(-1)
        generated = torch.cat((generated, best), dim=1)

        if best.item() == eos_id:
            break

    return generated
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def beam_search_decode(
    model: nn.Module,
    src: torch.Tensor,
    bos_id: int,
    eos_id: int,
    pad_id: int = 0,
    max_len: int = 128,
    beam_width: int = 5,
    length_penalty: float = 0.6,
) -> torch.Tensor:
    """
    Beam-search decoding for a single source sequence.

    Parameters
    ----------
    model : TransformerTranslator
    src : (1, src_len) source token IDs.
    bos_id, eos_id, pad_id : special token IDs.
    max_len : maximum decoding steps.
    beam_width : number of beams kept alive at each step.
    length_penalty : α for length normalisation: score / len**α.

    Returns
    -------
    (1, out_len) best hypothesis token IDs (including [BOS], up to [EOS]).
    """
    device = src.device
    model.eval()

    # Encode once; every beam shares the same encoder memory.
    src_pad_mask = (src == pad_id)
    memory = model.encode(src, src_key_padding_mask=src_pad_mask)

    def norm_score(hyp):
        # Length-normalised log-probability: score / len**alpha.
        score, tokens = hyp
        return score / (len(tokens) ** length_penalty)

    beams = [(0.0, [bos_id])]   # live hypotheses: (cumulative log-prob, tokens)
    completed = []              # finished hypotheses (last token == [EOS])

    for _ in range(max_len - 1):
        candidates = []
        for score, tokens in beams:
            # Move finished beams to the completed pool exactly once.
            if tokens[-1] == eos_id:
                completed.append((score, tokens))
                continue

            ys = torch.tensor([tokens], dtype=torch.long, device=device)
            logits = model.decode(
                ys, memory,
                memory_key_padding_mask=src_pad_mask,
            )
            log_probs = torch.log_softmax(logits[:, -1, :], dim=-1).squeeze(0)

            topk_log_probs, topk_ids = log_probs.topk(beam_width)
            for k in range(beam_width):
                candidates.append((
                    score + topk_log_probs[k].item(),
                    tokens + [topk_ids[k].item()],
                ))

        if not candidates:
            # Every live beam had already finished and was moved to
            # ``completed`` above — nothing left to expand.
            beams = []
            break

        # Keep the best ``beam_width`` expansions by normalised score.
        candidates.sort(key=norm_score, reverse=True)
        beams = candidates[:beam_width]

        # Stop early once every surviving beam has emitted [EOS].
        if all(b[1][-1] == eos_id for b in beams):
            break

    # Any still-live beams (max_len reached, or the early stop above) have
    # not entered ``completed`` yet; add them exactly once.  The previous
    # version extended ``completed`` both inside the loop and here, so the
    # same hypothesis could be counted twice.
    completed.extend(beams)

    best = max(completed, key=norm_score)
    return torch.tensor([best[1]], dtype=torch.long, device=device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def translate(
    model: nn.Module,
    sentence: str,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    bos_id: int,
    eos_id: int,
    pad_id: int = 0,
    max_len: int = 128,
    device: Optional[torch.device] = None,
    beam_width: int = 1,
    length_penalty: float = 0.6,
) -> str:
    """Translate a single English sentence to Malay.

    Set beam_width=1 for greedy, >1 for beam search.
    """
    # Default to wherever the model's weights live.
    if device is None:
        device = next(model.parameters()).device

    # Tokenise the source sentence into a (1, src_len) ID tensor.
    encoded = src_tokenizer.encode(sentence)
    src = torch.tensor([encoded.ids], dtype=torch.long, device=device)

    # Pick the decoding strategy from the beam width.
    if beam_width <= 1:
        out_ids = greedy_decode(model, src, bos_id, eos_id, pad_id, max_len)
    else:
        out_ids = beam_search_decode(
            model, src, bos_id, eos_id, pad_id, max_len,
            beam_width=beam_width, length_penalty=length_penalty,
        )

    # Detokenise, then tidy up spacing and capitalisation.
    raw = tgt_tokenizer.decode(out_ids.squeeze(0).tolist(), skip_special_tokens=True)
    return postprocess_translation(raw)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_chrf(hypotheses: List[str], references: List[str]) -> sacrebleu.CHRFScore:
    """
    Compute the corpus-level chrF score.

    Parameters
    ----------
    hypotheses : list[str]
        System outputs (decoded translations).
    references : list[str]
        Gold reference translations, aligned with ``hypotheses``.

    Returns
    -------
    sacrebleu.CHRFScore
        Carries the 0-100 score in its ``.score`` attribute.
    """
    # sacrebleu expects a list of reference *streams*; we have one stream.
    ref_streams = [references]
    return sacrebleu.corpus_chrf(hypotheses, ref_streams)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate(
    model: nn.Module,
    hf_dataset,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    src_lang: str = "en",
    tgt_lang: str = "ms",
    bos_id: int = 5,
    eos_id: int = 6,
    pad_id: int = 0,
    max_len: int = 128,
    device: Optional[torch.device] = None,
    num_samples: int = 5,
    beam_width: int = 1,
    length_penalty: float = 0.6,
) -> float:
    """
    Decode every example in *hf_dataset*, compute corpus chrF, and
    print ``num_samples`` side-by-side translations.

    Set beam_width=1 for greedy, >1 for beam search.

    Returns
    -------
    chrf_score : float (0-100)
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    sources: List[str] = []
    hypotheses: List[str] = []
    references: List[str] = []

    # Decode the whole dataset.  Source texts are cached so the sample
    # print-out below does not have to index the dataset a second time
    # (which also keeps this working for iterable-style datasets).
    for example in hf_dataset:
        pair = example["translation"]
        src_text = pair[src_lang]
        ref_text = pair[tgt_lang]

        hyp_text = translate(
            model, src_text,
            src_tokenizer, tgt_tokenizer,
            bos_id, eos_id, pad_id, max_len, device,
            beam_width=beam_width,
            length_penalty=length_penalty,
        )

        sources.append(src_text)
        hypotheses.append(hyp_text)
        references.append(ref_text)

    chrf = compute_chrf(hypotheses, references)

    # Report the corpus score and a few example translations (truncated
    # to 120 characters so long sentences don't wreck the console).
    print(f"\n{'='*60}")
    print(f"chrF Score: {chrf.score:.2f}")
    print(f"{'='*60}")
    for i in range(min(num_samples, len(hypotheses))):
        print(f"\n[{i}] SRC: {sources[i][:120]}")
        print(f"    REF: {references[i][:120]}")
        print(f"    HYP: {hypotheses[i][:120]}")

    return chrf.score