# ============================================================================
# CaptionBERT-8192: HuggingFace AutoModel with Alignment Bank
#
# Usage:
# from transformers import AutoModel, AutoTokenizer
# model = AutoModel.from_pretrained("AbstractPhil/geolip-captionbert-8192",
# trust_remote_code=True)
# tokenizer = AutoTokenizer.from_pretrained("AbstractPhil/geolip-captionbert-8192",
# trust_remote_code=True)
# inputs = tokenizer("A cat on a windowsill", return_tensors="pt",
# padding=True, truncation=True, max_length=512)
# outputs = model(**inputs)
#
# # Core embedding (consensus-distilled, L2-normalized)
# embedding = outputs.last_hidden_state # (B, 768)
#
# # Enriched embedding (with geometric context from 5-expert bank)
# enriched = outputs.enriched # (B, 768 + bank_dim)
#
# # Token-level representations (pre-pooling, for sequence tasks)
# tokens = outputs.token_embeddings # (B, L, 384)
#
# # Geometric diagnostics
# geo = outputs.geometric_context # dict with expert cos, anchors, etc.
# ============================================================================
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel
class CaptionBertConfig(PretrainedConfig):
model_type = "caption_bert"
def __init__(
self,
vocab_size=30522,
max_position_embeddings=8192,
hidden_size=384,
num_attention_heads=6,
num_hidden_layers=6,
intermediate_size=1536,
output_dim=768,
hidden_dropout_prob=0.0,
pad_token_id=0,
# Alignment bank
bank_enabled=True,
bank_n_experts=5,
bank_n_anchors=512,
bank_dim=128,
bank_cv_target=0.082,
**kwargs,
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.num_hidden_layers = num_hidden_layers
self.intermediate_size = intermediate_size
self.output_dim = output_dim
self.hidden_dropout_prob = hidden_dropout_prob
self.bank_enabled = bank_enabled
self.bank_n_experts = bank_n_experts
self.bank_n_anchors = bank_n_anchors
self.bank_dim = bank_dim
self.bank_cv_target = bank_cv_target
class AlignmentBank(nn.Module):
"""
Geometric interface layer preserving 5-expert differentiation structure.
Trained post-hoc on frozen encoder via GPA + whitened Procrustes.
Stores per-expert rotation matrices, whiteners, and means that encode
how each expert's geometric perspective differs from the consensus center.
Provides geometric context annotations (128-dim) alongside the core
768-dim consensus embedding for downstream heads.
"""
def __init__(self, d_embed=768, n_experts=5, n_anchors=512, d_bank=128):
super().__init__()
self.d_embed = d_embed
self.n_experts = n_experts
self.n_anchors = n_anchors
self.d_bank = d_bank
# Per-expert Procrustes components (the differentiation structure)
self.expert_rotations = nn.ParameterList([
nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)])
self.expert_whiteners = nn.ParameterList([
nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)])
self.expert_means = nn.ParameterList([
nn.Parameter(torch.zeros(d_embed)) for _ in range(n_experts)])
# Consensus landmarks on the hypersphere
self.anchors = nn.Parameter(
F.normalize(torch.randn(n_anchors, d_embed), dim=-1))
# Geometric context projection
n_cross = n_experts * (n_experts - 1) // 2
geo_dim = n_experts + n_experts + n_cross + 1 + n_experts + n_anchors
self.geo_proj = nn.Sequential(
nn.Linear(geo_dim, d_bank * 2), nn.GELU(), nn.LayerNorm(d_bank * 2),
nn.Linear(d_bank * 2, d_bank), nn.LayerNorm(d_bank))
# Calibrated consensus targets (preserved from training)
self.register_buffer("target_cv", torch.tensor(0.082))
self.register_buffer("target_mean_cos", torch.tensor(0.0))
self.register_buffer("target_spectral", torch.zeros(50))
self.register_buffer("target_cross_cos_mean", torch.tensor(0.0))
self.register_buffer("target_cross_cos_std", torch.tensor(0.0))
self.register_buffer("target_disagreement_ratio", torch.tensor(0.0))
def forward(self, embedding):
B = embedding.shape[0]
emb = embedding.float()
        # Full whitened Procrustes per expert: center -> whiten -> normalize -> rotate
expert_consistency = []
expert_recon = []
expert_projected = []
for i in range(self.n_experts):
R = self.expert_rotations[i]
W = self.expert_whiteners[i]
mu = self.expert_means[i]
centered = emb - mu
whitened = centered @ W
whitened_n = F.normalize(whitened, dim=-1)
in_expert = whitened_n @ R.T
back = in_expert @ R
cos = F.cosine_similarity(whitened_n, back, dim=-1)
recon = (whitened_n - back).pow(2).mean(dim=-1)
expert_consistency.append(cos)
expert_recon.append(recon)
expert_projected.append(in_expert)
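        # Stack per-expert statistics: each is (B, n_experts)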
expert_cos = torch.stack(expert_consistency, dim=-1)
expert_mse = torch.stack(expert_recon, dim=-1)
# Cross-expert differentiation (10 pairs for 5 experts)
cross_cos = []
for i in range(self.n_experts):
for j in range(i + 1, self.n_experts):
cc = F.cosine_similarity(
expert_projected[i], expert_projected[j], dim=-1)
cross_cos.append(cc)
cross_features = torch.stack(cross_cos, dim=-1)
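        # cross_features: (B, n_experts * (n_experts - 1) // 2) pairwise cosines
        # between expert projections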
# Per-sample disagreement
per_sample_agreement = expert_cos.mean(dim=-1)
per_sample_disagreement = expert_cos.std(dim=-1)
disagreement_ratio = per_sample_disagreement / (per_sample_agreement + 1e-8)
# Expert norm ratios
expert_norms = []
for i in range(self.n_experts):
W = self.expert_whiteners[i]; mu = self.expert_means[i]
whitened = (emb - mu) @ W
expert_norms.append(whitened.norm(dim=-1))
norm_ratio = torch.stack(expert_norms, dim=-1)
norm_ratio = norm_ratio / (norm_ratio.mean(dim=-1, keepdim=True) + 1e-8)
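        # norm_ratio: (B, n_experts), each expert's whitened norm relative to
        # the per-sample mean across experts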
# Anchor distances
anchors_n = F.normalize(self.anchors, dim=-1)
anchor_cos = emb @ anchors_n.T
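        # anchor_cos: (B, n_anchors), similarity of the embedding to each
        # consensus landmark on the hypersphere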
# Geometric context vector
geo_input = torch.cat([
expert_cos, expert_mse, cross_features,
disagreement_ratio.unsqueeze(-1), norm_ratio, anchor_cos
], dim=-1)
geo_context = self.geo_proj(geo_input)
enriched = torch.cat([embedding, geo_context], dim=-1)
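        # enriched: (B, d_embed + d_bank), core embedding with geometric
        # context appended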
# Diagnostics
diagnostics = {
"expert_cos_mean": expert_cos.mean().item(),
"expert_cos_std": expert_cos.std().item(),
"cross_expert_cos": cross_features.mean().item(),
"cross_expert_cos_std": cross_features.std().item(),
"anchor_max_cos": anchor_cos.max(dim=-1).values.mean().item(),
"anchor_mean_cos": anchor_cos.mean().item(),
"disagreement_ratio": disagreement_ratio.mean().item(),
"norm_ratio_spread": norm_ratio.std(dim=-1).mean().item(),
}
return enriched, geo_context, diagnostics
class CaptionBertModel(PreTrainedModel):
"""
Consensus-distilled caption encoder with geometric alignment bank.
The encoder produces L2-normalized 768-dim embeddings in the geometric
consensus space of 5 BERT-family models (BERT, ModernBERT, RoBERTa,
ALBERT, DistilBERT), aligned via Generalized Procrustes Analysis.
The alignment bank annotates each embedding with 128-dim geometric
    context from the 5-expert differentiation structure: per-expert
consistency, cross-expert disagreement, and anchor distances.
Output fields:
last_hidden_state: (B, 768) L2-normalized consensus embedding
pooler_output: (B, 768) same (HF compatibility)
token_embeddings: (B, L, 384) pre-pooling token representations
enriched: (B, 896) embedding + bank geometric context
geometric_context: dict expert cos, cross-expert, anchors, etc.
hidden_states: tuple per-layer outputs (if requested)
"""
config_class = CaptionBertConfig
def __init__(self, config):
super().__init__(config)
self.config = config
        # ── Encoder ──
self.token_emb = nn.Embedding(
config.vocab_size, config.hidden_size,
padding_idx=config.pad_token_id)
self.pos_emb = nn.Embedding(
config.max_position_embeddings, config.hidden_size)
self.emb_norm = nn.LayerNorm(config.hidden_size)
self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
encoder_layer = nn.TransformerEncoderLayer(
d_model=config.hidden_size,
nhead=config.num_attention_heads,
dim_feedforward=config.intermediate_size,
dropout=config.hidden_dropout_prob,
activation="gelu",
batch_first=True,
norm_first=True,
)
self.encoder = nn.TransformerEncoder(
encoder_layer, num_layers=config.num_hidden_layers,
enable_nested_tensor=False)
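        # Projection head: hidden_size-dim mean-pooled state -> output_dim
        # consensus embedding (384 -> 768 with the default config)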
self.output_proj = nn.Sequential(
nn.Linear(config.hidden_size, config.hidden_size),
nn.GELU(),
nn.LayerNorm(config.hidden_size),
nn.Linear(config.hidden_size, config.output_dim),
)
        # ── Alignment Bank ──
if getattr(config, 'bank_enabled', False):
self.bank = AlignmentBank(
d_embed=config.output_dim,
n_experts=config.bank_n_experts,
n_anchors=config.bank_n_anchors,
d_bank=config.bank_dim,
)
else:
self.bank = None
self.post_init()
def forward(self, input_ids=None, attention_mask=None,
output_hidden_states=False, **kwargs):
B, L = input_ids.shape
device = input_ids.device
        # ── Encode ──
positions = torch.arange(L, device=device).unsqueeze(0)
x = self.token_emb(input_ids) + self.pos_emb(positions)
x = self.emb_drop(self.emb_norm(x))
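        # Build key padding mask: True marks positions attention should ignore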
if attention_mask is not None:
key_padding_mask = ~attention_mask.bool()
else:
key_padding_mask = (input_ids == self.config.pad_token_id)
hidden_states = [x] if output_hidden_states else None
for layer in self.encoder.layers:
x = layer(x, src_key_padding_mask=key_padding_mask)
if output_hidden_states:
hidden_states.append(x)
        # ── Pool + Project ──
if attention_mask is not None:
mask = attention_mask.unsqueeze(-1).float()
else:
mask = (~key_padding_mask).unsqueeze(-1).float()
pooled = (x * mask).sum(1) / mask.sum(1).clamp(min=1)
embedding = F.normalize(self.output_proj(pooled), dim=-1)
        # ── Alignment Bank ──
enriched = None
geo_diagnostics = None
if self.bank is not None:
enriched, _, geo_diagnostics = self.bank(embedding)
        # ── Output ──
result = {
'last_hidden_state': embedding, # (B, 768)
'pooler_output': embedding, # (B, 768) compat
'token_embeddings': x, # (B, L, 384)
'enriched': enriched, # (B, 896) or None
'geometric_context': geo_diagnostics, # dict or None
}
if output_hidden_states:
result['hidden_states'] = tuple(hidden_states)
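        # Lightweight attribute-access container (stands in for a HF ModelOutput)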
return type('Output', (), result)()
def encode(self, texts, tokenizer=None, max_length=512, batch_size=128,
device=None):
"""Convenience: raw text β L2-normalized (N, 768) embeddings."""
if isinstance(texts, str):
texts = [texts]
if tokenizer is None:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
if device is None:
device = next(self.parameters()).device
self.eval()
all_emb = []
with torch.no_grad():
for i in range(0, len(texts), batch_size):
batch = texts[i:i+batch_size]
inputs = tokenizer(
batch, max_length=max_length, padding="max_length",
truncation=True, return_tensors="pt"
).to(device)
out = self(input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"])
all_emb.append(out.last_hidden_state.cpu())
        return torch.cat(all_emb)
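

# ============================================================================
# Minimal smoke test (illustrative sketch; assumes the default config above and
# that the "google-bert/bert-base-uncased" tokenizer is available). It runs only
# when this file is executed directly, not when loaded via trust_remote_code.
# ============================================================================
if __name__ == "__main__":
    from transformers import AutoTokenizer

    cfg = CaptionBertConfig()
    model = CaptionBertModel(cfg)  # randomly initialized weights, shapes only
    tok = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    batch = tok(["A cat on a windowsill", "A dog in the yard"],
                return_tensors="pt", padding=True, truncation=True, max_length=32)
    out = model(**batch)
    print("last_hidden_state:", tuple(out.last_hidden_state.shape))  # (2, 768)
    print("enriched:", tuple(out.enriched.shape))                    # (2, 896)
    print("token_embeddings:", tuple(out.token_embeddings.shape))    # (2, L, 384)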