AbstractPhil commited on
Commit
a1933ae
Β·
verified Β·
1 Parent(s): 85bb85c

added bank support

Browse files
Files changed (1) hide show
  1. modeling_caption_bert.py +189 -39
modeling_caption_bert.py CHANGED
@@ -1,17 +1,30 @@
1
  # ============================================================================
2
- # CaptionBERT-8192: HuggingFace AutoModel-Compatible Implementation
3
  #
4
  # Usage:
5
  # from transformers import AutoModel, AutoTokenizer
6
  # model = AutoModel.from_pretrained("AbstractPhil/geolip-captionbert-8192",
7
  # trust_remote_code=True)
8
- # tokenizer = AutoTokenizer.from_pretrained("AbstractPhil/geolip-captionbert-8192")
 
9
  # inputs = tokenizer("A cat on a windowsill", return_tensors="pt",
10
  # padding=True, truncation=True, max_length=512)
11
  # outputs = model(**inputs)
12
- # embedding = outputs.last_hidden_state # (B, 768) L2-normalized
 
 
 
 
 
 
 
 
 
 
 
13
  # ============================================================================
14
 
 
15
  import torch
16
  import torch.nn as nn
17
  import torch.nn.functional as F
@@ -30,8 +43,14 @@ class CaptionBertConfig(PretrainedConfig):
30
  num_hidden_layers=6,
31
  intermediate_size=1536,
32
  output_dim=768,
33
- hidden_dropout_prob=0.1,
34
  pad_token_id=0,
 
 
 
 
 
 
35
  **kwargs,
36
  ):
37
  super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -43,18 +62,151 @@ class CaptionBertConfig(PretrainedConfig):
43
  self.intermediate_size = intermediate_size
44
  self.output_dim = output_dim
45
  self.hidden_dropout_prob = hidden_dropout_prob
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
 
48
  class CaptionBertModel(PreTrainedModel):
49
  """
50
- Consensus-distilled caption encoder.
51
 
52
- Produces L2-normalized 768-dim embeddings in the geometric consensus
53
- space of 5 BERT-family models (BERT, ModernBERT, RoBERTa, ALBERT, DistilBERT).
 
54
 
55
- Output:
56
- last_hidden_state: (B, output_dim) L2-normalized embedding
57
- pooler_output: (B, output_dim) same as last_hidden_state (for compatibility)
 
 
 
 
 
 
 
 
58
  """
59
  config_class = CaptionBertConfig
60
 
@@ -62,7 +214,7 @@ class CaptionBertModel(PreTrainedModel):
62
  super().__init__(config)
63
  self.config = config
64
 
65
- # Embeddings
66
  self.token_emb = nn.Embedding(
67
  config.vocab_size, config.hidden_size,
68
  padding_idx=config.pad_token_id)
@@ -71,7 +223,6 @@ class CaptionBertModel(PreTrainedModel):
71
  self.emb_norm = nn.LayerNorm(config.hidden_size)
72
  self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
73
 
74
- # Transformer encoder
75
  encoder_layer = nn.TransformerEncoderLayer(
76
  d_model=config.hidden_size,
77
  nhead=config.num_attention_heads,
@@ -85,7 +236,6 @@ class CaptionBertModel(PreTrainedModel):
85
  encoder_layer, num_layers=config.num_hidden_layers,
86
  enable_nested_tensor=False)
87
 
88
- # Output projection to consensus space
89
  self.output_proj = nn.Sequential(
90
  nn.Linear(config.hidden_size, config.hidden_size),
91
  nn.GELU(),
@@ -93,6 +243,17 @@ class CaptionBertModel(PreTrainedModel):
93
  nn.Linear(config.hidden_size, config.output_dim),
94
  )
95
 
 
 
 
 
 
 
 
 
 
 
 
96
  self.post_init()
97
 
98
  def forward(self, input_ids=None, attention_mask=None,
@@ -100,39 +261,43 @@ class CaptionBertModel(PreTrainedModel):
100
  B, L = input_ids.shape
101
  device = input_ids.device
102
 
103
- # Embed
104
  positions = torch.arange(L, device=device).unsqueeze(0)
105
  x = self.token_emb(input_ids) + self.pos_emb(positions)
106
  x = self.emb_drop(self.emb_norm(x))
107
 
108
- # Transformer with padding mask
109
  if attention_mask is not None:
110
  key_padding_mask = ~attention_mask.bool()
111
  else:
112
  key_padding_mask = (input_ids == self.config.pad_token_id)
113
 
114
- # Layer-by-layer for hidden state capture
115
  hidden_states = [x] if output_hidden_states else None
116
  for layer in self.encoder.layers:
117
  x = layer(x, src_key_padding_mask=key_padding_mask)
118
  if output_hidden_states:
119
  hidden_states.append(x)
120
 
121
- # Mean pool over non-padding tokens
122
  if attention_mask is not None:
123
  mask = attention_mask.unsqueeze(-1).float()
124
  else:
125
  mask = (~key_padding_mask).unsqueeze(-1).float()
126
  pooled = (x * mask).sum(1) / mask.sum(1).clamp(min=1)
127
-
128
- # Project and normalize
129
  embedding = F.normalize(self.output_proj(pooled), dim=-1)
130
 
131
- # Return in HuggingFace-compatible format
 
 
 
 
 
 
132
  result = {
133
- 'last_hidden_state': embedding, # (B, 768) pooled, normalized
134
- 'pooler_output': embedding, # same, for compatibility
135
- 'token_embeddings': x, # (B, L, 384) pre-pooling sequence
 
 
136
  }
137
  if output_hidden_states:
138
  result['hidden_states'] = tuple(hidden_states)
@@ -141,29 +306,14 @@ class CaptionBertModel(PreTrainedModel):
141
 
142
  def encode(self, texts, tokenizer=None, max_length=512, batch_size=128,
143
  device=None):
144
- """
145
- Convenience method: raw text β†’ L2-normalized embeddings.
146
-
147
- Args:
148
- texts: str or list of str
149
- tokenizer: AutoTokenizer instance (loads default if None)
150
- max_length: max token length
151
- batch_size: encoding batch size
152
- device: torch device
153
-
154
- Returns:
155
- (N, 768) L2-normalized tensor
156
- """
157
  if isinstance(texts, str):
158
  texts = [texts]
159
-
160
  if tokenizer is None:
161
  from transformers import AutoTokenizer
162
  tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
163
-
164
  if device is None:
165
  device = next(self.parameters()).device
166
-
167
  self.eval()
168
  all_emb = []
169
  with torch.no_grad():
 
1
  # ============================================================================
2
+ # CaptionBERT-8192: HuggingFace AutoModel with Alignment Bank
3
  #
4
  # Usage:
5
  # from transformers import AutoModel, AutoTokenizer
6
  # model = AutoModel.from_pretrained("AbstractPhil/geolip-captionbert-8192",
7
  # trust_remote_code=True)
8
+ # tokenizer = AutoTokenizer.from_pretrained("AbstractPhil/geolip-captionbert-8192",
9
+ # trust_remote_code=True)
10
  # inputs = tokenizer("A cat on a windowsill", return_tensors="pt",
11
  # padding=True, truncation=True, max_length=512)
12
  # outputs = model(**inputs)
13
+ #
14
+ # # Core embedding (consensus-distilled, L2-normalized)
15
+ # embedding = outputs.last_hidden_state # (B, 768)
16
+ #
17
+ # # Enriched embedding (with geometric context from 5-expert bank)
18
+ # enriched = outputs.enriched # (B, 768 + bank_dim)
19
+ #
20
+ # # Token-level representations (pre-pooling, for sequence tasks)
21
+ # tokens = outputs.token_embeddings # (B, L, 384)
22
+ #
23
+ # # Geometric diagnostics
24
+ # geo = outputs.geometric_context # dict with expert cos, anchors, etc.
25
  # ============================================================================
26
 
27
+ import math
28
  import torch
29
  import torch.nn as nn
30
  import torch.nn.functional as F
 
43
  num_hidden_layers=6,
44
  intermediate_size=1536,
45
  output_dim=768,
46
+ hidden_dropout_prob=0.0,
47
  pad_token_id=0,
48
+ # Alignment bank
49
+ bank_enabled=True,
50
+ bank_n_experts=5,
51
+ bank_n_anchors=512,
52
+ bank_dim=128,
53
+ bank_cv_target=0.082,
54
  **kwargs,
55
  ):
56
  super().__init__(pad_token_id=pad_token_id, **kwargs)
 
62
  self.intermediate_size = intermediate_size
63
  self.output_dim = output_dim
64
  self.hidden_dropout_prob = hidden_dropout_prob
65
+ self.bank_enabled = bank_enabled
66
+ self.bank_n_experts = bank_n_experts
67
+ self.bank_n_anchors = bank_n_anchors
68
+ self.bank_dim = bank_dim
69
+ self.bank_cv_target = bank_cv_target
70
+
71
+
72
+ class AlignmentBank(nn.Module):
73
+ """
74
+ Geometric interface layer preserving 5-expert differentiation structure.
75
+
76
+ Trained post-hoc on frozen encoder via GPA + whitened Procrustes.
77
+ Stores per-expert rotation matrices, whiteners, and means that encode
78
+ how each expert's geometric perspective differs from the consensus center.
79
+
80
+ Provides geometric context annotations (128-dim) alongside the core
81
+ 768-dim consensus embedding for downstream heads.
82
+ """
83
+ def __init__(self, d_embed=768, n_experts=5, n_anchors=512, d_bank=128):
84
+ super().__init__()
85
+ self.d_embed = d_embed
86
+ self.n_experts = n_experts
87
+ self.n_anchors = n_anchors
88
+ self.d_bank = d_bank
89
+
90
+ # Per-expert Procrustes components (the differentiation structure)
91
+ self.expert_rotations = nn.ParameterList([
92
+ nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)])
93
+ self.expert_whiteners = nn.ParameterList([
94
+ nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)])
95
+ self.expert_means = nn.ParameterList([
96
+ nn.Parameter(torch.zeros(d_embed)) for _ in range(n_experts)])
97
+
98
+ # Consensus landmarks on the hypersphere
99
+ self.anchors = nn.Parameter(
100
+ F.normalize(torch.randn(n_anchors, d_embed), dim=-1))
101
+
102
+ # Geometric context projection
103
+ n_cross = n_experts * (n_experts - 1) // 2
104
+ geo_dim = n_experts + n_experts + n_cross + 1 + n_experts + n_anchors
105
+ self.geo_proj = nn.Sequential(
106
+ nn.Linear(geo_dim, d_bank * 2), nn.GELU(), nn.LayerNorm(d_bank * 2),
107
+ nn.Linear(d_bank * 2, d_bank), nn.LayerNorm(d_bank))
108
+
109
+ # Calibrated consensus targets (preserved from training)
110
+ self.register_buffer("target_cv", torch.tensor(0.082))
111
+ self.register_buffer("target_cross_cos_mean", torch.tensor(0.0))
112
+ self.register_buffer("target_cross_cos_std", torch.tensor(0.0))
113
+ self.register_buffer("target_disagreement_ratio", torch.tensor(0.0))
114
+
115
+ def forward(self, embedding):
116
+ B = embedding.shape[0]
117
+ emb = embedding.float()
118
+
119
+ # Full whitened Procrustes per expert: center β†’ whiten β†’ normalize β†’ rotate
120
+ expert_consistency = []
121
+ expert_recon = []
122
+ expert_projected = []
123
+ for i in range(self.n_experts):
124
+ R = self.expert_rotations[i]
125
+ W = self.expert_whiteners[i]
126
+ mu = self.expert_means[i]
127
+ centered = emb - mu
128
+ whitened = centered @ W
129
+ whitened_n = F.normalize(whitened, dim=-1)
130
+ in_expert = whitened_n @ R.T
131
+ back = in_expert @ R
132
+ cos = F.cosine_similarity(whitened_n, back, dim=-1)
133
+ recon = (whitened_n - back).pow(2).mean(dim=-1)
134
+ expert_consistency.append(cos)
135
+ expert_recon.append(recon)
136
+ expert_projected.append(in_expert)
137
+
138
+ expert_cos = torch.stack(expert_consistency, dim=-1)
139
+ expert_mse = torch.stack(expert_recon, dim=-1)
140
+
141
+ # Cross-expert differentiation (10 pairs for 5 experts)
142
+ cross_cos = []
143
+ for i in range(self.n_experts):
144
+ for j in range(i + 1, self.n_experts):
145
+ cc = F.cosine_similarity(
146
+ expert_projected[i], expert_projected[j], dim=-1)
147
+ cross_cos.append(cc)
148
+ cross_features = torch.stack(cross_cos, dim=-1)
149
+
150
+ # Per-sample disagreement
151
+ per_sample_agreement = expert_cos.mean(dim=-1)
152
+ per_sample_disagreement = expert_cos.std(dim=-1)
153
+ disagreement_ratio = per_sample_disagreement / (per_sample_agreement + 1e-8)
154
+
155
+ # Expert norm ratios
156
+ expert_norms = []
157
+ for i in range(self.n_experts):
158
+ W = self.expert_whiteners[i]; mu = self.expert_means[i]
159
+ whitened = (emb - mu) @ W
160
+ expert_norms.append(whitened.norm(dim=-1))
161
+ norm_ratio = torch.stack(expert_norms, dim=-1)
162
+ norm_ratio = norm_ratio / (norm_ratio.mean(dim=-1, keepdim=True) + 1e-8)
163
+
164
+ # Anchor distances
165
+ anchors_n = F.normalize(self.anchors, dim=-1)
166
+ anchor_cos = emb @ anchors_n.T
167
+
168
+ # Geometric context vector
169
+ geo_input = torch.cat([
170
+ expert_cos, expert_mse, cross_features,
171
+ disagreement_ratio.unsqueeze(-1), norm_ratio, anchor_cos
172
+ ], dim=-1)
173
+ geo_context = self.geo_proj(geo_input)
174
+ enriched = torch.cat([embedding, geo_context], dim=-1)
175
+
176
+ # Diagnostics
177
+ diagnostics = {
178
+ "expert_cos_mean": expert_cos.mean().item(),
179
+ "expert_cos_std": expert_cos.std().item(),
180
+ "cross_expert_cos": cross_features.mean().item(),
181
+ "cross_expert_cos_std": cross_features.std().item(),
182
+ "anchor_max_cos": anchor_cos.max(dim=-1).values.mean().item(),
183
+ "anchor_mean_cos": anchor_cos.mean().item(),
184
+ "disagreement_ratio": disagreement_ratio.mean().item(),
185
+ "norm_ratio_spread": norm_ratio.std(dim=-1).mean().item(),
186
+ }
187
+
188
+ return enriched, geo_context, diagnostics
189
 
190
 
191
  class CaptionBertModel(PreTrainedModel):
192
  """
193
+ Consensus-distilled caption encoder with geometric alignment bank.
194
 
195
+ The encoder produces L2-normalized 768-dim embeddings in the geometric
196
+ consensus space of 5 BERT-family models (BERT, ModernBERT, RoBERTa,
197
+ ALBERT, DistilBERT), aligned via Generalized Procrustes Analysis.
198
 
199
+ The alignment bank annotates each embedding with 128-dim geometric
200
+ context from the 5-expert differentiation structure β€” per-expert
201
+ consistency, cross-expert disagreement, and anchor distances.
202
+
203
+ Output fields:
204
+ last_hidden_state: (B, 768) L2-normalized consensus embedding
205
+ pooler_output: (B, 768) same (HF compatibility)
206
+ token_embeddings: (B, L, 384) pre-pooling token representations
207
+ enriched: (B, 896) embedding + bank geometric context
208
+ geometric_context: dict expert cos, cross-expert, anchors, etc.
209
+ hidden_states: tuple per-layer outputs (if requested)
210
  """
211
  config_class = CaptionBertConfig
212
 
 
214
  super().__init__(config)
215
  self.config = config
216
 
217
+ # ── Encoder ──
218
  self.token_emb = nn.Embedding(
219
  config.vocab_size, config.hidden_size,
220
  padding_idx=config.pad_token_id)
 
223
  self.emb_norm = nn.LayerNorm(config.hidden_size)
224
  self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
225
 
 
226
  encoder_layer = nn.TransformerEncoderLayer(
227
  d_model=config.hidden_size,
228
  nhead=config.num_attention_heads,
 
236
  encoder_layer, num_layers=config.num_hidden_layers,
237
  enable_nested_tensor=False)
238
 
 
239
  self.output_proj = nn.Sequential(
240
  nn.Linear(config.hidden_size, config.hidden_size),
241
  nn.GELU(),
 
243
  nn.Linear(config.hidden_size, config.output_dim),
244
  )
245
 
246
+ # ── Alignment Bank ──
247
+ if getattr(config, 'bank_enabled', False):
248
+ self.bank = AlignmentBank(
249
+ d_embed=config.output_dim,
250
+ n_experts=config.bank_n_experts,
251
+ n_anchors=config.bank_n_anchors,
252
+ d_bank=config.bank_dim,
253
+ )
254
+ else:
255
+ self.bank = None
256
+
257
  self.post_init()
258
 
259
  def forward(self, input_ids=None, attention_mask=None,
 
261
  B, L = input_ids.shape
262
  device = input_ids.device
263
 
264
+ # ── Encode ──
265
  positions = torch.arange(L, device=device).unsqueeze(0)
266
  x = self.token_emb(input_ids) + self.pos_emb(positions)
267
  x = self.emb_drop(self.emb_norm(x))
268
 
 
269
  if attention_mask is not None:
270
  key_padding_mask = ~attention_mask.bool()
271
  else:
272
  key_padding_mask = (input_ids == self.config.pad_token_id)
273
 
 
274
  hidden_states = [x] if output_hidden_states else None
275
  for layer in self.encoder.layers:
276
  x = layer(x, src_key_padding_mask=key_padding_mask)
277
  if output_hidden_states:
278
  hidden_states.append(x)
279
 
280
+ # ── Pool + Project ──
281
  if attention_mask is not None:
282
  mask = attention_mask.unsqueeze(-1).float()
283
  else:
284
  mask = (~key_padding_mask).unsqueeze(-1).float()
285
  pooled = (x * mask).sum(1) / mask.sum(1).clamp(min=1)
 
 
286
  embedding = F.normalize(self.output_proj(pooled), dim=-1)
287
 
288
+ # ── Alignment Bank ──
289
+ enriched = None
290
+ geo_diagnostics = None
291
+ if self.bank is not None:
292
+ enriched, _, geo_diagnostics = self.bank(embedding)
293
+
294
+ # ── Output ──
295
  result = {
296
+ 'last_hidden_state': embedding, # (B, 768)
297
+ 'pooler_output': embedding, # (B, 768) compat
298
+ 'token_embeddings': x, # (B, L, 384)
299
+ 'enriched': enriched, # (B, 896) or None
300
+ 'geometric_context': geo_diagnostics, # dict or None
301
  }
302
  if output_hidden_states:
303
  result['hidden_states'] = tuple(hidden_states)
 
306
 
307
  def encode(self, texts, tokenizer=None, max_length=512, batch_size=128,
308
  device=None):
309
+ """Convenience: raw text β†’ L2-normalized (N, 768) embeddings."""
 
 
 
 
 
 
 
 
 
 
 
 
310
  if isinstance(texts, str):
311
  texts = [texts]
 
312
  if tokenizer is None:
313
  from transformers import AutoTokenizer
314
  tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
 
315
  if device is None:
316
  device = next(self.parameters()).device
 
317
  self.eval()
318
  all_emb = []
319
  with torch.no_grad():