|
|
"""This module implements the SPICE metric.""" |
|
|
|
|
|
import os |
|
|
import shutil |
|
|
import subprocess |
|
|
import json |
|
|
import tempfile |
|
|
from typing import List, Dict |
|
|
|
|
|
import evaluate |
|
|
import datasets |
|
|
from evaluate.utils.logging import get_logger |
|
|
|
|
|
logger = get_logger(__name__)


# Name of the Stanford CoreNLP distribution the SPICE jar depends on.
CORENLP = "stanford-corenlp-full-2015-12-09"
# Directory where the CoreNLP jars are placed (relative to the working directory).
SPICELIB = "lib"
# The SPICE scorer jar, expected in the working directory.
SPICE_JAR = "spice-1.0.jar"

_CITATION = """\
@inproceedings{spice2016,
  title = {SPICE: Semantic Propositional Image Caption Evaluation},
  author = {Peter Anderson and Basura Fernando and Mark Johnson and Stephen Gould},
  year = {2016},
  booktitle = {ECCV}
}
"""

_DESCRIPTION = """\
This module is designed to evaluate the quality of image captions using the SPICE metric.
It compares generated captions with reference captions to assess their semantic similarity.
"""

_KWARGS_DESCRIPTION = """
Compute SPICE score.
Args:
    predictions: list of predictions to score. Each prediction
        should be a string.
    references: list of references for each prediction. Each
        reference should be a string or a list of strings.
Returns:
    spice: SPICE score
Examples:
    >>> metric = evaluate.load("sunhill/spice")
    >>> results = metric.compute(
    ...     predictions=['train traveling down a track in front of a road'],
    ...     references=[
    ...         [
    ...             'a train traveling down tracks next to lights',
    ...             'a blue and silver train next to train station and trees',
    ...             'a blue train is next to a sidewalk on the rails',
    ...             'a passenger train pulls into a train station',
    ...             'a train coming down the tracks arriving at a station'
    ...         ]
    ...     ]
    ... )
    >>> print(results)
    {'spice_pr': 0.25, 'spice_re': 0.07142857142857142, 'spice_f': 0.11111111111111112, 'spice_fn': 13.0, 'spice_numImages': 1.0, 'spice_fp': 3.0, 'spice_tp': 1.0}
"""
|
|
|
|
|
|
|
|
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class SPICE(evaluate.Metric):
    """SPICE metric for evaluating image-captioning models.

    Scores are produced by shelling out to the external Java scorer
    (``SPICE_JAR``), which in turn requires the Stanford CoreNLP jars
    downloaded into ``SPICELIB`` by :meth:`_download_and_prepare`.
    """

    def _info(self):
        """Return the metric metadata.

        Two feature schemas are declared so callers may pass each reference
        either as a single string or as a list of strings per prediction.
        """
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=[
                datasets.Features(
                    {
                        "predictions": datasets.Value("string"),
                        "references": datasets.Value("string"),
                    }
                ),
                datasets.Features(
                    {
                        "predictions": datasets.Value("string"),
                        "references": datasets.Sequence(datasets.Value("string")),
                    }
                ),
            ],
            homepage="https://huggingface.co/spaces/sunhill/spice",
            codebase_urls=[
                "https://github.com/peteanderson80/SPICE",
                "https://github.com/EricWWWW/image-caption-metrics",
            ],
            reference_urls=["https://panderson.me/spice"],
        )

    def _download_and_prepare(self, dl_manager):
        """Download the Stanford CoreNLP jars needed by the SPICE jar.

        No-op when both jars are already present in ``SPICELIB``.
        """
        # Fix: build both paths from the SPICELIB constant instead of the
        # hard-coded "lib/..." literals, so the existence check and the copy
        # destination below can never diverge.
        models_jar = os.path.join(SPICELIB, "stanford-corenlp-3.6.0-models.jar")
        corenlp_jar = os.path.join(SPICELIB, "stanford-corenlp-3.6.0.jar")
        if os.path.exists(models_jar) and os.path.exists(corenlp_jar):
            logger.info("`stanford-corenlp` already exists. Skip downloading.")
            return
        logger.info("Downloading `stanford-corenlp`...")
        url = f"http://nlp.stanford.edu/software/{CORENLP}.zip"
        extracted_path = dl_manager.download_and_extract(url)
        tmp_path = os.path.join(extracted_path, CORENLP)
        shutil.copyfile(
            os.path.join(tmp_path, "stanford-corenlp-3.6.0-models.jar"),
            models_jar,
        )
        shutil.copyfile(
            os.path.join(tmp_path, "stanford-corenlp-3.6.0.jar"),
            corenlp_jar,
        )
        logger.info(f"`stanford-corenlp` has been downloaded to {SPICELIB}")

    @staticmethod
    def float_convert(obj):
        """Coerce ``obj`` to ``float``, mapping unconvertible values to NaN.

        The Java scorer may emit non-numeric entries (e.g. for undefined
        ratios); those become ``nan`` rather than raising.
        """
        # Staticmethod: the conversion never touches instance state, and
        # existing ``self.float_convert(...)`` call sites keep working.
        try:
            return float(obj)
        except (ValueError, TypeError):
            return float("nan")

    def _compute_batch(self, scores: List[Dict]) -> Dict[str, float]:
        """Aggregate per-image tuple counts into corpus-level scores.

        Args:
            scores: per-image dicts with (at least) "fn", "fp", "tp" counts.

        Returns:
            Dict with summed "fn"/"fp"/"tp", the image count, and the
            micro-averaged "pr"/"re"/"f" (NaN when a denominator is zero).
        """
        aggregate_scores = {
            "pr": 0.0,
            "re": 0.0,
            "f": 0.0,
            "fn": 0.0,
            "numImages": 0.0,
            "fp": 0.0,
            "tp": 0.0,
        }
        if not scores:
            return aggregate_scores

        for score in scores:
            for key in ("fn", "fp", "tp"):
                if key in score:
                    aggregate_scores[key] += score[key]
            aggregate_scores["numImages"] += 1

        tp = aggregate_scores["tp"]
        fp = aggregate_scores["fp"]
        fn = aggregate_scores["fn"]
        precision = tp / (tp + fp) if (tp + fp) > 0 else float("nan")
        recall = tp / (tp + fn) if (tp + fn) > 0 else float("nan")
        # Removed the dead `is not None` guards: precision/recall are always
        # floats here. A NaN operand already falls through to the NaN branch,
        # because `nan + x > 0` evaluates to False.
        f_score = (
            2 * precision * recall / (precision + recall)
            if (precision + recall) > 0
            else float("nan")
        )
        aggregate_scores["pr"] = precision
        aggregate_scores["re"] = recall
        aggregate_scores["f"] = f_score
        return aggregate_scores

    def _compute(self, predictions, references, spice_name="All"):
        """Run the SPICE jar and return aggregated scores.

        Args:
            predictions: list of candidate captions, one string per image.
            references: list of references, one entry per image; each entry
                is a string or a list of strings.
            spice_name: score category to aggregate ("All", "Object",
                "Relation", "Attribute", "Color", "Size", "Cardinality").

        Returns:
            Dict mapping "spice_pr", "spice_re", "spice_f", "spice_fn",
            "spice_numImages", "spice_fp", "spice_tp" to floats.

        Raises:
            RuntimeError: if the Java SPICE process exits with an error.
        """
        assert len(predictions) == len(references), (
            "The number of predictions and references should be the same. "
            f"Got {len(predictions)} predictions and {len(references)} references."
        )
        input_data = []
        for i, (prediction, reference) in enumerate(zip(predictions, references)):
            assert isinstance(prediction, str), (
                "Each prediction should be a string. "
                f"Got {type(prediction)} for image {i}."
            )
            # Normalize a bare reference string to a one-element list.
            if isinstance(reference, str):
                reference = [reference]
            assert isinstance(reference, list) and all(
                isinstance(ref, str) for ref in reference
            ), (
                "Each reference should be a list of strings. "
                f"Got {type(reference)} with elements of type {[type(ref) for ref in reference]} for index {i}."
            )
            input_data.append({"image_id": i, "test": prediction, "refs": reference})

        # delete=False so the files survive closing and can be handed to the
        # Java subprocess by name (required on Windows, where an open
        # NamedTemporaryFile cannot be reopened).
        in_file = tempfile.NamedTemporaryFile(delete=False)
        in_file.write(json.dumps(input_data, indent=2).encode("utf-8"))
        in_file.close()

        out_file = tempfile.NamedTemporaryFile(delete=False)
        out_file.close()

        # Fix: cleanup moved into `finally` — previously the temp files leaked
        # whenever the scorer failed, because os.remove was only reached on
        # the success path.
        try:
            with tempfile.TemporaryDirectory() as cache_dir:
                spice_cmd = [
                    "java",
                    "-jar",
                    "-Xmx8G",
                    SPICE_JAR,
                    in_file.name,
                    "-cache",
                    cache_dir,
                    "-out",
                    out_file.name,
                    "-subset",
                    "-silent",
                ]
                try:
                    subprocess.run(
                        spice_cmd,
                        check=True,
                        stdout=subprocess.PIPE,
                        stderr=subprocess.PIPE,
                    )
                except subprocess.CalledProcessError as e:
                    raise RuntimeError(
                        f"SPICE command '{' '.join(spice_cmd)}' returned non-zero exit status {e.returncode}. "
                        f"stderr: {e.stderr.decode('utf-8')}"
                    ) from e

            with open(out_file.name, "r") as f:
                results = json.load(f)
        finally:
            os.remove(in_file.name)
            os.remove(out_file.name)

        img_id_to_scores = {
            item["image_id"]: item["scores"][spice_name] for item in results
        }
        scores = [
            {k: self.float_convert(v) for k, v in img_id_to_scores[image_id].items()}
            for image_id in range(len(predictions))
        ]
        return {f"spice_{k}": v for k, v in self._compute_batch(scores).items()}
|
|
|