Spaces:
Runtime error
Runtime error
| import os | |
| from typing import Any | |
| import mmengine | |
| import torch | |
| import torch.nn as nn | |
| from einops import rearrange | |
| from mmdet.models.utils import samplelist_boxtype2tensor | |
| from mmdet.structures import SampleList | |
| from mmdet.utils import InstanceList | |
| from mmpl.registry import MODELS | |
| from ..builder import build_backbone, build_loss, build_neck, build_head | |
| from .base_pler import BasePLer | |
| from mmpl.structures import ClsDataSample | |
| from .base import BaseClassifier | |
| import lightning.pytorch as pl | |
| import torch.nn.functional as F | |
| class MMClsPLer(BasePLer): | |
| def __init__(self, | |
| whole_model=None, | |
| *args, **kwargs): | |
| super().__init__(*args, **kwargs) | |
| self.save_hyperparameters() | |
| self.whole_model = MODELS.build(whole_model) | |
| def setup(self, stage: str) -> None: | |
| super().setup(stage) | |
| def validation_step(self, batch, batch_idx): | |
| data = self.whole_model.data_preprocessor(batch, False) | |
| batch_data_samples = self.whole_model._run_forward(data, mode='predict') # type: ignore | |
| pred_label = torch.cat([data_sample.pred_label for data_sample in batch_data_samples]) | |
| gt_label = torch.cat([data_sample.gt_label for data_sample in batch_data_samples]) | |
| self.val_evaluator.update(pred_label, gt_label) | |
| # self.val_evaluator.update(batch, batch_data_samples) | |
| def training_step(self, batch, batch_idx): | |
| data = self.whole_model.data_preprocessor(batch, True) | |
| losses = self.whole_model._run_forward(data, mode='loss') # type: ignore | |
| parsed_losses, log_vars = self.parse_losses(losses) | |
| log_vars = {f'train_{k}': v for k, v in log_vars.items()} | |
| log_vars['loss'] = parsed_losses | |
| self.log_dict(log_vars, prog_bar=True) | |
| return log_vars | |
| def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any): | |
| data = self.whole_model.data_preprocessor(batch, False) | |
| batch_data_samples = self.whole_model._run_forward(data, mode='predict') # type: ignore | |
| self.test_evaluator.update(batch, batch_data_samples) | |