| """
|
| Copyright (c) 2022, salesforce.com, inc.
|
| All rights reserved.
|
| SPDX-License-Identifier: BSD-3-Clause
|
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| """

import json
import os

from lavis.common.dist_utils import main_process
from lavis.common.logger import MetricLogger
from lavis.common.registry import registry
from lavis.tasks.base_task import BaseTask
from lavis.datasets.data_utils import prepare_sample

import numpy as np


@registry.register_task("dialogue")
class DialogueTask(BaseTask):
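    """Dialogue task.

    Validation is scored by the average of the per-batch loss returned by the
    model; generation hyperparameters (num_beams, max_len, min_len) are read
    from the run config in setup_task.
    """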

    def __init__(self, num_beams, max_len, min_len, evaluate, report_metric=True):
        super().__init__()

        self.num_beams = num_beams
        self.max_len = max_len
        self.min_len = min_len
        self.evaluate = evaluate

        self.report_metric = report_metric

    @classmethod
    def setup_task(cls, cfg):
        run_cfg = cfg.run_cfg

        num_beams = run_cfg.num_beams
        max_len = run_cfg.max_len
        min_len = run_cfg.min_len
        evaluate = run_cfg.evaluate

        report_metric = run_cfg.get("report_metric", True)

        return cls(
            num_beams=num_beams,
            max_len=max_len,
            min_len=min_len,
            evaluate=evaluate,
            report_metric=report_metric,
        )
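
    # Illustrative run config for this task. The key names mirror the reads in
    # setup_task above; the values are placeholders, not taken from a shipped
    # config file:
    #
    #   run:
    #     task: dialogue
    #     num_beams: 5
    #     max_len: 20
    #     min_len: 1
    #     evaluate: True
    #     report_metric: True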

    def valid_step(self, model, samples):
        # Validation is scored by the model's loss on the batch.
        loss = model(samples)["loss"].item()

        return [loss]

    def after_evaluation(self, val_result, split_name, epoch, **kwargs):
        # val_result is the list of per-batch losses collected by valid_step;
        # report their mean as the aggregate metric for this split.
        if self.report_metric:
            avg_loss = np.mean(val_result)
            metrics = {"agg_metrics": avg_loss}
        else:
            metrics = {"agg_metrics": 0.0}

        return metrics

    @main_process
    def _report_metrics(self, eval_result_file, split_name):
        # Score a dumped result file with the COCO caption metrics and log them.
        coco_gt_root = os.path.join(registry.get_path("cache_root"), "coco_gt")
        coco_val = coco_dialogue_eval(coco_gt_root, eval_result_file, split_name)

        agg_metrics = coco_val.eval["CIDEr"] + coco_val.eval["Bleu_4"]
        log_stats = {split_name: {k: v for k, v in coco_val.eval.items()}}

        with open(
            os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
        ) as f:
            f.write(json.dumps(log_stats) + "\n")

        coco_res = {k: v for k, v in coco_val.eval.items()}
        coco_res["agg_metrics"] = agg_metrics

        return coco_res


from pycocoevalcap.eval import COCOEvalCap
from pycocotools.coco import COCO
from torchvision.datasets.utils import download_url

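
# coco_dialogue_eval mirrors the COCO caption evaluation: it downloads the
# Karpathy-split ground-truth annotations and scores the result file with
# pycocoevalcap (Bleu, METEOR, ROUGE_L, CIDEr, ...).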
def coco_dialogue_eval(coco_gt_root, results_file, split):
    urls = {
        "val": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json",
        "test": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json",
    }
    filenames = {
        "val": "coco_karpathy_val_gt.json",
        "test": "coco_karpathy_test_gt.json",
    }

    # Fetch the ground-truth annotations for the requested split (cached locally).
    download_url(urls[split], coco_gt_root)
    annotation_file = os.path.join(coco_gt_root, filenames[split])

    # Load ground truth and predicted results in COCO format.
    coco = COCO(annotation_file)
    coco_result = coco.loadRes(results_file)

    coco_eval = COCOEvalCap(coco, coco_result)

    # Compute the caption metrics over the result file.
    coco_eval.evaluate()

    for metric, score in coco_eval.eval.items():
        print(f"{metric}: {score:.3f}")

    return coco_eval