Spaces:
Configuration error
Configuration error
File size: 4,681 Bytes
91a1214 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 | """Single entry point that returns every implemented caption-quality metric.
``compute_all_metrics`` is the shared aggregator used by the CLI
(:mod:`scripts.evaluate`) and the per-sample inspection utility. It produces
a single :class:`MetricsReport` so downstream code never has to know which
metrics exist in the package — only how to read fields off the dataclass.
Adding a new metric follows the same four-step pattern used elsewhere in
this package:
1. Implement ``corpus_<metric>_score`` in a sibling module.
2. Add a field for it to :class:`MetricsReport`.
3. Call it from :func:`compute_all_metrics` (wrapped in a try/except so a
single broken metric never poisons the whole report).
4. Add a unit test on a toy fixture.
The exception swallowing is deliberate: METEOR needs Java and CIDEr needs a
minimum number of samples, so either may be unavailable in a given run,
whereas sacrebleu is always available. We do NOT want one unavailable metric
to kill the entire evaluation pass; instead we record ``None`` for that
metric and store the reason in the ``errors`` dict, keyed by metric name, so
callers (and the CLI) can flag the issue without losing the metrics that did
work.
"""
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import asdict, dataclass, field
from captioning.evaluation.bleu import corpus_bleu_breakdown
from captioning.evaluation.cider import MIN_SAMPLES_FOR_CIDER, corpus_cider_score
from captioning.evaluation.meteor import corpus_meteor_score
from captioning.evaluation.rouge import corpus_rouge_l_score
@dataclass(frozen=True)
class MetricsReport:
    """Immutable bundle of corpus-level caption metrics for one evaluation run.

    Each score is ``float | None``; ``None`` means the score was not produced
    (metric disabled, dependency such as Java missing, corpus too small for
    CIDEr, or the scorer raised). When a reason was recorded it is stored in
    :attr:`errors` under the metric's name; ``compute_all_metrics`` files a
    BLEU failure under the single key ``"bleu"`` for all four BLEU fields.

    ``frozen=True`` only prevents rebinding fields: the ``errors`` dict itself
    remains mutable, and because it is a dict, hashing an instance raises
    ``TypeError``.
    """

    # Number of examples in the scored corpus.
    n_examples: int
    # BLEU-1 .. BLEU-4 scores.
    bleu1: float | None = None
    bleu2: float | None = None
    bleu3: float | None = None
    bleu4: float | None = None
    # ROUGE-L score.
    rouge_l: float | None = None
    # METEOR score.
    meteor: float | None = None
    # CIDEr score.
    cider: float | None = None
    # Metric name -> human-readable reason the metric is missing.
    # default_factory gives every report its own dict.
    errors: dict[str, str] = field(default_factory=dict)

    def to_dict(self) -> dict[str, object]:
        """Convert the report into plain builtins ready for ``json.dumps``.

        ``errors`` becomes a nested dict. :func:`dataclasses.asdict` builds
        fresh containers, so mutating the result never touches the report.
        """
        payload = asdict(self)
        return payload
def compute_all_metrics(
    predictions: Sequence[str],
    references: Sequence[Sequence[str]],
    *,
    include_meteor: bool = True,
    include_cider: bool = True,
) -> MetricsReport:
    """Score one ``(predictions, references)`` corpus with every available metric.

    Each metric runs independently. If a scorer raises, its score stays
    ``None``, ``repr`` of the exception is stored in the report's ``errors``
    under that metric's key (``"bleu"``, ``"rouge_l"``, ``"meteor"`` or
    ``"cider"``), and the remaining metrics still run.

    Args:
        predictions: One generated caption per example.
        references: One *list* of reference captions per example, aligned
            with ``predictions``.
        include_meteor: Pass False to skip METEOR entirely (avoids the JVM
            spawn, e.g. in CI where Java isn't installed).
        include_cider: Pass False to skip CIDEr entirely. Even when True,
            CIDEr is skipped below ``MIN_SAMPLES_FOR_CIDER`` examples, and
            that skip is noted in ``errors["cider"]``.

    Returns:
        A :class:`MetricsReport`. A metric disabled via ``include_meteor`` /
        ``include_cider`` is left as ``None`` and gets no ``errors`` entry.

    Raises:
        ValueError: If ``predictions`` and ``references`` differ in length.
    """
    n_examples = len(predictions)
    n_reference_sets = len(references)
    if n_examples != n_reference_sets:
        raise ValueError(
            f"predictions ({n_examples}) and references "
            f"({n_reference_sets}) must have the same length"
        )

    # Every slot starts as "not computed"; a metric only overwrites its slot
    # on success, so a failing or disabled metric leaves ``None`` behind.
    scores: dict[str, float | None] = dict.fromkeys(
        ("bleu1", "bleu2", "bleu3", "bleu4", "rouge_l", "meteor", "cider")
    )
    errors: dict[str, str] = {}

    # One BLEU call yields all four orders; failures share the "bleu" key.
    try:
        breakdown = corpus_bleu_breakdown(predictions, references)
        scores.update(
            bleu1=breakdown.bleu1,
            bleu2=breakdown.bleu2,
            bleu3=breakdown.bleu3,
            bleu4=breakdown.bleu4,
        )
    except Exception as exc:  # deliberate: record the failure, keep going
        errors["bleu"] = repr(exc)

    try:
        scores["rouge_l"] = corpus_rouge_l_score(predictions, references)
    except Exception as exc:
        errors["rouge_l"] = repr(exc)

    if include_meteor:
        try:
            scores["meteor"] = corpus_meteor_score(predictions, references)
        except Exception as exc:
            errors["meteor"] = repr(exc)

    if include_cider and n_examples < MIN_SAMPLES_FOR_CIDER:
        # Too few examples: don't even attempt CIDEr, just explain why.
        errors["cider"] = (
            f"skipped: needs >= {MIN_SAMPLES_FOR_CIDER} examples, got {n_examples}"
        )
    elif include_cider:
        try:
            scores["cider"] = corpus_cider_score(predictions, references)
        except Exception as exc:
            errors["cider"] = repr(exc)

    return MetricsReport(
        n_examples=n_examples,
        bleu1=scores["bleu1"],
        bleu2=scores["bleu2"],
        bleu3=scores["bleu3"],
        bleu4=scores["bleu4"],
        rouge_l=scores["rouge_l"],
        meteor=scores["meteor"],
        cider=scores["cider"],
        errors=errors,
    )
|