File size: 5,880 Bytes
91a1214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""Per-sample inspection utilities for diagnosing weak captions.

The aggregate corpus metric tells you *how bad* the model is; this module
tells you *why*. For each (image, prediction, reference-set) triple it
records per-sample BLEU-4, sentence-level ROUGE-L, the prediction length,
the longest repeated token run, and whether the prediction is empty after
stripping sentinels.

Three failure modes the evaluation pass is trying to surface:
    * **Generic captions** — high BLEU-1, low BLEU-4 (n-gram trickle out).
    * **Repetition** — large ``longest_repeat_run`` value.
    * **Early stopping** — ``length_tokens`` far below reference median.

Output JSONL is intentionally flat (one line per sample) so it can be loaded
with ``pandas.read_json(..., lines=True)`` or grep'd from the shell. The
runner that uses this module writes one such file per evaluation pass
alongside ``metrics.json`` for the same run.
"""

from __future__ import annotations

import json
from collections.abc import Iterable, Sequence
from dataclasses import asdict, dataclass
from itertools import pairwise
from pathlib import Path

from captioning.evaluation.tokenization import strip_sentinels


@dataclass(frozen=True)
class SampleDiagnostics:
    """Inspectable record for one (image, prediction, reference-set) triple.

    Instances are built by ``diagnose_sample`` and serialised one-per-line by
    ``write_diagnostics_jsonl``. The dataclass is frozen, so attributes cannot
    be rebound, but the two ``list`` fields are still mutable containers.
    """

    # Image identifier exactly as passed in by the caller; the row formatter
    # treats it as a path and prints only its final component.
    image: str
    # Prediction text after sentinel stripping.
    prediction: str
    # Reference captions after sentinel stripping (empty inputs are dropped).
    references: list[str]
    # Number of whitespace-separated tokens in ``prediction``.
    length_tokens: int
    # Longest run of identical consecutive tokens; 0 for an empty prediction.
    longest_repeat_run: int
    # Sentence-level BLEU-4 from sacrebleu, or None when sacrebleu is not
    # installed or there is nothing to score.
    sentence_bleu4: float | None
    # Best-of-references ROUGE-L F-measure scaled by 100, or None when
    # rouge_score is not installed or there is nothing to score.
    sentence_rouge_l: float | None
    # Failure-mode tags; ``diagnose_sample`` emits any of "empty",
    # "very_short", "repetitive", "under_length".
    flags: list[str]


def _longest_repeat_run(tokens: Sequence[str]) -> int:
    """Return the longest run of immediately-repeated tokens.

    Example: ``["a", "a", "a", "dog"]`` -> ``3``. Used to flag the classic
    transformer-decoder collapse where the same token is emitted on every step.
    """
    if not tokens:
        return 0
    best = current = 1
    for prev, cur in pairwise(tokens):
        current = current + 1 if cur == prev else 1
        best = max(best, current)
    return best


def _sentence_bleu4(prediction: str, references: Sequence[str]) -> float | None:
    """Sentence-level BLEU-4 via sacrebleu's effective-order smoothing."""
    try:
        import sacrebleu
    except ImportError:
        return None
    if not references or not prediction:
        return None
    scorer = sacrebleu.metrics.BLEU(effective_order=True, max_ngram_order=4)
    return float(scorer.sentence_score(prediction, list(references)).score)


def _sentence_rouge_l(prediction: str, references: Sequence[str]) -> float | None:
    """Best-of-references sentence-level ROUGE-L F-measure (0-100 scale)."""
    try:
        from rouge_score import rouge_scorer
    except ImportError:
        return None
    valid_refs = [r for r in references if r]
    if not valid_refs or not prediction:
        return None
    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
    best = max(scorer.score(r, prediction)["rougeL"].fmeasure for r in valid_refs)
    return float(100.0 * best)


def diagnose_sample(
    image: str,
    prediction: str,
    references: Sequence[str],
) -> SampleDiagnostics:
    """Return :class:`SampleDiagnostics` for one prediction-vs-references row.

    Args:
        image: Identifier of the captioned image; stored unchanged.
        prediction: Raw model output; sentinels are stripped before analysis.
        references: Raw reference captions. Entries that are empty, either
            as given or after sentinel stripping, are discarded.

    Returns:
        A record holding the cleaned texts, token length, longest repeated
        token run, sentence-level scores (``None`` when unavailable), and the
        list of failure-mode flags. Flags are appended in a fixed order:
        ``"empty"``, ``"very_short"``, ``"repetitive"``, ``"under_length"``.
    """
    pred_clean = strip_sentinels(prediction)
    # Filter twice: once so falsy entries never reach ``strip_sentinels``, and
    # again afterwards because a sentinel-only reference strips down to "".
    # A surviving empty reference would make the shortest reference length 0
    # and silently disable the ``under_length`` check below.
    stripped_refs = (strip_sentinels(r) for r in references if r)
    ref_clean = [r for r in stripped_refs if r]
    tokens = pred_clean.split()

    flags: list[str] = []
    if not pred_clean:
        flags.append("empty")
    # An empty prediction is also "very_short" (zero tokens).
    if len(tokens) <= 2:
        flags.append("very_short")
    repeat = _longest_repeat_run(tokens)
    if repeat >= 3:
        flags.append("repetitive")
    # Under-length: fewer tokens than half (floor) of the shortest reference.
    if ref_clean and tokens and len(tokens) < min(len(r.split()) for r in ref_clean) // 2:
        flags.append("under_length")

    return SampleDiagnostics(
        image=image,
        prediction=pred_clean,
        references=ref_clean,
        length_tokens=len(tokens),
        longest_repeat_run=repeat,
        sentence_bleu4=_sentence_bleu4(pred_clean, ref_clean),
        sentence_rouge_l=_sentence_rouge_l(pred_clean, ref_clean),
        flags=flags,
    )


def diagnose_many(
    images: Sequence[str],
    predictions: Sequence[str],
    references: Sequence[Sequence[str]],
) -> list[SampleDiagnostics]:
    """Vectorised :func:`diagnose_sample` over parallel sequences."""
    if not (len(images) == len(predictions) == len(references)):
        raise ValueError(
            "images, predictions, references must be the same length: "
            f"got {len(images)} / {len(predictions)} / {len(references)}"
        )
    return [
        diagnose_sample(img, pred, refs)
        for img, pred, refs in zip(images, predictions, references, strict=True)
    ]


def write_diagnostics_jsonl(
    diagnostics: Iterable[SampleDiagnostics],
    path: str | Path,
) -> None:
    """Dump diagnostics as JSON Lines: one flat JSON object per record.

    Args:
        diagnostics: Records to serialise, usually what
            :func:`diagnose_many` returned. Consumed exactly once.
        path: Output file. Missing parent directories are created and an
            existing file is overwritten.
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("w", encoding="utf-8") as handle:
        for record in diagnostics:
            # ensure_ascii=False keeps non-ASCII caption text readable.
            serialised = json.dumps(asdict(record), ensure_ascii=False)
            handle.write(f"{serialised}\n")


def format_diagnostic_row(d: SampleDiagnostics) -> str:
    """Return a short human-readable summary — used by the CLI tail print.

    The first line holds the image file name (padded to 35 columns), both
    sentence scores, the token length, the longest repeat run and the flags.
    Two further lines show the prediction and the first reference, if any.
    A missing score is shown as ``n/a`` in the same 5-character cell used for
    numeric scores, so columns stay aligned across rows.

    Args:
        d: The diagnostics record to render.

    Returns:
        A three-line string without a trailing newline.
    """

    def _score_cell(label: str, value: float | None) -> str:
        # One formatter for both cases keeps the cell width identical.
        shown = "n/a" if value is None else f"{value:.1f}"
        return f"{label}={shown:>5}"

    bleu = _score_cell("BLEU4", d.sentence_bleu4)
    rouge = _score_cell("R-L", d.sentence_rouge_l)
    flagstr = ",".join(d.flags) if d.flags else "-"
    first_ref = d.references[0] if d.references else ""
    return (
        f"{Path(d.image).name:35s}  "
        f"{bleu}  {rouge}  len={d.length_tokens:>2}  repeat={d.longest_repeat_run:>2}  "
        f"flags={flagstr}\n  pred: {d.prediction}\n  ref : {first_ref}"
    )