Spaces:
Build error
Build error
| """ | |
| Copyright (c) 2022, salesforce.com, inc. | |
| All rights reserved. | |
| SPDX-License-Identifier: BSD-3-Clause | |
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
| """ | |
| import json | |
| import logging | |
| import os | |
| import numpy as np | |
| import torch | |
| from lavis.common.dist_utils import is_main_process | |
| from lavis.common.registry import registry | |
| from lavis.tasks.base_task import BaseTask | |
class RetrievalTask(BaseTask):
    """Image-text retrieval evaluation task.

    Asks the model for image->text and text->image score matrices and, on the
    main process, reports recall@{1,5,10} in both directions.
    """

    # NOTE(review): no ``registry.register_task(...)`` decorator is visible on
    # this class; if tasks are looked up by name through the registry, confirm
    # it is registered somewhere else.

    def __init__(self, cfg):
        super().__init__()

        # Run config; forwarded to ``model.compute_sim_matrix`` as ``task_cfg``.
        self.cfg = cfg

    @classmethod
    def setup_task(cls, cfg):
        """Build the task from a full config, keeping only its ``run_cfg``.

        Takes ``cls`` and instantiates it, so it must be a classmethod;
        without the decorator a class-level call binds the config to ``cls``.
        """
        run_cfg = cfg.run_cfg

        return cls(cfg=run_cfg)

    def evaluation(self, model, data_loader, **kwargs):
        """Score ``data_loader`` with ``model`` and return retrieval metrics.

        Every process calls ``model.compute_sim_matrix``. Only the main
        process turns the scores into metrics, using the dataset's
        ``txt2img`` / ``img2txt`` ground truth, and logs them. Other
        processes return ``None``.
        """
        score_i2t, score_t2i = model.compute_sim_matrix(data_loader, task_cfg=self.cfg)

        if is_main_process():
            eval_result = self._report_metrics(
                score_i2t,
                score_t2i,
                data_loader.dataset.txt2img,
                data_loader.dataset.img2txt,
            )
            logging.info(eval_result)
        else:
            eval_result = None

        return eval_result

    def after_evaluation(self, val_result, **kwargs):
        """Return ``val_result`` unchanged."""
        return val_result

    @staticmethod
    @torch.no_grad()
    def _report_metrics(scores_i2t, scores_t2i, txt2img, img2txt):
        """Compute retrieval recall metrics and append them to ``evaluate.txt``.

        Must be a staticmethod: ``evaluation`` calls it through ``self`` with
        exactly these four arguments.

        Args:
            scores_i2t: Score matrix with one row per image; row ``i`` scores
                every text for image ``i`` (higher scores rank first).
            scores_t2i: Score matrix with one row per text; row ``t`` scores
                every image for text ``t``.
            txt2img: Maps a text index to the index of its ground-truth image.
            img2txt: Maps an image index to an iterable of ground-truth text
                indices.

        Returns:
            dict with recall@{1,5,10} percentages for text retrieval
            (``txt_*``) and image retrieval (``img_*``), the per-direction
            means, their average ``r_mean``, and ``agg_metrics`` (equal to the
            text-retrieval mean). The same dict is appended as one JSON line
            to ``<output_dir>/evaluate.txt``.
        """

        def recall_at(ranks, k):
            # Percentage of queries whose best ground-truth hit ranks in the top k.
            return 100.0 * np.count_nonzero(ranks < k) / len(ranks)

        # Images -> Text: an image's rank is the best (lowest) rank among all
        # of its ground-truth captions.
        ranks = np.zeros(scores_i2t.shape[0])
        for index, score in enumerate(scores_i2t):
            inds = np.argsort(score)[::-1]  # text indices, best score first
            # An image with no ground-truth caption keeps the sentinel rank
            # and therefore counts as a miss at every k.
            ranks[index] = min(
                (np.where(inds == i)[0][0] for i in img2txt[index]),
                default=1e20,
            )

        tr1 = recall_at(ranks, 1)
        tr5 = recall_at(ranks, 5)
        tr10 = recall_at(ranks, 10)

        # Text -> Images: each text has a single ground-truth image.
        ranks = np.zeros(scores_t2i.shape[0])
        for index, score in enumerate(scores_t2i):
            inds = np.argsort(score)[::-1]  # image indices, best score first
            ranks[index] = np.where(inds == txt2img[index])[0][0]

        ir1 = recall_at(ranks, 1)
        ir5 = recall_at(ranks, 5)
        ir10 = recall_at(ranks, 10)

        tr_mean = (tr1 + tr5 + tr10) / 3
        ir_mean = (ir1 + ir5 + ir10) / 3
        r_mean = (tr_mean + ir_mean) / 2

        # The aggregate metric is the text-retrieval mean recall.
        agg_metrics = tr_mean

        eval_result = {
            "txt_r1": tr1,
            "txt_r5": tr5,
            "txt_r10": tr10,
            "txt_r_mean": tr_mean,
            "img_r1": ir1,
            "img_r5": ir5,
            "img_r10": ir10,
            "img_r_mean": ir_mean,
            "r_mean": r_mean,
            "agg_metrics": agg_metrics,
        }
        with open(
            os.path.join(registry.get_path("output_dir"), "evaluate.txt"),
            "a",
            encoding="utf-8",
        ) as f:
            f.write(json.dumps(eval_result) + "\n")
        return eval_result