# spice/spice.py — SPICE metric module (sunhill/spice, commit 5689a44)
"""This module implements the SPICE metric."""
import os
import shutil
import subprocess
import json
import tempfile
from typing import List, Dict
import evaluate
import datasets
from evaluate.utils.logging import get_logger
logger = get_logger(__name__)
CORENLP = "stanford-corenlp-full-2015-12-09"
SPICELIB = "lib"
SPICE_JAR = "spice-1.0.jar"
_CITATION = """\
@inproceedings{spice2016,
title = {SPICE: Semantic Propositional Image Caption Evaluation},
author = {Peter Anderson and Basura Fernando and Mark Johnson and Stephen Gould},
year = {2016},
booktitle = {ECCV}
}
"""
_DESCRIPTION = """\
This module is designed to evaluate the quality of image captions using the SPICE metric.
It compares generated captions with reference captions to assess their semantic similarity.
"""
_KWARGS_DESCRIPTION = """
Compute SPICE score.
Args:
predictions: list of predictions to score. Each predictions
should be a string.
references: list of reference for each prediction. Each
reference should be a string.
Returns:
spice: SPICE score
Examples:
>>> metric = evaluate.load("sunhill/spice")
>>> results = metric.compute(
predictions=[['train traveling down a track in front of a road']],
references=[
[
'a train traveling down tracks next to lights',
'a blue and silver train next to train station and trees',
'a blue train is next to a sidewalk on the rails',
'a passenger train pulls into a train station',
'a train coming down the tracks arriving at a station'
]
]
)
>>> print(results)
[
{
"All": {
"pr": 0.25,
"re": 0.07142857142857142,
"f": 0.11111111111111112,
"fn": 13.0,
"numImages": 1.0,
"fp": 3.0,
"tp": 1.0,
},
"Relation": {
"pr": 0.0,
"re": 0.0,
"f": 0.0,
"fn": 5.0,
"numImages": 1.0,
"fp": 1.0,
"tp": 0.0,
},
"Cardinality": {
"pr": nan,
"re": nan,
"f": nan,
"fn": 0.0,
"numImages": 1.0,
"fp": 0.0,
"tp": 0.0,
},
"Attribute": {
"pr": 0.0,
"re": 0.0,
"f": 0.0,
"fn": 4.0,
"numImages": 1.0,
"fp": 0.0,
"tp": 0.0,
},
"Size": {
"pr": nan,
"re": nan,
"f": nan,
"fn": 0.0,
"numImages": 1.0,
"fp": 0.0,
"tp": 0.0,
},
"Color": {
"pr": 0.0,
"re": 0.0,
"f": 0.0,
"fn": 1.0,
"numImages": 1.0,
"fp": 0.0,
"tp": 0.0,
},
"Object": {
"pr": 0.3333333333333333,
"re": 0.2,
"f": 0.25,
"fn": 4.0,
"numImages": 1.0,
"fp": 2.0,
"tp": 1.0,
},
}
]
"""
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class SPICE(evaluate.Metric):
"""This module implements the SPICE metric for evaluating image captioning models."""
def _info(self):
return evaluate.MetricInfo(
# This is the description that will appear on the modules page.
module_type="metric",
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
# This defines the format of each prediction and reference
features=[
datasets.Features(
{
"predictions": datasets.Value("string"),
"references": datasets.Value("string"),
}
),
datasets.Features(
{
"predictions": datasets.Value("string"),
"references": datasets.Sequence(datasets.Value("string")),
}
),
],
# Homepage of the module for documentation
homepage="https://huggingface.co/spaces/sunhill/spice",
# Additional links to the codebase or references
codebase_urls=[
"https://github.com/peteanderson80/SPICE",
"https://github.com/EricWWWW/image-caption-metrics",
],
reference_urls=["https://panderson.me/spice"],
)
def _download_and_prepare(self, dl_manager):
"""Optional: download external resources useful to compute the scores"""
if os.path.exists("lib/stanford-corenlp-3.6.0-models.jar") and os.path.exists(
"lib/stanford-corenlp-3.6.0.jar"
):
logger.info("`stanford-corenlp` already exists. Skip downloading.")
return
logger.info("Downloading `stanford-corenlp`...")
url = f"http://nlp.stanford.edu/software/{CORENLP}.zip"
extracted_path = dl_manager.download_and_extract(url)
tmp_path = os.path.join(extracted_path, CORENLP)
shutil.copyfile(
os.path.join(tmp_path, "stanford-corenlp-3.6.0-models.jar"),
os.path.join(SPICELIB, "stanford-corenlp-3.6.0-models.jar"),
)
shutil.copyfile(
os.path.join(tmp_path, "stanford-corenlp-3.6.0.jar"),
os.path.join(SPICELIB, "stanford-corenlp-3.6.0.jar"),
)
logger.info(f"`stanford-corenlp` has been downloaded to {SPICELIB}")
def float_convert(self, obj):
try:
return float(obj)
except (ValueError, TypeError):
return float("nan")
def _compute_batch(self, scores: List[Dict]) -> Dict[str, float]:
"""Compute average scores over all images in the batch."""
# Initialize aggregate_scores with zero values
aggregate_scores = {
"pr": 0.0,
"re": 0.0,
"f": 0.0,
"fn": 0.0,
"numImages": 0.0,
"fp": 0.0,
"tp": 0.0,
}
num_images = len(scores)
if num_images == 0:
return aggregate_scores
# Sum up scores for each category
for score in scores:
for k, v in score.items():
if k in ["fn", "fp", "tp"]:
aggregate_scores[k] += v
aggregate_scores["numImages"] += 1
# Compute average scores
tp = aggregate_scores["tp"]
fp = aggregate_scores["fp"]
fn = aggregate_scores["fn"]
precision = tp / (tp + fp) if (tp + fp) > 0 else float("nan")
recall = tp / (tp + fn) if (tp + fn) > 0 else float("nan")
f_score = (
2 * precision * recall / (precision + recall)
if precision is not None and recall is not None and (precision + recall) > 0
else float("nan")
)
aggregate_scores["pr"] = precision
aggregate_scores["re"] = recall
aggregate_scores["f"] = f_score
return aggregate_scores
def _compute(self, predictions, references, spice_name="All"):
"""Returns the scores"""
assert len(predictions) == len(references), (
"The number of predictions and references should be the same. "
f"Got {len(predictions)} predictions and {len(references)} references."
)
input_data = []
for i, (prediction, reference) in enumerate(zip(predictions, references)):
assert isinstance(prediction, str), (
"Each prediction should be a string. "
f"Got {type(prediction)} for image {i}."
)
if isinstance(reference, str):
reference = [reference]
assert isinstance(reference, list) and all(
isinstance(ref, str) for ref in reference
), (
"Each reference should be a list of strings. "
f"Got {type(reference)} with elements of type {[type(ref) for ref in reference]} for index {i}."
)
input_data.append({"image_id": i, "test": prediction, "refs": reference})
in_file = tempfile.NamedTemporaryFile(delete=False)
in_file.write(json.dumps(input_data, indent=2).encode("utf-8"))
in_file.close()
out_file = tempfile.NamedTemporaryFile(delete=False)
out_file.close()
with tempfile.TemporaryDirectory() as cache_dir:
spice_cmd = [
"java",
"-jar",
"-Xmx8G",
SPICE_JAR,
in_file.name,
"-cache",
cache_dir,
"-out",
out_file.name,
"-subset",
"-silent",
]
try:
subprocess.run(
spice_cmd,
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
except subprocess.CalledProcessError as e:
raise RuntimeError(
f"SPICE command '{' '.join(spice_cmd)}' returned non-zero exit status {e.returncode}. "
f"stderr: {e.stderr.decode('utf-8')}"
) from e
with open(out_file.name, "r") as f:
results = json.load(f)
os.remove(in_file.name)
os.remove(out_file.name)
img_id_to_scores = {
item["image_id"]: item["scores"][spice_name] for item in results
}
scores = [
{k: self.float_convert(v) for k, v in img_id_to_scores[image_id].items()}
for image_id in range(len(predictions))
]
return {f"spice_{k}": v for k, v in self._compute_batch(scores).items()}