# modeling_my_model.py
from dataclasses import dataclass
from typing import Optional, Tuple, List
import inspect
import torch
from torch import nn
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import ModelOutput

from .configuration_selectivevit import SMSelectiveViTConfig

from .selective_vit import VisionTransformer


@dataclass
class ImageClassifierWithMasksOutput(ModelOutput):
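    """Classifier output that extends `ModelOutput` with distillation logits
    and the selection masks returned by the backbone (when not skipped)."""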
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    distil_logits: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    masks: Optional[List[torch.FloatTensor]] = None



class SMSelectiveViTModelForClassification(PreTrainedModel):
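    """`PreTrainedModel` wrapper that builds the original `VisionTransformer`
    from an `SMSelectiveViTConfig` and exposes it through the Hugging Face API."""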
    config_class = SMSelectiveViTConfig
    base_model_prefix = "backbone"

    def __init__(self, config: SMSelectiveViTConfig):
        super().__init__(config)

        # Build kwargs for the wrapped VisionTransformer from config fields
        cfg_dict = config.to_dict()

        # Keep only the arguments that VisionTransformer.__init__ accepts
        sig = inspect.signature(VisionTransformer.__init__)
        allowed = set(sig.parameters.keys())
        allowed.discard("self")
        model_kwargs = {k: v for k, v in cfg_dict.items() if k in allowed}

        self.backbone = VisionTransformer(**model_kwargs)

        self.post_init()  # important: ties into HF weight init hooks

    def forward(
        self,
        pixel_values=None,
        labels=None,
        full=False,
        output_hidden_states=None,
        return_dict=None,
        skip_masks=False,
        **kwargs,
    ):
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict
            if return_dict is not None
            else self.config.use_return_dict
        )

        last_hidden, all_hidden, masks = self.backbone.forward_features(
            pixel_values,
            full=full,
            output_hidden_states=output_hidden_states,
            skip_masks=skip_masks
        )

        logits, distil_logits = self.backbone.forward_classifier(last_hidden)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)

        if not return_dict:
            return (loss, logits, distil_logits, last_hidden, all_hidden, masks)

        return ImageClassifierWithMasksOutput(
            loss=loss,
            logits=logits,
            distil_logits=distil_logits,
            last_hidden_state=last_hidden,
            hidden_states=all_hidden,
            masks=masks,
        )
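

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original file). It
# assumes SMSelectiveViTConfig() provides usable defaults and that the
# backbone expects standard 3x224x224 inputs; adjust both to your config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = SMSelectiveViTConfig()
    model = SMSelectiveViTModelForClassification(config)
    model.eval()

    pixel_values = torch.randn(1, 3, 224, 224)  # assumed input resolution
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values, return_dict=True)

    print(outputs.logits.shape)        # (batch_size, num_classes)
    print(len(outputs.masks or []))    # selection masks, if the backbone returns them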