# ============================================================================
# CaptionBERT-8192: HuggingFace AutoModel with Alignment Bank
#
# Usage:
#     from transformers import AutoModel, AutoTokenizer
#     model = AutoModel.from_pretrained("AbstractPhil/geolip-captionbert-8192",
#                                       trust_remote_code=True)
#     tokenizer = AutoTokenizer.from_pretrained("AbstractPhil/geolip-captionbert-8192",
#                                               trust_remote_code=True)
#     inputs = tokenizer("A cat on a windowsill", return_tensors="pt",
#                        padding=True, truncation=True, max_length=512)
#     outputs = model(**inputs)
#
#     # Core embedding (consensus-distilled, L2-normalized)
#     embedding = outputs.last_hidden_state      # (B, 768)
#
#     # Enriched embedding (with geometric context from 5-expert bank)
#     enriched = outputs.enriched                # (B, 768 + bank_dim)
#
#     # Token-level representations (pre-pooling, for sequence tasks)
#     tokens = outputs.token_embeddings          # (B, L, 384)
#
#     # Geometric diagnostics
#     geo = outputs.geometric_context            # dict with expert cos, anchors, etc.
# ============================================================================

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel


class CaptionBertConfig(PretrainedConfig):
    model_type = "caption_bert"

    def __init__(
        self,
        vocab_size=30522,
        max_position_embeddings=8192,
        hidden_size=384,
        num_attention_heads=6,
        num_hidden_layers=6,
        intermediate_size=1536,
        output_dim=768,
        hidden_dropout_prob=0.0,
        pad_token_id=0,
        # Alignment bank
        bank_enabled=True,
        bank_n_experts=5,
        bank_n_anchors=512,
        bank_dim=128,
        bank_cv_target=0.082,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.output_dim = output_dim
        self.hidden_dropout_prob = hidden_dropout_prob
        self.bank_enabled = bank_enabled
        self.bank_n_experts = bank_n_experts
        self.bank_n_anchors = bank_n_anchors
        self.bank_dim = bank_dim
        self.bank_cv_target = bank_cv_target


class AlignmentBank(nn.Module):
    """
    Geometric interface layer preserving the 5-expert differentiation structure.
    Trained post-hoc on the frozen encoder via GPA + whitened Procrustes.

    Stores per-expert rotation matrices, whiteners, and means that encode how
    each expert's geometric perspective differs from the consensus center.
    Provides geometric context annotations (128-dim) alongside the core
    768-dim consensus embedding for downstream heads.
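    The context vector is projected from: per-expert consistency cosines and
    reconstruction MSE under each whitened Procrustes map, pairwise
    cross-expert cosines, a per-sample disagreement ratio, whitened norm
    ratios, and cosines to the consensus anchors (see ``forward``).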
""" def __init__(self, d_embed=768, n_experts=5, n_anchors=512, d_bank=128): super().__init__() self.d_embed = d_embed self.n_experts = n_experts self.n_anchors = n_anchors self.d_bank = d_bank # Per-expert Procrustes components (the differentiation structure) self.expert_rotations = nn.ParameterList([ nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)]) self.expert_whiteners = nn.ParameterList([ nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)]) self.expert_means = nn.ParameterList([ nn.Parameter(torch.zeros(d_embed)) for _ in range(n_experts)]) # Consensus landmarks on the hypersphere self.anchors = nn.Parameter( F.normalize(torch.randn(n_anchors, d_embed), dim=-1)) # Geometric context projection n_cross = n_experts * (n_experts - 1) // 2 geo_dim = n_experts + n_experts + n_cross + 1 + n_experts + n_anchors self.geo_proj = nn.Sequential( nn.Linear(geo_dim, d_bank * 2), nn.GELU(), nn.LayerNorm(d_bank * 2), nn.Linear(d_bank * 2, d_bank), nn.LayerNorm(d_bank)) # Calibrated consensus targets (preserved from training) self.register_buffer("target_cv", torch.tensor(0.082)) self.register_buffer("target_mean_cos", torch.tensor(0.0)) self.register_buffer("target_spectral", torch.zeros(50)) self.register_buffer("target_cross_cos_mean", torch.tensor(0.0)) self.register_buffer("target_cross_cos_std", torch.tensor(0.0)) self.register_buffer("target_disagreement_ratio", torch.tensor(0.0)) def forward(self, embedding): B = embedding.shape[0] emb = embedding.float() # Full whitened Procrustes per expert: center → whiten → normalize → rotate expert_consistency = [] expert_recon = [] expert_projected = [] for i in range(self.n_experts): R = self.expert_rotations[i] W = self.expert_whiteners[i] mu = self.expert_means[i] centered = emb - mu whitened = centered @ W whitened_n = F.normalize(whitened, dim=-1) in_expert = whitened_n @ R.T back = in_expert @ R cos = F.cosine_similarity(whitened_n, back, dim=-1) recon = (whitened_n - back).pow(2).mean(dim=-1) expert_consistency.append(cos) expert_recon.append(recon) expert_projected.append(in_expert) expert_cos = torch.stack(expert_consistency, dim=-1) expert_mse = torch.stack(expert_recon, dim=-1) # Cross-expert differentiation (10 pairs for 5 experts) cross_cos = [] for i in range(self.n_experts): for j in range(i + 1, self.n_experts): cc = F.cosine_similarity( expert_projected[i], expert_projected[j], dim=-1) cross_cos.append(cc) cross_features = torch.stack(cross_cos, dim=-1) # Per-sample disagreement per_sample_agreement = expert_cos.mean(dim=-1) per_sample_disagreement = expert_cos.std(dim=-1) disagreement_ratio = per_sample_disagreement / (per_sample_agreement + 1e-8) # Expert norm ratios expert_norms = [] for i in range(self.n_experts): W = self.expert_whiteners[i]; mu = self.expert_means[i] whitened = (emb - mu) @ W expert_norms.append(whitened.norm(dim=-1)) norm_ratio = torch.stack(expert_norms, dim=-1) norm_ratio = norm_ratio / (norm_ratio.mean(dim=-1, keepdim=True) + 1e-8) # Anchor distances anchors_n = F.normalize(self.anchors, dim=-1) anchor_cos = emb @ anchors_n.T # Geometric context vector geo_input = torch.cat([ expert_cos, expert_mse, cross_features, disagreement_ratio.unsqueeze(-1), norm_ratio, anchor_cos ], dim=-1) geo_context = self.geo_proj(geo_input) enriched = torch.cat([embedding, geo_context], dim=-1) # Diagnostics diagnostics = { "expert_cos_mean": expert_cos.mean().item(), "expert_cos_std": expert_cos.std().item(), "cross_expert_cos": cross_features.mean().item(), "cross_expert_cos_std": cross_features.std().item(), 
"anchor_max_cos": anchor_cos.max(dim=-1).values.mean().item(), "anchor_mean_cos": anchor_cos.mean().item(), "disagreement_ratio": disagreement_ratio.mean().item(), "norm_ratio_spread": norm_ratio.std(dim=-1).mean().item(), } return enriched, geo_context, diagnostics class CaptionBertModel(PreTrainedModel): """ Consensus-distilled caption encoder with geometric alignment bank. The encoder produces L2-normalized 768-dim embeddings in the geometric consensus space of 5 BERT-family models (BERT, ModernBERT, RoBERTa, ALBERT, DistilBERT), aligned via Generalized Procrustes Analysis. The alignment bank annotates each embedding with 128-dim geometric context from the 5-expert differentiation structure — per-expert consistency, cross-expert disagreement, and anchor distances. Output fields: last_hidden_state: (B, 768) L2-normalized consensus embedding pooler_output: (B, 768) same (HF compatibility) token_embeddings: (B, L, 384) pre-pooling token representations enriched: (B, 896) embedding + bank geometric context geometric_context: dict expert cos, cross-expert, anchors, etc. hidden_states: tuple per-layer outputs (if requested) """ config_class = CaptionBertConfig def __init__(self, config): super().__init__(config) self.config = config # ── Encoder ── self.token_emb = nn.Embedding( config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) self.pos_emb = nn.Embedding( config.max_position_embeddings, config.hidden_size) self.emb_norm = nn.LayerNorm(config.hidden_size) self.emb_drop = nn.Dropout(config.hidden_dropout_prob) encoder_layer = nn.TransformerEncoderLayer( d_model=config.hidden_size, nhead=config.num_attention_heads, dim_feedforward=config.intermediate_size, dropout=config.hidden_dropout_prob, activation="gelu", batch_first=True, norm_first=True, ) self.encoder = nn.TransformerEncoder( encoder_layer, num_layers=config.num_hidden_layers, enable_nested_tensor=False) self.output_proj = nn.Sequential( nn.Linear(config.hidden_size, config.hidden_size), nn.GELU(), nn.LayerNorm(config.hidden_size), nn.Linear(config.hidden_size, config.output_dim), ) # ── Alignment Bank ── if getattr(config, 'bank_enabled', False): self.bank = AlignmentBank( d_embed=config.output_dim, n_experts=config.bank_n_experts, n_anchors=config.bank_n_anchors, d_bank=config.bank_dim, ) else: self.bank = None self.post_init() def forward(self, input_ids=None, attention_mask=None, output_hidden_states=False, **kwargs): B, L = input_ids.shape device = input_ids.device # ── Encode ── positions = torch.arange(L, device=device).unsqueeze(0) x = self.token_emb(input_ids) + self.pos_emb(positions) x = self.emb_drop(self.emb_norm(x)) if attention_mask is not None: key_padding_mask = ~attention_mask.bool() else: key_padding_mask = (input_ids == self.config.pad_token_id) hidden_states = [x] if output_hidden_states else None for layer in self.encoder.layers: x = layer(x, src_key_padding_mask=key_padding_mask) if output_hidden_states: hidden_states.append(x) # ── Pool + Project ── if attention_mask is not None: mask = attention_mask.unsqueeze(-1).float() else: mask = (~key_padding_mask).unsqueeze(-1).float() pooled = (x * mask).sum(1) / mask.sum(1).clamp(min=1) embedding = F.normalize(self.output_proj(pooled), dim=-1) # ── Alignment Bank ── enriched = None geo_diagnostics = None if self.bank is not None: enriched, _, geo_diagnostics = self.bank(embedding) # ── Output ── result = { 'last_hidden_state': embedding, # (B, 768) 'pooler_output': embedding, # (B, 768) compat 'token_embeddings': x, # (B, L, 384) 'enriched': 
                enriched,                          # (B, 896) or None
            'geometric_context': geo_diagnostics,  # dict or None
        }
        if output_hidden_states:
            result['hidden_states'] = tuple(hidden_states)

        # Lightweight attribute-style output: dict keys become attributes,
        # so callers can use outputs.last_hidden_state etc. as in the header.
        return type('Output', (), result)()

    def encode(self, texts, tokenizer=None, max_length=512, batch_size=128, device=None):
        """Convenience: raw text → L2-normalized (N, 768) embeddings."""
        if isinstance(texts, str):
            texts = [texts]
        if tokenizer is None:
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        if device is None:
            device = next(self.parameters()).device

        self.eval()
        all_emb = []
        with torch.no_grad():
            for i in range(0, len(texts), batch_size):
                batch = texts[i:i + batch_size]
                inputs = tokenizer(
                    batch, max_length=max_length, padding="max_length",
                    truncation=True, return_tensors="pt"
                ).to(device)
                out = self(input_ids=inputs["input_ids"],
                           attention_mask=inputs["attention_mask"])
                all_emb.append(out.last_hidden_state.cpu())
        return torch.cat(all_emb)
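
# ============================================================================
# Local smoke test (illustrative sketch, not part of the published API).
# Instantiates a randomly initialized CaptionBertModel from CaptionBertConfig
# and runs a dummy batch through it to show the output fields. The reduced
# sizes below (vocab_size=1000, num_hidden_layers=2, bank_n_anchors=16) are
# arbitrary assumptions chosen only to keep the check fast; they are not the
# pretrained checkpoint's settings.
# ============================================================================
if __name__ == "__main__":
    cfg = CaptionBertConfig(
        vocab_size=1000,
        num_hidden_layers=2,
        bank_n_anchors=16,
    )
    model = CaptionBertModel(cfg).eval()

    # Dummy batch: token ids in [1, vocab_size), pad id 0 avoided
    input_ids = torch.randint(1, cfg.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask)

    print(out.last_hidden_state.shape)    # torch.Size([2, 768])
    print(out.token_embeddings.shape)     # torch.Size([2, 16, 384])
    print(out.enriched.shape)             # torch.Size([2, 896])
    print(sorted(out.geometric_context))  # diagnostic keys from AlignmentBank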