# deepspeed/scripts/tools/compute_metrics.py
# Last commit by xingzhikb: "init" (002bd9b)
import click
import json
import evaluate
import tqdm
import logging
# Configure root logging at INFO so this script's logger.info progress
# messages are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
def _collect_pairs(data, split):
    """Flatten the samples of one split into aligned reference/prediction lists.

    Each sample is read as a dict with ``"split"``, ``"references"`` and
    ``"candidates"`` keys. Samples whose ``"split"`` differs from *split* are
    skipped. For every kept sample, the shorter of its references/candidates
    lists is padded by repeating its last element, so both lists contribute the
    same number of entries and the i-th candidate is scored against the i-th
    reference.

    The padding width starts at the larger list length of the first kept
    sample (``num_heads``). A later, longer sample widens only its own
    contribution, so the two output lists never drift out of alignment.

    Args:
        data: Iterable of sample dicts loaded from the JSON file.
        split: Value of ``sample["split"]`` to keep.

    Returns:
        ``(refs, preds)``: two lists of equal length.

    Raises:
        ValueError: If a kept sample has an empty ``"references"`` or
            ``"candidates"`` list (there is no element to pad with).
    """
    refs = []
    preds = []
    num_heads = None
    for index, sample in enumerate(tqdm.tqdm(data)):
        if sample["split"] != split:
            continue
        sample_refs = sample["references"]
        sample_preds = sample["candidates"]
        if not sample_refs or not sample_preds:
            raise ValueError(
                f"sample {index} has an empty 'references' or 'candidates' "
                "list; cannot pad it for pairing"
            )
        if num_heads is None:
            num_heads = max(len(sample_refs), len(sample_preds))
        # Pad both sides to the same width. Using the max (rather than only
        # num_heads) keeps refs/preds aligned when this sample is longer than
        # the first one; otherwise the padding count would go negative.
        width = max(num_heads, len(sample_refs), len(sample_preds))
        refs.extend(sample_refs + [sample_refs[-1]] * (width - len(sample_refs)))
        preds.extend(
            sample_preds + [sample_preds[-1]] * (width - len(sample_preds))
        )
    return refs, preds


@click.command()
@click.argument("json_file", type=click.Path(exists=True))
@click.option("--split")
def compute_metrics(json_file, split):
    """Compute the METEOR score for the samples of one split in JSON_FILE.

    JSON_FILE holds a list of sample dicts with "split", "references" and
    "candidates" keys; only samples whose "split" equals --split are scored.
    The metric result is printed to stdout as indented JSON.
    """
    compute_metrics_func = evaluate.load("meteor")

    logger.info("Loading %s", json_file)
    # Read as UTF-8 explicitly rather than relying on the platform default.
    with open(json_file, encoding="utf-8") as f:
        data = json.load(f)

    try:
        refs, preds = _collect_pairs(data, split)
    except ValueError as err:
        # Surface malformed samples as a clean CLI error instead of a traceback.
        raise click.ClickException(str(err)) from err
    if not preds:
        # Typically a mistyped or omitted --split; the metric would be
        # computed over nothing.
        logger.warning("No samples with split=%r found in %s", split, json_file)

    logger.info("Computing metrics for %s", json_file)
    metrics = compute_metrics_func.compute(predictions=preds, references=refs)
    print(json.dumps(metrics, indent=2))
if __name__ == "__main__":
    # Run the click command; click parses the command-line arguments.
    compute_metrics()