# geolip-vit-large-x3 / modeling_geolip_vit.py
# (Hub page residue, commented out so the file parses:
#  AbstractPhil's picture — Create modeling_geolip_vit.py — d677e5c verified)
# ============================================================================
# GeoLIP ViT: HuggingFace AutoModel
#
# Usage:
# from transformers import AutoModel
# model = AutoModel.from_pretrained("AbstractPhil/geolip-vit-base-x3",
# trust_remote_code=True)
#
# from torchvision import transforms
# transform = transforms.Compose([
# transforms.Resize((224, 224)),
# transforms.ToTensor(),
# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
# ])
# pixel_values = transform(image).unsqueeze(0)
# outputs = model(pixel_values)
#
# # 128-d embedding on hypersphere (L2-normalized)
# embedding = outputs.embedding # (B, 128)
#
# # Multi-label classification logits (80 COCO classes)
# logits = outputs.logits  # (B, 80) — if soup_enabled
#
# # Triangulation distances to 256 constellation anchors
# triangulation = outputs.triangulation # (B, 256)
#
# # Nearest anchor index per sample
# nearest = outputs.nearest # (B,)
#
# # Geometric diagnostics
# diagnostics = outputs.diagnostics # dict
# ============================================================================
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel
from dataclasses import dataclass, field
from typing import Optional, Dict, Any
# ══════════════════════════════════════════════════════════════════
# CONFIG
# ══════════════════════════════════════════════════════════════════
class GeoLIPViTConfig(PretrainedConfig):
    """
    Configuration for GeoLIPViTModel.

    Covers the ViT encoder geometry (image/patch sizes, transformer
    width/depth/heads), the hypersphere embedding dimension, the optional
    constellation/patchwork ("soup") classification pipeline, and the
    consensus-distillation metadata (expert list, consensus CV).
    """
    model_type = "geolip_vit"

    def __init__(
        self,
        image_size=224,
        patch_size=16,
        hidden_size=384,
        num_attention_heads=6,
        num_hidden_layers=6,
        intermediate_size=1536,
        output_dim=128,
        n_anchors=256,
        n_comp=8,
        d_comp=64,
        n_classes=80,
        hidden_dropout_prob=0.1,
        soup_enabled=True,
        consensus_cv=0.2731,
        experts=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Store every scalar hyperparameter verbatim on the config object.
        for attr_name, attr_value in (
            ("image_size", image_size),
            ("patch_size", patch_size),
            ("hidden_size", hidden_size),
            ("num_attention_heads", num_attention_heads),
            ("num_hidden_layers", num_hidden_layers),
            ("intermediate_size", intermediate_size),
            ("output_dim", output_dim),
            ("n_anchors", n_anchors),
            ("n_comp", n_comp),
            ("d_comp", d_comp),
            ("n_classes", n_classes),
            ("hidden_dropout_prob", hidden_dropout_prob),
            ("soup_enabled", soup_enabled),
            ("consensus_cv", consensus_cv),
        ):
            setattr(self, attr_name, attr_value)
        # Fall back to the default 3-expert distillation ensemble when the
        # caller passes None (or an empty list).
        self.experts = experts or ["clip_l14_openai", "dinov2_b14", "siglip_b16_384"]
# ══════════════════════════════════════════════════════════════════
# OUTPUT
# ══════════════════════════════════════════════════════════════════
@dataclass
class GeoLIPViTOutput:
    """
    Container for GeoLIPViTModel forward outputs.

    Fields:
        embedding:     (B, output_dim) L2-normalized on hypersphere
        logits:        (B, n_classes) multi-label classification (if soup_enabled)
        triangulation: (B, n_anchors) distances to constellation anchors
        nearest:       (B,) nearest anchor index
        patch_tokens:  (B, n_patches, hidden_size) pre-pooling patch representations
        diagnostics:   dict of geometric metrics
    """
    # Every field is Optional: each defaults to None when the corresponding
    # pipeline stage is disabled or its output was not requested.
    # (Fix: `embedding` was previously annotated as a bare torch.Tensor
    # while still defaulting to None, contradicting its own default.)
    embedding: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    triangulation: Optional[torch.Tensor] = None
    nearest: Optional[torch.Tensor] = None
    patch_tokens: Optional[torch.Tensor] = None
    diagnostics: Optional[Dict[str, Any]] = None
# ══════════════════════════════════════════════════════════════════
# GEOMETRIC COMPONENTS
# ══════════════════════════════════════════════════════════════════
class Constellation(nn.Module):
    """Learnable set of unit-norm anchor directions on the embedding sphere."""

    def __init__(self, n_anchors, d):
        super().__init__()
        self.n_anchors = n_anchors
        # Random Gaussian directions, renormalized onto the unit sphere.
        self.anchors = nn.Parameter(F.normalize(torch.randn(n_anchors, d), dim=-1))

    def triangulate(self, emb):
        """Return (cosine distances to every anchor, nearest-anchor index).

        `emb` is expected to be L2-normalized, so the matmul against the
        re-normalized anchors yields cosine similarity directly.
        """
        unit_anchors = F.normalize(self.anchors, dim=-1)
        similarity = emb @ unit_anchors.T
        distances = 1.0 - similarity
        return distances, similarity.argmax(dim=-1)
class Patchwork(nn.Module):
    """Splits anchor-distance vectors across `n_comp` small MLP components.

    Anchor i is routed to component i % n_comp; each component maps its
    slice of the triangulation vector to a d_comp feature vector, and the
    per-component outputs are concatenated along the feature dimension.
    """

    def __init__(self, n_anchors, n_comp, d_comp):
        super().__init__()
        self.n_comp = n_comp
        self.n_anchors = n_anchors
        self.register_buffer("asgn", torch.arange(n_anchors) % n_comp)
        # Derive each component's input width from plain ints, never from
        # tensor ops, so construction works on meta tensors too.
        base, extra = divmod(n_anchors, n_comp)
        components = []
        for k in range(n_comp):
            in_features = base + 1 if k < extra else base
            components.append(nn.Sequential(
                nn.Linear(in_features, d_comp * 2),
                nn.GELU(),
                nn.Linear(d_comp * 2, d_comp),
                nn.LayerNorm(d_comp),
            ))
        self.comps = nn.ModuleList(components)

    def forward(self, tri):
        """Map (B, n_anchors) distances to (B, n_comp * d_comp) features."""
        pieces = [comp(tri[:, self.asgn == k])
                  for k, comp in enumerate(self.comps)]
        return torch.cat(pieces, -1)
# ══════════════════════════════════════════════════════════════════
# MODEL
# ══════════════════════════════════════════════════════════════════
class GeoLIPViTModel(PreTrainedModel):
    """
    From-scratch Vision Transformer producing L2-normalized embeddings
    on a 128-d hypersphere, geometrically anchored by a constellation
    of 256 reference points trained via 3-expert consensus distillation.

    The encoder is trained from Xavier initialization against consensus
    targets from CLIP ViT-L/14, DINOv2 ViT-B/14, and SigLIP ViT-B/16.
    Optional soup pipeline (constellation + patchwork + classifier)
    provides multi-label COCO classification from the embedding.

    Output fields (shapes shown for the default config):
        embedding:     (B, 128) L2-normalized, consensus-aligned
        logits:        (B, 80) multi-label COCO logits (if soup_enabled)
        triangulation: (B, 256) distances to constellation anchors
        nearest:       (B,) nearest anchor index
        patch_tokens:  (B, 196, 384) pre-pooling patch representations
        diagnostics:   dict of geometric metrics
    """
    config_class = GeoLIPViTConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        n_patches = (config.image_size // config.patch_size) ** 2
        # ── Encoder ──
        # Non-overlapping patchification: kernel == stride == patch_size.
        self.patch_embed = nn.Conv2d(
            3, config.hidden_size,
            kernel_size=config.patch_size, stride=config.patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        # Learned position embedding covers [CLS] + all patch tokens.
        self.pos_embed = nn.Parameter(
            torch.zeros(1, n_patches + 1, config.hidden_size))
        self.embed_norm = nn.LayerNorm(config.hidden_size)
        self.embed_drop = nn.Dropout(config.hidden_dropout_prob)
        # Individual layers (not one nn.TransformerEncoder) so a geometric
        # token can be injected before each layer in forward().
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                dim_feedforward=config.intermediate_size,
                dropout=config.hidden_dropout_prob,
                activation="gelu",
                batch_first=True,
                norm_first=True)
            for _ in range(config.num_hidden_layers)])
        # Geometric injection: pool -> anchor_dim -> triangulate -> hidden_size
        self.geo_pool_proj = nn.Linear(config.hidden_size, config.output_dim)
        self.geo_tri_proj = nn.Sequential(
            nn.Linear(config.n_anchors, config.hidden_size), nn.GELU(),
            nn.LayerNorm(config.hidden_size))
        # Final head: hidden_size -> output_dim (L2-normalized in forward()).
        self.output_proj = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, config.output_dim),
        )
        # ── Soup Pipeline (optional) ──
        if getattr(config, "soup_enabled", False):
            self.constellation = Constellation(config.n_anchors, config.output_dim)
            self.patchwork = Patchwork(
                config.n_anchors, config.n_comp, config.d_comp)
            pw_dim = config.n_comp * config.d_comp
            # Classifier consumes patchwork features concatenated with the
            # raw embedding.  Dropout(0.0) is a no-op at inference —
            # presumably kept for checkpoint structure compatibility.
            self.classifier = nn.Sequential(
                nn.Linear(pw_dim + config.output_dim, pw_dim),
                nn.GELU(), nn.LayerNorm(pw_dim), nn.Dropout(0.0),
                nn.Linear(pw_dim, config.n_classes))
        else:
            self.constellation = None
            self.patchwork = None
            self.classifier = None
        self.post_init()

    def _init_weights(self, module):
        """Xavier-uniform init for Linear/Conv2d; identity init for LayerNorm."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Conv2d):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, pixel_values: torch.Tensor, output_patch_tokens: bool = False, **kwargs):
        """Encode images into hypersphere embeddings (+ optional soup outputs).

        Args:
            pixel_values: (B, 3, image_size, image_size) pixel tensor —
                per the usage header, ImageNet-normalized; TODO confirm
                no further preprocessing is expected here.
            output_patch_tokens: when True, include per-patch features in
                the returned output; otherwise `patch_tokens` is None.
            **kwargs: ignored (accepted for AutoModel call compatibility).

        Returns:
            GeoLIPViTOutput; `logits`/`triangulation`/`nearest`/`diagnostics`
            are populated only when the soup pipeline is enabled.
        """
        B = pixel_values.shape[0]
        # ── Encode ──
        x = self.patch_embed(pixel_values)   # (B, hidden, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)     # (B, n_patches, hidden)
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1)       # prepend [CLS] token
        x = x + self.pos_embed
        x = self.embed_drop(self.embed_norm(x))
        # ── Transformer with geometric injection ──
        # Get anchors for triangulation (from constellation if available).
        # detach(): the per-layer injection must not backprop into anchors.
        if self.constellation is not None:
            anchors_n = F.normalize(self.constellation.anchors.detach(), dim=-1)
        else:
            anchors_n = None
        for layer in self.layers:
            if anchors_n is not None:
                # Pool patches -> project to anchor space -> triangulate ->
                # lift distances back to hidden_size as an extra geo token.
                pooled = x[:, 1:, :].mean(dim=1)   # mean over patches, skip [CLS]
                geo_128 = F.normalize(self.geo_pool_proj(pooled), dim=-1)
                tri_dists = 1.0 - geo_128 @ anchors_n.T
                geo_token = self.geo_tri_proj(tri_dists).unsqueeze(1)
                x_with_geo = torch.cat([geo_token, x], dim=1)
                x_with_geo = layer(x_with_geo)
                # Strip the geo token so sequence length stays constant.
                x = x_with_geo[:, 1:, :]
            else:
                x = layer(x)
        # ── Pool + Project ──
        patch_tokens = x[:, 1:, :]           # drop [CLS]; (B, n_patches, hidden)
        pooled = patch_tokens.mean(dim=1)
        embedding = F.normalize(self.output_proj(pooled), dim=-1)
        # ── Soup Pipeline ──
        logits = None
        triangulation = None
        nearest = None
        diagnostics = {}
        if self.constellation is not None:
            tri, near = self.constellation.triangulate(embedding)
            triangulation = tri
            nearest = near
            if self.patchwork is not None and self.classifier is not None:
                pw = self.patchwork(tri)
                logits = self.classifier(torch.cat([pw, embedding], -1))
            # Geometric diagnostics: scalar metrics only, kept out of the
            # autograd graph via no_grad.
            with torch.no_grad():
                anchors_n = F.normalize(self.constellation.anchors, dim=-1)
                cos_to_anchors = embedding @ anchors_n.T
                diagnostics = {
                    "nearest_cos": cos_to_anchors.max(dim=-1).values.mean().item(),
                    "mean_anchor_cos": cos_to_anchors.mean().item(),
                    "n_active_anchors": near.unique().numel(),
                    "embedding_norm": embedding.norm(dim=-1).mean().item(),
                }
        return GeoLIPViTOutput(
            embedding=embedding,
            logits=logits,
            triangulation=triangulation,
            nearest=nearest,
            patch_tokens=patch_tokens if output_patch_tokens else None,
            diagnostics=diagnostics,
        )