Spaces:
Configuration error
Configuration error
File size: 5,880 Bytes
"""Per-sample inspection utilities for diagnosing weak captions.
The aggregate corpus metric tells you *how bad* the model is; this module
tells you *why*. For each (image, prediction, reference-set) triple it
records per-sample BLEU-4, sentence-level ROUGE-L, the prediction length,
the longest repeated token run, and whether the prediction is empty after
stripping sentinels.
Three failure modes the evaluation pass is trying to surface:
* **Generic captions** — high BLEU-1, low BLEU-4 (matches thin out as the n-gram order grows).
* **Repetition** — large ``longest_repeat_run`` value.
* **Early stopping** — ``length_tokens`` below half the length of the shortest reference.
Output JSONL is intentionally flat (one line per sample) so it can be loaded
with ``pandas.read_json(..., lines=True)`` or grep'd from the shell. The
runner that uses this module writes one such file per evaluation pass
alongside ``metrics.json`` for the same run.
"""
from __future__ import annotations
import json
from collections.abc import Iterable, Sequence
from dataclasses import asdict, dataclass
from itertools import pairwise
from pathlib import Path
from captioning.evaluation.tokenization import strip_sentinels
@dataclass(frozen=True)
class SampleDiagnostics:
    """Immutable diagnostic row describing a single captioned image.

    Attributes:
        image: Identifier of the image the caption belongs to.
        prediction: The predicted caption text.
        references: The reference captions the prediction is compared with.
        length_tokens: Number of tokens in the prediction.
        longest_repeat_run: Length of the longest run of identical
            consecutive tokens in the prediction.
        sentence_bleu4: Sentence-level BLEU-4 score, or ``None`` when no
            score is available.
        sentence_rouge_l: Sentence-level ROUGE-L score, or ``None`` when no
            score is available.
        flags: Names of the failure modes detected for this sample.
    """

    image: str
    prediction: str
    references: list[str]
    length_tokens: int
    longest_repeat_run: int
    sentence_bleu4: float | None
    sentence_rouge_l: float | None
    flags: list[str]
def _longest_repeat_run(tokens: Sequence[str]) -> int:
"""Return the longest run of immediately-repeated tokens.
Example: ``["a", "a", "a", "dog"]`` -> ``3``. Used to flag the classic
transformer-decoder collapse where the same token is emitted on every step.
"""
if not tokens:
return 0
best = current = 1
for prev, cur in pairwise(tokens):
current = current + 1 if cur == prev else 1
best = max(best, current)
return best
def _sentence_bleu4(prediction: str, references: Sequence[str]) -> float | None:
"""Sentence-level BLEU-4 via sacrebleu's effective-order smoothing."""
try:
import sacrebleu
except ImportError:
return None
if not references or not prediction:
return None
scorer = sacrebleu.metrics.BLEU(effective_order=True, max_ngram_order=4)
return float(scorer.sentence_score(prediction, list(references)).score)
def _sentence_rouge_l(prediction: str, references: Sequence[str]) -> float | None:
"""Best-of-references sentence-level ROUGE-L F-measure (0-100 scale)."""
try:
from rouge_score import rouge_scorer
except ImportError:
return None
valid_refs = [r for r in references if r]
if not valid_refs or not prediction:
return None
scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
best = max(scorer.score(r, prediction)["rougeL"].fmeasure for r in valid_refs)
return float(100.0 * best)
def diagnose_sample(
    image: str,
    prediction: str,
    references: Sequence[str],
) -> SampleDiagnostics:
    """Return :class:`SampleDiagnostics` for one prediction-vs-references row.

    The prediction and every reference are passed through
    ``strip_sentinels``; references that are empty before or after
    stripping are discarded. The prediction is split on whitespace to
    obtain its tokens.

    Flags that may be raised:

    * ``"empty"`` — the stripped prediction is an empty string.
    * ``"very_short"`` — the prediction has at most two tokens.
    * ``"repetitive"`` — some token repeats three or more times in a row.
    * ``"under_length"`` — the prediction is non-empty and has fewer tokens
      than half (integer division) of the shortest reference.

    Args:
        image: Identifier of the image being diagnosed.
        prediction: Raw predicted caption.
        references: Raw reference captions for the same image.

    Returns:
        The populated diagnostics record, with cleaned prediction and
        references.
    """
    pred_clean = strip_sentinels(prediction)
    # Filter AFTER stripping: a reference made only of sentinels becomes ""
    # and would otherwise drag the shortest-reference length down to zero
    # (silently disabling "under_length") and reach the scorers as an
    # empty reference.
    ref_clean = [
        cleaned
        for raw in references
        if raw and (cleaned := strip_sentinels(raw))
    ]
    tokens = pred_clean.split()
    repeat = _longest_repeat_run(tokens)

    flags: list[str] = []
    if not pred_clean:
        flags.append("empty")
    if len(tokens) <= 2:
        flags.append("very_short")
    if repeat >= 3:
        flags.append("repetitive")
    if ref_clean and tokens:
        shortest_ref = min(len(ref.split()) for ref in ref_clean)
        if len(tokens) < shortest_ref // 2:
            flags.append("under_length")

    return SampleDiagnostics(
        image=image,
        prediction=pred_clean,
        references=ref_clean,
        length_tokens=len(tokens),
        longest_repeat_run=repeat,
        sentence_bleu4=_sentence_bleu4(pred_clean, ref_clean),
        sentence_rouge_l=_sentence_rouge_l(pred_clean, ref_clean),
        flags=flags,
    )
def diagnose_many(
images: Sequence[str],
predictions: Sequence[str],
references: Sequence[Sequence[str]],
) -> list[SampleDiagnostics]:
"""Vectorised :func:`diagnose_sample` over parallel sequences."""
if not (len(images) == len(predictions) == len(references)):
raise ValueError(
"images, predictions, references must be the same length: "
f"got {len(images)} / {len(predictions)} / {len(references)}"
)
return [
diagnose_sample(img, pred, refs)
for img, pred, refs in zip(images, predictions, references, strict=True)
]
def write_diagnostics_jsonl(
    diagnostics: Iterable[SampleDiagnostics],
    path: str | Path,
) -> None:
    """Serialise diagnostics as JSON Lines, one record per line.

    Args:
        diagnostics: Dataclass records to serialise; each one is converted
            with :func:`dataclasses.asdict`.
        path: Output file, overwritten if it exists. Missing parent
            directories are created first.
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("w", encoding="utf-8") as handle:
        # ensure_ascii=False keeps non-ASCII caption text readable in the file.
        handle.writelines(
            json.dumps(asdict(record), ensure_ascii=False) + "\n"
            for record in diagnostics
        )
def format_diagnostic_row(d: SampleDiagnostics) -> str:
    """Return a human-readable summary of one diagnostics record.

    The first line holds the image file name (padded to 35 characters),
    the BLEU-4 and ROUGE-L scores, the token count, the longest repeat run
    and the comma-joined flags (``-`` when there are none). Two further
    lines show the prediction and the first reference (blank when there
    are no references).

    Args:
        d: The record to render.

    Returns:
        The formatted summary, without a trailing newline.
    """
    # Scores are rendered five characters wide; pad the placeholder to the
    # same width so columns line up whether or not a score is available.
    missing = f"{'n/a':>5}"
    bleu_text = missing if d.sentence_bleu4 is None else f"{d.sentence_bleu4:5.1f}"
    rouge_text = missing if d.sentence_rouge_l is None else f"{d.sentence_rouge_l:5.1f}"
    bleu = f"BLEU4={bleu_text}"
    rouge = f"R-L={rouge_text}"
    flagstr = ",".join(d.flags) if d.flags else "-"
    first_ref = d.references[0] if d.references else ""
    return (
        f"{Path(d.image).name:35s} "
        f"{bleu} {rouge} len={d.length_tokens:>2} repeat={d.longest_repeat_run:>2} "
        f"flags={flagstr}\n pred: {d.prediction}\n ref : {first_ref}"
    )
|