Spaces:
Build error
Build error
| """ | |
| Copyright (c) 2022, salesforce.com, inc. | |
| All rights reserved. | |
| SPDX-License-Identifier: BSD-3-Clause | |
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
| """ | |
| import json | |
| import os | |
| import logging | |
| import numpy as np | |
| import torch | |
| from lavis.common.dist_utils import main_process | |
| from lavis.common.registry import registry | |
| from lavis.tasks.base_task import BaseTask | |
| class MultimodalClassificationTask(BaseTask): | |
| def __init__(self): | |
| super().__init__() | |
| def valid_step(self, model, samples): | |
| results = [] | |
| outputs = model.predict(samples) | |
| predictions = outputs["predictions"] | |
| targets = outputs["targets"] | |
| predictions = predictions.max(1)[1].cpu().numpy() | |
| targets = targets.cpu().numpy() | |
| indices = samples[self.inst_id_key] | |
| for pred, tgt, index in zip(predictions, targets, indices): | |
| if isinstance(index, torch.Tensor): | |
| index = index.item() | |
| results.append( | |
| { | |
| self.inst_id_key: index, | |
| "prediction": pred.item(), | |
| "target": tgt.item(), | |
| } | |
| ) | |
| return results | |
| def after_evaluation(self, val_result, split_name, epoch, **kwargs): | |
| eval_result_file = self.save_result( | |
| result=val_result, | |
| result_dir=registry.get_path("result_dir"), | |
| filename="{}_epoch{}".format(split_name, epoch), | |
| remove_duplicate=self.inst_id_key, | |
| ) | |
| metrics = self._report_metrics( | |
| eval_result_file=eval_result_file, split_name=split_name | |
| ) | |
| return metrics | |
| def _report_metrics(self, eval_result_file, split_name): | |
| results = json.load(open(eval_result_file)) | |
| predictions = np.array([res["prediction"] for res in results]) | |
| targets = np.array([res["target"] for res in results]) | |
| accuracy = (targets == predictions).sum() / targets.shape[0] | |
| metrics = {"agg_metrics": accuracy, "acc": accuracy} | |
| log_stats = {split_name: {k: v for k, v in metrics.items()}} | |
| with open( | |
| os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a" | |
| ) as f: | |
| f.write(json.dumps(log_stats) + "\n") | |
| logging.info(metrics) | |
| return metrics | |