Spaces:
Runtime error
Runtime error
| from typing import Any | |
| import torch | |
| import torch.nn as nn | |
| from mmpl.registry import MODELS | |
| from ..builder import build_backbone, build_loss | |
| from .base_pler import BasePLer | |
| from mmpl.structures import ClsDataSample | |
| from .base import BaseClassifier | |
| import lightning.pytorch as pl | |
| import torch.nn.functional as F | |
| class GPTPLer(BasePLer): | |
| def __init__(self, | |
| backbone, | |
| loss=dict(type='CrossEntropyLoss', loss_weight=1.0), | |
| *args, **kwargs): | |
| super().__init__(*args, **kwargs) | |
| self.backbone = build_backbone(backbone) | |
| self.loss = build_loss(loss) | |
| def training_step(self, batch, batch_idx): | |
| x, gt_label = batch['x'], batch['gt_label'] | |
| outputs = self(input_ids=x, labels=gt_label) | |
| loss, logits = outputs['loss'], outputs['logits'] | |
| return loss | |
| def forward(self, *args: Any, **kwargs: Any) -> Any: | |
| return self.backbone(*args, **kwargs) | |
| def validation_step(self, batch, batch_idx): | |
| pass | |