# SM-Selective-ViT-Base-224-Distilled / modeling_selectivevit.py
# (Hub metadata: uploaded by XAFT; commit "Add support for FlashAttention", 1c44bca, verified)
# modeling_selectivevit.py
from dataclasses import dataclass
from typing import Optional, Tuple, List
import inspect
import torch
from torch import nn
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from .configuration_selectivevit import SMSelectiveViTConfig
from .selective_vit import VisionTransformer
@dataclass
class ImageClassifierWithMasksOutput(ModelOutput):
    """Output of ``SMSelectiveViTModelForClassification.forward``.

    Extends the standard HF ``ModelOutput`` with a distillation head's
    logits and the token-selection masks produced by the backbone.
    All fields are optional so the object can be built with any subset
    populated (e.g. ``loss`` is None at inference time).
    """

    # Cross-entropy loss over `logits`; only set when labels were passed.
    loss: Optional[torch.FloatTensor] = None
    # Main classification head scores.
    logits: Optional[torch.FloatTensor] = None
    # Distillation head scores (second output of `forward_classifier`).
    distil_logits: Optional[torch.FloatTensor] = None
    # Final feature map from `forward_features`.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Per-layer hidden states; only populated when output_hidden_states=True.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Selection masks from the backbone; skipped when skip_masks=True.
    masks: Optional[List[torch.FloatTensor]] = None
class SMSelectiveViTModelForClassification(PreTrainedModel):
    """Hugging Face wrapper around the original ``VisionTransformer``.

    Adapts the standalone backbone to the ``PreTrainedModel`` API so it
    gains ``from_pretrained`` / ``save_pretrained`` support, HF weight
    initialization, and the standard ``ModelOutput`` return convention.
    """

    config_class = SMSelectiveViTConfig
    base_model_prefix = "backbone"

    def __init__(self, config: SMSelectiveViTConfig):
        super().__init__(config)
        # The HF config may carry extra fields (model_type, torch_dtype, ...);
        # pass through only those that VisionTransformer.__init__ accepts.
        accepted = set(inspect.signature(VisionTransformer.__init__).parameters)
        accepted.discard("self")
        backbone_kwargs = {
            name: value
            for name, value in config.to_dict().items()
            if name in accepted
        }
        self.backbone = VisionTransformer(**backbone_kwargs)
        # Hooks into HF's weight-init / tying machinery — must run last.
        self.post_init()

    def forward(
        self,
        pixel_values=None,
        labels=None,
        full=False,
        output_hidden_states=None,
        return_dict=None,
        skip_masks=False,
        **kwargs,
    ):
        """Run the backbone and classifier heads.

        Args:
            pixel_values: Input image batch, forwarded to the backbone.
            labels: Optional class indices; when given, a cross-entropy
                loss over the main head's logits is returned.
            full: Backbone-specific flag, forwarded to ``forward_features``
                (presumably disables token selection — confirm in backbone).
            output_hidden_states: Whether to return per-layer states;
                defaults to the config value.
            return_dict: Whether to return a ``ModelOutput``; defaults to
                ``config.use_return_dict``. When False, a plain 6-tuple
                ``(loss, logits, distil_logits, last_hidden, all_hidden,
                masks)`` is returned instead.
            skip_masks: Forwarded to the backbone to skip mask computation.
            **kwargs: Ignored; absorbs extra HF Trainer arguments.
        """
        # Fall back to config defaults for unspecified flags.
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        last_hidden, all_hidden, masks = self.backbone.forward_features(
            pixel_values,
            full=full,
            output_hidden_states=output_hidden_states,
            skip_masks=skip_masks,
        )
        logits, distil_logits = self.backbone.forward_classifier(last_hidden)

        # Training loss only applies to the main head; the distillation
        # head's loss (if any) is handled by the caller.
        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)

        if not return_dict:
            return (loss, logits, distil_logits, last_hidden, all_hidden, masks)
        return ImageClassifierWithMasksOutput(
            loss=loss,
            logits=logits,
            distil_logits=distil_logits,
            last_hidden_state=last_hidden,
            hidden_states=all_hidden,
            masks=masks,
        )