Upload model files
Browse files- .gitattributes +2 -32
- README.md +110 -0
- __init__.py +7 -0
- config.json +29 -0
- configuration_selectivevit.py +43 -0
- model.safetensors +3 -0
- modeling_selectivevit.py +91 -0
- preprocessor_config.json +27 -0
- selective_vit.py +412 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,5 @@
|
|
| 1 |
-
*.
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.
|
| 5 |
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 2 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 4 |
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
*.onnx filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
library_name: transformers
|
| 4 |
+
tags:
|
| 5 |
+
- vision-transformer
|
| 6 |
+
- image-classification
|
| 7 |
+
- efficient-transformer
|
| 8 |
+
- selective-attention
|
| 9 |
+
- knowledge-distillation
|
| 10 |
+
- computer-vision
|
| 11 |
+
pipeline_tag: image-classification
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# Soft-Masked Selective Vision Transformer
|
| 15 |
+
|
| 16 |
+
## Model Description
|
| 17 |
+
|
| 18 |
+
Soft-Masked Selective Vision Transformer is an efficient **Vision Transformer (ViT)** model designed to reduce the computational overhead of self-attention while maintaining competitive accuracy.
|
| 19 |
+
The model introduces a **patch-selective attention mechanism** that enables the transformer to focus on the most salient image regions and dynamically disregard less informative patches. This selective strategy significantly reduces the quadratic complexity typically associated with full self-attention, making the model particularly suitable for **high-resolution vision tasks** and **resource-constrained environments**.
|
| 20 |
+
|
| 21 |
+
To further improve performance, the model leverages **knowledge distillation**, transferring representational knowledge from a stronger teacher network to enhance the accuracy of lightweight transformer variants.
|
| 22 |
+
|
| 23 |
+
---
|
| 24 |
+
|
| 25 |
+
## Intended Use
|
| 26 |
+
|
| 27 |
+
This model is intended for:
|
| 28 |
+
|
| 29 |
+
- Image classification tasks
|
| 30 |
+
- Deployment in **compute- or memory-constrained environments**
|
| 31 |
+
- High-resolution image processing where standard ViTs are prohibitively expensive
|
| 32 |
+
- Research on efficient attention mechanisms and transformer compression
|
| 33 |
+
|
| 34 |
+
### Example Use Cases
|
| 35 |
+
|
| 36 |
+
- Edge or embedded vision systems
|
| 37 |
+
- Large-scale image analysis with reduced inference cost
|
| 38 |
+
- Efficient backbones for downstream vision tasks
|
| 39 |
+
|
| 40 |
+
---
|
| 41 |
+
|
| 42 |
+
## Training Details
|
| 43 |
+
|
| 44 |
+
- **Training Objective:** Cross-entropy loss with optional distillation loss
|
| 45 |
+
- **Distillation:** Teacher–student framework
|
| 46 |
+
- **Optimization:** AdamW
|
| 47 |
+
- **Training Dataset:** ILSVRC 2012
|
| 48 |
+
- **Evaluation Metrics:** Top-1 accuracy, FLOPs, parameter count
|
| 49 |
+
|
| 50 |
+
---
|
| 51 |
+
|
| 52 |
+
## Usage
|
| 53 |
+
|
| 54 |
+
### Image Classification Example
|
| 55 |
+
|
| 56 |
+
```python
|
| 57 |
+
from transformers import AutoModelForImageClassification, AutoImageProcessor
|
| 58 |
+
from PIL import Image
|
| 59 |
+
import requests
|
| 60 |
+
|
| 61 |
+
# Load image
|
| 62 |
+
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 63 |
+
image = Image.open(requests.get(url, stream=True).raw)
|
| 64 |
+
|
| 65 |
+
# Load processor and model
|
| 66 |
+
processor = AutoImageProcessor.from_pretrained(
|
| 67 |
+
"XAFT/SM-Selective-ViT-Base-224-Distilled",
|
| 68 |
+
trust_remote_code=True,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
model = AutoModelForImageClassification.from_pretrained(
|
| 72 |
+
"XAFT/SM-Selective-ViT-Base-224-Distilled",
|
| 73 |
+
trust_remote_code=True,
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
# Preprocess
|
| 77 |
+
inputs = processor(
|
| 78 |
+
images=image,
|
| 79 |
+
return_tensors="pt",
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# Forward pass
|
| 83 |
+
outputs = model(**inputs)
|
| 84 |
+
logits = outputs.logits
|
| 85 |
+
predicted_class = logits.argmax(-1).item()
|
| 86 |
+
|
| 87 |
+
print("Predicted class index:", predicted_class)
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
---
|
| 91 |
+
|
| 92 |
+
## Evaluation Results
|
| 93 |
+
|
| 94 |
+
| Model | Top-1 Acc. | Top-5 Acc. | # Params | Avg. GFLOPs |
|
| 95 |
+
|------------------------|------------|------------|----------|-------------|
|
| 96 |
+
| Base | 80.350% | 94.980% | 86.60M | 9.61 |
|
| 97 |
+
| Base (distilled) | 80.990% | 95.386% | 87.37M | 9.21 |
|
| 98 |
+
| Small | 78.662% | 94.454% | 22.06M | 3.12 |
|
| 99 |
+
| Small (distilled) | 79.000% | 94.494% | 22.45M | 3.05 |
|
| 100 |
+
| Tiny tall | 74.802% | 92.794% | 11.07M | 1.64 |
|
| 101 |
+
| Tiny tall (distilled) | 75.676% | 92.988% | 11.26M | 1.64 |
|
| 102 |
+
| Tiny | 71.056% | 90.192% | 5.72M | 0.95 |
|
| 103 |
+
| Tiny (distilled) | 72.618% | 91.338% | 5.92M | 0.93 |
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
---
|
| 107 |
+
|
| 108 |
+
## Acknowledgments
|
| 109 |
+
|
| 110 |
+
We thank the TPU Research Cloud program for providing cloud TPUs that were used to build and train the models for our extensive experiments.
|
__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .configuration_selectivevit import SMSelectiveViTConfig
|
| 2 |
+
from .modeling_selectivevit import SMSelectiveViTModelForClassification
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"SMSelectiveViTConfig",
|
| 6 |
+
"SMSelectiveViTModelForClassification",
|
| 7 |
+
]
|
config.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"SMSelectiveViTModelForClassification"
|
| 4 |
+
],
|
| 5 |
+
"atten_dim": 768,
|
| 6 |
+
"attention_scale": 8.0,
|
| 7 |
+
"auto_map": {
|
| 8 |
+
"AutoConfig": "configuration_selectivevit.SMSelectiveViTConfig",
|
| 9 |
+
"AutoModelForImageClassification": "modeling_selectivevit.SMSelectiveViTModelForClassification"
|
| 10 |
+
},
|
| 11 |
+
"channels": 3,
|
| 12 |
+
"depth": 12,
|
| 13 |
+
"drop_path": 0.2,
|
| 14 |
+
"dropout": 0.0,
|
| 15 |
+
"dtype": "float32",
|
| 16 |
+
"embed_dim": 768,
|
| 17 |
+
"ffn_groups": null,
|
| 18 |
+
"image_size": 224,
|
| 19 |
+
"mask_threshold": 0.05,
|
| 20 |
+
"mlp_dim": 3072,
|
| 21 |
+
"model_type": "softmasked_selective_vit",
|
| 22 |
+
"num_classes": 1000,
|
| 23 |
+
"num_groups": 12,
|
| 24 |
+
"num_heads": 12,
|
| 25 |
+
"patch_drop": 0.05,
|
| 26 |
+
"patch_size": 16,
|
| 27 |
+
"transformers_version": "4.57.3",
|
| 28 |
+
"use_distil_token": true
|
| 29 |
+
}
|
configuration_selectivevit.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# configuration_my_model.py
|
| 2 |
+
from transformers import PretrainedConfig
|
| 3 |
+
|
| 4 |
+
class SMSelectiveViTConfig(PretrainedConfig):
    """Configuration for the Soft-Masked Selective Vision Transformer.

    Every constructor argument is stored verbatim as an attribute so that
    Hugging Face serializes it to config.json unchanged.
    """

    model_type = "softmasked_selective_vit"

    def __init__(
        self,
        image_size=224,
        patch_size=16,
        num_classes=1000,
        embed_dim=768,
        atten_dim=768,
        depth=12,
        num_heads=12,
        mlp_dim=3072,
        channels=3,
        dropout=0.0,
        drop_path=0.0,
        attention_scale=0.0,
        mask_threshold=0.0,
        patch_drop=0.0,
        use_distil_token=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Copy every model hyper-parameter onto the config, in declaration
        # order; HF persists these attributes in config.json when saving.
        for attr_name, attr_value in (
            ("image_size", image_size),
            ("patch_size", patch_size),
            ("num_classes", num_classes),
            ("embed_dim", embed_dim),
            ("atten_dim", atten_dim),
            ("depth", depth),
            ("num_heads", num_heads),
            ("mlp_dim", mlp_dim),
            ("channels", channels),
            ("dropout", dropout),
            ("drop_path", drop_path),
            ("attention_scale", attention_scale),
            ("mask_threshold", mask_threshold),
            ("patch_drop", patch_drop),
            ("use_distil_token", use_distil_token),
        ):
            setattr(self, attr_name, attr_value)
| 43 |
+
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4cd1a78510e92f164589ff458c23d69309413d14687029d5bcfccbc102ac06c6
|
| 3 |
+
size 349491616
|
modeling_selectivevit.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# modeling_my_model.py
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Optional, Tuple, List
|
| 4 |
+
import inspect
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 8 |
+
from transformers.modeling_outputs import ModelOutput
|
| 9 |
+
|
| 10 |
+
from .configuration_selectivevit import SMSelectiveViTConfig
|
| 11 |
+
|
| 12 |
+
from .selective_vit import VisionTransformer
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
class ImageClassifierWithMasksOutput(ModelOutput):
    """Output of SMSelectiveViTModelForClassification.

    Extends the usual image-classifier output with the distillation-head
    logits and the soft token masks produced by the backbone's layers.
    """

    # Cross-entropy loss; only set when `labels` were passed to forward().
    loss: Optional[torch.FloatTensor] = None
    # Classification-head logits, shape (batch, num_classes).
    logits: Optional[torch.FloatTensor] = None
    # Distillation-head logits; None when the model has no distil token.
    distil_logits: Optional[torch.FloatTensor] = None
    # Final (post-norm) hidden states from the backbone, shape (batch, seq, dim).
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Per-layer hidden states; populated only when output_hidden_states=True.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Soft attention masks collected from the encoder layers; may be None.
    masks: Optional[List[torch.FloatTensor]] = None
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class SMSelectiveViTModelForClassification(PreTrainedModel):
    """Hugging Face wrapper around the selective-ViT backbone for image
    classification."""

    config_class = SMSelectiveViTConfig
    base_model_prefix = "backbone"

    def __init__(self, config: SMSelectiveViTConfig):
        super().__init__(config)

        # Translate the HF config into keyword arguments for the raw
        # VisionTransformer, keeping only names its __init__ declares.
        accepted = {
            name
            for name in inspect.signature(VisionTransformer.__init__).parameters
            if name != "self"
        }
        backbone_kwargs = {
            key: value for key, value in config.to_dict().items() if key in accepted
        }

        self.backbone = VisionTransformer(**backbone_kwargs)

        self.post_init()  # important: ties into HF weight init hooks

    def forward(
        self,
        pixel_values=None,
        labels=None,
        full=False,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """Run the backbone and optionally compute the classification loss.

        Args:
            pixel_values: input images, shape (B, C, H, W).
            labels: optional class indices for the cross-entropy loss.
            full: forwarded to the backbone's selective-inference path.
            output_hidden_states / return_dict: standard HF toggles; fall
                back to the config values when None.
        """
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        last_hidden, all_hidden, masks = self.backbone.forward_features(
            pixel_values,
            full=full,
            output_hidden_states=output_hidden_states,
        )

        logits, distil_logits = self.backbone.forward_classifier(last_hidden)

        loss = None
        if labels is not None:
            # Loss is computed on the main classification head only.
            loss = nn.CrossEntropyLoss()(logits, labels)

        if not return_dict:
            return (loss, logits, distil_logits, last_hidden, all_hidden, masks)

        return ImageClassifierWithMasksOutput(
            loss=loss,
            logits=logits,
            distil_logits=distil_logits,
            last_hidden_state=last_hidden,
            hidden_states=all_hidden,
            masks=masks,
        )
| 90 |
+
|
| 91 |
+
|
preprocessor_config.json
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"crop_size": {
|
| 3 |
+
"height": 224,
|
| 4 |
+
"width": 224
|
| 5 |
+
},
|
| 6 |
+
"do_center_crop": true,
|
| 7 |
+
"do_normalize": true,
|
| 8 |
+
"do_rescale": true,
|
| 9 |
+
"do_resize": true,
|
| 10 |
+
"image_mean": [
|
| 11 |
+
0.485,
|
| 12 |
+
0.456,
|
| 13 |
+
0.406
|
| 14 |
+
],
|
| 15 |
+
"image_processor_type": "DeiTImageProcessor",
|
| 16 |
+
"image_std": [
|
| 17 |
+
0.229,
|
| 18 |
+
0.224,
|
| 19 |
+
0.225
|
| 20 |
+
],
|
| 21 |
+
"resample": 3,
|
| 22 |
+
"rescale_factor": 0.00392156862745098,
|
| 23 |
+
"size": {
|
| 24 |
+
"height": 236,
|
| 25 |
+
"width": 236
|
| 26 |
+
}
|
| 27 |
+
}
|
selective_vit.py
ADDED
|
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torchvision.ops import StochasticDepth
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
class SoftMaskedMultiheadAttention(nn.Module):
    """Multi-head attention that accepts a soft [0, 1] token mask.

    Instead of a hard boolean mask, ``attn_mask`` holds per-token keep
    scores; their (scaled) log is added to the attention scores, so tokens
    with small mask values receive proportionally less attention.

    Args:
        embed_dim: total embedding dimension of the model.
        num_heads: number of attention heads; must divide embed_dim.
        dropout: dropout probability on the attention weights.
        bias: whether the query projection (and, jointly with add_bias_kv,
            the key/value projections) has a bias term.
        add_bias_kv: whether the key/value projections get a bias
            (effective only when ``bias`` is also True).
        kdim, vdim: input feature sizes of key/value (default: embed_dim).
        scale: multiplier applied to log(attn_mask) before it is added to
            the raw attention scores.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True,
                 add_bias_kv=True, kdim=None, vdim=None,
                 scale=8., device=None, dtype=None):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.scale = scale

        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias and add_bias_kv, **factory_kwargs)
        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias and add_bias_kv, **factory_kwargs)

        self.dropout_layer = nn.Dropout(dropout)

        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-init all projection weights; zero every existing bias."""
        nn.init.xavier_uniform_(self.q_proj.weight)
        nn.init.xavier_uniform_(self.k_proj.weight)
        nn.init.xavier_uniform_(self.v_proj.weight)

        if self.q_proj.bias is not None:
            nn.init.constant_(self.q_proj.bias, 0.)
        if self.k_proj.bias is not None:
            nn.init.constant_(self.k_proj.bias, 0.)
        if self.v_proj.bias is not None:
            nn.init.constant_(self.v_proj.bias, 0.)

        nn.init.xavier_uniform_(self.out_proj.weight)
        # Bug fix: the guard previously checked self.v_proj.bias, so
        # out_proj.bias was left at its default init whenever the k/v
        # projections had no bias (add_bias_kv=False).
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.)

    def forward(self, query, key, value, key_padding_mask=None,
                need_weights=True, attn_mask=None, average_attn_weights=True):
        """
        query, key, value: shape (N, L, E) -- batch-first.
            N is the batch size, L the sequence length, E the embedding
            dimension.  (The previous docstring claimed (L, N, E); the
            reshapes below clearly treat dim 0 as the batch.)
        attn_mask: optional soft mask of shape (N, L_src), values in [0, 1].
        key_padding_mask: optional bool mask (N, L_src); True positions are
            excluded from attention.

        Returns (attn_output, attn_weights); attn_weights is None when
        need_weights is False, otherwise head-averaged if requested.
        """
        batch_size, tgt_len, embed_dim = query.size()
        batch_size, src_len, _ = key.size()

        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        # Reshape q, k, v for multihead attention
        q = q.view(batch_size, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Compute scaled dot-product attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # scores shape: (batch_size, num_heads, tgt_len, src_len)

        # Apply the soft [0, 1] mask
        if attn_mask is not None:
            # attn_mask shape [b, l] -> [b, 1, 1, l]
            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
            if not self.training:
                # At inference, a mask of exactly 0 removes the token entirely.
                scores = scores.masked_fill((attn_mask == 0.), float('-inf'))
            # Clip to avoid log(0), then bias scores by scale * log(mask).
            eps = 1e-6
            attn_mask = attn_mask.clip(min=eps).log()
            scores = scores + self.scale * attn_mask

        # Apply key padding mask
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.view(batch_size, 1, 1, src_len)
            scores = scores.masked_fill(key_padding_mask, float('-inf'))

        # Compute attention weights
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout_layer(attn_weights)

        # Compute attention output
        attn_output = torch.matmul(attn_weights, v)
        # attn_output shape: (batch_size, num_heads, tgt_len, head_dim)

        # Concatenate heads and project
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, tgt_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        if need_weights:
            # Optionally average attention weights over heads
            if average_attn_weights:
                attn_weights = attn_weights.mean(dim=1)
        else:
            attn_weights = None

        return attn_output, attn_weights
| 112 |
+
|
| 113 |
+
def get_ffn(input_dim, output_dim, middle_dim, dropout=0.1):
    """Build the transformer feed-forward block:
    Linear -> GELU -> Dropout -> Linear -> Dropout.

    The original version appended an nn.Identity() as a final stage; it was
    a no-op and has been dropped.  Parameter names are unchanged (the two
    linear layers stay at Sequential indices 0 and 3), so existing
    checkpoints still load.
    """
    return nn.Sequential(
        nn.Linear(input_dim, middle_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(middle_dim, output_dim),
        nn.Dropout(dropout),
    )
| 125 |
+
# Assuming SoftMaskedMultiheadAttention is already defined as provided earlier
|
| 126 |
+
class EncoderBlock(nn.Module):
    """Transformer encoder block with soft token masking.

    When attention_scale > 0 a per-token keep score in [0, 1] is predicted
    by ``linear_mask`` and fed to SoftMaskedMultiheadAttention.  At
    inference time, tokens whose score falls below ``mask_threshold`` are
    skipped entirely via the sparse path in :meth:`infer_forward`.
    """

    def __init__(self, input_dim, embed_dim, num_heads, mlp_dim, dropout=0.1, drop_path=0.0, patch_drop=0.0, attention_scale=2., mask_threshold=0.05):
        super().__init__()
        self.mask_threshold = mask_threshold
        self.self_attn = SoftMaskedMultiheadAttention(
            embed_dim, num_heads, dropout=dropout, scale=attention_scale
        )
        if attention_scale > 0:
            self.linear_mask = nn.Linear(input_dim, 1)  # Linear layer to compute mask scores
            self.patch_drop = nn.Dropout(patch_drop)
        else:
            self.linear_mask = None
        if input_dim != embed_dim:
            # Bug fix: this message was a plain string missing its f-prefix,
            # so the placeholders were never interpolated.
            raise ValueError(f"embed_dim must equal atten_dim but {input_dim}!={embed_dim}")
        else:
            self.embed = nn.Identity()
            self.project = nn.Identity()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        # Feed-forward network (MLP)
        self.mlp = get_ffn(embed_dim, embed_dim, mlp_dim, dropout=dropout)
        self.path_drop = StochasticDepth(drop_path, mode='row')
        self.norm3 = nn.LayerNorm(input_dim)

    def _reset_parameters(self):
        """Truncated-normal init for linears; norm3 weight is zeroed so the
        residual branch starts as a no-op."""
        for n, m in self.named_modules():
            if n.startswith('self_attn'):
                continue
            # Bug fix: the isinstance check referenced an undefined
            # GroupedLinear name, raising NameError whenever this method ran.
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight.data, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias.data)
        nn.init.ones_(self.norm1.weight)
        nn.init.zeros_(self.norm1.bias)
        nn.init.ones_(self.norm2.weight)
        nn.init.zeros_(self.norm2.bias)
        nn.init.zeros_(self.norm3.weight)
        nn.init.zeros_(self.norm3.bias)

    def forward_common(self, x, mask):
        """
        Dense path: attention + MLP with a masked residual connection.

        x: shape (batch_size, seq_len, embed_dim)
        mask: optional soft keep-scores, shape (batch_size, seq_len)
        """
        # Compute mask scores: (batch_size, seq_len, 1)
        x1 = x
        x = self.embed(x)
        x = self.norm1(x)
        # Apply attention mechanism
        attn_output, _ = self.self_attn(x, x, x, attn_mask=mask)
        # Add & Norm
        x = x + self.path_drop(attn_output)
        x = self.norm2(x)
        # Feed-forward network
        mlp_output = self.mlp(x)
        # Add & Norm
        x = self.path_drop(self.project(x + mlp_output))
        x = self.norm3(x)
        if mask is not None:
            # Scale the residual branch by the keep-scores so masked tokens
            # contribute (almost) nothing to the update.
            x = x * mask.unsqueeze(-1)
        x = x1 + x
        return x

    def get_groups(self, mask, full=False):
        """Group batch samples by their number of kept (non-zero) tokens.

        Returns a list of ([batch indices], n_keep) pairs, built over the
        samples sorted by descending kept-token count; a new group starts
        when the group's n_keep exceeds the sample's count by more than the
        tolerance t (exact grouping when ``full`` is True, 20% slack
        otherwise).  Samples with zero kept tokens are skipped.
        """
        n_items, index = (mask != 0.0).sum(-1).cpu().sort(descending=True)
        n_items, index = n_items.tolist(), index.tolist()
        groups = []
        t = 1.0 if full else 1.2
        for ni, ii in zip(n_items, index):
            if ni == 0:
                break
            if len(groups) == 0 or groups[-1][1] / ni > t:
                groups.append(([], ni))
            groups[-1][0].append(ii)
        return groups

    def infer_forward(self, x, mask, full=False):
        """
        The "sparse-inference" path: for each group of batch samples that
        have (roughly) the same number n of tokens >= mask_threshold, gather
        only those top-n tokens (in original order), run forward_common on
        the smaller (b', n, dim) tensor, then scatter the results back.
        Fully masked tokens are left untouched.
        """
        # Step 1: Threshold the mask without in-place ops
        mask_thresholded = mask * (mask >= self.mask_threshold)
        # Step 2: Prepare output tensor (copy of x)
        x_out = x.clone()
        # Step 3: Group samples by number of kept tokens
        groups = self.get_groups(mask_thresholded, full)
        # Step 4: Process each group
        for batch_indices, n_keep in groups:
            x_sel = x[batch_indices]  # (Bg, seq_len, input_dim)
            mask_sel = mask_thresholded[batch_indices]  # (Bg, seq_len)
            # Top-k selection and sorting
            topk_vals, topk_idx_unsorted = torch.topk(mask_sel, k=n_keep, dim=1, sorted=False)
            topk_idx_sorted, _ = topk_idx_unsorted.sort(dim=1)
            # Gather tokens in sorted (original) order
            idx_expanded = topk_idx_sorted.unsqueeze(-1).expand(-1, -1, x_sel.size(-1))
            X_topk = torch.gather(x_sel, dim=1, index=idx_expanded)
            mask_topk = torch.gather(mask_sel, dim=1, index=topk_idx_sorted)
            # Run forward pass on the reduced sequence
            results = self.forward_common(X_topk, mask_topk)
            # Scatter results into a new x_sel tensor
            x_sel_updated = x_sel.clone()
            x_sel_updated = x_sel_updated.scatter(1, idx_expanded, results)
            # Write the updated batch slice into the output tensor
            x_out[batch_indices] = x_sel_updated
        return x_out

    def forward(self, x, full=False):
        """Predict keep-scores (if enabled) and run the dense or sparse path.

        Returns (output, attn_mask); attn_mask is None when masking is
        disabled (attention_scale <= 0 at construction).
        """
        if self.linear_mask is not None:
            attn_mask = self.patch_drop(self.linear_mask(x).sigmoid().squeeze(-1))
        else:
            attn_mask = None
        # Sparse path only at inference time, when a mask exists and
        # thresholding is enabled.
        if not self.training and attn_mask is not None and self.mask_threshold >= 0:
            x = self.infer_forward(x, attn_mask, full)
        else:
            x = self.forward_common(x, attn_mask)
        return x, attn_mask
| 244 |
+
|
| 245 |
+
|
| 246 |
+
class VisionTransformer(nn.Module):
    """Soft-Masked Selective Vision Transformer backbone.

    A standard ViT (conv patch embedding, CLS token, learned positional
    embeddings, a stack of EncoderBlocks) extended with per-layer soft
    token masks and an optional DeiT-style distillation token/head.
    """

    def __init__(
        self,
        image_size=256,
        patch_size=16,
        num_classes=1000,
        embed_dim=768,
        atten_dim=192,
        depth=12,
        num_heads=3,
        mlp_dim=768,
        channels=3,
        dropout=0.1,
        drop_path=0.1,
        patch_drop=0.1,
        attention_scale=2.,
        mask_threshold=0.05,
        use_distil_token=False
    ):
        super().__init__()
        assert image_size % patch_size == 0, "Image dimensions must be divisible by the patch size."
        num_patches = (image_size // patch_size) ** 2

        # Patch embedding layer (non-overlapping conv = linear patch projection)
        self.patch_embed = nn.Conv2d(
            in_channels=channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )

        # Class token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Positional embedding (one extra slot for the distil token if used)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1 + (1 if use_distil_token else 0), embed_dim))

        self.dropout = nn.Dropout(dropout)

        # Encoder blocks with a linearly increasing stochastic-depth rate.
        # Bug fix: guard the divisor so depth=1 no longer raises
        # ZeroDivisionError (the single block then gets drop rate 0).
        self.encoder_layers = nn.ModuleList([
            EncoderBlock(
                embed_dim, atten_dim,
                num_heads, mlp_dim,
                dropout, drop_path * i / max(depth - 1, 1),
                patch_drop=patch_drop,
                attention_scale=attention_scale,
                mask_threshold=mask_threshold,
            )
            for i in range(depth)
        ])

        # Classification head
        self.post_norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        if use_distil_token:
            self.dis_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            self.dis_head = nn.Linear(embed_dim, num_classes)
        else:
            self.dis_token = None

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Truncated-normal init for all non-encoder modules and tokens;
        encoder layers keep their own initialization."""
        for n, m in self.named_modules():
            if n.startswith('encoder_layers'):
                continue
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.trunc_normal_(m.weight.data, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias.data)
            if isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight.data)
                nn.init.zeros_(m.bias.data)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.cls_token is not None:
            nn.init.trunc_normal_(self.cls_token, std=0.02)
        if self.dis_token is not None:
            nn.init.trunc_normal_(self.dis_token, std=0.02)

    def forward_features(
        self,
        pixel_values,
        full=False,
        output_hidden_states=False,
    ):
        """
        Args:
            pixel_values: (B, C, H, W)
            full: forwarded to each EncoderBlock's selective-inference path.
            output_hidden_states: collect per-layer hidden states when True.
        Returns:
            last_hidden_state: (B, N, D) after the final LayerNorm
            all_hidden_states: tuple of per-layer states, or None
            masks: tuple of per-layer soft masks, or None
        """
        batch_size = pixel_values.size(0)
        hidden_states = []

        # Patch embedding: (B, D, H/p, W/p) -> (B, num_patches, D)
        x = self.patch_embed(pixel_values)
        x = x.flatten(2).transpose(1, 2)

        # Distillation token (prepended before the CLS token below, so the
        # final order is [CLS, DIS, patches...])
        if self.dis_token is not None:
            dis_tokens = self.dis_token.expand(batch_size, -1, -1)
            x = torch.cat((dis_tokens, x), dim=1)

        # CLS token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Position + dropout
        x = x + self.pos_embed
        x = self.dropout(x)

        masks = []

        for layer in self.encoder_layers:
            x, mask = layer(x, full)

            if output_hidden_states:
                hidden_states.append(x)

            if mask is not None:
                masks.append(mask)

        x = self.post_norm(x)

        if output_hidden_states:
            hidden_states = tuple(hidden_states)
        else:
            hidden_states = None

        if len(masks) > 0:
            masks = tuple(masks)
        else:
            masks = None

        return x, hidden_states, masks

    def forward_classifier(self, hidden_states):
        """
        Args:
            hidden_states: (B, N, D); token 0 is CLS, token 1 the distil
                token (when present).
        Returns:
            logits: (B, num_classes) -- at inference time, averaged with the
                distillation head when a distil token exists.
            dis_logits: (B, num_classes) or None
        """
        cls_token = hidden_states[:, 0]
        logits = self.head(cls_token)

        dis_logits = None
        if self.dis_token is not None:
            dis_cls_token = hidden_states[:, 1]
            dis_logits = self.dis_head(dis_cls_token)

            # Inference-time averaging (same as original DeiT recipe)
            if not self.training:
                logits = (logits + dis_logits) / 2

        return logits, dis_logits

    def forward(self, x, full=False):
        """Full forward: features then classifier heads."""
        last_hidden_states, hidden_states, masks = self.forward_features(x, full)
        logits, dis_logits = self.forward_classifier(last_hidden_states)
        return logits, dis_logits, masks
|