# apoorvrajdev's picture
# feat(evaluation): add beam search, metrics pipeline, and stabilized training workflow
# 91a1214
"""Per-sample inspection utilities for diagnosing weak captions.
The aggregate corpus metric tells you *how bad* the model is; this module
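tells you *why*. For each (image, prediction, reference-set) triple it
records per-sample BLEU-4, sentence-level ROUGE-L, the prediction length,
the longest run of immediately repeated tokens, and a list of heuristic
flags (``empty``, ``very_short``, ``repetitive``, ``under_length``); the
``empty`` flag marks a prediction that is empty after stripping sentinels.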
Three failure modes the evaluation pass is trying to surface:
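* **Generic captions** — high BLEU-1 but low BLEU-4: single words match while
  longer n-gram matches fall away.
* **Repetition** — a large ``longest_repeat_run`` value (``repetitive`` flag).
* **Early stopping** — ``length_tokens`` far below the reference lengths
  (``under_length`` flag: fewer than half the tokens of the shortest reference).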
Output JSONL is intentionally flat (one line per sample) so it can be loaded
with ``pandas.read_json(..., lines=True)`` or grep'd from the shell. The
runner that uses this module writes one such file per evaluation pass
alongside ``metrics.json`` for the same run.
"""
from __future__ import annotations
import json
from collections.abc import Iterable, Sequence
from dataclasses import asdict, dataclass
from itertools import pairwise
from pathlib import Path
from captioning.evaluation.tokenization import strip_sentinels
@dataclass(frozen=True)
class SampleDiagnostics:
    """Diagnostics for a single caption prediction, ready for ``asdict``/JSON.

    One instance describes one (image, prediction, reference-set) triple.
    Field order is significant: it is the positional-constructor order and
    the key order produced by :func:`dataclasses.asdict`.

    NOTE: ``frozen=True`` blocks attribute reassignment only. The list-valued
    fields can still be mutated in place, and because they are lists,
    ``hash()`` on an instance raises ``TypeError``.
    """

    # Image identifier/path, stored exactly as the caller passed it.
    image: str
    # Prediction text (diagnose_sample stores the sentinel-stripped form).
    prediction: str
    # Reference captions the prediction was compared against.
    references: list[str]
    # Number of whitespace-separated tokens in the prediction.
    length_tokens: int
    # Longest run of the same token emitted back-to-back.
    longest_repeat_run: int
    # Sentence-level BLEU-4 (0-100), or None when it could not be computed.
    sentence_bleu4: float | None
    # Best-of-references ROUGE-L F-measure (0-100), or None when unavailable.
    sentence_rouge_l: float | None
    # Heuristic labels describing why the caption may be weak.
    flags: list[str]
def _longest_repeat_run(tokens: Sequence[str]) -> int:
"""Return the longest run of immediately-repeated tokens.
Example: ``["a", "a", "a", "dog"]`` -> ``3``. Used to flag the classic
transformer-decoder collapse where the same token is emitted on every step.
"""
if not tokens:
return 0
best = current = 1
for prev, cur in pairwise(tokens):
current = current + 1 if cur == prev else 1
best = max(best, current)
return best
def _sentence_bleu4(prediction: str, references: Sequence[str]) -> float | None:
"""Sentence-level BLEU-4 via sacrebleu's effective-order smoothing."""
try:
import sacrebleu
except ImportError:
return None
if not references or not prediction:
return None
scorer = sacrebleu.metrics.BLEU(effective_order=True, max_ngram_order=4)
return float(scorer.sentence_score(prediction, list(references)).score)
def _sentence_rouge_l(prediction: str, references: Sequence[str]) -> float | None:
"""Best-of-references sentence-level ROUGE-L F-measure (0-100 scale)."""
try:
from rouge_score import rouge_scorer
except ImportError:
return None
valid_refs = [r for r in references if r]
if not valid_refs or not prediction:
return None
scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
best = max(scorer.score(r, prediction)["rougeL"].fmeasure for r in valid_refs)
return float(100.0 * best)
def diagnose_sample(
    image: str,
    prediction: str,
    references: Sequence[str],
) -> SampleDiagnostics:
    """Return :class:`SampleDiagnostics` for one prediction-vs-references row.

    Sentinels are stripped from the prediction and from every reference before
    any metric or flag is computed.

    Args:
        image: Image identifier/path, stored verbatim on the record.
        prediction: Raw model output; may still contain sentinel tokens.
        references: Ground-truth captions. Falsy entries, and entries that
            become empty once sentinels are stripped, are discarded.

    Returns:
        A populated :class:`SampleDiagnostics`. ``flags`` may contain:

        * ``"empty"`` — nothing remains after stripping sentinels.
        * ``"very_short"`` — two or fewer tokens (also set when empty).
        * ``"repetitive"`` — one token emitted three or more times in a row.
        * ``"under_length"`` — fewer than half the tokens of the shortest
          remaining reference.
    """
    pred_clean = strip_sentinels(prediction)
    # Filter *after* stripping. Otherwise a sentinel-only reference survives
    # as "", which drags the shortest-reference length to 0 (silently
    # disabling under_length) and feeds an empty reference to the metrics.
    ref_clean = [
        cleaned
        for cleaned in (strip_sentinels(r) for r in references if r)
        if cleaned.strip()
    ]
    tokens = pred_clean.split()
    flags: list[str] = []
    if not pred_clean:
        flags.append("empty")
    if len(tokens) <= 2:
        flags.append("very_short")
    repeat = _longest_repeat_run(tokens)
    if repeat >= 3:
        flags.append("repetitive")
    if ref_clean and tokens:
        shortest_ref = min(len(r.split()) for r in ref_clean)
        if len(tokens) < shortest_ref // 2:
            flags.append("under_length")
    return SampleDiagnostics(
        image=image,
        prediction=pred_clean,
        references=ref_clean,
        length_tokens=len(tokens),
        longest_repeat_run=repeat,
        sentence_bleu4=_sentence_bleu4(pred_clean, ref_clean),
        sentence_rouge_l=_sentence_rouge_l(pred_clean, ref_clean),
        flags=flags,
    )
def diagnose_many(
images: Sequence[str],
predictions: Sequence[str],
references: Sequence[Sequence[str]],
) -> list[SampleDiagnostics]:
"""Vectorised :func:`diagnose_sample` over parallel sequences."""
if not (len(images) == len(predictions) == len(references)):
raise ValueError(
"images, predictions, references must be the same length: "
f"got {len(images)} / {len(predictions)} / {len(references)}"
)
return [
diagnose_sample(img, pred, refs)
for img, pred, refs in zip(images, predictions, references, strict=True)
]
def write_diagnostics_jsonl(
    diagnostics: Iterable[SampleDiagnostics],
    path: str | Path,
) -> None:
    """Serialise diagnostics as JSON Lines, one flat JSON object per line.

    Non-ASCII text is written as-is (UTF-8), so the output loads directly
    with ``pandas.read_json(..., lines=True)`` or ``jq``.

    Args:
        diagnostics: An iterable of :class:`SampleDiagnostics`, typically the
            output of :func:`diagnose_many`. It is consumed exactly once.
        path: Destination file. It is overwritten, and missing parent
            directories are created.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("w", encoding="utf-8") as handle:
        handle.writelines(
            json.dumps(asdict(record), ensure_ascii=False) + "\n"
            for record in diagnostics
        )
def format_diagnostic_row(d: SampleDiagnostics) -> str:
    """Render one diagnostics record as a short multi-line summary for CLI output.

    Line 1 has the image basename (padded to 35 columns), BLEU-4, ROUGE-L,
    length, repeat run and flags. Missing metrics print as ``n/a`` and no
    flags prints as ``-``. Lines 2-3 show the prediction and the first
    reference (blank if there are none).
    """
    if d.sentence_bleu4 is None:
        bleu = "BLEU4= n/a"
    else:
        bleu = f"BLEU4={d.sentence_bleu4:5.1f}"
    if d.sentence_rouge_l is None:
        rouge = "R-L= n/a"
    else:
        rouge = f"R-L={d.sentence_rouge_l:5.1f}"
    flagstr = "-" if not d.flags else ",".join(d.flags)
    first_ref = d.references[0] if d.references else ""
    basename = Path(d.image).name
    header = (
        f"{basename:35s} {bleu} {rouge} "
        f"len={d.length_tokens:>2} repeat={d.longest_repeat_run:>2} flags={flagstr}"
    )
    return "\n".join((header, f" pred: {d.prediction}", f" ref : {first_ref}"))