AbstractPhil committed on
Commit
d677e5c
·
verified ·
1 Parent(s): 27f893b

Create modeling_geolip_vit.py

Browse files
Files changed (1) hide show
  1. modeling_geolip_vit.py +314 -0
modeling_geolip_vit.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================================
2
+ # GeoLIP ViT: HuggingFace AutoModel
3
+ #
4
+ # Usage:
5
+ # from transformers import AutoModel
6
+ # model = AutoModel.from_pretrained("AbstractPhil/geolip-vit-base-x3",
7
+ # trust_remote_code=True)
8
+ #
9
+ # from torchvision import transforms
10
+ # transform = transforms.Compose([
11
+ # transforms.Resize((224, 224)),
12
+ # transforms.ToTensor(),
13
+ # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
14
+ # ])
15
+ # pixel_values = transform(image).unsqueeze(0)
16
+ # outputs = model(pixel_values)
17
+ #
18
+ # # 128-d embedding on hypersphere (L2-normalized)
19
+ # embedding = outputs.embedding # (B, 128)
20
+ #
21
+ # # Multi-label classification logits (80 COCO classes)
22
+ # logits = outputs.logits # (B, 80) — if soup_enabled
23
+ #
24
+ # # Triangulation distances to 256 constellation anchors
25
+ # triangulation = outputs.triangulation # (B, 256)
26
+ #
27
+ # # Nearest anchor index per sample
28
+ # nearest = outputs.nearest # (B,)
29
+ #
30
+ # # Geometric diagnostics
31
+ # diagnostics = outputs.diagnostics # dict
32
+ # ============================================================================
33
+
34
+ import math
35
+ import torch
36
+ import torch.nn as nn
37
+ import torch.nn.functional as F
38
+ from transformers import PretrainedConfig, PreTrainedModel
39
+ from dataclasses import dataclass, field
40
+ from typing import Optional, Dict, Any
41
+
42
+
43
+ # ══════════════════════════════════════════════════════════════════
44
+ # CONFIG
45
+ # ══════════════════════════════════════════════════════════════════
46
+
47
class GeoLIPViTConfig(PretrainedConfig):
    """Configuration for GeoLIP ViT: a small ViT encoder plus an optional
    "soup" head (constellation anchors + patchwork encoder + classifier).

    Defaults describe the published checkpoint: ViT-S/16 at 224px,
    128-d spherical embeddings, 256 anchors, 80 COCO classes.
    """

    model_type = "geolip_vit"

    # Fallback expert list used when `experts` is not supplied.
    _DEFAULT_EXPERTS = ["clip_l14_openai", "dinov2_b14", "siglip_b16_384"]

    def __init__(
        self,
        image_size=224,
        patch_size=16,
        hidden_size=384,
        num_attention_heads=6,
        num_hidden_layers=6,
        intermediate_size=1536,
        output_dim=128,
        n_anchors=256,
        n_comp=8,
        d_comp=64,
        n_classes=80,
        hidden_dropout_prob=0.1,
        soup_enabled=True,
        consensus_cv=0.2731,
        experts=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Encoder geometry.
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        # Embedding / soup-head geometry.
        self.output_dim = output_dim
        self.n_anchors = n_anchors
        self.n_comp = n_comp
        self.d_comp = d_comp
        self.n_classes = n_classes
        self.soup_enabled = soup_enabled
        # Distillation provenance metadata.
        self.consensus_cv = consensus_cv
        if experts:
            self.experts = experts
        else:
            self.experts = list(self._DEFAULT_EXPERTS)
85
+
86
+
87
+ # ══════════════════════════════════════════════════════════════════
88
+ # OUTPUT
89
+ # ══════════════════════════════════════════════════════════════════
90
+
91
@dataclass
class GeoLIPViTOutput:
    """
    Container for GeoLIPViTModel outputs.

    Output fields:
        embedding:     (B, output_dim) L2-normalized on hypersphere
        logits:        (B, n_classes) multi-label classification (if soup_enabled)
        triangulation: (B, n_anchors) distances to constellation anchors
        nearest:       (B,) nearest anchor index
        patch_tokens:  (B, n_patches, hidden_size) pre-pooling patch representations
        diagnostics:   dict of geometric metrics
    """
    # Every field defaults to None so partial outputs (e.g. soup pipeline
    # disabled, patch tokens not requested) are representable; the annotation
    # on `embedding` is Optional to match its None default (the original
    # declared a bare torch.Tensor, which contradicted the default).
    embedding: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    triangulation: Optional[torch.Tensor] = None
    nearest: Optional[torch.Tensor] = None
    patch_tokens: Optional[torch.Tensor] = None
    diagnostics: Optional[Dict[str, Any]] = None
108
+
109
+
110
+ # ══════════════════════════════════════════════════════════════════
111
+ # GEOMETRIC COMPONENTS
112
+ # ══════════════════════════════════════════════════════════════════
113
+
114
class Constellation(nn.Module):
    """A learnable set of unit-norm reference points ("anchors") on the
    embedding hypersphere, used to triangulate embeddings."""

    def __init__(self, n_anchors, d):
        super().__init__()
        self.n_anchors = n_anchors
        # Random directions projected to the unit sphere at initialization;
        # trained freely afterwards (re-normalized at use time).
        initial = F.normalize(torch.randn(n_anchors, d), dim=-1)
        self.anchors = nn.Parameter(initial)

    def triangulate(self, emb):
        """Return (cosine distance to every anchor, nearest-anchor index).

        `emb` is expected to be L2-normalized, shape (B, d); distances are
        1 - cosine similarity, shape (B, n_anchors).
        """
        unit_anchors = F.normalize(self.anchors, dim=-1)
        similarity = torch.matmul(emb, unit_anchors.t())
        distances = 1.0 - similarity
        nearest = similarity.argmax(dim=-1)
        return distances, nearest
124
+
125
+
126
class Patchwork(nn.Module):
    """Encode a triangulation vector piecewise: anchors are dealt round-robin
    into n_comp groups, each group is passed through its own small MLP, and
    the per-group codes are concatenated into one (B, n_comp * d_comp) vector."""

    def __init__(self, n_anchors, n_comp, d_comp):
        super().__init__()
        self.n_comp = n_comp
        self.n_anchors = n_anchors
        # Round-robin map: anchor slot i belongs to component i % n_comp.
        self.register_buffer("asgn", torch.arange(n_anchors) % n_comp)
        # Component input widths are plain ints (meta-tensor safe): the first
        # `extra` components receive one anchor more when n_anchors % n_comp != 0.
        base, extra = divmod(n_anchors, n_comp)
        components = []
        for k in range(n_comp):
            in_features = base + 1 if k < extra else base
            components.append(nn.Sequential(
                nn.Linear(in_features, d_comp * 2),
                nn.GELU(),
                nn.Linear(d_comp * 2, d_comp),
                nn.LayerNorm(d_comp),
            ))
        self.comps = nn.ModuleList(components)

    def forward(self, tri):
        """Encode each component's slice of `tri` (B, n_anchors) and
        concatenate along the feature dimension."""
        pieces = []
        for k, comp in enumerate(self.comps):
            pieces.append(comp(tri[:, self.asgn == k]))
        return torch.cat(pieces, dim=-1)
145
+
146
+
147
+ # ══════════════════════════════════════════════════════════════════
148
+ # MODEL
149
+ # ══════════════════════════════════════════════════════════════════
150
+
151
class GeoLIPViTModel(PreTrainedModel):
    """
    From-scratch Vision Transformer producing L2-normalized embeddings
    on a 128-d hypersphere, geometrically anchored by a constellation
    of 256 reference points trained via 3-expert consensus distillation.

    The encoder is trained from Xavier initialization against consensus
    targets from CLIP ViT-L/14, DINOv2 ViT-B/14, and SigLIP ViT-B/16.

    Optional soup pipeline (constellation + patchwork + classifier)
    provides multi-label COCO classification from the embedding.

    Output fields:
        embedding:     (B, 128) L2-normalized, consensus-aligned
        logits:        (B, 80) multi-label COCO logits (if soup_enabled)
        triangulation: (B, 256) distances to constellation anchors
        nearest:       (B,) nearest anchor index
        patch_tokens:  (B, 196, 384) pre-pooling patch representations
        diagnostics:   dict geometric metrics
    """
    config_class = GeoLIPViTConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Square patch grid; 224/16 -> 14, so 196 patches by default.
        n_patches = (config.image_size // config.patch_size) ** 2

        # ── Encoder ──
        # Patchify via a strided conv: each patch_size x patch_size tile
        # becomes one hidden_size-dim token.
        self.patch_embed = nn.Conv2d(
            3, config.hidden_size,
            kernel_size=config.patch_size, stride=config.patch_size)
        # cls_token and pos_embed are zero-initialized (and _init_weights
        # does not touch nn.Parameter, so they start at exactly zero).
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, n_patches + 1, config.hidden_size))
        self.embed_norm = nn.LayerNorm(config.hidden_size)
        self.embed_drop = nn.Dropout(config.hidden_dropout_prob)

        # Individual layers (rather than nn.TransformerEncoder) so a fresh
        # geometric token can be injected before every layer in forward().
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                dim_feedforward=config.intermediate_size,
                dropout=config.hidden_dropout_prob,
                activation="gelu",
                batch_first=True,
                norm_first=True)
            for _ in range(config.num_hidden_layers)])

        # Geometric injection: pool → anchor_dim → triangulate → hidden_size
        self.geo_pool_proj = nn.Linear(config.hidden_size, config.output_dim)
        self.geo_tri_proj = nn.Sequential(
            nn.Linear(config.n_anchors, config.hidden_size), nn.GELU(),
            nn.LayerNorm(config.hidden_size))

        # Final head: mean-pooled patch tokens -> output_dim embedding
        # (normalized in forward()).
        self.output_proj = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, config.output_dim),
        )

        # ── Soup Pipeline (optional) ──
        if getattr(config, "soup_enabled", False):
            self.constellation = Constellation(config.n_anchors, config.output_dim)
            self.patchwork = Patchwork(
                config.n_anchors, config.n_comp, config.d_comp)
            pw_dim = config.n_comp * config.d_comp
            # Classifier consumes the patchwork code concatenated with the
            # raw embedding. Dropout(0.0) is a no-op kept for checkpoint
            # layout compatibility — presumably non-zero during training.
            self.classifier = nn.Sequential(
                nn.Linear(pw_dim + config.output_dim, pw_dim),
                nn.GELU(), nn.LayerNorm(pw_dim), nn.Dropout(0.0),
                nn.Linear(pw_dim, config.n_classes))
        else:
            self.constellation = None
            self.patchwork = None
            self.classifier = None

        self.post_init()

    def _init_weights(self, module):
        # Xavier-uniform for all Linear/Conv2d weights, zeros for biases,
        # identity init for LayerNorm. Called per-module by post_init().
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Conv2d):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, pixel_values, output_patch_tokens=False, **kwargs):
        """Encode images to spherical embeddings (plus optional soup outputs).

        Args:
            pixel_values: (B, 3, H, W) image batch. H and W must match
                config.image_size — pos_embed is fixed-size, no interpolation.
            output_patch_tokens: if True, include (B, n_patches, hidden_size)
                patch tokens in the output.
            **kwargs: ignored; accepted for HF pipeline compatibility.

        Returns:
            GeoLIPViTOutput. `logits`/`triangulation`/`nearest`/`diagnostics`
            are populated only when the soup pipeline is enabled.
        """
        B = pixel_values.shape[0]

        # ── Encode ──
        x = self.patch_embed(pixel_values)
        # (B, C, H', W') -> (B, n_patches, hidden_size)
        x = x.flatten(2).transpose(1, 2)

        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1)
        x = x + self.pos_embed
        x = self.embed_drop(self.embed_norm(x))

        # ── Transformer with geometric injection ──
        # Get anchors for triangulation (from constellation if available).
        # detach(): the encoder reads anchor positions but its loss must not
        # move them through this path.
        if self.constellation is not None:
            anchors_n = F.normalize(self.constellation.anchors.detach(), dim=-1)
        else:
            anchors_n = None

        for layer in self.layers:
            if anchors_n is not None:
                # Pool → project → triangulate → geo token.
                # x[:, 1:, :] skips the CLS token at position 0.
                pooled = x[:, 1:, :].mean(dim=1)
                geo_128 = F.normalize(self.geo_pool_proj(pooled), dim=-1)
                tri_dists = 1.0 - geo_128 @ anchors_n.T
                geo_token = self.geo_tri_proj(tri_dists).unsqueeze(1)
                # Prepend the geo token for this layer only, then strip it so
                # the sequence layout (CLS first) is restored for the next
                # layer and for final pooling.
                x_with_geo = torch.cat([geo_token, x], dim=1)
                x_with_geo = layer(x_with_geo)
                x = x_with_geo[:, 1:, :]
            else:
                x = layer(x)

        # ── Pool + Project ──
        # CLS token is dropped; the embedding comes from mean patch pooling.
        patch_tokens = x[:, 1:, :]
        pooled = patch_tokens.mean(dim=1)
        embedding = F.normalize(self.output_proj(pooled), dim=-1)

        # ── Soup Pipeline ──
        logits = None
        triangulation = None
        nearest = None
        diagnostics = {}

        if self.constellation is not None:
            tri, near = self.constellation.triangulate(embedding)
            triangulation = tri
            nearest = near

            if self.patchwork is not None and self.classifier is not None:
                pw = self.patchwork(tri)
                logits = self.classifier(torch.cat([pw, embedding], -1))

            # Geometric diagnostics — monitoring only, kept out of autograd.
            with torch.no_grad():
                anchors_n = F.normalize(self.constellation.anchors, dim=-1)
                cos_to_anchors = embedding @ anchors_n.T
                diagnostics = {
                    # Mean cosine to each sample's nearest anchor.
                    "nearest_cos": cos_to_anchors.max(dim=-1).values.mean().item(),
                    # Mean cosine over all (sample, anchor) pairs.
                    "mean_anchor_cos": cos_to_anchors.mean().item(),
                    # Distinct anchors selected as nearest within this batch.
                    "n_active_anchors": near.unique().numel(),
                    # Should be ~1.0 by construction (F.normalize above).
                    "embedding_norm": embedding.norm(dim=-1).mean().item(),
                }

        return GeoLIPViTOutput(
            embedding=embedding,
            logits=logits,
            triangulation=triangulation,
            nearest=nearest,
            patch_tokens=patch_tokens if output_patch_tokens else None,
            diagnostics=diagnostics,
        )