nebulette commited on
Commit
28d6428
·
verified ·
1 Parent(s): 2f42f2e

Upload 6 files

Browse files
Files changed (6) hide show
  1. __init__.py +4 -0
  2. config.json +19 -0
  3. configuration_tips.py +29 -0
  4. image_encoder.py +446 -0
  5. model.safetensors +3 -0
  6. modeling_tips.py +67 -0
__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .configuration_tips import TIPSv2ImageConfig
2
+ from .modeling_tips import TIPSv2ImageModel, TIPSv2ImageOutput
3
+
4
+ __all__ = ["TIPSv2ImageConfig", "TIPSv2ImageModel", "TIPSv2ImageOutput"]
config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "TIPSv2ImageModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_tips.TIPSv2ImageConfig",
7
+ "AutoModel": "modeling_tips.TIPSv2ImageModel"
8
+ },
9
+ "dtype": "float32",
10
+ "ffn_layer": "mlp",
11
+ "hidden_size": 768,
12
+ "image_size": 448,
13
+ "init_values": 1.0,
14
+ "model_type": "tipsv2",
15
+ "model_variant": "vit_base",
16
+ "num_register_tokens": 1,
17
+ "patch_size": 14,
18
+ "transformers_version": "4.57.3"
19
+ }
configuration_tips.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TIPSv2 model configuration."""
2
+
3
+ from transformers import PretrainedConfig
4
+
5
+
6
+ class TIPSv2ImageConfig(PretrainedConfig):
7
+ """Configuration for TIPSv2 vision-language model."""
8
+
9
+ model_type = "tipsv2"
10
+
11
+ def __init__(
12
+ self,
13
+ model_variant="base",
14
+ hidden_size=768,
15
+ patch_size=14,
16
+ image_size=448,
17
+ ffn_layer="mlp",
18
+ init_values=1.0,
19
+ num_register_tokens=1,
20
+ **kwargs,
21
+ ):
22
+ super().__init__(**kwargs)
23
+ self.model_variant = model_variant
24
+ self.hidden_size = hidden_size
25
+ self.patch_size = patch_size
26
+ self.image_size = image_size
27
+ self.ffn_layer = ffn_layer
28
+ self.init_values = init_values
29
+ self.num_register_tokens = num_register_tokens
image_encoder.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+
9
+ class MLP(nn.Module):
10
+ def __init__(
11
+ self,
12
+ in_features: int,
13
+ hidden_features: int,
14
+ out_features: Optional[int] = None,
15
+ bias: bool = True,
16
+ ) -> None:
17
+ super().__init__()
18
+ out_features = out_features or in_features
19
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
20
+ self.act = nn.GELU()
21
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
22
+
23
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
24
+ return self.fc2(self.act(self.fc1(x)))
25
+
26
+
27
+ class SwiGLUFFN(nn.Module):
28
+ def __init__(
29
+ self,
30
+ in_features: int,
31
+ hidden_features: int,
32
+ out_features: Optional[int] = None,
33
+ bias: bool = True,
34
+ ) -> None:
35
+ super().__init__()
36
+ out_features = out_features or in_features
37
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
38
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
39
+
40
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
41
+ x1, x2 = self.w12(x).chunk(2, dim=-1)
42
+ return self.w3(F.silu(x1) * x2)
43
+
44
+
45
+ class PatchEmbed(nn.Module):
46
+ """
47
+ Image to patch embedding.
48
+
49
+ Input:
50
+ (B, C, H, W)
51
+ Output:
52
+ (B, N, D)
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ img_size: int = 224,
58
+ patch_size: int = 16,
59
+ in_chans: int = 3,
60
+ embed_dim: int = 768,
61
+ ) -> None:
62
+ super().__init__()
63
+ self.img_size = img_size
64
+ self.patch_size = patch_size
65
+ self.grid_size = (img_size // patch_size, img_size // patch_size)
66
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
67
+
68
+ self.proj = nn.Conv2d(
69
+ in_chans,
70
+ embed_dim,
71
+ kernel_size=patch_size,
72
+ stride=patch_size,
73
+ bias=True,
74
+ )
75
+
76
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
77
+ _, _, h, w = x.shape
78
+ if h % self.patch_size != 0 or w % self.patch_size != 0:
79
+ raise ValueError(
80
+ f"Input size {(h, w)} must be divisible by patch_size={self.patch_size}."
81
+ )
82
+
83
+ x = self.proj(x) # (B, D, H', W')
84
+ x = x.flatten(2).transpose(1, 2) # (B, N, D)
85
+ return x
86
+
87
+
88
+ class LayerScale(nn.Module):
89
+ def __init__(self, dim: int, init_values: Optional[float]) -> None:
90
+ super().__init__()
91
+ if init_values is None:
92
+ self.gamma = None
93
+ else:
94
+ self.gamma = nn.Parameter(torch.full((dim,), float(init_values)))
95
+
96
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
97
+ if self.gamma is None:
98
+ return x
99
+ return x * self.gamma
100
+
101
+
102
+ class Attention(nn.Module):
103
+ """
104
+ Standard multi-head self-attention using PyTorch SDPA.
105
+ """
106
+
107
+ def __init__(
108
+ self,
109
+ dim: int,
110
+ num_heads: int = 8,
111
+ qkv_bias: bool = True,
112
+ proj_bias: bool = True,
113
+ ) -> None:
114
+ super().__init__()
115
+ if dim % num_heads != 0:
116
+ raise ValueError(f"dim={dim} must be divisible by num_heads={num_heads}")
117
+
118
+ self.num_heads = num_heads
119
+ self.head_dim = dim // num_heads
120
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
121
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
122
+
123
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
124
+ bsz, seq_len, dim = x.shape
125
+
126
+ qkv = self.qkv(x)
127
+ qkv = qkv.view(bsz, seq_len, 3, self.num_heads, self.head_dim)
128
+ qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, H, N, Dh)
129
+ q, k, v = qkv.unbind(dim=0)
130
+
131
+ x = F.scaled_dot_product_attention(
132
+ q,
133
+ k,
134
+ v,
135
+ attn_mask=None,
136
+ dropout_p=0.0,
137
+ is_causal=False,
138
+ )
139
+ x = x.transpose(1, 2).contiguous().view(bsz, seq_len, dim)
140
+ x = self.proj(x)
141
+ return x
142
+
143
+
144
+ def build_ffn(
145
+ ffn_layer: str,
146
+ dim: int,
147
+ mlp_ratio: float,
148
+ bias: bool = True,
149
+ ) -> nn.Module:
150
+ hidden_dim = int(dim * mlp_ratio)
151
+
152
+ if ffn_layer == "mlp":
153
+ return MLP(
154
+ in_features=dim,
155
+ hidden_features=hidden_dim,
156
+ out_features=dim,
157
+ bias=bias,
158
+ )
159
+ if ffn_layer in {"swiglu", "swiglufused"}:
160
+ return SwiGLUFFN(
161
+ in_features=dim,
162
+ hidden_features=hidden_dim,
163
+ out_features=dim,
164
+ bias=bias,
165
+ )
166
+
167
+ raise ValueError(f"Unsupported ffn_layer: {ffn_layer}")
168
+
169
+
170
+ class Block(nn.Module):
171
+ def __init__(
172
+ self,
173
+ dim: int,
174
+ num_heads: int,
175
+ mlp_ratio: float = 4.0,
176
+ qkv_bias: bool = True,
177
+ proj_bias: bool = True,
178
+ ffn_bias: bool = True,
179
+ init_values: Optional[float] = None,
180
+ ffn_layer: str = "mlp",
181
+ norm_eps: float = 1e-6,
182
+ ) -> None:
183
+ super().__init__()
184
+ self.norm1 = nn.LayerNorm(dim, eps=norm_eps)
185
+ self.attn = Attention(
186
+ dim=dim,
187
+ num_heads=num_heads,
188
+ qkv_bias=qkv_bias,
189
+ proj_bias=proj_bias,
190
+ )
191
+ self.ls1 = LayerScale(dim, init_values)
192
+
193
+ self.norm2 = nn.LayerNorm(dim, eps=norm_eps)
194
+ self.mlp = build_ffn(
195
+ ffn_layer=ffn_layer,
196
+ dim=dim,
197
+ mlp_ratio=mlp_ratio,
198
+ bias=ffn_bias,
199
+ )
200
+ self.ls2 = LayerScale(dim, init_values)
201
+
202
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
203
+ x = x + self.ls1(self.attn(self.norm1(x)))
204
+ x = x + self.ls2(self.mlp(self.norm2(x)))
205
+ return x
206
+
207
+
208
+ class VisionTransformer(nn.Module):
209
+ def __init__(
210
+ self,
211
+ image_size: int = 224,
212
+ patch_size: int = 16,
213
+ in_chans: int = 3,
214
+ hidden_size: int = 768,
215
+ num_layers: int = 12,
216
+ num_heads: int = 12,
217
+ mlp_ratio: float = 4.0,
218
+ qkv_bias: bool = True,
219
+ ffn_bias: bool = True,
220
+ proj_bias: bool = True,
221
+ init_values: Optional[float] = None,
222
+ ffn_layer: str = "mlp",
223
+ num_register_tokens: int = 0,
224
+ norm_eps: float = 1e-6,
225
+ ) -> None:
226
+ super().__init__()
227
+ self.embed_dim = hidden_size
228
+ self.patch_size = patch_size
229
+ self.num_register_tokens = num_register_tokens
230
+ self.num_tokens = 1 # cls token
231
+
232
+ self.patch_embed = PatchEmbed(
233
+ img_size=image_size,
234
+ patch_size=patch_size,
235
+ in_chans=in_chans,
236
+ embed_dim=hidden_size,
237
+ )
238
+ num_patches = self.patch_embed.num_patches
239
+
240
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
241
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, hidden_size))
242
+
243
+ self.register_tokens = (
244
+ nn.Parameter(torch.zeros(1, num_register_tokens, hidden_size))
245
+ if num_register_tokens > 0
246
+ else None
247
+ )
248
+ self.mask_token = nn.Parameter(torch.zeros(1, hidden_size))
249
+ self.blocks = nn.ModuleList(
250
+ [
251
+ Block(
252
+ dim=hidden_size,
253
+ num_heads=num_heads,
254
+ mlp_ratio=mlp_ratio,
255
+ qkv_bias=qkv_bias,
256
+ proj_bias=proj_bias,
257
+ ffn_bias=ffn_bias,
258
+ init_values=init_values,
259
+ ffn_layer=ffn_layer,
260
+ norm_eps=norm_eps,
261
+ )
262
+ for _ in range(num_layers)
263
+ ]
264
+ )
265
+ self.norm = nn.LayerNorm(hidden_size, eps=norm_eps)
266
+ self.head = nn.Identity()
267
+
268
+ self.reset_parameters()
269
+
270
+ def reset_parameters(self) -> None:
271
+ nn.init.trunc_normal_(self.pos_embed, std=0.02)
272
+ nn.init.normal_(self.cls_token, std=1e-6)
273
+ nn.init.normal_(self.mask_token, std=1e-6)
274
+
275
+ if self.register_tokens is not None:
276
+ nn.init.normal_(self.register_tokens, std=1e-6)
277
+
278
+ self.apply(self._init_module)
279
+
280
+ @staticmethod
281
+ def _init_module(module: nn.Module) -> None:
282
+ if isinstance(module, nn.Linear):
283
+ nn.init.trunc_normal_(module.weight, std=0.02)
284
+ if module.bias is not None:
285
+ nn.init.zeros_(module.bias)
286
+ elif isinstance(module, nn.Conv2d):
287
+ nn.init.trunc_normal_(module.weight, std=0.02)
288
+ if module.bias is not None:
289
+ nn.init.zeros_(module.bias)
290
+ elif isinstance(module, nn.LayerNorm):
291
+ nn.init.ones_(module.weight)
292
+ nn.init.zeros_(module.bias)
293
+
294
+ def interpolate_pos_encoding(
295
+ self,
296
+ x: torch.Tensor,
297
+ width: int,
298
+ height: int,
299
+ ) -> torch.Tensor:
300
+ """
301
+ Interpolate positional embeddings for arbitrary image size.
302
+ Positional embedding covers cls + patch tokens only.
303
+ Register tokens are inserted after position embedding is added.
304
+ """
305
+ dtype = x.dtype
306
+ num_tokens = x.shape[1] - 1
307
+ num_ref_tokens = self.pos_embed.shape[1] - 1
308
+
309
+ grid_h = height // self.patch_size
310
+ grid_w = width // self.patch_size
311
+
312
+ if num_tokens == num_ref_tokens and grid_h * grid_w == num_ref_tokens:
313
+ return self.pos_embed.to(dtype=dtype)
314
+
315
+ cls_pos = self.pos_embed[:, :1]
316
+ patch_pos = self.pos_embed[:, 1:]
317
+
318
+ ref_size = int(math.sqrt(num_ref_tokens))
319
+ if ref_size * ref_size != num_ref_tokens:
320
+ raise ValueError("Reference positional embedding is not a square grid.")
321
+
322
+ patch_pos = patch_pos.view(1, ref_size, ref_size, self.embed_dim).permute(
323
+ 0, 3, 1, 2
324
+ )
325
+ patch_pos = F.interpolate(
326
+ patch_pos,
327
+ size=(grid_h, grid_w),
328
+ mode="bicubic",
329
+ align_corners=False,
330
+ )
331
+ patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(
332
+ 1, grid_h * grid_w, self.embed_dim
333
+ )
334
+
335
+ return torch.cat([cls_pos, patch_pos], dim=1).to(dtype=dtype)
336
+
337
+ def prepare_tokens_with_masks(
338
+ self,
339
+ x: torch.Tensor,
340
+ masks: Optional[torch.Tensor] = None,
341
+ ) -> torch.Tensor:
342
+ batch_size, _, height, width = x.shape
343
+
344
+ x = self.patch_embed(x) # (B, N, D)
345
+
346
+ if masks is not None:
347
+ if masks.shape != x.shape[:2]:
348
+ raise ValueError(
349
+ f"masks shape {masks.shape} must match patch sequence shape {x.shape[:2]}"
350
+ )
351
+ x = torch.where(
352
+ masks.unsqueeze(-1),
353
+ self.mask_token.to(dtype=x.dtype).unsqueeze(0),
354
+ x,
355
+ )
356
+
357
+ cls_token = self.cls_token.expand(batch_size, -1, -1)
358
+ x = torch.cat([cls_token, x], dim=1)
359
+ x = x + self.interpolate_pos_encoding(x, width=width, height=height)
360
+
361
+ if self.register_tokens is not None:
362
+ reg = self.register_tokens.expand(batch_size, -1, -1)
363
+ x = torch.cat([x[:, :1], reg, x[:, 1:]], dim=1)
364
+
365
+ return x
366
+
367
+ def forward(
368
+ self,
369
+ x: torch.Tensor,
370
+ masks: Optional[torch.Tensor] = None,
371
+ ) -> dict[str, torch.Tensor]:
372
+ x = self.prepare_tokens_with_masks(x, masks)
373
+
374
+ for block in self.blocks:
375
+ x = block(x)
376
+
377
+ x_norm = self.norm(x)
378
+
379
+ reg_start = 1
380
+ reg_end = 1 + self.num_register_tokens
381
+
382
+ cls_token = x_norm[:, :1]
383
+ register_tokens = x_norm[:, reg_start:reg_end]
384
+ patch_tokens = x_norm[:, reg_end:]
385
+
386
+ return self.head(cls_token), self.head(register_tokens), patch_tokens
387
+
388
+
389
+ def vit_small(patch_size: int = 14, **kwargs) -> VisionTransformer:
390
+ return VisionTransformer(
391
+ patch_size=patch_size,
392
+ hidden_size=384,
393
+ num_layers=12,
394
+ num_heads=6,
395
+ mlp_ratio=4.0,
396
+ num_register_tokens=1,
397
+ **kwargs,
398
+ )
399
+
400
+
401
+ def vit_base(patch_size: int = 14, **kwargs) -> VisionTransformer:
402
+ return VisionTransformer(
403
+ patch_size=patch_size,
404
+ hidden_size=768,
405
+ num_layers=12,
406
+ num_heads=12,
407
+ mlp_ratio=4.0,
408
+ num_register_tokens=1,
409
+ **kwargs,
410
+ )
411
+
412
+
413
+ def vit_large(patch_size: int = 14, **kwargs) -> VisionTransformer:
414
+ return VisionTransformer(
415
+ patch_size=patch_size,
416
+ hidden_size=1024,
417
+ num_layers=24,
418
+ num_heads=16,
419
+ mlp_ratio=4.0,
420
+ num_register_tokens=1,
421
+ **kwargs,
422
+ )
423
+
424
+
425
+ def vit_so400m(patch_size: int = 14, **kwargs) -> VisionTransformer:
426
+ return VisionTransformer(
427
+ patch_size=patch_size,
428
+ hidden_size=1152,
429
+ num_layers=27,
430
+ num_heads=16,
431
+ mlp_ratio=4304 / 1152,
432
+ num_register_tokens=1,
433
+ **kwargs,
434
+ )
435
+
436
+
437
+ def vit_giant2(patch_size: int = 14, **kwargs) -> VisionTransformer:
438
+ return VisionTransformer(
439
+ patch_size=patch_size,
440
+ hidden_size=1536,
441
+ num_layers=40,
442
+ num_heads=24,
443
+ mlp_ratio=4.0,
444
+ num_register_tokens=1,
445
+ **kwargs,
446
+ )
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9c6dad533bc4331ba20fec2a992bc259d8d5a226f02dffd6be5a274f6f7cb04
3
+ size 345282432
modeling_tips.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TIPSv2 image encoder for HuggingFace."""
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ from transformers import AutoConfig, AutoModel, PreTrainedModel
7
+
8
+ from .configuration_tips import TIPSv2ImageConfig
9
+ from .image_encoder import (
10
+ VisionTransformer,
11
+ vit_base,
12
+ vit_giant2,
13
+ vit_large,
14
+ vit_small,
15
+ vit_so400m,
16
+ )
17
+
18
+
19
+ MODEL_INIT_FUNCTIONS = {
20
+ "vit_small": vit_small,
21
+ "vit_base": vit_base,
22
+ "vit_large": vit_large,
23
+ "vit_so400m": vit_so400m,
24
+ "vit_giant2": vit_giant2,
25
+ }
26
+
27
+
28
+ @dataclass
29
+ class TIPSv2ImageOutput:
30
+ cls_token: torch.Tensor
31
+ register_tokens: torch.Tensor
32
+ patch_tokens: torch.Tensor
33
+
34
+
35
+ class TIPSv2ImageModel(PreTrainedModel):
36
+ config_class = TIPSv2ImageConfig
37
+ base_model_prefix = "model"
38
+ all_tied_weights_keys = dict()
39
+
40
+ def __init__(self, config: TIPSv2ImageConfig):
41
+ super().__init__(config)
42
+
43
+ if config.model_variant not in MODEL_INIT_FUNCTIONS:
44
+ raise ValueError(
45
+ f"Unknown model_variant={config.model_variant!r}. "
46
+ f"Expected one of {list(MODEL_INIT_FUNCTIONS)}."
47
+ )
48
+
49
+ build_fn = MODEL_INIT_FUNCTIONS[config.model_variant]
50
+ self.model: VisionTransformer = build_fn(
51
+ image_size=config.image_size,
52
+ patch_size=config.patch_size,
53
+ ffn_layer=config.ffn_layer,
54
+ init_values=config.init_values,
55
+ )
56
+
57
+ def forward(self, pixel_values: torch.Tensor) -> TIPSv2ImageOutput:
58
+ cls_token, register_tokens, patch_tokens = self.model(pixel_values)
59
+ return TIPSv2ImageOutput(
60
+ cls_token=cls_token,
61
+ register_tokens=register_tokens,
62
+ patch_tokens=patch_tokens,
63
+ )
64
+
65
+
66
+ AutoConfig.register("tipsv2", TIPSv2ImageConfig, exist_ok=True)
67
+ AutoModel.register(TIPSv2ImageConfig, TIPSv2ImageModel, exist_ok=True)