Feature Extraction
Transformers
Safetensors
custom_code
mranzinger committed on
Commit
668c73e
·
verified ·
1 Parent(s): 309430d
adaptor_attn.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+ import math
9
+ from typing import Dict, Optional
10
+
11
+ import torch
12
+ from torch import nn
13
+
14
+ from einops import rearrange
15
+ from timm.models.vision_transformer import Block
16
+
17
+ from .enable_spectral_reparam import disable_spectral_reparam, enable_spectral_reparam
18
+ from .adaptor_base import AdaptorModuleBase
19
+ from .adaptor_mlp import MLP2
20
+
21
+
22
+ class AttnFDHead(AdaptorModuleBase):
23
+ def __init__(
24
+ self,
25
+ input_size: int,
26
+ hidden_size: int,
27
+ output_size: int,
28
+ num_inner: int = 0,
29
+ pre_norm: bool = False,
30
+ device: torch.device = None,
31
+ upsample_factor: int = 1,
32
+ upsample_rank: int = 0,
33
+ **kwargs # Ignore kwargs that might be to other "mlp" versions, e.g. teacher_summary_idxs
34
+ ) -> None:
35
+ super().__init__(requires_summary_and_spatial=False)
36
+ from timm.models.vision_transformer import Block
37
+ self.blocks = nn.Sequential(*[
38
+ Block(input_size, num_heads=16, init_values=1e-5)
39
+ for _ in range(2)
40
+ ])
41
+ self.mlp = MLP2(input_size, hidden_size, output_size,
42
+ num_inner=0, pre_norm=pre_norm, device=device,
43
+ upsample_factor=upsample_factor, upsample_rank=upsample_rank, **kwargs)
44
+
45
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
46
+ x = self.blocks(x)
47
+ x = self.mlp(x)
48
+ return x
adaptor_base.py CHANGED
@@ -32,6 +32,19 @@ class RadioOutput(NamedTuple):
32
  )
33
 
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  class AdaptorBase(nn.Module):
36
  def forward(self, input: AdaptorInput) -> RadioOutput:
37
  raise NotImplementedError("Subclasses must implement this!")
 
32
  )
33
 
34
 
35
+ class AdaptorModuleBase(nn.Module):
36
+ def __init__(
37
+ self,
38
+ requires_summary_and_spatial: bool,
39
+ handles_summary_and_spatial: bool = False
40
+ ) -> None:
41
+ super().__init__()
42
+ self.requires_summary_and_spatial = requires_summary_and_spatial
43
+ self.handles_summary_and_spatial = handles_summary_and_spatial
44
+
45
+ assert not handles_summary_and_spatial or requires_summary_and_spatial, "If handles summary and spatial, must require it too!"
46
+
47
+
48
  class AdaptorBase(nn.Module):
49
  def forward(self, input: AdaptorInput) -> RadioOutput:
50
  raise NotImplementedError("Subclasses must implement this!")
adaptor_generic.py CHANGED
@@ -12,7 +12,7 @@ from torch import nn
12
  import torch.nn.functional as F
13
 
14
  from .adaptor_base import AdaptorBase, AdaptorInput, RadioOutput
15
- from .adaptor_mlp import create_mlp_from_state, create_mlp_from_config
16
 
17
 
18
  class GenericAdaptor(AdaptorBase):
 
12
  import torch.nn.functional as F
13
 
14
  from .adaptor_base import AdaptorBase, AdaptorInput, RadioOutput
15
+ from .adaptor_module_factory import create_mlp_from_state, create_mlp_from_config
16
 
17
 
18
  class GenericAdaptor(AdaptorBase):
adaptor_mlp.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
  #
3
  # NVIDIA CORPORATION and its licensors retain all intellectual property
4
  # and proprietary rights in and to this software, related documentation
@@ -15,21 +15,10 @@ from einops import rearrange
15
  from timm.models.vision_transformer import Block
16
 
17
  from .enable_spectral_reparam import disable_spectral_reparam, enable_spectral_reparam
 
18
 
19
 
20
- class MLPBase(nn.Module):
21
- def __init__(
22
- self,
23
- requires_summary_and_spatial: bool,
24
- handles_summary_and_spatial: bool = False
25
- ) -> None:
26
- super().__init__()
27
- self.requires_summary_and_spatial = requires_summary_and_spatial
28
- self.handles_summary_and_spatial = handles_summary_and_spatial
29
-
30
- assert not handles_summary_and_spatial or requires_summary_and_spatial, "If handles summary and spatial, must require it too!"
31
-
32
- class MLP(MLPBase):
33
  def __init__(self, input_size: int, hidden_size: int, output_size: int,
34
  num_inner: int = 0, device: torch.device = None, **kwargs):
35
  super(MLP, self).__init__(requires_summary_and_spatial=False)
@@ -60,7 +49,7 @@ class MLP(MLPBase):
60
  return x
61
 
62
 
63
- class MLP2(MLPBase):
64
  def __init__(self, input_size: int, hidden_size: int, output_size: int,
65
  num_inner: int = 0,
66
  pre_norm: bool = False, device: torch.device = None,
@@ -118,109 +107,3 @@ class MLP2(MLPBase):
118
  c=self._real_output_dim)
119
 
120
  return x
121
-
122
-
123
- class AttnFDHead(MLPBase):
124
- def __init__(
125
- self,
126
- input_size: int,
127
- hidden_size: int,
128
- output_size: int,
129
- num_inner: int = 0,
130
- pre_norm: bool = False,
131
- device: torch.device = None,
132
- upsample_factor: int = 1,
133
- upsample_rank: int = 0,
134
- **kwargs # Ignore kwargs that might be to other "mlp" verions, e.g. teacher_summary_idxs
135
- ) -> None:
136
- super().__init__(requires_summary_and_spatial=False)
137
- from timm.models.vision_transformer import Block
138
- self.blocks = nn.Sequential(*[
139
- Block(input_size, num_heads=16, init_values=1e-5)
140
- for _ in range(2)
141
- ])
142
- self.mlp = MLP2(input_size, hidden_size, output_size,
143
- num_inner=0, pre_norm=pre_norm, device=device,
144
- upsample_factor=upsample_factor, upsample_rank=upsample_rank, **kwargs)
145
-
146
- def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
147
- x = self.blocks(x)
148
- x = self.mlp(x)
149
- return x
150
-
151
-
152
- MLP_SUMMARY_FACTORY = {
153
- 'v1': MLP,
154
- 'v2': MLP2,
155
- }
156
-
157
- MLP_FD_FACTORY = {
158
- 'v1': MLP,
159
- 'v2': MLP2,
160
- 'attn': AttnFDHead,
161
- }
162
-
163
-
164
- def strip_prefix(state: Dict[str, torch.Tensor], prefix: str):
165
- state = {
166
- k[len(prefix):]: v
167
- for k, v in state.items()
168
- if k.startswith(prefix)
169
- }
170
- return state
171
-
172
-
173
- def get_mlp_info_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = '', spectral_weights: bool = False):
174
- state = strip_prefix(state, prefix)
175
-
176
- weight_suffix = 'weight' if not spectral_weights else 'parametrizations.weight.original'
177
-
178
- if version == 'v1':
179
- hidden_dim, input_dim = state[f'fc1.{weight_suffix}'].shape
180
- output_dim = state[f'fc2.{weight_suffix}'].shape[0]
181
-
182
- for num_inner in range(1000):
183
- k = f'inner.{num_inner}.0.weight'
184
- if k not in state:
185
- break
186
- elif version == 'v2':
187
- hidden_dim, input_dim = state[f'fc1.{weight_suffix}'].shape
188
- output_dim = state[f'final.2.{weight_suffix}'].shape[0]
189
-
190
- for num_inner in range(1000):
191
- k = f'blocks.{num_inner}.0.weight'
192
- if k not in state:
193
- break
194
- elif version == 'attn':
195
- hidden_dim, input_dim = state[f'mlp.fc1.{weight_suffix}'].shape
196
- output_dim = state[f'mlp.final.2.{weight_suffix}'].shape[0]
197
- num_inner = 0
198
- else:
199
- raise ValueError(f'Unsupported MLP version: {version}')
200
-
201
- return input_dim, hidden_dim, output_dim, num_inner
202
-
203
-
204
- def create_mlp_from_config(version: str, input_dim: int, hidden_dim: int, output_dim: int, num_inner: int, is_summary: bool = True, **kwargs):
205
- factory = MLP_SUMMARY_FACTORY if is_summary else MLP_FD_FACTORY
206
-
207
- ret: nn.Module = factory[version](input_dim, hidden_dim, output_dim, num_inner, from_config=True, **kwargs)
208
-
209
- return ret
210
-
211
-
212
- def create_mlp_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = '', spectral_weights: bool = False, is_summary: bool = True, **kwargs):
213
- state = strip_prefix(state, prefix)
214
-
215
- input_dim, hidden_dim, output_dim, num_inner = get_mlp_info_from_state(version, state, spectral_weights=spectral_weights)
216
-
217
- ret: nn.Module = create_mlp_from_config(version, input_dim, hidden_dim, output_dim, num_inner, is_summary=is_summary, **kwargs)
218
- if spectral_weights:
219
- enable_spectral_reparam(ret, init_norm_to_current=False, state_dict_guidance=state)
220
-
221
- ret.load_state_dict(state)
222
-
223
- if spectral_weights:
224
- disable_spectral_reparam(ret)
225
-
226
- return ret
 
1
+ # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
2
  #
3
  # NVIDIA CORPORATION and its licensors retain all intellectual property
4
  # and proprietary rights in and to this software, related documentation
 
15
  from timm.models.vision_transformer import Block
16
 
17
  from .enable_spectral_reparam import disable_spectral_reparam, enable_spectral_reparam
18
+ from .adaptor_base import AdaptorModuleBase
19
 
20
 
21
+ class MLP(AdaptorModuleBase):
 
 
 
 
 
 
 
 
 
 
 
 
22
  def __init__(self, input_size: int, hidden_size: int, output_size: int,
23
  num_inner: int = 0, device: torch.device = None, **kwargs):
24
  super(MLP, self).__init__(requires_summary_and_spatial=False)
 
49
  return x
50
 
51
 
52
+ class MLP2(AdaptorModuleBase):
53
  def __init__(self, input_size: int, hidden_size: int, output_size: int,
54
  num_inner: int = 0,
55
  pre_norm: bool = False, device: torch.device = None,
 
107
  c=self._real_output_dim)
108
 
109
  return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
adaptor_module_factory.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+ import math
9
+ from typing import Dict, Optional
10
+
11
+ import torch
12
+ from torch import nn
13
+
14
+ from einops import rearrange
15
+ from timm.models.vision_transformer import Block
16
+
17
+ from .enable_spectral_reparam import disable_spectral_reparam, enable_spectral_reparam
18
+ from .adaptor_mlp import MLP, MLP2
19
+ from .adaptor_attn import AttnFDHead
20
+
21
+
22
+ MLP_SUMMARY_FACTORY = {
23
+ 'v1': MLP,
24
+ 'v2': MLP2,
25
+ }
26
+
27
+ MLP_FD_FACTORY = {
28
+ 'v1': MLP,
29
+ 'v2': MLP2,
30
+ 'attn': AttnFDHead,
31
+ }
32
+
33
+
34
+ def strip_prefix(state: Dict[str, torch.Tensor], prefix: str):
35
+ state = {
36
+ k[len(prefix):]: v
37
+ for k, v in state.items()
38
+ if k.startswith(prefix)
39
+ }
40
+ return state
41
+
42
+
43
+ def get_mlp_info_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = '', spectral_weights: bool = False):
44
+ state = strip_prefix(state, prefix)
45
+
46
+ weight_suffix = 'weight' if not spectral_weights else 'parametrizations.weight.original'
47
+
48
+ if version == 'v1':
49
+ hidden_dim, input_dim = state[f'fc1.{weight_suffix}'].shape
50
+ output_dim = state[f'fc2.{weight_suffix}'].shape[0]
51
+
52
+ for num_inner in range(1000):
53
+ k = f'inner.{num_inner}.0.weight'
54
+ if k not in state:
55
+ break
56
+ elif version == 'v2':
57
+ hidden_dim, input_dim = state[f'fc1.{weight_suffix}'].shape
58
+ output_dim = state[f'final.2.{weight_suffix}'].shape[0]
59
+
60
+ for num_inner in range(1000):
61
+ k = f'blocks.{num_inner}.0.weight'
62
+ if k not in state:
63
+ break
64
+ elif version == 'attn':
65
+ hidden_dim, input_dim = state[f'mlp.fc1.{weight_suffix}'].shape
66
+ output_dim = state[f'mlp.final.2.{weight_suffix}'].shape[0]
67
+ num_inner = 0
68
+ else:
69
+ raise ValueError(f'Unsupported MLP version: {version}')
70
+
71
+ return input_dim, hidden_dim, output_dim, num_inner
72
+
73
+
74
+ def create_mlp_from_config(version: str, input_dim: int, hidden_dim: int, output_dim: int, num_inner: int, is_summary: bool = True, **kwargs):
75
+ factory = MLP_SUMMARY_FACTORY if is_summary else MLP_FD_FACTORY
76
+
77
+ ret: nn.Module = factory[version](input_dim, hidden_dim, output_dim, num_inner, from_config=True, **kwargs)
78
+
79
+ return ret
80
+
81
+
82
+ def create_mlp_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = '', spectral_weights: bool = False, is_summary: bool = True, **kwargs):
83
+ state = strip_prefix(state, prefix)
84
+
85
+ input_dim, hidden_dim, output_dim, num_inner = get_mlp_info_from_state(version, state, spectral_weights=spectral_weights)
86
+
87
+ ret: nn.Module = create_mlp_from_config(version, input_dim, hidden_dim, output_dim, num_inner, is_summary=is_summary, **kwargs)
88
+ if spectral_weights:
89
+ enable_spectral_reparam(ret, init_norm_to_current=False, state_dict_guidance=state)
90
+
91
+ ret.load_state_dict(state)
92
+
93
+ if spectral_weights:
94
+ disable_spectral_reparam(ret)
95
+
96
+ return ret
common.py CHANGED
@@ -146,7 +146,7 @@ RESOURCE_MAP = {
146
  "c-radio_v4-so400m": RadioResource(
147
  # NOTE: C-RADIO models are bound by different license terms than that present in the LICENSE file.
148
  # Please refer to the readme, or to https://huggingface.co/nvidia/C-RADIOv4-SO400M for more information.
149
- "https://huggingface.co/nvidia/C-RADIOv4-SO400M/resolve/main/c-radio-v4-so400m_half.pth.tar?download=true",
150
  patch_size=16,
151
  max_resolution=2048,
152
  preferred_resolution=Resolution(512, 512),
 
146
  "c-radio_v4-so400m": RadioResource(
147
  # NOTE: C-RADIO models are bound by different license terms than that present in the LICENSE file.
148
  # Please refer to the readme, or to https://huggingface.co/nvidia/C-RADIOv4-SO400M for more information.
149
+ "https://huggingface.co/nvidia/C-RADIOv4-SO400M/resolve/main/c-radio_v4-so400m_half.pth.tar?download=true",
150
  patch_size=16,
151
  max_resolution=2048,
152
  preferred_resolution=Resolution(512, 512),
hf_model.py CHANGED
@@ -25,7 +25,9 @@ from .common import RESOURCE_MAP, DEFAULT_VERSION
25
  # Import all required modules.
26
  from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput
27
  from .adaptor_generic import GenericAdaptor, AdaptorBase
28
- from .adaptor_mlp import create_mlp_from_config
 
 
29
  from .adaptor_registry import adaptor_registry
30
  from .cls_token import ClsToken
31
  from .dinov2_arch import dinov2_vitg14_reg
 
25
  # Import all required modules.
26
  from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput
27
  from .adaptor_generic import GenericAdaptor, AdaptorBase
28
+ from .adaptor_module_factory import create_mlp_from_config
29
+ from .adaptor_mlp import MLP, MLP2
30
+ from .adaptor_attn import AttnFDHead
31
  from .adaptor_registry import adaptor_registry
32
  from .cls_token import ClsToken
33
  from .dinov2_arch import dinov2_vitg14_reg