# modeling_my_model.py
"""Hugging Face-compatible wrapper exposing the project's selective
VisionTransformer backbone as a `PreTrainedModel` for image classification."""

import inspect
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers.modeling_outputs import ModelOutput
from transformers.modeling_utils import PreTrainedModel

from .configuration_selectivevit import SMSelectiveViTConfig
from .selective_vit import VisionTransformer


@dataclass
class ImageClassifierWithMasksOutput(ModelOutput):
    """Classifier output that also carries the backbone's selection masks.

    Attributes:
        loss: Cross-entropy loss (only set when `labels` are provided).
        logits: Main classification head output.
        distil_logits: Distillation head output (DeiT-style second head).
        last_hidden_state: Final backbone features fed to the classifier.
        hidden_states: Per-layer features, when `output_hidden_states=True`.
        masks: Token-selection masks produced by the backbone.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    distil_logits: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    masks: Optional[List[torch.FloatTensor]] = None


class SMSelectiveViTModelForClassification(PreTrainedModel):
    """`PreTrainedModel` adapter around the non-HF `VisionTransformer`.

    Translates an `SMSelectiveViTConfig` into the backbone's constructor
    kwargs and wraps its two-stage forward (`forward_features` +
    `forward_classifier`) in the standard HF `forward` contract.
    """

    config_class = SMSelectiveViTConfig
    base_model_prefix = "backbone"

    def __init__(self, config: SMSelectiveViTConfig):
        super().__init__(config)
        # The backbone takes plain keyword args, not an HF config object.
        # Forward only the config fields its __init__ actually accepts, so
        # HF bookkeeping keys in the config don't raise a TypeError.
        sig = inspect.signature(VisionTransformer.__init__)
        allowed = set(sig.parameters.keys())
        allowed.discard("self")
        model_kwargs = {
            k: v for k, v in config.to_dict().items() if k in allowed
        }
        self.backbone = VisionTransformer(**model_kwargs)
        # Important: ties into the HF weight-initialization hooks.
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        full: bool = False,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        skip_masks: bool = False,
        **kwargs,  # absorbs extra HF Trainer kwargs; intentionally unused
    ):
        """Run classification; compute cross-entropy loss when labels given.

        Args:
            pixel_values: Input image batch for the backbone.
            labels: Class indices for `nn.CrossEntropyLoss` (optional).
            full: Passed through to `forward_features` — presumably disables
                token selection; verify against the backbone. TODO confirm.
            output_hidden_states: Return per-layer features; falls back to
                `config.output_hidden_states`.
            return_dict: Return `ImageClassifierWithMasksOutput` instead of a
                tuple; falls back to `config.use_return_dict`.
            skip_masks: Passed through to `forward_features`.

        Returns:
            `ImageClassifierWithMasksOutput`, or when `return_dict` is False a
            6-tuple `(loss, logits, distil_logits, last_hidden, all_hidden,
            masks)`.
        """
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        last_hidden, all_hidden, masks = self.backbone.forward_features(
            pixel_values,
            full=full,
            output_hidden_states=output_hidden_states,
            skip_masks=skip_masks,
        )
        logits, distil_logits = self.backbone.forward_classifier(last_hidden)

        loss = None
        if labels is not None:
            # NOTE(review): loss uses only the main head; distil_logits are
            # returned but not trained here — confirm that is intended.
            loss = nn.CrossEntropyLoss()(logits, labels)

        if not return_dict:
            # NOTE(review): HF convention drops a None loss from the tuple
            # (`(loss,) + output if loss is not None else output`); kept as-is
            # to preserve the positional contract for existing callers.
            return (loss, logits, distil_logits, last_hidden, all_hidden, masks)

        return ImageClassifierWithMasksOutput(
            loss=loss,
            logits=logits,
            distil_logits=distil_logits,
            last_hidden_state=last_hidden,
            hidden_states=all_hidden,
            masks=masks,
        )