| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PretrainedConfig, PreTrainedModel |
| from dataclasses import dataclass, field |
| from typing import Optional, Dict, Any |
|
|
|
|
| |
| |
| |
|
|
class GeoLIPViTConfig(PretrainedConfig):
    """Configuration for ``GeoLIPViTModel``.

    Bundles the ViT encoder hyperparameters together with the settings
    for the optional "soup" pipeline (constellation anchors, patchwork
    components, and the multi-label classifier head).
    """

    model_type = "geolip_vit"

    def __init__(
        self,
        image_size=224,
        patch_size=16,
        hidden_size=384,
        num_attention_heads=6,
        num_hidden_layers=6,
        intermediate_size=1536,
        output_dim=128,
        n_anchors=256,
        n_comp=8,
        d_comp=64,
        n_classes=80,
        hidden_dropout_prob=0.1,
        soup_enabled=True,
        consensus_cv=0.2731,
        experts=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # --- ViT encoder geometry ---
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.output_dim = output_dim
        self.hidden_dropout_prob = hidden_dropout_prob

        # --- soup pipeline (constellation / patchwork / classifier) ---
        self.n_anchors = n_anchors
        self.n_comp = n_comp
        self.d_comp = d_comp
        self.n_classes = n_classes
        self.soup_enabled = soup_enabled
        self.consensus_cv = consensus_cv

        # Any falsy value (None or an empty list) falls back to the
        # default three-expert teacher ensemble.
        default_experts = ["clip_l14_openai", "dinov2_b14", "siglip_b16_384"]
        self.experts = experts or default_experts
|
|
|
|
| |
| |
| |
|
|
@dataclass
class GeoLIPViTOutput:
    """Container for ``GeoLIPViTModel.forward`` outputs.

    Attributes:
        embedding: (B, output_dim) L2-normalized embedding on the hypersphere.
        logits: (B, n_classes) multi-label classification logits
            (only populated when soup_enabled).
        triangulation: (B, n_anchors) cosine distances to constellation anchors.
        nearest: (B,) index of the nearest anchor per sample.
        patch_tokens: (B, n_patches, hidden_size) pre-pooling patch
            representations (only populated when requested).
        diagnostics: dict of scalar geometric metrics.
    """
    # Every field may legitimately be absent, so each is Optional — the
    # original annotated `embedding` as a bare `torch.Tensor` while still
    # defaulting it to None, which made the annotation untruthful.
    embedding: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    triangulation: Optional[torch.Tensor] = None
    nearest: Optional[torch.Tensor] = None
    patch_tokens: Optional[torch.Tensor] = None
    diagnostics: Optional[Dict[str, Any]] = None
|
|
|
|
| |
| |
| |
|
|
class Constellation(nn.Module):
    """Learnable set of unit-norm anchor points on the embedding hypersphere.

    Args:
        n_anchors: number of reference anchors.
        d: dimensionality of the embedding space.
    """

    def __init__(self, n_anchors, d):
        super().__init__()
        self.n_anchors = n_anchors
        # Start from random directions projected onto the unit sphere.
        init = torch.randn(n_anchors, d)
        self.anchors = nn.Parameter(F.normalize(init, dim=-1))

    def triangulate(self, emb):
        """Return cosine distances to every anchor and the nearest anchor.

        Args:
            emb: (B, d) embeddings (expected L2-normalized by the caller).

        Returns:
            Tuple of (B, n_anchors) distances ``1 - cos`` and (B,) index
            of the highest-similarity anchor per row.
        """
        # Re-normalize: the parameter may drift off the sphere during training.
        unit_anchors = F.normalize(self.anchors, dim=-1)
        similarity = emb @ unit_anchors.T
        distances = 1.0 - similarity
        nearest = similarity.argmax(dim=-1)
        return distances, nearest
|
|
|
|
class Patchwork(nn.Module):
    """Splits the triangulation vector across small per-component MLPs.

    Anchor ``i`` is routed to component ``i % n_comp`` (round-robin), so
    each component receives ``n_anchors // n_comp`` anchors, with one
    extra anchor for the first ``n_anchors % n_comp`` components.
    """

    def __init__(self, n_anchors, n_comp, d_comp):
        super().__init__()
        self.n_comp = n_comp
        self.n_anchors = n_anchors
        # Round-robin anchor-to-component assignment. Registered as a
        # buffer so it follows the module across devices / state_dict.
        self.register_buffer("asgn", torch.arange(n_anchors) % n_comp)

        base, extra = divmod(n_anchors, n_comp)
        mlps = []
        for k in range(n_comp):
            # Input width must match the number of anchors routed to k.
            in_features = base + 1 if k < extra else base
            mlps.append(nn.Sequential(
                nn.Linear(in_features, d_comp * 2),
                nn.GELU(),
                nn.Linear(d_comp * 2, d_comp),
                nn.LayerNorm(d_comp),
            ))
        self.comps = nn.ModuleList(mlps)

    def forward(self, tri):
        """Map (B, n_anchors) distances to (B, n_comp * d_comp) features."""
        pieces = []
        for k, comp in enumerate(self.comps):
            pieces.append(comp(tri[:, self.asgn == k]))
        return torch.cat(pieces, -1)
|
|
|
|
| |
| |
| |
|
|
class GeoLIPViTModel(PreTrainedModel):
    """
    From-scratch Vision Transformer producing L2-normalized embeddings
    on a 128-d hypersphere, geometrically anchored by a constellation
    of 256 reference points trained via 3-expert consensus distillation.

    The encoder is trained from Xavier initialization against consensus
    targets from CLIP ViT-L/14, DINOv2 ViT-B/14, and SigLIP ViT-B/16.

    Optional soup pipeline (constellation + patchwork + classifier)
    provides multi-label COCO classification from the embedding.

    Output fields:
        embedding: (B, 128) L2-normalized, consensus-aligned
        logits: (B, 80) multi-label COCO logits (if soup_enabled)
        triangulation: (B, 256) distances to constellation anchors
        nearest: (B,) nearest anchor index
        patch_tokens: (B, 196, 384) pre-pooling patch representations
        diagnostics: dict geometric metrics
    """
    config_class = GeoLIPViTConfig

    def __init__(self, config):
        """Build patch embedding, transformer stack, geometric feedback
        projections, the embedding head, and (optionally) the soup head."""
        super().__init__(config)
        self.config = config

        # Square grid of non-overlapping patches (e.g. 224/16 -> 14x14 = 196).
        n_patches = (config.image_size // config.patch_size) ** 2

        # Patch "tokenizer": strided conv, plus a learned CLS token and
        # learned positional embeddings. Both parameters start at zero;
        # _init_weights does not touch nn.Parameter tensors, so they stay
        # zero-initialized until training moves them.
        self.patch_embed = nn.Conv2d(
            3, config.hidden_size,
            kernel_size=config.patch_size, stride=config.patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, n_patches + 1, config.hidden_size))
        self.embed_norm = nn.LayerNorm(config.hidden_size)
        self.embed_drop = nn.Dropout(config.hidden_dropout_prob)

        # Pre-norm (norm_first=True) transformer encoder stack.
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                dim_feedforward=config.intermediate_size,
                dropout=config.hidden_dropout_prob,
                activation="gelu",
                batch_first=True,
                norm_first=True)
            for _ in range(config.num_hidden_layers)])

        # Geometric feedback used inside forward(): geo_pool_proj maps the
        # pooled patch state down to the output sphere; geo_tri_proj lifts
        # the resulting (B, n_anchors) distance vector back to hidden_size
        # so it can be injected as one extra token per layer.
        self.geo_pool_proj = nn.Linear(config.hidden_size, config.output_dim)
        self.geo_tri_proj = nn.Sequential(
            nn.Linear(config.n_anchors, config.hidden_size), nn.GELU(),
            nn.LayerNorm(config.hidden_size))

        # Final embedding head: MLP from mean-pooled patch tokens to output_dim.
        self.output_proj = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, config.output_dim),
        )

        # Soup pipeline: anchors on the embedding sphere, per-component MLPs
        # over the triangulation vector, and a classifier over
        # [patchwork features ‖ embedding].
        if getattr(config, "soup_enabled", False):
            self.constellation = Constellation(config.n_anchors, config.output_dim)
            self.patchwork = Patchwork(
                config.n_anchors, config.n_comp, config.d_comp)
            pw_dim = config.n_comp * config.d_comp
            self.classifier = nn.Sequential(
                nn.Linear(pw_dim + config.output_dim, pw_dim),
                # Dropout(0.0) is a no-op; presumably kept as a tuning hook.
                nn.GELU(), nn.LayerNorm(pw_dim), nn.Dropout(0.0),
                nn.Linear(pw_dim, config.n_classes))
        else:
            self.constellation = None
            self.patchwork = None
            self.classifier = None

        # PreTrainedModel hook: applies _init_weights over submodules.
        self.post_init()

    def _init_weights(self, module):
        """Xavier-uniform init for Linear/Conv2d weights, zeros for biases,
        identity init for LayerNorm. Called by post_init() per submodule."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Conv2d):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, pixel_values, output_patch_tokens=False, **kwargs):
        """Encode a batch of images.

        Args:
            pixel_values: (B, 3, H, W) image batch. The fixed-size pos_embed
                implies H == W == config.image_size — TODO confirm upstream
                preprocessing guarantees this.
            output_patch_tokens: if True, include per-patch tokens in the
                returned output.
            **kwargs: ignored (accepted for Trainer/API compatibility).

        Returns:
            GeoLIPViTOutput (fields documented on the class docstring).
        """
        B = pixel_values.shape[0]

        # Tokenize: (B, C, H, W) -> (B, hidden, h, w) -> (B, n_patches, hidden).
        x = self.patch_embed(pixel_values)
        x = x.flatten(2).transpose(1, 2)

        # Prepend CLS, add positional embeddings, then norm + dropout.
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1)
        x = x + self.pos_embed
        x = self.embed_drop(self.embed_norm(x))

        # Detach the anchors once up front: the per-layer geo tokens read the
        # constellation but do not backprop into it (anchors only receive
        # gradient through the post-encoder triangulate() call below).
        if self.constellation is not None:
            anchors_n = F.normalize(self.constellation.anchors.detach(), dim=-1)
        else:
            anchors_n = None

        for layer in self.layers:
            if anchors_n is not None:
                # Per-layer geometric feedback: pool the current patch tokens
                # (x[:, 1:] skips CLS), project onto the output sphere,
                # triangulate against the anchors, lift the distances to one
                # hidden-size "geo token", and prepend it for this layer only.
                pooled = x[:, 1:, :].mean(dim=1)
                geo_128 = F.normalize(self.geo_pool_proj(pooled), dim=-1)
                tri_dists = 1.0 - geo_128 @ anchors_n.T
                geo_token = self.geo_tri_proj(tri_dists).unsqueeze(1)
                x_with_geo = torch.cat([geo_token, x], dim=1)
                x_with_geo = layer(x_with_geo)
                # Strip the geo token so the sequence stays [CLS, patches].
                x = x_with_geo[:, 1:, :]
            else:
                x = layer(x)

        # Head: mean-pool patch tokens (CLS is carried through the stack but
        # never read here) and project to the L2-normalized embedding.
        patch_tokens = x[:, 1:, :]
        pooled = patch_tokens.mean(dim=1)
        embedding = F.normalize(self.output_proj(pooled), dim=-1)

        # Soup outputs default to absent when soup_enabled is False.
        logits = None
        triangulation = None
        nearest = None
        diagnostics = {}

        if self.constellation is not None:
            # This triangulation is NOT detached: gradients flow into anchors.
            tri, near = self.constellation.triangulate(embedding)
            triangulation = tri
            nearest = near

            if self.patchwork is not None and self.classifier is not None:
                pw = self.patchwork(tri)
                logits = self.classifier(torch.cat([pw, embedding], -1))

            # Monitoring-only scalar metrics; no_grad keeps them out of the
            # autograd graph.
            with torch.no_grad():
                anchors_n = F.normalize(self.constellation.anchors, dim=-1)
                cos_to_anchors = embedding @ anchors_n.T
                diagnostics = {
                    "nearest_cos": cos_to_anchors.max(dim=-1).values.mean().item(),
                    "mean_anchor_cos": cos_to_anchors.mean().item(),
                    "n_active_anchors": near.unique().numel(),
                    "embedding_norm": embedding.norm(dim=-1).mean().item(),
                }

        return GeoLIPViTOutput(
            embedding=embedding,
            logits=logits,
            triangulation=triangulation,
            nearest=nearest,
            patch_tokens=patch_tokens if output_patch_tokens else None,
            diagnostics=diagnostics,
        )