| |
| from dataclasses import dataclass |
| from typing import Optional, Tuple, List |
| import inspect |
| import torch |
| from torch import nn |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.modeling_outputs import ModelOutput |
|
|
| from .configuration_selectivevit import SMSelectiveViTConfig |
|
|
| from .selective_vit import VisionTransformer |
|
|
|
|
@dataclass
class ImageClassifierWithMasksOutput(ModelOutput):
    """Output container for selective-ViT classification.

    Extends the standard HuggingFace ``ModelOutput`` with the token-selection
    masks produced by the backbone, alongside the usual classification fields.
    All fields default to ``None`` so partial outputs are valid.
    """

    # Cross-entropy loss; only populated when `labels` are passed to forward().
    loss: Optional[torch.FloatTensor] = None
    # Classification logits from the main head.
    logits: Optional[torch.FloatTensor] = None
    # Logits from the distillation head (DeiT-style — TODO confirm against backbone).
    distil_logits: Optional[torch.FloatTensor] = None
    # Final hidden representation returned by the backbone's forward_features().
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Per-layer hidden states; populated only when output_hidden_states is set.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Selection masks emitted by the backbone, one entry per selective layer
    # (presumably — shape/semantics depend on VisionTransformer, not visible here).
    masks: Optional[List[torch.FloatTensor]] = None
|
|
|
|
|
|
class SMSelectiveViTModelForClassification(PreTrainedModel):
    """HuggingFace ``PreTrainedModel`` wrapper around the selective ViT backbone.

    Builds a ``VisionTransformer`` from the subset of config entries that its
    constructor actually accepts, and exposes a standard classification
    ``forward`` that returns either a plain tuple or an
    ``ImageClassifierWithMasksOutput``.
    """

    config_class = SMSelectiveViTConfig
    base_model_prefix = "backbone"

    def __init__(self, config: SMSelectiveViTConfig):
        super().__init__(config)

        # The HF config may carry extra bookkeeping keys; forward only the
        # entries that VisionTransformer.__init__ declares as parameters.
        accepted = set(inspect.signature(VisionTransformer.__init__).parameters)
        accepted.discard("self")
        backbone_kwargs = {
            key: value
            for key, value in config.to_dict().items()
            if key in accepted
        }

        self.backbone = VisionTransformer(**backbone_kwargs)

        # Standard HF hook: weight init + any post-construction processing.
        self.post_init()

    def forward(
        self,
        pixel_values=None,
        labels=None,
        full=False,
        output_hidden_states=None,
        return_dict=None,
        skip_masks=False,
        **kwargs,
    ):
        """Run classification on a batch of images.

        Args:
            pixel_values: Input image tensor, passed straight to the backbone.
            labels: Optional class indices; when given, a cross-entropy loss
                is computed from the main head's logits.
            full: Backbone-specific flag forwarded to ``forward_features``
                (semantics defined by VisionTransformer, not visible here).
            output_hidden_states: Whether to return per-layer hidden states;
                falls back to the config value when ``None``.
            return_dict: Whether to return a ModelOutput instead of a tuple;
                falls back to ``config.use_return_dict`` when ``None``.
            skip_masks: Backbone-specific flag to skip mask computation.
            **kwargs: Ignored; accepted for HF pipeline compatibility.

        Returns:
            ``ImageClassifierWithMasksOutput`` when ``return_dict`` is true,
            otherwise the tuple
            ``(loss, logits, distil_logits, last_hidden, all_hidden, masks)``.
        """
        # Resolve per-call overrides against config defaults.
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        last_hidden, all_hidden, masks = self.backbone.forward_features(
            pixel_values,
            full=full,
            output_hidden_states=output_hidden_states,
            skip_masks=skip_masks,
        )

        logits, distil_logits = self.backbone.forward_classifier(last_hidden)

        # Loss only uses the main head; the distillation head is returned
        # untouched for the caller/trainer to combine as it sees fit.
        loss = None if labels is None else nn.CrossEntropyLoss()(logits, labels)

        if not return_dict:
            return (loss, logits, distil_logits, last_hidden, all_hidden, masks)

        return ImageClassifierWithMasksOutput(
            loss=loss,
            logits=logits,
            distil_logits=distil_logits,
            last_hidden_state=last_hidden,
            hidden_states=all_hidden,
            masks=masks,
        )
|
|
|
|
|
|