File size: 1,089 Bytes
002bd9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
import click
import json
import evaluate
import tqdm
import logging
# Configure the root logger at INFO level so the logger.info() progress
# messages emitted by this script are actually shown.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def _pad_with_last(items, width):
    """Return *items* extended to *width* entries by repeating its last entry.

    *items* must be a non-empty list; if it already has at least *width*
    entries it is returned unchanged in content (a negative repeat count
    yields an empty padding list).
    """
    return items + [items[-1]] * (width - len(items))


@click.command()
@click.argument("json_file", type=click.Path(exists=True))
@click.option("--split")
def compute_metrics(json_file, split):
    """Compute the METEOR metric for the samples of JSON_FILE in one split.

    JSON_FILE must contain a JSON array of objects, each with the keys
    "split", "references" (a list) and "candidates" (a list). Only samples
    whose "split" value equals --split are scored.

    Within every sample the references and candidates are paired up
    positionally. The shorter of the two lists is padded by repeating its
    last entry, and every sample is padded to at least the width of the first
    matching sample (``num_heads``). The flattened pairs are handed to the
    ``meteor`` metric loaded through ``evaluate.load`` and the resulting
    metrics are printed to stdout as indented JSON.

    Raises:
        click.ClickException: if a matching sample has an empty "references"
            or "candidates" list, or if no sample matches --split.
    """
    logger.info("Loading %s", json_file)
    # JSON text is UTF-8; do not depend on the platform default encoding.
    with open(json_file, encoding="utf-8") as f:
        data = json.load(f)

    refs = []
    preds = []
    num_heads = None
    for index, sample in enumerate(tqdm.tqdm(data)):
        if sample["split"] != split:
            continue
        refs_ = sample["references"]
        preds_ = sample["candidates"]
        if not refs_ or not preds_:
            # Padding repeats the last entry, so an empty list cannot be
            # padded; report it clearly instead of failing with IndexError.
            raise click.ClickException(
                f"sample {index} in {json_file} has an empty "
                "'references' or 'candidates' list"
            )
        if num_heads is None:
            # The first matching sample fixes the minimum number of pairs
            # every sample contributes.
            num_heads = max(len(refs_), len(preds_))
        # Never pad to less than this sample's own longest list: otherwise a
        # sample wider than the first one would add a different number of
        # references and predictions and shift every following pair.
        width = max(num_heads, len(refs_), len(preds_))
        refs.extend(_pad_with_last(refs_, width))
        preds.extend(_pad_with_last(preds_, width))

    if not preds:
        raise click.ClickException(
            f"no samples with split {split!r} found in {json_file}"
        )

    logger.info("Computing metrics for %s", json_file)
    # Load the metric only once the input is known to be usable.
    compute_metrics_func = evaluate.load("meteor")
    metrics = compute_metrics_func.compute(predictions=preds, references=refs)
    print(json.dumps(metrics, indent=2))
# Run the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    compute_metrics()
|