# ============================================================================
# CaptionBERT-8192: HuggingFace AutoModel with Alignment Bank
#
# Usage:
#   from transformers import AutoModel, AutoTokenizer
#   model = AutoModel.from_pretrained("AbstractPhil/geolip-captionbert-8192",
#                                      trust_remote_code=True)
#   tokenizer = AutoTokenizer.from_pretrained("AbstractPhil/geolip-captionbert-8192",
#                                              trust_remote_code=True)
#   inputs = tokenizer("A cat on a windowsill", return_tensors="pt",
#                       padding=True, truncation=True, max_length=512)
#   outputs = model(**inputs)
#
#   # Core embedding (consensus-distilled, L2-normalized)
#   embedding = outputs.last_hidden_state      # (B, 768)
#
#   # Enriched embedding (with geometric context from 5-expert bank)
#   enriched = outputs.enriched                # (B, 768 + bank_dim)
#
#   # Token-level representations (pre-pooling, for sequence tasks)
#   tokens = outputs.token_embeddings          # (B, L, 384)
#
#   # Geometric diagnostics
#   geo = outputs.geometric_context            # dict with expert cos, anchors, etc.
# ============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel


class CaptionBertConfig(PretrainedConfig):
    model_type = "caption_bert"

    def __init__(
        self,
        vocab_size=30522,
        max_position_embeddings=8192,
        hidden_size=384,
        num_attention_heads=6,
        num_hidden_layers=6,
        intermediate_size=1536,
        output_dim=768,
        hidden_dropout_prob=0.0,
        pad_token_id=0,
        # Alignment bank
        bank_enabled=True,
        bank_n_experts=5,
        bank_n_anchors=512,
        bank_dim=128,
        bank_cv_target=0.082,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.output_dim = output_dim
        self.hidden_dropout_prob = hidden_dropout_prob
        self.bank_enabled = bank_enabled
        self.bank_n_experts = bank_n_experts
        self.bank_n_anchors = bank_n_anchors
        self.bank_dim = bank_dim
        self.bank_cv_target = bank_cv_target

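# Illustrative only (not from the released checkpoint): non-default configs are
# possible, e.g. a lighter variant with the alignment bank disabled:
#
#   cfg = CaptionBertConfig(num_hidden_layers=4, bank_enabled=False)
#   model = CaptionBertModel(cfg)   # randomly initialized; model.bank is None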

class AlignmentBank(nn.Module):
    """
    Geometric interface layer preserving 5-expert differentiation structure.

    Trained post-hoc on frozen encoder via GPA + whitened Procrustes.
    Stores per-expert rotation matrices, whiteners, and means that encode
    how each expert's geometric perspective differs from the consensus center.

    Provides geometric context annotations (128-dim) alongside the core
    768-dim consensus embedding for downstream heads.
    """
    def __init__(self, d_embed=768, n_experts=5, n_anchors=512, d_bank=128):
        super().__init__()
        self.d_embed = d_embed
        self.n_experts = n_experts
        self.n_anchors = n_anchors
        self.d_bank = d_bank

        # Per-expert Procrustes components (the differentiation structure)
        self.expert_rotations = nn.ParameterList([
            nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)])
        self.expert_whiteners = nn.ParameterList([
            nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)])
        self.expert_means = nn.ParameterList([
            nn.Parameter(torch.zeros(d_embed)) for _ in range(n_experts)])

        # Consensus landmarks on the hypersphere
        self.anchors = nn.Parameter(
            F.normalize(torch.randn(n_anchors, d_embed), dim=-1))

        # Geometric context projection
        n_cross = n_experts * (n_experts - 1) // 2
        geo_dim = n_experts + n_experts + n_cross + 1 + n_experts + n_anchors
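        # With the defaults (5 experts, 512 anchors) this is
        #   5 (expert cos) + 5 (expert MSE) + 10 (cross-pair cos)
        #   + 1 (disagreement ratio) + 5 (norm ratios) + 512 (anchor cos) = 538.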
        self.geo_proj = nn.Sequential(
            nn.Linear(geo_dim, d_bank * 2), nn.GELU(), nn.LayerNorm(d_bank * 2),
            nn.Linear(d_bank * 2, d_bank), nn.LayerNorm(d_bank))

        # Calibrated consensus targets (preserved from training)
        self.register_buffer("target_cv", torch.tensor(0.082))
        self.register_buffer("target_mean_cos", torch.tensor(0.0))
        self.register_buffer("target_spectral", torch.zeros(50))
        self.register_buffer("target_cross_cos_mean", torch.tensor(0.0))
        self.register_buffer("target_cross_cos_std", torch.tensor(0.0))
        self.register_buffer("target_disagreement_ratio", torch.tensor(0.0))

    def forward(self, embedding):
        emb = embedding.float()

        # Full whitened Procrustes per expert: center β†’ whiten β†’ normalize β†’ rotate
        expert_consistency = []
        expert_recon = []
        expert_projected = []
        for i in range(self.n_experts):
            R = self.expert_rotations[i]
            W = self.expert_whiteners[i]
            mu = self.expert_means[i]
            centered = emb - mu
            whitened = centered @ W
            whitened_n = F.normalize(whitened, dim=-1)
            in_expert = whitened_n @ R.T
            back = in_expert @ R
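            # Round trip x -> xR^T -> xR^T R is lossless when R is orthogonal
            # (cos -> 1, recon -> 0); deviations measure how much expert i's
            # learned map distorts this particular sample.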
            cos = F.cosine_similarity(whitened_n, back, dim=-1)
            recon = (whitened_n - back).pow(2).mean(dim=-1)
            expert_consistency.append(cos)
            expert_recon.append(recon)
            expert_projected.append(in_expert)

        expert_cos = torch.stack(expert_consistency, dim=-1)
        expert_mse = torch.stack(expert_recon, dim=-1)

        # Cross-expert differentiation (10 pairs for 5 experts)
        cross_cos = []
        for i in range(self.n_experts):
            for j in range(i + 1, self.n_experts):
                cc = F.cosine_similarity(
                    expert_projected[i], expert_projected[j], dim=-1)
                cross_cos.append(cc)
        cross_features = torch.stack(cross_cos, dim=-1)

        # Per-sample disagreement
        per_sample_agreement = expert_cos.mean(dim=-1)
        per_sample_disagreement = expert_cos.std(dim=-1)
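        # std/mean across experts is a per-sample coefficient of variation,
        # plausibly the statistic the calibrated target_cv buffer (0.082) tracks.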
        disagreement_ratio = per_sample_disagreement / (per_sample_agreement + 1e-8)

        # Expert norm ratios
        expert_norms = []
        for i in range(self.n_experts):
            W = self.expert_whiteners[i]
            mu = self.expert_means[i]
            whitened = (emb - mu) @ W
            expert_norms.append(whitened.norm(dim=-1))
        norm_ratio = torch.stack(expert_norms, dim=-1)
        norm_ratio = norm_ratio / (norm_ratio.mean(dim=-1, keepdim=True) + 1e-8)

        # Anchor distances
        anchors_n = F.normalize(self.anchors, dim=-1)
        anchor_cos = emb @ anchors_n.T

        # Geometric context vector
        geo_input = torch.cat([
            expert_cos, expert_mse, cross_features,
            disagreement_ratio.unsqueeze(-1), norm_ratio, anchor_cos
        ], dim=-1)
        geo_context = self.geo_proj(geo_input)
        enriched = torch.cat([embedding, geo_context.to(embedding.dtype)], dim=-1)

        # Diagnostics
        diagnostics = {
            "expert_cos_mean": expert_cos.mean().item(),
            "expert_cos_std": expert_cos.std().item(),
            "cross_expert_cos": cross_features.mean().item(),
            "cross_expert_cos_std": cross_features.std().item(),
            "anchor_max_cos": anchor_cos.max(dim=-1).values.mean().item(),
            "anchor_mean_cos": anchor_cos.mean().item(),
            "disagreement_ratio": disagreement_ratio.mean().item(),
            "norm_ratio_spread": norm_ratio.std(dim=-1).mean().item(),
        }

        return enriched, geo_context, diagnostics

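# Standalone sketch (illustrative; random weights, not the trained bank): the
# module can be exercised on unit embeddings with the default geometry:
#
#   bank = AlignmentBank()                          # 768-d, 5 experts, 512 anchors
#   emb = F.normalize(torch.randn(4, 768), dim=-1)  # (B, 768) unit vectors
#   enriched, geo_context, diags = bank(emb)
#   enriched.shape       # (4, 896) = 768 core + 128 geometric context
#   geo_context.shape    # (4, 128)
#   sorted(diags)        # diagnostic keys, e.g. "expert_cos_mean"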

class CaptionBertModel(PreTrainedModel):
    """
    Consensus-distilled caption encoder with geometric alignment bank.

    The encoder produces L2-normalized 768-dim embeddings in the geometric
    consensus space of 5 BERT-family models (BERT, ModernBERT, RoBERTa,
    ALBERT, DistilBERT), aligned via Generalized Procrustes Analysis.

    The alignment bank annotates each embedding with 128-dim geometric
    context from the 5-expert differentiation structure β€” per-expert
    consistency, cross-expert disagreement, and anchor distances.

    Output fields:
        last_hidden_state:   (B, 768)         L2-normalized consensus embedding
        pooler_output:       (B, 768)         same (HF compatibility)
        token_embeddings:    (B, L, 384)      pre-pooling token representations
        enriched:            (B, 896)         embedding + bank geometric context
        geometric_context:   dict             expert cos, cross-expert, anchors, etc.
        hidden_states:       tuple            per-layer outputs (if requested)
    """
    config_class = CaptionBertConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # ── Encoder ──
        self.token_emb = nn.Embedding(
            config.vocab_size, config.hidden_size,
            padding_idx=config.pad_token_id)
        self.pos_emb = nn.Embedding(
            config.max_position_embeddings, config.hidden_size)
        self.emb_norm = nn.LayerNorm(config.hidden_size)
        self.emb_drop = nn.Dropout(config.hidden_dropout_prob)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            dim_feedforward=config.intermediate_size,
            dropout=config.hidden_dropout_prob,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=config.num_hidden_layers,
            enable_nested_tensor=False)

        self.output_proj = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, config.output_dim),
        )

        # ── Alignment Bank ──
        if getattr(config, 'bank_enabled', False):
            self.bank = AlignmentBank(
                d_embed=config.output_dim,
                n_experts=config.bank_n_experts,
                n_anchors=config.bank_n_anchors,
                d_bank=config.bank_dim,
            )
        else:
            self.bank = None

        self.post_init()

    def forward(self, input_ids=None, attention_mask=None,
                output_hidden_states=False, **kwargs):
        if input_ids is None:
            raise ValueError("CaptionBertModel.forward requires input_ids")
        _, L = input_ids.shape
        device = input_ids.device

        # ── Encode ──
        positions = torch.arange(L, device=device).unsqueeze(0)
        x = self.token_emb(input_ids) + self.pos_emb(positions)
        x = self.emb_drop(self.emb_norm(x))

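        # nn.TransformerEncoder convention: True in src_key_padding_mask marks
        # positions to ignore (the inverse of the HF attention_mask).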
        if attention_mask is not None:
            key_padding_mask = ~attention_mask.bool()
        else:
            key_padding_mask = (input_ids == self.config.pad_token_id)

        hidden_states = [x] if output_hidden_states else None
        for layer in self.encoder.layers:
            x = layer(x, src_key_padding_mask=key_padding_mask)
            if output_hidden_states:
                hidden_states.append(x)

        # ── Pool + Project ──
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).float()
        else:
            mask = (~key_padding_mask).unsqueeze(-1).float()
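        # Masked mean pooling; clamp(min=1) guards against division by zero
        # for rows that are entirely padding.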
        pooled = (x * mask).sum(1) / mask.sum(1).clamp(min=1)
        embedding = F.normalize(self.output_proj(pooled), dim=-1)

        # ── Alignment Bank ──
        enriched = None
        geo_diagnostics = None
        if self.bank is not None:
            enriched, _, geo_diagnostics = self.bank(embedding)

        # ── Output ──
        result = {
            'last_hidden_state': embedding,       # (B, 768)
            'pooler_output': embedding,            # (B, 768) compat
            'token_embeddings': x,                 # (B, L, 384)
            'enriched': enriched,                  # (B, 896) or None
            'geometric_context': geo_diagnostics,  # dict or None
        }
        if output_hidden_states:
            result['hidden_states'] = tuple(hidden_states)

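        # Ad-hoc namespace object: exposes out.last_hidden_state, out.enriched,
        # etc., without depending on transformers' ModelOutput dataclasses.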
        return type('Output', (), result)()

    def encode(self, texts, tokenizer=None, max_length=512, batch_size=128,
               device=None):
        """Convenience: raw text β†’ L2-normalized (N, 768) embeddings."""
        if isinstance(texts, str):
            texts = [texts]
        if tokenizer is None:
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        if device is None:
            device = next(self.parameters()).device
        self.eval()
        all_emb = []
        with torch.no_grad():
            for i in range(0, len(texts), batch_size):
                batch = texts[i:i+batch_size]
                inputs = tokenizer(
                    batch, max_length=max_length, padding="max_length",
                    truncation=True, return_tensors="pt"
                ).to(device)
                out = self(input_ids=inputs["input_ids"],
                           attention_mask=inputs["attention_mask"])
                all_emb.append(out.last_hidden_state.cpu())
        return torch.cat(all_emb)
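

# ── Quick smoke test (sketch; assumes Hub access, repo id as in the header) ──
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    repo = "AbstractPhil/geolip-captionbert-8192"
    model = AutoModel.from_pretrained(repo, trust_remote_code=True)
    tok = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)

    emb = model.encode(["A cat on a windowsill", "A dog in the park"],
                       tokenizer=tok)
    print(emb.shape)         # torch.Size([2, 768])
    print(emb.norm(dim=-1))  # ~1.0 everywhere: embeddings are L2-normalized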