Spaces:
Runtime error
Runtime error
| """ | |
| Copyright (c) 2022, salesforce.com, inc. | |
| All rights reserved. | |
| SPDX-License-Identifier: BSD-3-Clause | |
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
| """ | |
| import json | |
| import os | |
| import pandas as pd | |
| from tqdm import tqdm | |
| from lavis.common.dist_utils import main_process, get_rank | |
| from lavis.common.registry import registry | |
| from lavis.tasks.base_task import BaseTask | |
| from lavis.common.utils import is_convertible_to_int, is_url, cache_url | |
class CaptionTask(BaseTask):
    """Captioning task: generate captions with the model and score them with COCO caption metrics.

    Generation settings are stored on the instance and forwarded to
    ``model.generate`` in :meth:`valid_step`; metrics are computed by
    ``coco_caption_eval`` in :meth:`_report_metrics`.
    """

    def __init__(
        self,
        num_beams,
        max_len,
        min_len,
        repetition_penalty,
        length_penalty,
        top_p,
        temperature,
        evaluate,
        report_metric=True,
        annotation_file=None,
        sample_id_key="image_id",
        caption_key="caption",
        split=None,
        load_gt_from_file=False,
        img_ids=None,
    ):
        """Store generation and evaluation settings.

        Args:
            num_beams, max_len, min_len, repetition_penalty, length_penalty,
            top_p, temperature: forwarded to ``model.generate`` in ``valid_step``.
            evaluate: stored as ``self.evaluate``; not read by the methods of this class.
            report_metric: when false, ``after_evaluation`` skips metric computation.
            annotation_file: COCO-format ground-truth file. When ``None`` it is
                derived in ``build_datasets`` (non-COCO datasets) or the COCO
                Karpathy file is used at evaluation time.
            sample_id_key: key holding the sample/image id in samples and annotations.
            caption_key: key holding the caption(s) in the ground-truth annotations.
            split: one-element sequence naming the evaluation split; defaults to ``["val"]``.
            load_gt_from_file: falsy, or the path/URL of a file to read ground truth from.
            img_ids: optional subset of image ids to evaluate; empty/``None`` means all.
        """
        super().__init__()

        # Defaults are created per call: a list literal in the signature would be
        # shared by every instance.
        split = ["val"] if split is None else split
        img_ids = [] if img_ids is None else img_ids

        self.num_beams = num_beams
        self.max_len = max_len
        self.min_len = min_len
        self.repetition_penalty = repetition_penalty
        self.length_penalty = length_penalty
        self.top_p = top_p
        self.temperature = temperature
        self.evaluate = evaluate
        self.report_metric = report_metric
        self.annotation_file = annotation_file
        self.sample_id_key = sample_id_key
        self.caption_key = caption_key
        assert len(split) == 1, "Only support one split for evaluation."
        self.split = split[0]
        self.load_gt_from_file = load_gt_from_file
        self.img_ids = img_ids

    @classmethod
    def setup_task(cls, cfg):
        """Alternate constructor: build the task from ``cfg.run_cfg``.

        Every option except ``evaluate`` is optional and falls back to the
        default shown below; ``run_cfg.evaluate`` must be present.
        """
        run_cfg = cfg.run_cfg

        return cls(
            num_beams=run_cfg.get("num_beams", 5),
            max_len=run_cfg.get("max_len", 30),
            min_len=run_cfg.get("min_len", 1),
            repetition_penalty=run_cfg.get("repetition_penalty", 1.15),
            length_penalty=run_cfg.get("length_penalty", 0.),
            top_p=run_cfg.get("top_p", 0.9),
            temperature=run_cfg.get("temperature", 1.),
            evaluate=run_cfg.evaluate,
            report_metric=run_cfg.get("report_metric", True),
            annotation_file=run_cfg.get("annotation_file", None),
            sample_id_key=run_cfg.get("sample_id_key", "image_id"),
            caption_key=run_cfg.get("caption_key", "caption"),
            split=run_cfg.get("valid_splits", ["val"]),
            load_gt_from_file=run_cfg.get("load_gt_from_file", False),
            img_ids=run_cfg.get("img_ids", []),  # evaluate only a subset of images
        )

    def build_datasets(self, cfg):
        """Build the datasets and, if needed, generate a COCO-format ground-truth file.

        When no annotation file was configured and the validation dataset's
        name does not contain ``"coco"``, ``self.annotation_file`` is pointed
        at ``<cache_root>/<name>_gt/<name>_<split>_annotations.json`` and rank 0
        writes that file via ``convert_to_coco_gt``.

        Returns:
            The datasets object returned by ``BaseTask.build_datasets``.
        """
        datasets = super().build_datasets(cfg)

        # Names of the datasets that provide the evaluation split.
        val_ds_names = [name for name, splits in datasets.items() if self.split in splits]
        if not val_ds_names:
            return datasets  # no validation sets
        assert len(val_ds_names) == 1, "Only support one dataset for validation"
        val_ds_name = val_ds_names[0]

        # COCO ground truth is precomputed and fetched at evaluation time, so a
        # file is only generated for other datasets.
        if self.annotation_file is None and "coco" not in val_ds_name:
            gt_dir = os.path.join(registry.get_path("cache_root"), f"{val_ds_name}_gt")
            self.annotation_file = os.path.join(
                gt_dir, f"{val_ds_name}_{self.split}_annotations.json"
            )
            # Every rank records the path; only rank 0 writes the file.
            if get_rank() == 0:
                os.makedirs(gt_dir, exist_ok=True)
                convert_to_coco_gt(
                    datasets[val_ds_name],
                    self.annotation_file,
                    self.caption_key,
                    self.sample_id_key,
                    self.split,
                    load_gt_from_file=self.load_gt_from_file,
                    img_ids=self.img_ids,
                )

        return datasets

    def valid_step(self, model, samples):
        """Generate captions for one batch.

        Returns:
            A list of ``{"caption": ..., "image_id": ...}`` dicts, restricted
            to ``self.img_ids`` when that list is non-empty.
        """
        captions = model.generate(
            samples,
            use_nucleus_sampling=False,
            num_beams=self.num_beams,
            max_length=self.max_len,
            min_length=self.min_len,
            repetition_penalty=self.repetition_penalty,
            length_penalty=self.length_penalty,
            top_p=self.top_p,
            temperature=self.temperature,
        )

        results = []
        for caption, img_id in zip(captions, samples[self.sample_id_key]):
            # Not all ids are ints; convert only when possible.
            if is_convertible_to_int(img_id):
                img_id = int(img_id)
            # Keep only the requested subset, if one was specified.
            if self.img_ids and img_id not in self.img_ids:
                continue
            results.append({"caption": caption, "image_id": img_id})

        return results

    def after_evaluation(self, val_result, split_name, epoch, **kwargs):
        """Save the predictions through ``self.save_result`` and optionally score them.

        Returns:
            The metrics dict from ``_report_metrics``, or
            ``{"agg_metrics": 0.0}`` when ``report_metric`` is disabled.
        """
        eval_result_file = self.save_result(
            result=val_result,
            result_dir=registry.get_path("result_dir"),
            filename="{}_epoch{}".format(split_name, epoch),
            remove_duplicate="image_id",
        )

        if not self.report_metric:
            return {"agg_metrics": 0.0}

        return self._report_metrics(
            eval_result_file=eval_result_file, split_name=split_name
        )

    # NOTE(review): `main_process` is imported by this module but unused in the
    # visible source; confirm whether this method is meant to run on the main
    # process only.
    def _report_metrics(self, eval_result_file, split_name):
        """Run COCO caption evaluation and append the scores to ``evaluate.txt``.

        Returns:
            A copy of the evaluation scores plus ``"agg_metrics"``, the sum of
            the ``CIDEr`` and ``Bleu_4`` scores.
        """
        if self.annotation_file is None:
            # TODO better way to define this
            coco_gt_root = os.path.join(registry.get_path("cache_root"), "coco_gt")
            coco_val = coco_caption_eval(
                coco_gt_root, eval_result_file, split_name, img_ids=self.img_ids
            )
        else:
            coco_val = coco_caption_eval(
                None,
                eval_result_file,
                split_name,
                annotation_file=self.annotation_file,
                img_ids=self.img_ids,
            )

        agg_metrics = coco_val.eval["CIDEr"] + coco_val.eval["Bleu_4"]
        log_stats = {split_name: dict(coco_val.eval)}

        with open(
            os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
        ) as f:
            f.write(json.dumps(log_stats) + "\n")

        coco_res = dict(coco_val.eval)
        coco_res["agg_metrics"] = agg_metrics

        return coco_res
def load_gt_file(file_path):
    """Load ground-truth records from a CSV, TSV, JSON-lines or JSON file.

    The format is chosen from the file name's extension(s); anything that is
    not CSV/TSV/JSONL is parsed as a single JSON document.

    Args:
        file_path: local path or URL. URLs are first fetched with ``cache_url``
            into the registry's ``cache_root``.

    Returns:
        A list of record dicts. A JSON object at the top level is flattened to
        one record per key, with the key stored under ``"sample_id"``; a
        top-level JSON value that is neither a list nor an object yields ``[]``.
    """
    if is_url(file_path):
        file_path = cache_url(file_path, registry.get_path("cache_root"))

    # Look only at the file name's extensions (so "x.csv.gz" still counts as
    # CSV) instead of substring-matching the whole path, where a directory
    # such as "csv_data/" would wrongly select the CSV reader.
    suffixes = os.path.basename(file_path).lower().split(".")[1:]

    if "csv" in suffixes or "tsv" in suffixes:
        # TSV needs an explicit tab separator; read_csv defaults to commas.
        sep = "\t" if "tsv" in suffixes else ","
        return list(pd.read_csv(file_path, sep=sep).to_dict(orient="records"))

    if "jsonl" in suffixes:
        with open(file_path, "r") as f:
            # Skip blank lines (e.g. a trailing newline), which are not valid JSON.
            return [json.loads(line) for line in f if line.strip()]

    with open(file_path, "r") as f:
        loaded = json.load(f)

    if isinstance(loaded, list):
        return list(loaded)

    if isinstance(loaded, dict):
        # A dict value is merged into the record; any other value is assumed to
        # be the caption(s) for that key.
        return [
            {"sample_id": k, **v} if isinstance(v, dict) else {"sample_id": k, "caption": v}
            for k, v in loaded.items()
        ]

    return []
def _append_coco_gt_entry(gt_data, raw_id, captions, img_ids):
    """Add one sample's image entry and caption annotation(s) to ``gt_data`` in place."""
    # Not all ids are ints; convert only when possible.
    img_id = int(raw_id) if is_convertible_to_int(raw_id) else raw_id
    # Keep only the requested subset, if one was specified.
    if img_ids and img_id not in img_ids:
        return

    gt_data["images"].append({"id": img_id})
    if isinstance(captions, str):
        captions = [captions]
    # Each annotation reuses the image id as its own "id".
    gt_data["annotations"].extend(
        {"image_id": img_id, "caption": c, "id": img_id} for c in captions
    )


def convert_to_coco_gt(data, outpath, caption_key, sample_id_key, split, load_gt_from_file=False, img_ids=None):
    """Write a COCO-style ground-truth JSON file (``images`` + ``annotations``).

    Args:
        data: object indexable by ``split``; ``data[split]`` is iterated for the
            sample ids and its ``.annotation[i]`` supplies the captions.
            Ignored when ``load_gt_from_file`` is truthy.
        outpath: destination path of the JSON file.
        caption_key: key holding a caption string or an iterable of captions.
        sample_id_key: key holding the sample id.
        split: split name used to index ``data``.
        load_gt_from_file: falsy, or the path/URL passed to ``load_gt_file``
            to read the ground-truth records from instead of ``data``.
        img_ids: optional subset of ids to keep; empty/``None`` keeps all.
    """
    gt_data = {"annotations": [], "images": []}

    if load_gt_from_file:
        print(f"Generating ground truth file for evaluation from {load_gt_from_file}....")
        for ann in load_gt_file(load_gt_from_file):
            _append_coco_gt_entry(gt_data, ann[sample_id_key], ann[caption_key], img_ids)
    else:
        print("Generating ground truth file for evaluation....")
        dataset = data[split]
        for i, ann in tqdm(enumerate(dataset)):
            # The id comes from the iterated sample, the captions from the raw
            # annotation at the same index.
            captions = dataset.annotation[i][caption_key]
            _append_coco_gt_entry(gt_data, ann[sample_id_key], captions, img_ids)

    # Context manager so the handle is flushed and closed before anything reads it.
    with open(outpath, "w") as f:
        json.dump(gt_data, f)
    print(f"Saved annotations at {outpath}")
| # TODO better structure for this. | |
| from pycocoevalcap.eval import COCOEvalCap | |
| from pycocotools.coco import COCO | |
| from torchvision.datasets.utils import download_url | |
def coco_caption_eval(coco_gt_root, results_file, split, annotation_file=None, img_ids=None):
    """Score a caption results file against COCO-format ground truth.

    Args:
        coco_gt_root: directory the COCO Karpathy ground truth is downloaded
            into; only used when ``annotation_file`` is ``None``.
        results_file: path of the predictions file passed to ``COCO.loadRes``.
        split: ``"val"`` or ``"test"``; selects the Karpathy file when
            ``annotation_file`` is ``None`` (any other value raises ``KeyError``).
        annotation_file: path or URL of a COCO-format ground-truth file. URLs
            are cached under the registry's ``cache_root``.
        img_ids: when non-empty, evaluation is restricted to the image ids
            present in the results file.

    Returns:
        The ``COCOEvalCap`` object after ``evaluate()``; scores are in ``.eval``.
    """
    if annotation_file is None:
        urls = {
            "val": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json",
            "test": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json",
        }
        filenames = {
            "val": "coco_karpathy_val_gt.json",
            "test": "coco_karpathy_test_gt.json",
        }

        download_url(urls[split], coco_gt_root)
        annotation_file = os.path.join(coco_gt_root, filenames[split])

    if is_url(annotation_file):
        annotation_file = cache_url(annotation_file, registry.get_path("cache_root"))

    # Ground truth and predictions as COCO objects.
    coco = COCO(annotation_file)
    coco_result = coco.loadRes(results_file)

    coco_eval = COCOEvalCap(coco, coco_result)

    # For a subset evaluation, score only the images that have predictions.
    if img_ids:
        coco_eval.params['image_id'] = coco_result.getImgIds()

    # SPICE will take a few minutes the first time, but speeds up due to caching
    coco_eval.evaluate()

    for metric, score in coco_eval.eval.items():
        print(f"{metric}: {score:.3f}")

    return coco_eval