File size: 3,867 Bytes
91a1214
 
 
 
 
 
 
 
 
 
3a2e5f0
 
 
 
 
91a1214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a2e5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91a1214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a2e5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
91a1214
 
 
3a2e5f0
91a1214
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""Corpus BLEU score.

The IEEE paper reports BLEU-4 ~24 on COCO val. ``sacrebleu`` is the de-facto
BLEU implementation; NLTK's BLEU has idiosyncratic smoothing and would not
reproduce the published number across machines.

``corpus_bleu_score`` returns BLEU-4 (the default n=4 score) so existing
callers keep working. ``corpus_bleu_breakdown`` additionally exposes BLEU-1,
BLEU-2, BLEU-3, BLEU-4 in one pass — useful for the inspection utility and
for the JSON report consumed by Phase 3 cross-model comparison.
"""

from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass

from captioning.evaluation.tokenization import (
    strip_sentinels_many,
    strip_sentinels_references,
)


@dataclass(frozen=True)
class BleuBreakdown:
    """Cumulative BLEU-1 through BLEU-4 corpus scores on the 0-100 scale.

    Each field is a cumulative BLEU-n score (the combined score over n-gram
    orders 1..n), not an individual n-gram precision. ``bleu4`` is the
    value conventionally reported as "BLEU". Instances are frozen, so a
    result cannot be modified after it has been constructed.
    """

    # Cumulative BLEU using unigrams only.
    bleu1: float
    # Cumulative BLEU over 1- and 2-grams.
    bleu2: float
    # Cumulative BLEU over 1- to 3-grams.
    bleu3: float
    # Cumulative BLEU over 1- to 4-grams (the headline BLEU-4 number).
    bleu4: float


def _refs_by_slot(references: Sequence[Sequence[str]]) -> list[list[str]]:
    """Convert ragged per-example references to sacrebleu's per-slot layout."""
    max_refs = max(len(r) for r in references) if references else 0
    return [[refs[i] if i < len(refs) else "" for refs in references] for i in range(max_refs)]


def corpus_bleu_score(
    predictions: Sequence[str],
    references: Sequence[Sequence[str]],
) -> float:
    """Compute corpus BLEU-4 via ``sacrebleu``.

    Thin convenience wrapper: delegates to ``corpus_bleu_breakdown`` and
    returns only its ``bleu4`` field.

    Args:
        predictions: One generated caption per evaluation example.
        references: One *list* of reference captions per evaluation example.
            COCO has up to 5 references per image. The lists may be ragged;
            shorter lists are padded internally with the empty string ``""``
            before scoring, so callers do not need to pad them.

    Returns:
        BLEU-4 in the 0-100 range (sacrebleu's convention). Divide by 100
        to put it on the same 0-1 scale NLTK uses; even then the two
        implementations are not interchangeable.

    Raises:
        ImportError: If sacrebleu is not installed. Install via the eval
            extras: ``pip install -e ".[eval]"`` or the requirements file.
        ValueError: If ``predictions`` and ``references`` differ in length.
    """
    breakdown = corpus_bleu_breakdown(predictions, references)
    return breakdown.bleu4


def corpus_bleu_breakdown(
    predictions: Sequence[str],
    references: Sequence[Sequence[str]],
) -> BleuBreakdown:
    """Score a corpus with cumulative BLEU-1 through BLEU-4.

    Predictions and references are first passed through the sentinel-stripping
    helpers from ``captioning.evaluation.tokenization``; the references are
    then rearranged into the per-slot layout that sacrebleu's
    ``corpus_score`` is called with.

    Args:
        predictions: One generated caption per example.
        references: One *list* of reference captions per example.

    Returns:
        :class:`BleuBreakdown` holding the four cumulative BLEU-n scores on
        sacrebleu's 0-100 scale.

    Raises:
        ImportError: If sacrebleu is not installed.
        ValueError: If ``predictions`` and ``references`` differ in length.
    """
    # Imported lazily so the module can be imported without the eval extras.
    try:
        import sacrebleu
    except ImportError as e:
        raise ImportError(
            "sacrebleu is required for BLEU evaluation. "
            "Install it via `pip install -r requirements-eval.txt`."
        ) from e

    n_predictions, n_references = len(predictions), len(references)
    if n_predictions != n_references:
        raise ValueError(
            f"predictions ({n_predictions}) and references "
            f"({n_references}) must have the same length"
        )

    hypotheses = strip_sentinels_many(predictions)
    slot_major_refs = _refs_by_slot(strip_sentinels_references(references))

    def cumulative_bleu(order: int) -> float:
        # sacrebleu's ``corpus_bleu`` shortcut only yields BLEU-4, so build a
        # ``BLEU`` metric capped at ``order``: its score is then the geometric
        # mean over precisions[:order], i.e. cumulative BLEU-``order`` (the
        # convention NLTK's cumulative BLEU and the COCO eval scripts use).
        metric = sacrebleu.metrics.BLEU(max_ngram_order=order, effective_order=True)
        return float(metric.corpus_score(hypotheses, slot_major_refs).score)

    bleu1, bleu2, bleu3, bleu4 = (cumulative_bleu(order) for order in (1, 2, 3, 4))
    return BleuBreakdown(bleu1=bleu1, bleu2=bleu2, bleu3=bleu3, bleu4=bleu4)