XAFT committed
Commit ddcb7ea · verified · 1 Parent(s): a33ddda

Upload model files
.gitattributes CHANGED
@@ -1,35 +1,5 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
  *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
  *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,110 @@
---
license: mit
library_name: transformers
tags:
- vision-transformer
- image-classification
- efficient-transformer
- selective-attention
- knowledge-distillation
- computer-vision
pipeline_tag: image-classification
---

# Soft-Masked Selective Vision Transformer

## Model Description

The Soft-Masked Selective Vision Transformer is an efficient **Vision Transformer (ViT)** designed to reduce the computational overhead of self-attention while maintaining competitive accuracy.

The model introduces a **patch-selective attention mechanism** that lets the transformer focus on the most salient image regions and dynamically disregard less informative patches. This selective strategy significantly reduces the quadratic cost of full self-attention, making the model particularly suitable for **high-resolution vision tasks** and **resource-constrained environments**.

To further improve performance, the model leverages **knowledge distillation**, transferring representational knowledge from a stronger teacher network to enhance the accuracy of the lightweight transformer variants.

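Concretely, each encoder block predicts a per-token score in `[0, 1]` and adds its scaled logarithm to the attention logits, so low-scoring patches are softly down-weighted during training. A minimal sketch of that rule, mirroring `SoftMaskedMultiheadAttention` in `selective_vit.py` from this commit (tensor names here are illustrative):

```python
import torch

def soft_masked_scores(scores: torch.Tensor, mask: torch.Tensor,
                       scale: float = 8.0, eps: float = 1e-6) -> torch.Tensor:
    """scores: (B, heads, L, L) attention logits; mask: (B, L) patch scores in [0, 1]."""
    log_mask = mask.clamp(min=eps).log().unsqueeze(1).unsqueeze(1)  # (B, 1, 1, L)
    return scores + scale * log_mask  # keys with mask -> 0 get a large negative bias
```

At inference, tokens whose score falls below `mask_threshold` are removed from the sequence altogether (see `infer_forward` in `selective_vit.py`), which is where the FLOPs savings come from.
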
---

## Intended Use

This model is intended for:

- Image classification tasks
- Deployment in **compute- or memory-constrained environments**
- High-resolution image processing where standard ViTs are prohibitively expensive
- Research on efficient attention mechanisms and transformer compression

### Example Use Cases

- Edge or embedded vision systems
- Large-scale image analysis with reduced inference cost
- Efficient backbones for downstream vision tasks

---

## Training Details

- **Training Objective:** Cross-entropy loss with an optional distillation loss (sketched below)
- **Distillation:** Teacher–student framework with a dedicated distillation token
- **Optimization:** AdamW
- **Training Dataset:** ImageNet-1k (ILSVRC 2012)
- **Evaluation Metrics:** Top-1 accuracy, FLOPs, parameter count

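The training code is not part of this upload. As a rough sketch only, a DeiT-style combination of class-token cross-entropy and hard-label distillation through the distillation token could look like the following; the teacher model, the hard-label choice, and the equal weighting are assumptions, not taken from this repository:

```python
import torch.nn.functional as F

def classification_loss(logits, distil_logits, labels, teacher_logits, alpha=0.5):
    # Supervised cross-entropy on the class-token head
    ce = F.cross_entropy(logits, labels)
    # Hard-label distillation: the distil head mimics the teacher's argmax prediction
    distil = F.cross_entropy(distil_logits, teacher_logits.argmax(dim=-1))
    return (1 - alpha) * ce + alpha * distil
```
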
---

## Usage

### Image Classification Example

```python
import torch
import requests
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor

# Load an example image
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Load the processor and model (custom code, hence trust_remote_code=True)
processor = AutoImageProcessor.from_pretrained(
    "XAFT/SM-Selective-ViT-Base-224-Distilled",
    trust_remote_code=True,
)
model = AutoModelForImageClassification.from_pretrained(
    "XAFT/SM-Selective-ViT-Base-224-Distilled",
    trust_remote_code=True,
)
model.eval()  # token pruning and class/distil logit averaging only run in eval mode

# Preprocess
inputs = processor(images=image, return_tensors="pt")

# Forward pass
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits
predicted_class = logits.argmax(-1).item()

print("Predicted class index:", predicted_class)
```

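Because the model returns a custom `ImageClassifierWithMasksOutput`, the per-layer patch-selection masks are also available, which makes it easy to inspect how aggressively each block prunes tokens (a small sketch reusing `model` and `inputs` from above):

```python
with torch.no_grad():
    outputs = model(**inputs)

if outputs.masks is not None:
    for i, m in enumerate(outputs.masks):  # each mask: (batch, seq_len) scores in [0, 1]
        kept = (m >= model.config.mask_threshold).float().mean().item()
        print(f"layer {i}: {kept:.1%} of tokens kept")
```
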
---

## Evaluation Results

Results on ImageNet-1k (ILSVRC 2012) classification:

| Model                 | Top-1 Acc. | Top-5 Acc. | # Params | Avg. GFLOPs |
|-----------------------|------------|------------|----------|-------------|
| Base                  | 80.350%    | 94.980%    | 86.60M   | 9.61        |
| Base (distilled)      | 80.990%    | 95.386%    | 87.37M   | 9.21        |
| Small                 | 78.662%    | 94.454%    | 22.06M   | 3.12        |
| Small (distilled)     | 79.000%    | 94.494%    | 22.45M   | 3.05        |
| Tiny tall             | 74.802%    | 92.794%    | 11.07M   | 1.64        |
| Tiny tall (distilled) | 75.676%    | 92.988%    | 11.26M   | 1.64        |
| Tiny                  | 71.056%    | 90.192%    | 5.72M    | 0.95        |
| Tiny (distilled)      | 72.618%    | 91.338%    | 5.92M    | 0.93        |

---

## Acknowledgments

We thank the TPU Research Cloud program for providing the Cloud TPUs used to train these models.
__init__.py ADDED
@@ -0,0 +1,7 @@
from .configuration_selectivevit import SMSelectiveViTConfig
from .modeling_selectivevit import SMSelectiveViTModelForClassification

__all__ = [
    "SMSelectiveViTConfig",
    "SMSelectiveViTModelForClassification",
]
config.json ADDED
@@ -0,0 +1,29 @@
{
  "architectures": [
    "SMSelectiveViTModelForClassification"
  ],
  "atten_dim": 768,
  "attention_scale": 8.0,
  "auto_map": {
    "AutoConfig": "configuration_selectivevit.SMSelectiveViTConfig",
    "AutoModelForImageClassification": "modeling_selectivevit.SMSelectiveViTModelForClassification"
  },
  "channels": 3,
  "depth": 12,
  "drop_path": 0.2,
  "dropout": 0.0,
  "dtype": "float32",
  "embed_dim": 768,
  "ffn_groups": null,
  "image_size": 224,
  "mask_threshold": 0.05,
  "mlp_dim": 3072,
  "model_type": "softmasked_selective_vit",
  "num_classes": 1000,
  "num_groups": 12,
  "num_heads": 12,
  "patch_drop": 0.05,
  "patch_size": 16,
  "transformers_version": "4.57.3",
  "use_distil_token": true
}
configuration_selectivevit.py ADDED
@@ -0,0 +1,43 @@
# configuration_selectivevit.py
from transformers import PretrainedConfig


class SMSelectiveViTConfig(PretrainedConfig):
    model_type = "softmasked_selective_vit"

    def __init__(
        self,
        image_size=224,
        patch_size=16,
        num_classes=1000,
        embed_dim=768,
        atten_dim=768,
        depth=12,
        num_heads=12,
        mlp_dim=3072,
        channels=3,
        dropout=0.0,
        drop_path=0.0,
        attention_scale=0.0,
        mask_threshold=0.0,
        patch_drop=0.0,
        use_distil_token=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Store everything as attributes; PretrainedConfig serializes them to config.json
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.atten_dim = atten_dim
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.channels = channels
        self.dropout = dropout
        self.drop_path = drop_path
        self.attention_scale = attention_scale
        self.mask_threshold = mask_threshold
        self.patch_drop = patch_drop
        self.use_distil_token = use_distil_token
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4cd1a78510e92f164589ff458c23d69309413d14687029d5bcfccbc102ac06c6
size 349491616
modeling_selectivevit.py ADDED
@@ -0,0 +1,91 @@
# modeling_selectivevit.py
from dataclasses import dataclass
from typing import List, Optional, Tuple
import inspect

import torch
from torch import nn
from transformers.modeling_outputs import ModelOutput
from transformers.modeling_utils import PreTrainedModel

from .configuration_selectivevit import SMSelectiveViTConfig
from .selective_vit import VisionTransformer


@dataclass
class ImageClassifierWithMasksOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    distil_logits: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    masks: Optional[List[torch.FloatTensor]] = None


class SMSelectiveViTModelForClassification(PreTrainedModel):
    config_class = SMSelectiveViTConfig
    base_model_prefix = "backbone"

    def __init__(self, config: SMSelectiveViTConfig):
        super().__init__(config)

        # Build kwargs for the backbone from the config fields, keeping only
        # the arguments that VisionTransformer.__init__ actually accepts
        cfg_dict = config.to_dict()
        sig = inspect.signature(VisionTransformer.__init__)
        allowed = set(sig.parameters.keys())
        allowed.discard("self")
        model_kwargs = {k: v for k, v in cfg_dict.items() if k in allowed}

        self.backbone = VisionTransformer(**model_kwargs)

        self.post_init()  # important: ties into HF weight-init hooks

    def forward(
        self,
        pixel_values=None,
        labels=None,
        full=False,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        last_hidden, all_hidden, masks = self.backbone.forward_features(
            pixel_values,
            full=full,
            output_hidden_states=output_hidden_states,
        )

        logits, distil_logits = self.backbone.forward_classifier(last_hidden)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)

        if not return_dict:
            return (loss, logits, distil_logits, last_hidden, all_hidden, masks)

        return ImageClassifierWithMasksOutput(
            loss=loss,
            logits=logits,
            distil_logits=distil_logits,
            last_hidden_state=last_hidden,
            hidden_states=all_hidden,
            masks=masks,
        )
preprocessor_config.json ADDED
@@ -0,0 +1,27 @@
{
  "crop_size": {
    "height": 224,
    "width": 224
  },
  "do_center_crop": true,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_processor_type": "DeiTImageProcessor",
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 236,
    "width": 236
  }
}
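For readers preprocessing outside `transformers`: with these settings the DeiTImageProcessor resizes to 236×236 with bicubic interpolation (`"resample": 3` is PIL's bicubic), center-crops to 224×224, rescales by 1/255, and normalizes with the ImageNet mean and std. A roughly equivalent torchvision pipeline (a sketch, not shipped with the repo):

```python
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize((236, 236), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),  # also handles the 1/255 rescaling
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```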
selective_vit.py ADDED
@@ -0,0 +1,412 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import StochasticDepth


class SoftMaskedMultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True,
                 add_bias_kv=True, kdim=None, vdim=None,
                 scale=8., device=None, dtype=None):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.scale = scale

        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias and add_bias_kv, **factory_kwargs)
        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias and add_bias_kv, **factory_kwargs)

        self.dropout_layer = nn.Dropout(dropout)

        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        nn.init.xavier_uniform_(self.q_proj.weight)
        nn.init.xavier_uniform_(self.k_proj.weight)
        nn.init.xavier_uniform_(self.v_proj.weight)

        if self.q_proj.bias is not None:
            nn.init.constant_(self.q_proj.bias, 0.)
        if self.k_proj.bias is not None:
            nn.init.constant_(self.k_proj.bias, 0.)
        if self.v_proj.bias is not None:
            nn.init.constant_(self.v_proj.bias, 0.)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.)

    def forward(self, query, key, value, key_padding_mask=None,
                need_weights=True, attn_mask=None, average_attn_weights=True):
        """
        query, key, value: shape (N, L, E), batch-first, where N is the batch size,
        L is the sequence length, and E is the embedding dimension.
        attn_mask: optional soft mask of shape (N, L) with values in [0, 1].
        """
        batch_size, tgt_len, embed_dim = query.size()
        batch_size, src_len, _ = key.size()

        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        # Reshape q, k, v for multi-head attention: (N, num_heads, L, head_dim)
        q = q.view(batch_size, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention scores: (N, num_heads, tgt_len, src_len)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Apply the soft [0, 1] mask as an additive log-space bias
        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # (N, 1, 1, src_len)
            if not self.training:
                # At inference, fully masked keys are excluded outright
                scores = scores.masked_fill(attn_mask == 0., float('-inf'))
            eps = 1e-6
            attn_mask = attn_mask.clip(min=eps).log()  # clip to (0, 1] to avoid log(0)
            scores = scores + self.scale * attn_mask

        # Apply key padding mask
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.view(batch_size, 1, 1, src_len)
            scores = scores.masked_fill(key_padding_mask, float('-inf'))

        # Compute attention weights
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout_layer(attn_weights)

        # Attention output: (N, num_heads, tgt_len, head_dim)
        attn_output = torch.matmul(attn_weights, v)

        # Concatenate heads and project
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, tgt_len, embed_dim)
        attn_output = self.out_proj(attn_output)

        if need_weights:
            # Optionally average attention weights over heads
            if average_attn_weights:
                attn_weights = attn_weights.mean(dim=1)
        else:
            attn_weights = None

        return attn_output, attn_weights


def get_ffn(input_dim, output_dim, middle_dim, dropout=0.1):
    fc1 = nn.Linear(input_dim, middle_dim)
    fc2 = nn.Linear(middle_dim, output_dim)
    fc3 = nn.Identity()  # no-op placeholder, kept so checkpoint module indices stay stable
    return nn.Sequential(
        fc1,
        nn.GELU(),
        nn.Dropout(dropout),
        fc2,
        nn.Dropout(dropout),
        fc3,
    )


class EncoderBlock(nn.Module):
    def __init__(self, input_dim, embed_dim, num_heads, mlp_dim, dropout=0.1,
                 drop_path=0.0, patch_drop=0.0, attention_scale=2., mask_threshold=0.05):
        super().__init__()
        self.mask_threshold = mask_threshold
        self.self_attn = SoftMaskedMultiheadAttention(
            embed_dim, num_heads, dropout=dropout, scale=attention_scale
        )
        if attention_scale > 0:
            self.linear_mask = nn.Linear(input_dim, 1)  # computes per-token mask scores
            self.patch_drop = nn.Dropout(patch_drop)
        else:
            self.linear_mask = None
        if input_dim != embed_dim:
            raise ValueError(f"embed_dim must equal atten_dim, but {input_dim} != {embed_dim}")
        self.embed = nn.Identity()
        self.project = nn.Identity()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        # Feed-forward network (MLP)
        self.mlp = get_ffn(embed_dim, embed_dim, mlp_dim, dropout=dropout)
        self.path_drop = StochasticDepth(drop_path, mode='row')
        self.norm3 = nn.LayerNorm(input_dim)

    def _reset_parameters(self):
        for n, m in self.named_modules():
            if n.startswith('self_attn'):
                continue
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight.data, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias.data)
        nn.init.ones_(self.norm1.weight)
        nn.init.zeros_(self.norm1.bias)
        nn.init.ones_(self.norm2.weight)
        nn.init.zeros_(self.norm2.bias)
        # norm3's weight is zero-initialized so each block's residual update starts as a no-op
        nn.init.zeros_(self.norm3.weight)
        nn.init.zeros_(self.norm3.bias)

    def forward_common(self, x, mask):
        """
        x: shape (batch_size, seq_len, embed_dim)
        mask: optional (batch_size, seq_len) soft mask in [0, 1]
        """
        x1 = x  # keep the block input for the outer residual connection
        x = self.embed(x)
        x = self.norm1(x)
        # Soft-masked self-attention
        attn_output, _ = self.self_attn(x, x, x, attn_mask=mask)
        # Add & Norm
        x = x + self.path_drop(attn_output)
        x = self.norm2(x)
        # Feed-forward network
        mlp_output = self.mlp(x)
        # Add & Norm
        x = self.path_drop(self.project(x + mlp_output))
        x = self.norm3(x)
        if mask is not None:
            # Scale the residual update by the mask so suppressed tokens are barely changed
            x = x * mask.unsqueeze(-1)
        x = x1 + x
        return x

    def get_groups(self, mask, full=False):
        # Sort samples by how many tokens they keep, then bucket together samples whose
        # kept-token counts are within ~20% of the group's largest count (exact if full=True)
        n_items, index = (mask != 0.0).sum(-1).cpu().sort(descending=True)
        n_items, index = n_items.tolist(), index.tolist()
        groups = []
        t = 1.0 if full else 1.2
        for ni, ii in zip(n_items, index):
            if ni == 0:
                break
            if len(groups) == 0 or groups[-1][1] / ni > t:
                groups.append(([], ni))
            groups[-1][0].append(ii)
        return groups

    def infer_forward(self, x, mask, full=False):
        """
        The sparse-inference path: for each group of batch samples that keep the same
        number n of tokens >= mask_threshold, gather only those top-n tokens (in
        original order), run forward_common on the smaller (b', n, dim) tensor, then
        scatter the results back. Fully masked tokens are left untouched.
        """
        # Step 1: Threshold the mask without in-place ops
        mask_thresholded = mask * (mask >= self.mask_threshold)
        # Step 2: Prepare the output tensor (copy of x)
        x_out = x.clone()
        # Step 3: Group samples by number of kept tokens
        groups = self.get_groups(mask_thresholded, full)
        # Step 4: Process each group
        for batch_indices, n_keep in groups:
            x_sel = x[batch_indices]                    # (Bg, seq_len, input_dim)
            mask_sel = mask_thresholded[batch_indices]  # (Bg, seq_len)
            # Top-k selection, then sort the indices to restore the original token order
            topk_vals, topk_idx_unsorted = torch.topk(mask_sel, k=n_keep, dim=1, sorted=False)
            topk_idx_sorted, _ = topk_idx_unsorted.sort(dim=1)
            # Gather tokens in sorted order
            idx_expanded = topk_idx_sorted.unsqueeze(-1).expand(-1, -1, x_sel.size(-1))
            x_topk = torch.gather(x_sel, dim=1, index=idx_expanded)
            mask_topk = torch.gather(mask_sel, dim=1, index=topk_idx_sorted)
            # Run the dense forward pass on the reduced sequence
            results = self.forward_common(x_topk, mask_topk)
            # Scatter the results back into a copy of the group's slice
            x_sel_updated = x_sel.clone().scatter(1, idx_expanded, results)
            x_out[batch_indices] = x_sel_updated
        return x_out

    def forward(self, x, full=False):
        if self.linear_mask is not None:
            attn_mask = self.patch_drop(self.linear_mask(x).sigmoid().squeeze(-1))
        else:
            attn_mask = None
        if not self.training and attn_mask is not None and self.mask_threshold >= 0:
            x = self.infer_forward(x, attn_mask, full)
        else:
            x = self.forward_common(x, attn_mask)
        return x, attn_mask


class VisionTransformer(nn.Module):
    def __init__(
        self,
        image_size=256,
        patch_size=16,
        num_classes=1000,
        embed_dim=768,
        atten_dim=192,
        depth=12,
        num_heads=3,
        mlp_dim=768,
        channels=3,
        dropout=0.1,
        drop_path=0.1,
        patch_drop=0.1,
        attention_scale=2.,
        mask_threshold=0.05,
        use_distil_token=False
    ):
        super().__init__()
        assert image_size % patch_size == 0, "Image dimensions must be divisible by the patch size."
        num_patches = (image_size // patch_size) ** 2

        # Patch embedding layer
        self.patch_embed = nn.Conv2d(
            in_channels=channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )

        # Class token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Positional embedding (+1 for the CLS token, +1 more if a distillation token is used)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1 + (1 if use_distil_token else 0), embed_dim)
        )

        self.dropout = nn.Dropout(dropout)

        # Encoder blocks with a linearly increasing stochastic-depth rate
        self.encoder_layers = nn.ModuleList([
            EncoderBlock(
                embed_dim, atten_dim,
                num_heads, mlp_dim,
                dropout, drop_path * i / (depth - 1),
                patch_drop=patch_drop,
                attention_scale=attention_scale,
                mask_threshold=mask_threshold,
            )
            for i in range(depth)
        ])

        # Classification head
        self.post_norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        if use_distil_token:
            self.dis_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            self.dis_head = nn.Linear(embed_dim, num_classes)
        else:
            self.dis_token = None

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        for n, m in self.named_modules():
            if n.startswith('encoder_layers'):
                continue
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.trunc_normal_(m.weight.data, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias.data)
            if isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight.data)
                nn.init.zeros_(m.bias.data)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.cls_token is not None:
            nn.init.trunc_normal_(self.cls_token, std=0.02)
        if self.dis_token is not None:
            nn.init.trunc_normal_(self.dis_token, std=0.02)

    def forward_features(
        self,
        pixel_values,
        full=False,
        output_hidden_states=False,
    ):
        """
        Args:
            pixel_values: (B, C, H, W)
        Returns:
            last_hidden_state: (B, N, D)
            all_hidden_states: tuple or None
            masks: tuple of per-layer soft masks, or None
        """
        batch_size = pixel_values.size(0)
        hidden_states = []

        # Patch embedding
        x = self.patch_embed(pixel_values)
        x = x.flatten(2).transpose(1, 2)

        # Distillation token
        if self.dis_token is not None:
            dis_tokens = self.dis_token.expand(batch_size, -1, -1)
            x = torch.cat((dis_tokens, x), dim=1)

        # CLS token (final token order: CLS, distil, patches)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Position + dropout
        x = x + self.pos_embed
        x = self.dropout(x)

        masks = []

        for layer in self.encoder_layers:
            x, mask = layer(x, full)

            if output_hidden_states:
                hidden_states.append(x)

            if mask is not None:
                masks.append(mask)

        x = self.post_norm(x)

        hidden_states = tuple(hidden_states) if output_hidden_states else None
        masks = tuple(masks) if len(masks) > 0 else None

        return x, hidden_states, masks

    def forward_classifier(self, hidden_states):
        """
        Args:
            hidden_states: (B, N, D)
        Returns:
            logits: (B, num_classes)
            dis_logits: (B, num_classes) or None
        """
        cls_token = hidden_states[:, 0]
        logits = self.head(cls_token)

        dis_logits = None
        if self.dis_token is not None:
            dis_cls_token = hidden_states[:, 1]
            dis_logits = self.dis_head(dis_cls_token)

            # Average the class and distillation logits at inference time
            if not self.training:
                logits = (logits + dis_logits) / 2

        return logits, dis_logits

    def forward(self, x, full=False):
        last_hidden_states, hidden_states, masks = self.forward_features(x, full)
        logits, dis_logits = self.forward_classifier(last_hidden_states)
        return logits, dis_logits, masks
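
As a quick sanity check, the backbone can also be exercised standalone. This smoke test is a sketch (not part of the upload), using the Base hyper-parameters from `config.json` above:

```python
import torch

vit = VisionTransformer(
    image_size=224, patch_size=16, embed_dim=768, atten_dim=768,
    depth=12, num_heads=12, mlp_dim=3072, attention_scale=8.0,
    mask_threshold=0.05, use_distil_token=True,
)
vit.eval()  # eval mode exercises the sparse-inference path
with torch.no_grad():
    logits, dis_logits, masks = vit(torch.randn(1, 3, 224, 224))

print(logits.shape)  # torch.Size([1, 1000]) -- class/distil logits averaged in eval
print(len(masks))    # 12 -- one soft mask per encoder block
```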