import json
import logging

import click
import evaluate
import tqdm

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _pad_with_last(items, width, field, index):
    """Return ``items`` extended to ``width`` entries by repeating its last one.

    ``field`` and ``index`` are only used to build the error message.
    Raises ``ValueError`` when ``items`` is empty, because there is no last
    element to repeat.
    """
    if not items:
        raise ValueError(
            f"Sample {index} has an empty {field!r} list; cannot pad it"
        )
    # A non-positive multiplier yields an empty list, so an already
    # wide-enough list is returned unchanged (as a new list).
    return items + [items[-1]] * (width - len(items))


@click.command()
@click.argument("json_file", type=click.Path(exists=True))
@click.option("--split")
def compute_metrics(json_file, split):
    """Compute the METEOR metric for one split of JSON_FILE.

    JSON_FILE must contain a list of samples, each with the keys "split",
    "references" and "candidates". Only samples whose "split" equals --split
    are scored. Within a sample the shorter of the two lists is padded by
    repeating its last entry so that references and candidates pair up one
    to one. The resulting metrics are printed to stdout as JSON.
    """
    compute_metrics_func = evaluate.load("meteor")
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(json_file, encoding="utf-8") as f:
        data = json.load(f)

    refs = []
    preds = []
    # Baseline pairs-per-sample, fixed by the first sample of the split.
    num_heads = None
    logger.info("Loading %s", json_file)
    for index, sample in enumerate(tqdm.tqdm(data)):
        if sample["split"] != split:
            continue
        refs_ = sample["references"]
        preds_ = sample["candidates"]
        if num_heads is None:
            num_heads = max(len(refs_), len(preds_))
        # Pad to the baseline, but never below this sample's own longest
        # list. Padding strictly to ``num_heads`` would leave the two sides
        # with different lengths whenever a later sample is longer than the
        # first one, shifting every following prediction/reference pair.
        width = max(num_heads, len(refs_), len(preds_))
        refs.extend(_pad_with_last(refs_, width, "references", index))
        preds.extend(_pad_with_last(preds_, width, "candidates", index))

    if not refs:
        # Typically --split was omitted or does not occur in the file.
        logger.warning("No samples with split %r found in %s", split, json_file)

    logger.info("Computing metrics for %s", json_file)
    metrics = compute_metrics_func.compute(predictions=preds, references=refs)
    print(json.dumps(metrics, indent=2))


if __name__ == "__main__":
    compute_metrics()