Thatphum committed on
Commit c88b91b · verified · 1 Parent(s): b9a90bc

Upload FastViTImageEncoder

Files changed (5)
  1. README.md +199 -0
  2. config.json +16 -0
  3. image_encoder.py +92 -0
  4. mci.py +1480 -0
  5. model.safetensors +3 -0
README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,16 @@
+ {
+   "architectures": [
+     "FastViTImageEncoder"
+   ],
+   "auto_map": {
+     "AutoConfig": "image_encoder.FastViTImageConfig",
+     "AutoImageProcessor": "transformers.CLIPImageProcessor",
+     "AutoModel": "image_encoder.FastViTImageEncoder"
+   },
+   "embed_dim": 3072,
+   "image_size": 1024,
+   "model_type": "fastvit_image_encoder",
+   "patch_size": 64,
+   "torch_dtype": "float32",
+   "transformers_version": "4.55.2"
+ }
image_encoder.py ADDED
@@ -0,0 +1,92 @@
+ """
+ HF-compatible wrapper that turns the FastViT backbone into a pure *image encoder*.
+ Output: a (B, H*W, C) sequence of spatial image tokens taken from the
+ backbone's image embeddings (see `forward_images`).
+ """
+ import torch
+ from transformers import PreTrainedModel, PretrainedConfig
+ from .mci import fastvithd, GlobalPool2D  # imports your backbone factory
+
+
+ # ----------------------- Config -----------------------
+ class FastViTImageConfig(PretrainedConfig):
+     """Minimal config so HF knows the image size & embed dim."""
+     model_type = "fastvit_image_encoder"
+
+     def __init__(
+         self,
+         image_size: int = 1024,
+         embed_dim: int = 3072,  # channels after conv_exp
+         patch_size: int = 16,   # overridden to 64 by config.json
+         **kwargs
+     ):
+         self.image_size = image_size
+         self.embed_dim = embed_dim
+         self.patch_size = patch_size
+         super().__init__(**kwargs)
+
+
+ # ----------------------- Model ------------------------
+ class FastViTImageEncoder(PreTrainedModel):
+     """
+     Wraps FastViT-HD and exposes image features only;
+     no text tower, no CLIP logits.
+     """
+     config_class = FastViTImageConfig
+     main_input_name = "pixel_values"
+
+     def __init__(self, config: FastViTImageConfig):
+         super().__init__(config)
+
+         # Build the backbone without a classifier head, then attach
+         # GlobalPool2D explicitly as the head.
+         self.backbone = fastvithd(num_classes=0)
+         self.backbone.head = GlobalPool2D(
+             in_dim=3072,
+             out_dim=768,
+         )
+
+         # HF helper that runs weight initialization and final processing.
+         self.post_init()
+
+     # ------------------------------------------
+     def forward(self, images):
+         return self.forward_images(images)
+
+     def feature_select(self, image_forward_outs):
+         # Features from the penultimate layer.
+         image_features = image_forward_outs["image_embeddings"]
+
+         # Reshape the 4D (B, C, H, W) map into a 3D (B, H*W, C) token sequence.
+         B, C, H, W = image_features.shape
+         image_features = image_features.reshape(B, C, H * W)
+         image_features = image_features.transpose(1, 2)
+         return image_features
+
+     def forward_images(self, images):
+         if isinstance(images, list):
+             image_features = []
+             for image in images:
+                 image_forward_out = self.backbone(
+                     image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
+                     return_image_embeddings=True,
+                 )
+                 image_feature = self.feature_select(image_forward_out).to(image.dtype)
+                 image_features.append(image_feature)
+         else:
+             image_forward_outs = self.backbone(
+                 images.to(device=self.device, dtype=self.dtype),
+                 return_image_embeddings=True,
+             )
+             image_features = self.feature_select(image_forward_outs).to(images.dtype)
+
+         return image_features
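The `feature_select` reshape above is the only tensor gymnastics in this file: it flattens the spatial grid and moves channels last. A minimal sketch of that exact operation, with NumPy standing in for torch and illustrative shapes (a 4×4 grid rather than the checkpoint's real one):

```python
import numpy as np

# Dummy feature map: batch of 2, 3072 channels, 4x4 spatial grid.
B, C, H, W = 2, 3072, 4, 4
feats = np.random.randn(B, C, H, W).astype(np.float32)

# Same two steps as feature_select: flatten H*W, then move channels last.
tokens = feats.reshape(B, C, H * W)        # (B, C, H*W)
tokens = np.transpose(tokens, (0, 2, 1))   # (B, H*W, C)

print(tokens.shape)  # (2, 16, 3072)
```

Each of the 16 rows per batch element is one spatial position's 3072-dimensional feature vector, which is the token layout downstream consumers of this encoder receive.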
mci.py ADDED
@@ -0,0 +1,1480 @@
+ #
+ # For licensing see accompanying LICENSE file.
+ # Copyright (C) 2025 Apple Inc. All Rights Reserved.
+ #
+ import copy
+ from functools import partial
+ from typing import List, Tuple, Optional, Union, Dict
+
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+ import torch.nn.functional as F
+ from torch.nn.init import normal_
+
+ from timm.models import register_model
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+ from timm.layers import DropPath, SqueezeExcite
+
+
+ def _cfg(url="", **kwargs):
+     return {
+         "url": url,
+         "num_classes": 1000,
+         "input_size": (3, 256, 256),
+         "pool_size": None,
+         "crop_pct": 0.95,
+         "interpolation": "bicubic",
+         "mean": IMAGENET_DEFAULT_MEAN,
+         "std": IMAGENET_DEFAULT_STD,
+         "classifier": "head",
+         **kwargs,
+     }
+
+
+ default_cfgs = {
+     "fastvit_t": _cfg(crop_pct=0.9),
+     "fastvit_s": _cfg(crop_pct=0.9),
+     "fastvit_m": _cfg(crop_pct=0.95),
+ }
+
+
+ class SEBlock(nn.Module):
+     """Squeeze and Excite module.
+
+     PyTorch implementation of `Squeeze-and-Excitation Networks` -
+     https://arxiv.org/pdf/1709.01507.pdf
+     """
+
+     def __init__(self, in_channels: int, rd_ratio: float = 0.0625) -> None:
+         """Construct a Squeeze and Excite Module.
+
+         Args:
+             in_channels: Number of input channels.
+             rd_ratio: Input channel reduction ratio.
+         """
+         super(SEBlock, self).__init__()
+         self.reduce = nn.Conv2d(
+             in_channels=in_channels,
+             out_channels=int(in_channels * rd_ratio),
+             kernel_size=1,
+             stride=1,
+             bias=True,
+         )
+         self.expand = nn.Conv2d(
+             in_channels=int(in_channels * rd_ratio),
+             out_channels=in_channels,
+             kernel_size=1,
+             stride=1,
+             bias=True,
+         )
+
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+         """Apply forward pass."""
+         b, c, h, w = inputs.size()
+         x = F.avg_pool2d(inputs, kernel_size=[h, w])
+         x = self.reduce(x)
+         x = F.relu(x)
+         x = self.expand(x)
+         x = torch.sigmoid(x)
+         x = x.view(-1, c, 1, 1)
+         return inputs * x
+
+
+ class MobileOneBlock(nn.Module):
+     """MobileOne building block.
+
+     This block has a multi-branched architecture at train-time
+     and a plain-CNN style architecture at inference time.
+     For more details, please refer to our paper:
+     `An Improved One millisecond Mobile Backbone` -
+     https://arxiv.org/pdf/2206.04040.pdf
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int,
+         stride: int = 1,
+         padding: int = 0,
+         dilation: int = 1,
+         groups: int = 1,
+         inference_mode: bool = False,
+         use_se: bool = False,
+         use_act: bool = True,
+         use_scale_branch: bool = True,
+         num_conv_branches: int = 1,
+         activation: nn.Module = nn.GELU(),
+     ) -> None:
+         """Construct a MobileOneBlock module.
+
+         Args:
+             in_channels: Number of channels in the input.
+             out_channels: Number of channels produced by the block.
+             kernel_size: Size of the convolution kernel.
+             stride: Stride size.
+             padding: Zero-padding size.
+             dilation: Kernel dilation factor.
+             groups: Group number.
+             inference_mode: If True, instantiates model in inference mode.
+             use_se: Whether to use SE-ReLU activations.
+             use_act: Whether to use activation. Default: ``True``
+             use_scale_branch: Whether to use scale branch. Default: ``True``
+             num_conv_branches: Number of linear conv branches.
+         """
+         super(MobileOneBlock, self).__init__()
+         self.inference_mode = inference_mode
+         self.groups = groups
+         self.stride = stride
+         self.padding = padding
+         self.dilation = dilation
+         self.kernel_size = kernel_size
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.num_conv_branches = num_conv_branches
+
+         # Check if SE-ReLU is requested
+         if use_se:
+             self.se = SEBlock(out_channels)
+         else:
+             self.se = nn.Identity()
+
+         if use_act:
+             self.activation = activation
+         else:
+             self.activation = nn.Identity()
+
+         if inference_mode:
+             self.reparam_conv = nn.Conv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=kernel_size,
+                 stride=stride,
+                 padding=padding,
+                 dilation=dilation,
+                 groups=groups,
+                 bias=True,
+             )
+         else:
+             # Re-parameterizable skip connection.
+             # Fallback: batchnorm tensors sometimes do not get instantiated
+             # correctly on some processes when using deepspeed + accelerate.
+             norm_layer = nn.BatchNorm2d(num_features=in_channels)
+             if norm_layer.weight.shape[0] == 0:
+                 norm_layer.weight = nn.Parameter(torch.zeros(in_channels))
+             if norm_layer.bias.shape[0] == 0:
+                 norm_layer.bias = nn.Parameter(torch.zeros(in_channels))
+
+             self.rbr_skip = (
+                 norm_layer
+                 if out_channels == in_channels and stride == 1
+                 else None
+             )
+
+             # Re-parameterizable conv branches
+             if num_conv_branches > 0:
+                 rbr_conv = list()
+                 for _ in range(self.num_conv_branches):
+                     rbr_conv.append(
+                         self._conv_bn(kernel_size=kernel_size, padding=padding)
+                     )
+                 self.rbr_conv = nn.ModuleList(rbr_conv)
+             else:
+                 self.rbr_conv = None
+
+             # Re-parameterizable scale branch
+             self.rbr_scale = None
+             if not isinstance(kernel_size, int):
+                 kernel_size = kernel_size[0]
+             if (kernel_size > 1) and use_scale_branch:
+                 self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Apply forward pass."""
+         # Inference mode forward pass.
+         if self.inference_mode:
+             return self.activation(self.se(self.reparam_conv(x)))
+
+         # Multi-branched train-time forward pass.
+         # Skip branch output
+         identity_out = 0
+         if self.rbr_skip is not None:
+             identity_out = self.rbr_skip(x)
+
+         # Scale branch output
+         scale_out = 0
+         if self.rbr_scale is not None:
+             scale_out = self.rbr_scale(x)
+
+         # Other branches
+         out = scale_out + identity_out
+         if self.rbr_conv is not None:
+             for ix in range(self.num_conv_branches):
+                 out += self.rbr_conv[ix](x)
+
+         return self.activation(self.se(out))
+
+     def reparameterize(self):
+         """Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
+         https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize the multi-branched
+         architecture used at training time to obtain a plain CNN-like structure
+         for inference.
+         """
+         if self.inference_mode:
+             return
+         kernel, bias = self._get_kernel_bias()
+         self.reparam_conv = nn.Conv2d(
+             in_channels=self.in_channels,
+             out_channels=self.out_channels,
+             kernel_size=self.kernel_size,
+             stride=self.stride,
+             padding=self.padding,
+             dilation=self.dilation,
+             groups=self.groups,
+             bias=True,
+         )
+         self.reparam_conv.weight.data = kernel
+         self.reparam_conv.bias.data = bias
+
+         # Delete unused branches
+         self.__delattr__("rbr_conv")
+         self.__delattr__("rbr_scale")
+         if hasattr(self, "rbr_skip"):
+             self.__delattr__("rbr_skip")
+
+         self.inference_mode = True
+
+     def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Method to obtain re-parameterized kernel and bias.
+         Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
+
+         Returns:
+             Tuple of (kernel, bias) after fusing branches.
+         """
+         # Get weights and bias of scale branch
+         kernel_scale = 0
+         bias_scale = 0
+         if self.rbr_scale is not None:
+             kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
+             # Pad scale branch kernel to match conv branch kernel size.
+             pad = self.kernel_size // 2
+             kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])
+
+         # Get weights and bias of skip branch
+         kernel_identity = 0
+         bias_identity = 0
+         if self.rbr_skip is not None:
+             kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)
+
+         # Get weights and bias of conv branches
+         kernel_conv = 0
+         bias_conv = 0
+         if self.rbr_conv is not None:
+             for ix in range(self.num_conv_branches):
+                 _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
+                 kernel_conv += _kernel
+                 bias_conv += _bias
+
+         kernel_final = kernel_conv + kernel_scale + kernel_identity
+         bias_final = bias_conv + bias_scale + bias_identity
+         return kernel_final, bias_final
+
+     def _fuse_bn_tensor(
+         self, branch: Union[nn.Sequential, nn.BatchNorm2d]
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Method to fuse batchnorm layer with preceding conv layer.
+         Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
+
+         Args:
+             branch: Sequence of ops to be fused.
+
+         Returns:
+             Tuple of (kernel, bias) after fusing batchnorm.
+         """
+         if isinstance(branch, nn.Sequential):
+             kernel = branch.conv.weight
+             running_mean = branch.bn.running_mean
+             running_var = branch.bn.running_var
+             gamma = branch.bn.weight
+             beta = branch.bn.bias
+             eps = branch.bn.eps
+         else:
+             assert isinstance(branch, nn.BatchNorm2d)
+             if not hasattr(self, "id_tensor"):
+                 input_dim = self.in_channels // self.groups
+
+                 kernel_size = self.kernel_size
+                 if isinstance(self.kernel_size, int):
+                     kernel_size = (self.kernel_size, self.kernel_size)
+
+                 kernel_value = torch.zeros(
+                     (self.in_channels, input_dim, kernel_size[0], kernel_size[1]),
+                     dtype=branch.weight.dtype,
+                     device=branch.weight.device,
+                 )
+                 for i in range(self.in_channels):
+                     kernel_value[
+                         i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2
+                     ] = 1
+                 self.id_tensor = kernel_value
+             kernel = self.id_tensor
+             running_mean = branch.running_mean
+             running_var = branch.running_var
+             gamma = branch.weight
+             beta = branch.bias
+             eps = branch.eps
+         std = (running_var + eps).sqrt()
+         t = (gamma / std).reshape(-1, 1, 1, 1)
+         return kernel * t, beta - running_mean * gamma / std
+
+     def _conv_bn(self, kernel_size: int, padding: int) -> nn.Sequential:
+         """Helper method to construct conv-batchnorm layers.
+
+         Args:
+             kernel_size: Size of the convolution kernel.
+             padding: Zero-padding size.
+
+         Returns:
+             Conv-BN module.
+         """
+         # Fallback: batchnorm tensors sometimes do not get instantiated
+         # correctly on some processes when using deepspeed + accelerate.
+         norm_layer = nn.BatchNorm2d(num_features=self.out_channels)
+         if norm_layer.weight.shape[0] == 0:
+             norm_layer.weight = nn.Parameter(torch.zeros(self.out_channels))
+         if norm_layer.bias.shape[0] == 0:
+             norm_layer.bias = nn.Parameter(torch.zeros(self.out_channels))
+
+         mod_list = nn.Sequential()
+         mod_list.add_module(
+             "conv",
+             nn.Conv2d(
+                 in_channels=self.in_channels,
+                 out_channels=self.out_channels,
+                 kernel_size=kernel_size,
+                 stride=self.stride,
+                 padding=padding,
+                 groups=self.groups,
+                 bias=False,
+             ),
+         )
+         mod_list.add_module("bn", norm_layer)
+         return mod_list
+
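The fusion rule that `_fuse_bn_tensor` applies above is w' = w·γ/σ and b' = β − μ·γ/σ, so that conv followed by BatchNorm (in eval mode, using running statistics) equals a single conv with bias. A small standalone check of that identity, separate from the classes in this file:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(4, 4, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(4).eval()  # fusion is only valid with running stats

# Give BN non-trivial statistics and affine parameters.
with torch.no_grad():
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)

# Fold BN into the conv: w' = w * gamma / std, b' = beta - mean * gamma / std.
std = (bn.running_var + bn.eps).sqrt()
t = (bn.weight / std).reshape(-1, 1, 1, 1)
fused = nn.Conv2d(4, 4, kernel_size=3, padding=1, bias=True)
fused.weight.data = conv.weight * t
fused.bias.data = bn.bias - bn.running_mean * bn.weight / std

x = torch.randn(1, 4, 8, 8)
with torch.no_grad():
    assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)
```

`reparameterize` extends this same algebra by also summing the fused kernels of the scale and identity branches (padding the 1×1 scale kernel to the full kernel size first) before collapsing everything into `reparam_conv`.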
368
+ class ReparamLargeKernelConv(nn.Module):
369
+ """Building Block of RepLKNet
370
+
371
+ This class defines overparameterized large kernel conv block
372
+ introduced in `RepLKNet <https://arxiv.org/abs/2203.06717>`_
373
+
374
+ Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
375
+ """
376
+
377
+ def __init__(
378
+ self,
379
+ in_channels: int,
380
+ out_channels: int,
381
+ kernel_size: int,
382
+ stride: int,
383
+ groups: int,
384
+ small_kernel: int,
385
+ inference_mode: bool = False,
386
+ use_se: bool = False,
387
+ activation: nn.Module = nn.GELU(),
388
+ ) -> None:
389
+ """Construct a ReparamLargeKernelConv module.
390
+
391
+ Args:
392
+ in_channels: Number of input channels.
393
+ out_channels: Number of output channels.
394
+ kernel_size: Kernel size of the large kernel conv branch.
395
+ stride: Stride size. Default: 1
396
+ groups: Group number. Default: 1
397
+ small_kernel: Kernel size of small kernel conv branch.
398
+ inference_mode: If True, instantiates model in inference mode. Default: ``False``
399
+ activation: Activation module. Default: ``nn.GELU``
400
+ """
401
+ super(ReparamLargeKernelConv, self).__init__()
402
+
403
+ self.stride = stride
404
+ self.groups = groups
405
+ self.in_channels = in_channels
406
+ self.out_channels = out_channels
407
+ self.activation = activation
408
+
409
+ self.kernel_size = kernel_size
410
+ self.small_kernel = small_kernel
411
+ self.padding = kernel_size // 2
412
+
413
+ # Check if SE is requested
414
+ if use_se:
415
+ self.se = SqueezeExcite(out_channels, rd_ratio=0.25)
416
+ else:
417
+ self.se = nn.Identity()
418
+
419
+ if inference_mode:
420
+ self.lkb_reparam = nn.Conv2d(
421
+ in_channels=in_channels,
422
+ out_channels=out_channels,
423
+ kernel_size=kernel_size,
424
+ stride=stride,
425
+ padding=self.padding,
426
+ dilation=1,
427
+ groups=groups,
428
+ bias=True,
429
+ )
430
+ else:
431
+ self.lkb_origin = self._conv_bn(
432
+ kernel_size=kernel_size, padding=self.padding
433
+ )
434
+ if small_kernel is not None:
435
+ assert (
436
+ small_kernel <= kernel_size
437
+ ), "The kernel size for re-param cannot be larger than the large kernel!"
438
+ self.small_conv = self._conv_bn(
439
+ kernel_size=small_kernel, padding=small_kernel // 2
440
+ )
441
+
442
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
443
+ """Apply forward pass."""
444
+ if hasattr(self, "lkb_reparam"):
445
+ out = self.lkb_reparam(x)
446
+ else:
447
+ out = self.lkb_origin(x)
448
+ if hasattr(self, "small_conv"):
449
+ out += self.small_conv(x)
450
+
451
+ return self.activation(self.se(out))
452
+
453
+ def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
454
+ """Method to obtain re-parameterized kernel and bias.
455
+ Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
456
+
457
+ Returns:
458
+ Tuple of (kernel, bias) after fusing branches.
459
+ """
460
+ eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
461
+ if hasattr(self, "small_conv"):
462
+ small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn)
463
+ eq_b += small_b
464
+ eq_k += nn.functional.pad(
465
+ small_k, [(self.kernel_size - self.small_kernel) // 2] * 4
466
+ )
467
+ return eq_k, eq_b
468
+
469
+ def reparameterize(self) -> None:
470
+ """
471
+ Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
472
+ https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
473
+ architecture used at training time to obtain a plain CNN-like structure
474
+ for inference.
475
+ """
476
+ eq_k, eq_b = self.get_kernel_bias()
477
+ self.lkb_reparam = nn.Conv2d(
478
+ in_channels=self.in_channels,
479
+ out_channels=self.out_channels,
480
+ kernel_size=self.kernel_size,
481
+ stride=self.stride,
482
+ padding=self.padding,
483
+ dilation=self.lkb_origin.conv.dilation,
484
+ groups=self.groups,
485
+ bias=True,
486
+ )
487
+
488
+ self.lkb_reparam.weight.data = eq_k
489
+ self.lkb_reparam.bias.data = eq_b
490
+ self.__delattr__("lkb_origin")
491
+ if hasattr(self, "small_conv"):
492
+ self.__delattr__("small_conv")
493
+
494
+ @staticmethod
495
+ def _fuse_bn(
496
+ conv: torch.Tensor, bn: nn.BatchNorm2d
497
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
498
+ """Method to fuse batchnorm layer with conv layer.
499
+
500
+ Args:
501
+ conv: Convolutional kernel weights.
502
+ bn: Batchnorm 2d layer.
503
+
504
+ Returns:
505
+ Tuple of (kernel, bias) after fusing batchnorm.
506
+ """
507
+ kernel = conv.weight
508
+ running_mean = bn.running_mean
509
+ running_var = bn.running_var
510
+ gamma = bn.weight
511
+ beta = bn.bias
512
+ eps = bn.eps
513
+ std = (running_var + eps).sqrt()
514
+ t = (gamma / std).reshape(-1, 1, 1, 1)
515
+ return kernel * t, beta - running_mean * gamma / std
516
+
+     def _conv_bn(self, kernel_size: int, padding: int = 0) -> nn.Sequential:
+         """Helper method to construct conv-batchnorm layers.
+
+         Args:
+             kernel_size: Size of the convolution kernel.
+             padding: Zero-padding size.
+
+         Returns:
+             A nn.Sequential Conv-BN module.
+         """
+         # Fallback: batchnorm tensors sometimes do not get instantiated
+         # correctly on some processes when using deepspeed + accelerate.
+         norm_layer = nn.BatchNorm2d(num_features=self.out_channels)
+         if norm_layer.weight.shape[0] == 0:
+             norm_layer.weight = nn.Parameter(torch.zeros(self.out_channels))
+         if norm_layer.bias.shape[0] == 0:
+             norm_layer.bias = nn.Parameter(torch.zeros(self.out_channels))
+
+         mod_list = nn.Sequential()
+         mod_list.add_module(
+             "conv",
+             nn.Conv2d(
+                 in_channels=self.in_channels,
+                 out_channels=self.out_channels,
+                 kernel_size=kernel_size,
+                 stride=self.stride,
+                 padding=padding,
+                 groups=self.groups,
+                 bias=False,
+             ),
+         )
+         mod_list.add_module("bn", norm_layer)
+         return mod_list
+
+
+ def convolutional_stem(
+     in_channels: int,
+     out_channels: int,
+     inference_mode: bool = False,
+     use_scale_branch: bool = True,
+ ) -> nn.Sequential:
+     """Build convolutional stem with MobileOne blocks.
+
+     Args:
+         in_channels: Number of input channels.
+         out_channels: Number of output channels.
+         inference_mode: Flag to instantiate model in inference mode. Default: ``False``
+         use_scale_branch: Flag to enable the scale branch in MobileOne blocks. Default: ``True``
+
+     Returns:
+         nn.Sequential object with stem elements.
+     """
+     return nn.Sequential(
+         MobileOneBlock(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             kernel_size=3,
+             stride=2,
+             padding=1,
+             groups=1,
+             inference_mode=inference_mode,
+             use_se=False,
+             num_conv_branches=1,
+             use_scale_branch=use_scale_branch,
+         ),
+         MobileOneBlock(
+             in_channels=out_channels,
+             out_channels=out_channels,
+             kernel_size=3,
+             stride=2,
+             padding=1,
+             groups=out_channels,
+             inference_mode=inference_mode,
+             use_se=False,
+             num_conv_branches=1,
+             use_scale_branch=use_scale_branch,
+         ),
+         MobileOneBlock(
+             in_channels=out_channels,
+             out_channels=out_channels,
+             kernel_size=1,
+             stride=1,
+             padding=0,
+             groups=1,
+             inference_mode=inference_mode,
+             use_se=False,
+             num_conv_branches=1,
+             use_scale_branch=use_scale_branch,
+         ),
+     )
+
+
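The stem stacks two stride-2 3x3 convolutions (the second depthwise) and a 1x1 convolution, for an overall spatial stride of 4. A shape-only sketch with plain `nn.Conv2d` layers standing in for `MobileOneBlock` (which has the same geometry after reparameterization):

```python
import torch
import torch.nn as nn

# Stand-in for convolutional_stem(3, 96): same kernel sizes, strides, and
# groups, so the output spatial size matches the real stem.
stem = nn.Sequential(
    nn.Conv2d(3, 96, kernel_size=3, stride=2, padding=1),
    nn.Conv2d(96, 96, kernel_size=3, stride=2, padding=1, groups=96),
    nn.Conv2d(96, 96, kernel_size=1, stride=1, padding=0),
)
x = torch.randn(1, 3, 256, 256)
print(stem(x).shape)  # torch.Size([1, 96, 64, 64])
```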
+ class LayerNormChannel(nn.Module):
+     """LayerNorm applied only over the channel dimension.
+
+     Input: tensor in shape [B, C, H, W]
+     """
+
+     def __init__(self, num_features, eps=1e-05) -> None:
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(num_features))
+         self.bias = nn.Parameter(torch.zeros(num_features))
+         self.eps = eps
+
+     def forward(self, x) -> torch.Tensor:
+         u = x.mean(1, keepdim=True)
+         s = (x - u).pow(2).mean(1, keepdim=True)
+         x = (x - u) / torch.sqrt(s + self.eps)
+         x = self.weight.unsqueeze(-1).unsqueeze(-1) * x \
+             + self.bias.unsqueeze(-1).unsqueeze(-1)
+         return x
+
+
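The manual normalization above is equivalent to `F.layer_norm` applied to a channels-last view of the tensor. A quick equivalence check with unit weight and zero bias:

```python
import torch
import torch.nn.functional as F

# LayerNormChannel normalizes over the channel axis of an NCHW tensor;
# the same statistics fall out of F.layer_norm on a permuted view.
x = torch.randn(2, 8, 4, 4)
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
manual = (x - u) / torch.sqrt(s + 1e-5)
ref = F.layer_norm(x.permute(0, 2, 3, 1), (8,), eps=1e-5).permute(0, 3, 1, 2)
print(torch.allclose(manual, ref, atol=1e-5))  # True
```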
+ class MHSA(nn.Module):
+     """Multi-headed Self Attention module.
+
+     Source modified from:
+     https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         head_dim: int = 32,
+         qkv_bias: bool = False,
+         attn_drop: float = 0.0,
+         proj_drop: float = 0.0,
+     ) -> None:
+         """Build MHSA module that can handle 3D or 4D input tensors.
+
+         Args:
+             dim: Number of embedding dimensions.
+             head_dim: Number of hidden dimensions per head. Default: ``32``
+             qkv_bias: Use bias or not. Default: ``False``
+             attn_drop: Dropout rate for attention tensor.
+             proj_drop: Dropout rate for projection tensor.
+         """
+         super().__init__()
+         assert dim % head_dim == 0, "dim should be divisible by head_dim"
+         self.head_dim = head_dim
+         self.num_heads = dim // head_dim
+         self.scale = head_dim**-0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         shape = x.shape
+         if len(shape) == 4:
+             # Flatten (B, C, H, W) into a token sequence (B, N, C).
+             B, C, H, W = shape
+             N = H * W
+             x = torch.flatten(x, start_dim=2).transpose(-2, -1)
+         else:
+             B, N, C = shape
+         qkv = (
+             self.qkv(x)
+             .reshape(B, N, 3, self.num_heads, self.head_dim)
+             .permute(2, 0, 3, 1, 4)
+         )
+         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
+
+         # Applying the scale to q before the matmul keeps q @ k.T more stable.
+         attn = (q * self.scale) @ k.transpose(-2, -1)
+         attn = attn.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         if len(shape) == 4:
+             x = x.transpose(-2, -1).reshape(B, C, H, W)
+
+         return x
+
+
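The q-scaled attention computed above matches PyTorch's fused `scaled_dot_product_attention` (available in PyTorch >= 2.0, whose default scale is `1/sqrt(head_dim)`). A quick check on random per-head tensors:

```python
import torch
import torch.nn.functional as F

# (batch, heads, tokens, head_dim), mirroring the q/k/v layout in MHSA.forward.
q = torch.randn(1, 4, 49, 32)
k = torch.randn(1, 4, 49, 32)
v = torch.randn(1, 4, 49, 32)
scale = 32 ** -0.5

attn = ((q * scale) @ k.transpose(-2, -1)).softmax(dim=-1)
manual = attn @ v
fused = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(manual, fused, atol=1e-4))  # True
```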
+ class PatchEmbed(nn.Module):
+     """Convolutional patch embedding layer."""
+
+     def __init__(
+         self,
+         patch_size: int,
+         stride: int,
+         in_channels: int,
+         embed_dim: int,
+         inference_mode: bool = False,
+         use_se: bool = False,
+     ) -> None:
+         """Build patch embedding layer.
+
+         Args:
+             patch_size: Patch size for embedding computation.
+             stride: Stride for convolutional embedding layer.
+             in_channels: Number of channels of input tensor.
+             embed_dim: Number of embedding dimensions.
+             inference_mode: Flag to instantiate model in inference mode. Default: ``False``
+             use_se: If ``True``, an SE block will be used.
+         """
+         super().__init__()
+         block = []
+         block.append(
+             ReparamLargeKernelConv(
+                 in_channels=in_channels,
+                 out_channels=embed_dim,
+                 kernel_size=patch_size,
+                 stride=stride,
+                 groups=in_channels,
+                 small_kernel=3,
+                 inference_mode=inference_mode,
+                 use_se=use_se,
+             )
+         )
+         block.append(
+             MobileOneBlock(
+                 in_channels=embed_dim,
+                 out_channels=embed_dim,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+                 groups=1,
+                 inference_mode=inference_mode,
+                 use_se=False,
+                 num_conv_branches=1,
+             )
+         )
+         self.proj = nn.Sequential(*block)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.proj(x)
+         return x
+
+
+ class RepMixer(nn.Module):
+     """Reparameterizable token mixer.
+
+     For more details, please refer to our paper:
+     `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization <https://arxiv.org/pdf/2303.14189.pdf>`_
+     """
+
+     def __init__(
+         self,
+         dim,
+         kernel_size=3,
+         use_layer_scale=True,
+         layer_scale_init_value=1e-5,
+         inference_mode: bool = False,
+     ):
+         """Build RepMixer Module.
+
+         Args:
+             dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
+             kernel_size: Kernel size for spatial mixing. Default: 3
+             use_layer_scale: If True, learnable layer scale is used. Default: ``True``
+             layer_scale_init_value: Initial value for layer scale. Default: 1e-5
+             inference_mode: If True, instantiates model in inference mode. Default: ``False``
+         """
+         super().__init__()
+         self.dim = dim
+         self.kernel_size = kernel_size
+         self.inference_mode = inference_mode
+
+         if inference_mode:
+             self.reparam_conv = nn.Conv2d(
+                 in_channels=self.dim,
+                 out_channels=self.dim,
+                 kernel_size=self.kernel_size,
+                 stride=1,
+                 padding=self.kernel_size // 2,
+                 groups=self.dim,
+                 bias=True,
+             )
+         else:
+             self.norm = MobileOneBlock(
+                 dim,
+                 dim,
+                 kernel_size,
+                 padding=kernel_size // 2,
+                 groups=dim,
+                 use_act=False,
+                 use_scale_branch=False,
+                 num_conv_branches=0,
+             )
+             self.mixer = MobileOneBlock(
+                 dim,
+                 dim,
+                 kernel_size,
+                 padding=kernel_size // 2,
+                 groups=dim,
+                 use_act=False,
+             )
+             self.use_layer_scale = use_layer_scale
+             if use_layer_scale:
+                 self.layer_scale = nn.Parameter(
+                     layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
+                 )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if hasattr(self, "reparam_conv"):
+             x = self.reparam_conv(x)
+             return x
+         else:
+             if self.use_layer_scale:
+                 x = x + self.layer_scale * (self.mixer(x) - self.norm(x))
+             else:
+                 x = x + self.mixer(x) - self.norm(x)
+             return x
+
+     def reparameterize(self) -> None:
+         """Reparameterize mixer and norm into a single
+         convolutional layer for efficient inference.
+         """
+         if self.inference_mode:
+             return
+
+         self.mixer.reparameterize()
+         self.norm.reparameterize()
+
+         if self.use_layer_scale:
+             w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
+                 self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
+             )
+             b = torch.squeeze(self.layer_scale) * (
+                 self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
+             )
+         else:
+             w = (
+                 self.mixer.id_tensor
+                 + self.mixer.reparam_conv.weight
+                 - self.norm.reparam_conv.weight
+             )
+             b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
+
+         self.reparam_conv = nn.Conv2d(
+             in_channels=self.dim,
+             out_channels=self.dim,
+             kernel_size=self.kernel_size,
+             stride=1,
+             padding=self.kernel_size // 2,
+             groups=self.dim,
+             bias=True,
+         )
+         self.reparam_conv.weight.data = w
+         self.reparam_conv.bias.data = b
+
+         self.__delattr__("mixer")
+         self.__delattr__("norm")
+         if self.use_layer_scale:
+             self.__delattr__("layer_scale")
+
+
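A numeric sketch of what `reparameterize` folds: with layer scale `s`, the residual branch `y = x + s * (mixer(x) - norm(x))` collapses into one depthwise conv whose kernel is `identity + s * (W_mixer - W_norm)`. Plain depthwise `nn.Conv2d` layers stand in here for the already-reparameterized `MobileOneBlock`s:

```python
import torch
import torch.nn as nn

dim, k = 8, 3
mixer = nn.Conv2d(dim, dim, k, padding=1, groups=dim)
norm = nn.Conv2d(dim, dim, k, padding=1, groups=dim)
s = 1e-2 * torch.ones(dim, 1, 1)  # layer scale

# Identity kernel for a depthwise 3x3 conv: 1 at the center tap.
id_k = torch.zeros(dim, 1, k, k)
id_k[:, 0, 1, 1] = 1.0

w = id_k + s.unsqueeze(-1) * (mixer.weight - norm.weight)
b = s.squeeze() * (mixer.bias - norm.bias)

x = torch.randn(2, dim, 7, 7)
with torch.no_grad():
    ref = x + s * (mixer(x) - norm(x))
    out = nn.functional.conv2d(x, w, b, padding=1, groups=dim)
print(torch.allclose(ref, out, atol=1e-5))  # True
```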
+ class ConvFFN(nn.Module):
+     """Convolutional FFN Module."""
+
+     def __init__(
+         self,
+         in_channels: int,
+         hidden_channels: Optional[int] = None,
+         out_channels: Optional[int] = None,
+         act_layer: nn.Module = nn.GELU,
+         drop: float = 0.0,
+     ) -> None:
+         """Build convolutional FFN module.
+
+         Args:
+             in_channels: Number of input channels.
+             hidden_channels: Number of channels after expansion. Default: None
+             out_channels: Number of output channels. Default: None
+             act_layer: Activation layer. Default: ``GELU``
+             drop: Dropout rate. Default: ``0.0``.
+         """
+         super().__init__()
+         out_channels = out_channels or in_channels
+         hidden_channels = hidden_channels or in_channels
+         self.conv = nn.Sequential()
+         self.conv.add_module(
+             "conv",
+             nn.Conv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=7,
+                 padding=3,
+                 groups=in_channels,
+                 bias=False,
+             ),
+         )
+
+         # Fallback: batchnorm tensors sometimes do not get instantiated
+         # correctly on some processes when using deepspeed + accelerate.
+         norm_layer = nn.BatchNorm2d(num_features=out_channels)
+         if norm_layer.weight.shape[0] == 0:
+             norm_layer.weight = nn.Parameter(torch.zeros(out_channels))
+         if norm_layer.bias.shape[0] == 0:
+             norm_layer.bias = nn.Parameter(torch.zeros(out_channels))
+
+         self.conv.add_module("bn", norm_layer)
+         self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
+         self.act = act_layer()
+         self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
+         self.drop = nn.Dropout(drop)
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m: nn.Module) -> None:
+         if isinstance(m, nn.Conv2d):
+             normal_(m.weight, std=0.02)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+
+     _initialize_weights = _init_weights
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.conv(x)
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ class RepCPE(nn.Module):
+     """Implementation of conditional positional encoding.
+
+     For more details refer to paper:
+     `Conditional Positional Encodings for Vision Transformers <https://arxiv.org/pdf/2102.10882.pdf>`_
+
+     In our implementation, we can reparameterize this module to eliminate a skip connection.
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         embed_dim: int = 768,
+         spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
+         inference_mode=False,
+     ) -> None:
+         """Build reparameterizable conditional positional encoding.
+
+         Args:
+             in_channels: Number of input channels.
+             embed_dim: Number of embedding dimensions. Default: 768
+             spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7)
+             inference_mode: Flag to instantiate block in inference mode. Default: ``False``
+         """
+         super(RepCPE, self).__init__()
+         if isinstance(spatial_shape, int):
+             spatial_shape = tuple([spatial_shape] * 2)
+         assert isinstance(spatial_shape, tuple), (
+             f'"spatial_shape" must be a sequence or int, '
+             f"got {type(spatial_shape)} instead."
+         )
+         assert len(spatial_shape) == 2, (
+             f'Length of "spatial_shape" should be 2, '
+             f"got {len(spatial_shape)} instead."
+         )
+
+         self.spatial_shape = spatial_shape
+         self.embed_dim = embed_dim
+         self.in_channels = in_channels
+         self.groups = embed_dim
+
+         if inference_mode:
+             self.reparam_conv = nn.Conv2d(
+                 in_channels=self.in_channels,
+                 out_channels=self.embed_dim,
+                 kernel_size=self.spatial_shape,
+                 stride=1,
+                 padding=int(self.spatial_shape[0] // 2),
+                 groups=self.embed_dim,
+                 bias=True,
+             )
+         else:
+             self.pe = nn.Conv2d(
+                 in_channels,
+                 embed_dim,
+                 spatial_shape,
+                 1,
+                 int(spatial_shape[0] // 2),
+                 bias=True,
+                 groups=embed_dim,
+             )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if hasattr(self, "reparam_conv"):
+             x = self.reparam_conv(x)
+             return x
+         else:
+             x = self.pe(x) + x
+             return x
+
+     def reparameterize(self) -> None:
+         # Build equivalent Id tensor
+         input_dim = self.in_channels // self.groups
+         kernel_value = torch.zeros(
+             (
+                 self.in_channels,
+                 input_dim,
+                 self.spatial_shape[0],
+                 self.spatial_shape[1],
+             ),
+             dtype=self.pe.weight.dtype,
+             device=self.pe.weight.device,
+         )
+         for i in range(self.in_channels):
+             kernel_value[
+                 i,
+                 i % input_dim,
+                 self.spatial_shape[0] // 2,
+                 self.spatial_shape[1] // 2,
+             ] = 1
+         id_tensor = kernel_value
+
+         # Reparameterize Id tensor and conv
+         w_final = id_tensor + self.pe.weight
+         b_final = self.pe.bias
+
+         # Introduce reparam conv
+         self.reparam_conv = nn.Conv2d(
+             in_channels=self.in_channels,
+             out_channels=self.embed_dim,
+             kernel_size=self.spatial_shape,
+             stride=1,
+             padding=int(self.spatial_shape[0] // 2),
+             groups=self.embed_dim,
+             bias=True,
+         )
+         self.reparam_conv.weight.data = w_final
+         self.reparam_conv.bias.data = b_final
+
+         self.__delattr__("pe")
+
+
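A numeric sketch of `RepCPE.reparameterize`: the skip connection in `pe(x) + x` is absorbed by adding a depthwise identity kernel to the positional-encoding conv weights, exactly as the `id_tensor` construction above does:

```python
import torch
import torch.nn as nn

dim, k = 8, 7
pe = nn.Conv2d(dim, dim, k, padding=k // 2, groups=dim, bias=True)

# Depthwise identity kernel: 1 at the center tap of each channel.
id_k = torch.zeros(dim, 1, k, k)
id_k[:, 0, k // 2, k // 2] = 1.0

x = torch.randn(2, dim, 14, 14)
with torch.no_grad():
    ref = pe(x) + x
    out = nn.functional.conv2d(
        x, pe.weight + id_k, pe.bias, padding=k // 2, groups=dim
    )
print(torch.allclose(ref, out, atol=1e-5))  # True
```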
+ class RepMixerBlock(nn.Module):
+     """Implementation of Metaformer block with RepMixer as token mixer.
+
+     For more details on Metaformer structure, please refer to:
+     `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         kernel_size: int = 3,
+         mlp_ratio: float = 4.0,
+         act_layer: nn.Module = nn.GELU,
+         drop: float = 0.0,
+         drop_path: float = 0.0,
+         use_layer_scale: bool = True,
+         layer_scale_init_value: float = 1e-5,
+         inference_mode: bool = False,
+     ):
+         """Build RepMixer Block.
+
+         Args:
+             dim: Number of embedding dimensions.
+             kernel_size: Kernel size for repmixer. Default: 3
+             mlp_ratio: MLP expansion ratio. Default: 4.0
+             act_layer: Activation layer. Default: ``nn.GELU``
+             drop: Dropout rate. Default: 0.0
+             drop_path: Drop path rate. Default: 0.0
+             use_layer_scale: Flag to turn on layer scale. Default: ``True``
+             layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
+             inference_mode: Flag to instantiate block in inference mode. Default: ``False``
+         """
+
+         super().__init__()
+
+         self.token_mixer = RepMixer(
+             dim,
+             kernel_size=kernel_size,
+             use_layer_scale=use_layer_scale,
+             layer_scale_init_value=layer_scale_init_value,
+             inference_mode=inference_mode,
+         )
+
+         assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
+             mlp_ratio
+         )
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.convffn = ConvFFN(
+             in_channels=dim,
+             hidden_channels=mlp_hidden_dim,
+             act_layer=act_layer,
+             drop=drop,
+         )
+
+         # Drop path
+         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+         # Layer scale
+         self.use_layer_scale = use_layer_scale
+         if use_layer_scale:
+             self.layer_scale = nn.Parameter(
+                 layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
+             )
+
+     def forward(self, x):
+         if self.use_layer_scale:
+             x = self.token_mixer(x)
+             x = x + self.drop_path(self.layer_scale * self.convffn(x))
+         else:
+             x = self.token_mixer(x)
+             x = x + self.drop_path(self.convffn(x))
+         return x
+
+
+ class AttentionBlock(nn.Module):
+     """Implementation of Metaformer block with MHSA as token mixer.
+
+     For more details on Metaformer structure, please refer to:
+     `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         mlp_ratio: float = 4.0,
+         act_layer: nn.Module = nn.GELU,
+         norm_layer: nn.Module = nn.BatchNorm2d,
+         drop: float = 0.0,
+         drop_path: float = 0.0,
+         use_layer_scale: bool = True,
+         layer_scale_init_value: float = 1e-5,
+     ):
+         """Build Attention Block.
+
+         Args:
+             dim: Number of embedding dimensions.
+             mlp_ratio: MLP expansion ratio. Default: 4.0
+             act_layer: Activation layer. Default: ``nn.GELU``
+             norm_layer: Normalization layer. Default: ``nn.BatchNorm2d``
+             drop: Dropout rate. Default: 0.0
+             drop_path: Drop path rate. Default: 0.0
+             use_layer_scale: Flag to turn on layer scale. Default: ``True``
+             layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
+         """
+
+         super().__init__()
+
+         # Fallback: batchnorm tensors sometimes do not get instantiated
+         # correctly on some processes when using deepspeed + accelerate.
+         norm_layer_ = norm_layer(num_features=dim)
+         if norm_layer_.weight.shape[0] == 0:
+             norm_layer_.weight = nn.Parameter(torch.zeros(dim))
+         if norm_layer_.bias.shape[0] == 0:
+             norm_layer_.bias = nn.Parameter(torch.zeros(dim))
+
+         self.norm = norm_layer_
+         self.token_mixer = MHSA(dim=dim)
+
+         assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
+             mlp_ratio
+         )
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.convffn = ConvFFN(
+             in_channels=dim,
+             hidden_channels=mlp_hidden_dim,
+             act_layer=act_layer,
+             drop=drop,
+         )
+
+         # Drop path
+         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+         # Layer scale
+         self.use_layer_scale = use_layer_scale
+         if use_layer_scale:
+             self.layer_scale_1 = nn.Parameter(
+                 layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
+             )
+             self.layer_scale_2 = nn.Parameter(
+                 layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
+             )
+
+     def forward(self, x):
+         if self.use_layer_scale:
+             x = x + self.drop_path(self.layer_scale_1 * self.token_mixer(self.norm(x)))
+             x = x + self.drop_path(self.layer_scale_2 * self.convffn(x))
+         else:
+             x = x + self.drop_path(self.token_mixer(self.norm(x)))
+             x = x + self.drop_path(self.convffn(x))
+         return x
+
+
+ def basic_blocks(
+     dim: int,
+     block_index: int,
+     num_blocks: List[int],
+     token_mixer_type: str,
+     kernel_size: int = 3,
+     mlp_ratio: float = 4.0,
+     act_layer: nn.Module = nn.GELU,
+     norm_layer: nn.Module = nn.BatchNorm2d,
+     drop_rate: float = 0.0,
+     drop_path_rate: float = 0.0,
+     use_layer_scale: bool = True,
+     layer_scale_init_value: float = 1e-5,
+     inference_mode=False,
+ ) -> nn.Sequential:
+     """Build FastViT blocks within a stage.
+
+     Args:
+         dim: Number of embedding dimensions.
+         block_index: Block index.
+         num_blocks: List containing number of blocks per stage.
+         token_mixer_type: Token mixer type.
+         kernel_size: Kernel size for repmixer.
+         mlp_ratio: MLP expansion ratio.
+         act_layer: Activation layer.
+         norm_layer: Normalization layer.
+         drop_rate: Dropout rate.
+         drop_path_rate: Drop path rate.
+         use_layer_scale: Flag to turn on layer scale regularization.
+         layer_scale_init_value: Layer scale value at initialization.
+         inference_mode: Flag to instantiate block in inference mode.
+
+     Returns:
+         nn.Sequential object of all the blocks within the stage.
+     """
+     blocks = []
+     for block_idx in range(num_blocks[block_index]):
+         block_dpr = (
+             drop_path_rate
+             * (block_idx + sum(num_blocks[:block_index]))
+             / (sum(num_blocks) - 1)
+         )
+         if token_mixer_type == "repmixer":
+             blocks.append(
+                 RepMixerBlock(
+                     dim,
+                     kernel_size=kernel_size,
+                     mlp_ratio=mlp_ratio,
+                     act_layer=act_layer,
+                     drop=drop_rate,
+                     drop_path=block_dpr,
+                     use_layer_scale=use_layer_scale,
+                     layer_scale_init_value=layer_scale_init_value,
+                     inference_mode=inference_mode,
+                 )
+             )
+         elif token_mixer_type == "attention":
+             blocks.append(
+                 AttentionBlock(
+                     dim,
+                     mlp_ratio=mlp_ratio,
+                     act_layer=act_layer,
+                     norm_layer=norm_layer,
+                     drop=drop_rate,
+                     drop_path=block_dpr,
+                     use_layer_scale=use_layer_scale,
+                     layer_scale_init_value=layer_scale_init_value,
+                 )
+             )
+         else:
+             raise ValueError(
+                 "Token mixer type: {} not supported".format(token_mixer_type)
+             )
+     blocks = nn.Sequential(*blocks)
+     return blocks
+
+
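The `block_dpr` expression above scales the stochastic-depth rate linearly with the global block index across all stages. Worked out for the FastViTHD stage depths `[2, 12, 24, 4, 2]` and a hypothetical `drop_path_rate = 0.1`:

```python
num_blocks = [2, 12, 24, 4, 2]
drop_path_rate = 0.1
total = sum(num_blocks)  # 44 blocks overall

# Same formula as in basic_blocks, evaluated for every (stage, block) pair.
rates = [
    drop_path_rate * (block_idx + sum(num_blocks[:i])) / (total - 1)
    for i in range(len(num_blocks))
    for block_idx in range(num_blocks[i])
]
print(round(rates[0], 4), round(rates[-1], 4))  # 0.0 0.1
```

The first block gets rate 0 and the last block gets the full `drop_path_rate`, with all intermediate blocks spaced evenly between them.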
+ class GlobalPool2D(nn.Module):
+     """This class implements global pooling with linear projection."""
+
+     def __init__(self, in_dim: int, out_dim: int, *args, **kwargs) -> None:
+         super().__init__()
+         scale = in_dim**-0.5
+         self.proj = nn.Parameter(scale * torch.randn(size=(in_dim, out_dim)))
+         self.in_dim = in_dim
+         self.out_dim = out_dim
+
+     def pool(self, x) -> Tensor:
+         if x.dim() == 4:
+             dims = [-2, -1]
+         elif x.dim() == 5:
+             dims = [-3, -2, -1]
+         else:
+             raise NotImplementedError(f"Expected 4D or 5D input, got {x.dim()}D.")
+         x = torch.mean(x, dim=dims, keepdim=False)
+         return x
+
+     def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
+         # x is of shape [batch, in_dim, in_height, in_width]
+         assert (
+             x.dim() == 4
+         ), "Input should be 4-dimensional (Batch x in_dim x in_height x in_width). Got: {}".format(
+             x.shape
+         )
+
+         # [batch, in_dim, in_height, in_width] --> [batch, in_dim]
+         x = self.pool(x)
+         # [batch, in_dim] x [in_dim, out_dim] --> [batch, out_dim]
+         x = x @ self.proj
+         return x
+
+
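The pool-then-project step above is just a spatial mean followed by a matmul. A shape walk-through with stand-in tensors (`in_dim`/`out_dim` values here are illustrative, not taken from a released config):

```python
import torch

in_dim, out_dim = 1536, 768
# Same init as GlobalPool2D.proj: scaled Gaussian projection matrix.
proj = (in_dim ** -0.5) * torch.randn(in_dim, out_dim)

x = torch.randn(2, in_dim, 8, 8)
pooled = x.mean(dim=[-2, -1])  # [2, 1536]
out = pooled @ proj            # [2, 768]
print(out.shape)  # torch.Size([2, 768])
```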
+ class FastViT(nn.Module):
+     """
+     This class implements `FastViT architecture <https://arxiv.org/pdf/2303.14189.pdf>`_
+     """
+
+     def __init__(
+         self,
+         layers,
+         token_mixers: Tuple[str, ...],
+         embed_dims=None,
+         mlp_ratios=None,
+         downsamples=None,
+         se_downsamples=None,
+         repmixer_kernel_size=3,
+         norm_layer: nn.Module = nn.BatchNorm2d,
+         act_layer: nn.Module = nn.GELU,
+         num_classes=1000,
+         pos_embs=None,
+         down_patch_size=7,
+         down_stride=2,
+         drop_rate=0.0,
+         drop_path_rate=0.0,
+         use_layer_scale=True,
+         layer_scale_init_value=1e-5,
+         init_cfg=None,
+         pretrained=None,
+         cls_ratio=2.0,
+         inference_mode=False,
+         stem_scale_branch=True,
+         **kwargs,
+     ) -> None:
+
+         super().__init__()
+
+         self.num_classes = num_classes
+         if len(layers) == 4:
+             self.out_indices = [0, 2, 4, 7]
+         elif len(layers) == 5:
+             self.out_indices = [0, 2, 4, 7, 10]
+         else:
+             raise NotImplementedError("FPN is not implemented for more than 5 stages.")
+
+         if pos_embs is None:
+             pos_embs = [None] * len(layers)
+
+         if se_downsamples is None:
+             se_downsamples = [False] * len(layers)
+
+         # Convolutional stem
+         self.patch_embed = convolutional_stem(
+             3, embed_dims[0], inference_mode, use_scale_branch=stem_scale_branch
+         )
+
+         # Build the main stages of the network architecture
+         network = []
+         for i in range(len(layers)):
+             # Add position embeddings if requested
+             if pos_embs[i] is not None:
+                 network.append(
+                     pos_embs[i](
+                         embed_dims[i], embed_dims[i], inference_mode=inference_mode
+                     )
+                 )
+             stage = basic_blocks(
+                 embed_dims[i],
+                 i,
+                 layers,
+                 token_mixer_type=token_mixers[i],
+                 kernel_size=repmixer_kernel_size,
+                 mlp_ratio=mlp_ratios[i],
+                 act_layer=act_layer,
+                 norm_layer=norm_layer,
+                 drop_rate=drop_rate,
+                 drop_path_rate=drop_path_rate,
+                 use_layer_scale=use_layer_scale,
+                 layer_scale_init_value=layer_scale_init_value,
+                 inference_mode=inference_mode,
+             )
+             network.append(stage)
+             if i >= len(layers) - 1:
+                 break
+
+             # Patch merging/downsampling between stages.
+             if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
+                 network.append(
+                     PatchEmbed(
+                         patch_size=down_patch_size,
+                         stride=down_stride,
+                         in_channels=embed_dims[i],
+                         embed_dim=embed_dims[i + 1],
+                         inference_mode=inference_mode,
+                         use_se=se_downsamples[i + 1],
+                     )
+                 )
+         self.network = nn.ModuleList(network)
+
+         # Classifier head
+         self.conv_exp = MobileOneBlock(
+             in_channels=embed_dims[-1],
+             out_channels=int(embed_dims[-1] * cls_ratio),
+             kernel_size=3,
+             stride=1,
+             padding=1,
+             groups=embed_dims[-1],
+             inference_mode=inference_mode,
+             use_se=True,
+             num_conv_branches=1,
+         )
+         self.head = (
+             nn.Linear(int(embed_dims[-1] * cls_ratio), num_classes)
+             if num_classes > 0
+             else nn.Identity()
+         )
+         self.apply(self.cls_init_weights)
+         self.init_cfg = copy.deepcopy(init_cfg)
+
+     def cls_init_weights(self, m: nn.Module) -> None:
+         """Init. for classification"""
+         if isinstance(m, nn.Linear):
+             normal_(m.weight, std=0.02)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+
+     def forward_embeddings(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.patch_embed(x)
+         return x
+
+     def forward_tokens(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+         for block in self.network:
+             x = block(x)
+         return x
+
+     def forward(self, x: torch.Tensor, *args, **kwargs) -> Union[Tensor, Dict[str, Tensor]]:
+         # Input embedding
+         x = self.forward_embeddings(x)
+         # Through backbone
+         x = self.forward_tokens(x)
+         # For image classification/embedding
+         x = self.conv_exp(x)
+         cls_out = self.head(x)
+
+         if kwargs.get("return_image_embeddings", False):
+             out_dict = dict()
+             out_dict.update({"logits": cls_out})
+             out_dict.update({"image_embeddings": x})
+             return out_dict
+         else:
+             return cls_out
+
+
+ @register_model
+ def fastvithd(pretrained=False, **kwargs):
+     """Instantiate FastViTHD model variant."""
+     layers = [2, 12, 24, 4, 2]
+     embed_dims = [96, 192, 384, 768, 1536]
+     mlp_ratios = [4, 4, 4, 4, 4]
+     downsamples = [True, True, True, True, True]
+     pos_embs = [
+         None,
+         None,
+         None,
+         partial(RepCPE, spatial_shape=(7, 7)),
+         partial(RepCPE, spatial_shape=(7, 7)),
+     ]
+     token_mixers = ("repmixer", "repmixer", "repmixer", "attention", "attention")
+     model = FastViT(
+         layers,
+         token_mixers=token_mixers,
+         embed_dims=embed_dims,
+         pos_embs=pos_embs,
+         mlp_ratios=mlp_ratios,
+         downsamples=downsamples,
+         norm_layer=LayerNormChannel,
+         stem_scale_branch=False,
+         inference_mode=True,
+         **kwargs,
+     )
+     model.default_cfg = default_cfgs["fastvit_m"]
+     if pretrained:
+         raise ValueError("Functionality not implemented.")
+     return model
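A quick arithmetic check of the overall downsampling implied by this variant: the stem has stride 4, and each of the four `PatchEmbed` layers between the five stages halves the resolution (`down_stride=2`), so the final feature map is `input_size / 64`:

```python
stem_stride = 4              # two stride-2 convs in convolutional_stem
num_stage_transitions = 4    # five stages -> four downsampling PatchEmbeds
down_stride = 2

total_stride = stem_stride * down_stride ** num_stage_transitions
# e.g. for a 1024x1024 input, the final stage produces a 16x16 feature map
print(total_stride, 1024 // total_stride)  # 64 16
```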
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:182a922cc3b72983dd681331c7687aeb9365ea4555e84a1a092af77dca4ddc54
+ size 500509440