| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import math |
| | from dataclasses import dataclass |
| | from typing import Any, Dict, List, Optional, Tuple |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from transformers import BertModel |
| |
|
| |
|
| | |
| | |
| | |
| |
|
@dataclass
class DeepBertV3Config:
    """Hyperparameters for DeepBert v3 (teacher-distilled geometric memory)."""

    # --- Frozen BERT backbone ---
    bert_model: str = "google-bert/bert-large-uncased"
    hidden_size: int = 1024
    freeze_bert: bool = True

    # --- Recurrent memory tokens prepended to every segment ---
    n_memory_tokens: int = 16

    # --- Geometric memory bank (compressed depth-profile anchors) ---
    bank_size: int = 128        # max anchors kept; oldest are evicted first
    anchor_dim: int = 1024
    n_bank_heads: int = 8
    bank_cross_layers: int = 2

    # --- Memory update gate: "gru" (reset/update gates) or sigmoid blend ---
    gate_type: str = "gru"

    # --- Which BERT layers feed the fused representation / depth profile ---
    extract_layers: Tuple[int, ...] = (2, 5, 8, 11, 14, 17, 20, 23)
    layer_fusion: str = "learned"

    # --- Segmenting long inputs ---
    max_content_tokens: int = 480   # 480 content + 16 memory = 496 <= 512
    segment_overlap: int = 64
    max_position: int = 512         # BERT's absolute-position limit

    # --- Teacher distillation ---
    n_teachers: int = 2
    teacher_hidden: int = 1024

    # --- Target coefficient of variation (std/mean) of pentachoron volumes ---
    cv_target: float = 0.20

    @property
    def n_extract_layers(self):
        # Number of BERT layers whose hidden states are extracted.
        return len(self.extract_layers)

    @property
    def depth_profile_dim(self):
        # Flattened size of the per-token depth profile (n_layers * hidden).
        return self.n_extract_layers * self.hidden_size
| |
|
| |
|
| | |
| | |
| | |
| |
|
def cayley_menger_vol2(pts):
    """Squared volume of the simplex spanned by each point set.

    Uses the Cayley-Menger determinant. ``pts`` is (B, V, D); returns a (B,)
    tensor of squared (V-1)-simplex volumes. Round-off may make the result
    slightly negative, so callers should clamp before taking a square root.
    """
    # Determinants are numerically fragile under autocast half precision,
    # so force fp32 math regardless of the surrounding autocast context.
    with torch.amp.autocast("cuda", enabled=False):
        coords = pts.float()
        delta = coords.unsqueeze(-2) - coords.unsqueeze(-3)
        sq_dist = delta.pow(2).sum(dim=-1)  # (B, V, V) pairwise squared dists
        batch, n_vtx, _ = sq_dist.shape
        # Bordered (V+1)x(V+1) Cayley-Menger matrix: zero corner, ones on the
        # first row/column, squared distances in the lower-right block.
        cm = torch.zeros(batch, n_vtx + 1, n_vtx + 1,
                         device=sq_dist.device, dtype=torch.float32)
        cm[:, 0, 1:] = 1
        cm[:, 1:, 0] = 1
        cm[:, 1:, 1:] = sq_dist
        fact = math.factorial(n_vtx - 1)
        sign = (-1.0) ** n_vtx
        return sign / ((2.0 ** (n_vtx - 1)) * fact * fact) * torch.linalg.det(cm)
| |
|
| |
|
def pentachoron_cv(embeddings, n_samples=16):
    """Coefficient of variation (std/mean) of sampled pentachoron volumes.

    Draws ``n_samples`` random 5-row subsets of ``embeddings`` (B, D) and
    measures how uniform their 4-simplex volumes are. Returns 0.0 when
    fewer than 5 rows are available.
    """
    n_rows = embeddings.shape[0]
    if n_rows < 5:
        return torch.tensor(0.0, device=embeddings.device)
    volumes = []
    for _ in range(n_samples):
        pick = torch.randperm(n_rows, device=embeddings.device)[:5]
        sq_vol = cayley_menger_vol2(embeddings[pick].unsqueeze(0))
        # Clamp tiny negative round-off before the square root.
        volumes.append(torch.sqrt(F.relu(sq_vol[0]) + 1e-12))
    vols = torch.stack(volumes)
    return vols.std() / (vols.mean() + 1e-8)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class GeometricMemoryBank(nn.Module):
    """
    Bank stores compressed depth-profile anchors from each segment.
    Memory tokens query the bank via cross-attention.
    No alignment transform — both spaces learned end-to-end.
    """
    def __init__(self, config: DeepBertV3Config):
        super().__init__()
        self.config = config
        self.max_size = config.bank_size
        self.dim = config.anchor_dim

        # Compress a flattened depth profile (n_layers * hidden) to one anchor.
        depth_dim = config.depth_profile_dim
        self.depth_compressor = nn.Sequential(
            nn.Linear(depth_dim, config.hidden_size * 2),
            nn.GELU(),
            nn.LayerNorm(config.hidden_size * 2),
            nn.Linear(config.hidden_size * 2, config.anchor_dim),
        )

        # Learned encoding of the normalized segment index, mixed into anchors.
        self.temporal_proj = nn.Linear(1, config.anchor_dim, bias=False)

        # Pre-norm cross-attention stack: memory tokens (queries) attend over
        # stored anchors (keys and values), each layer followed by an FFN.
        self.cross_attn = nn.ModuleList([
            nn.MultiheadAttention(config.hidden_size, config.n_bank_heads,
                                  batch_first=True, dropout=0.1)
            for _ in range(config.bank_cross_layers)
        ])
        self.cross_norms = nn.ModuleList([
            nn.LayerNorm(config.hidden_size)
            for _ in range(config.bank_cross_layers)
        ])
        self.cross_ffns = nn.ModuleList([
            nn.Sequential(
                nn.Linear(config.hidden_size, config.hidden_size * 2),
                nn.GELU(),
                nn.Linear(config.hidden_size * 2, config.hidden_size),
            )
            for _ in range(config.bank_cross_layers)
        ])
        self.ffn_norms = nn.ModuleList([
            nn.LayerNorm(config.hidden_size)
            for _ in range(config.bank_cross_layers)
        ])

    def init_bank(self, batch_size: int, device: torch.device) -> Dict[str, Any]:
        """Return an empty bank: zero stored anchors, nothing written yet."""
        return {"anchors": torch.zeros(batch_size, 0, self.dim, device=device),
                "n_written": 0}

    def write(self, bank, content_hidden, attention_mask=None,
              segment_idx=0, depth_cls=None):
        """Append one unit-norm anchor per batch element for this segment.

        Prefers the multi-layer CLS depth profile ``depth_cls`` of shape
        (B, n_layers, hidden); falls back to mask-aware mean pooling of
        ``content_hidden`` tiled across layers. When the bank exceeds
        ``max_size``, the oldest anchors are dropped (FIFO). Returns a new
        bank dict that also carries ``live_anchor`` (this segment's anchor).
        """
        anchors = bank["anchors"]

        if depth_cls is not None:
            B = depth_cls.shape[0]
            anchor = self.depth_compressor(depth_cls.reshape(B, -1))
        else:
            # Fallback path: masked mean over content tokens, tiled so the
            # compressor sees the same input width as a real depth profile.
            if attention_mask is not None:
                m = attention_mask.float().unsqueeze(-1)
                pooled = (content_hidden * m).sum(1) / m.sum(1).clamp(min=1)
            else:
                pooled = content_hidden.mean(dim=1)
            anchor = self.depth_compressor(
                pooled.repeat(1, self.config.n_extract_layers))

        anchor = F.normalize(anchor, dim=-1)

        # Add a small temporal signal (segment index scaled to [0, 1] by the
        # bank capacity), then re-normalize back to the unit sphere.
        t = torch.tensor([[segment_idx]], dtype=anchor.dtype, device=anchor.device)
        anchor = anchor + 0.1 * self.temporal_proj(t / max(self.max_size, 1))
        anchor = F.normalize(anchor, dim=-1)

        # FIFO append; evict oldest entries past capacity.
        anchors = torch.cat([anchors, anchor.unsqueeze(1)], dim=1)
        if anchors.shape[1] > self.max_size:
            anchors = anchors[:, -self.max_size:]

        return {"anchors": anchors, "n_written": bank["n_written"] + 1,
                "live_anchor": anchor}

    def read(self, memory_tokens, bank):
        """Refine memory tokens by cross-attending over stored anchors.

        Returns the tokens unchanged while the bank is empty.
        """
        anchors = bank["anchors"]
        if anchors.shape[1] == 0:
            return memory_tokens

        x = memory_tokens
        for attn, norm, ffn, ffn_norm in zip(
            self.cross_attn, self.cross_norms,
            self.cross_ffns, self.ffn_norms,
        ):
            residual = x
            # Pre-norm query; anchors serve as both keys and values.
            x, _ = attn(norm(x), anchors, anchors)
            x = residual + x
            residual = x
            x = residual + ffn(ffn_norm(x))
        return x
| |
|
| |
|
| | |
| | |
| | |
| |
|
class DeltaMemoryGate(nn.Module):
    """Blend the previous memory state with freshly computed memory tokens.

    With ``gate_type == "gru"`` this applies GRU-style reset/update gates;
    any other setting uses a single learned sigmoid interpolation gate.
    The blended result is layer-normalized.
    """

    def __init__(self, config: DeepBertV3Config):
        super().__init__()
        H = config.hidden_size
        self.gate_type = config.gate_type
        if config.gate_type == "gru":
            self.reset_proj = nn.Linear(H * 2, H)
            self.update_proj = nn.Linear(H * 2, H)
            self.candidate_proj = nn.Linear(H * 2, H)
        else:
            self.gate_proj = nn.Linear(H * 2, H)
        self.norm = nn.LayerNorm(H)

    def forward(self, old, new):
        joint = torch.cat([old, new], dim=-1)
        if self.gate_type != "gru":
            # Simple gate: convex combination of old and new states.
            mix = torch.sigmoid(self.gate_proj(joint))
            blended = mix * old + (1 - mix) * new
        else:
            # GRU-style: reset gate modulates the old state feeding the
            # candidate; the update gate chooses how much old state to keep.
            reset = torch.sigmoid(self.reset_proj(joint))
            keep = torch.sigmoid(self.update_proj(joint))
            cand = torch.tanh(
                self.candidate_proj(torch.cat([reset * old, new], dim=-1)))
            blended = keep * old + (1 - keep) * cand
        return self.norm(blended)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class LayerFusion(nn.Module):
    """Softmax-weighted fusion of hidden states from several BERT layers."""

    def __init__(self, config: DeepBertV3Config):
        super().__init__()
        n = len(config.extract_layers)
        # One learnable scalar per extracted layer, uniform at initialization.
        self.weights = nn.Parameter(torch.ones(n) / n)
        self.proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.norm = nn.LayerNorm(config.hidden_size)

    def forward(self, layer_outputs):
        mix = F.softmax(self.weights, dim=0)
        per_layer = torch.stack(layer_outputs)            # (n, B, T, H)
        combined = (per_layer * mix.view(-1, 1, 1, 1)).sum(0)
        return self.norm(self.proj(combined))
| |
|
| |
|
| | |
| | |
| | |
| |
|
class TeacherProjector(nn.Module):
    """
    Linear map from the student embedding space into one teacher's space.
    Starts out as the identity; may be re-initialized from a static
    Procrustes rotation (``init_from_procrustes``) and then fine-tuned
    during training to absorb non-linear differences between the spaces.
    """

    def __init__(self, student_dim: int, teacher_dim: int, name: str = ""):
        super().__init__()
        self.name = name
        self.proj = nn.Linear(student_dim, teacher_dim, bias=True)
        # Identity weight and zero bias: the projector begins as a no-op.
        nn.init.eye_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x):
        return self.proj(x)

    def init_from_procrustes(self, rotation, student_mean, teacher_mean):
        """
        Load a pre-computed orthogonal Procrustes alignment.
        rotation: (D, D) orthogonal matrix mapping student -> teacher
        student_mean, teacher_mean: (D,) centering vectors
        Sets weight = rotation, bias = teacher_mean - rotation @ student_mean
        so that forward(x) = rotation @ (x - student_mean) + teacher_mean.
        """
        with torch.no_grad():
            self.proj.weight.copy_(rotation)
            self.proj.bias.copy_(teacher_mean - rotation @ student_mean)
        print(f" [{self.name}] Procrustes init: |R|={rotation.norm():.3f}")
| |
|
| |
|
| | |
| | |
| | |
| |
|
class DeepBertV3(nn.Module):
    """Frozen BERT-large augmented with recurrent memory tokens, a geometric
    anchor bank, a gated memory update, and linear projectors into two
    teacher embedding spaces.

    Long inputs are processed segment by segment; ``state`` (created by
    ``init_state``) carries the memory tokens and anchor bank across calls
    to ``forward``.
    """
    def __init__(self, config: DeepBertV3Config):
        super().__init__()
        self.config = config

        # Backbone: BERT with all hidden states exposed; frozen by default.
        self.bert = BertModel.from_pretrained(
            config.bert_model, add_pooling_layer=False,
            attn_implementation="eager")
        self.bert.config.output_hidden_states = True
        if config.freeze_bert:
            for p in self.bert.parameters():
                p.requires_grad = False

        # Trainable memory machinery: learned memory-token embeddings,
        # multi-layer fusion, anchor bank, and the memory update gate.
        self.memory_embeddings = nn.Parameter(
            torch.randn(1, config.n_memory_tokens, config.hidden_size) * 0.02)
        self.layer_fusion = LayerFusion(config)
        self.bank = GeometricMemoryBank(config)
        self.gate = DeltaMemoryGate(config)

        # Output heads: projected CLS plus a memory-conditioned delta.
        self.output_proj = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(), nn.LayerNorm(config.hidden_size))
        self.memory_output_fusion = nn.Sequential(
            nn.Linear(config.hidden_size * 2, config.hidden_size),
            nn.GELU(),
            nn.Linear(config.hidden_size, config.hidden_size))

        # One identity-initialized projector per teacher space.
        self.proj_modern = TeacherProjector(
            config.hidden_size, config.teacher_hidden, "ModernBERT")
        self.proj_longformer = TeacherProjector(
            config.hidden_size, config.teacher_hidden, "Longformer")

    @classmethod
    def from_pretrained(cls, config=None, **kwargs):
        """Build a model (constructing a config from kwargs when none is
        given) and print a parameter-count summary."""
        if config is None:
            config = DeepBertV3Config(**kwargs)
        model = cls(config)
        n_train = sum(p.numel() for p in model.parameters() if p.requires_grad)
        n_frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
        print(f"DeepBert v3 initialized:")
        print(f" BERT: {n_frozen:,} frozen")
        print(f" Memory + projectors: {n_train:,} trainable")
        print(f" Extract: {config.extract_layers} β {config.depth_profile_dim}-dim anchor")
        print(f" Bank: {config.bank_size} anchors, {config.bank_cross_layers} cross-attn")
        print(f" Memory: {config.n_memory_tokens} tokens, {config.gate_type} gate")
        return model

    def init_state(self, batch_size, device=None):
        """Fresh recurrent state: initial memory tokens, empty bank, seg 0."""
        if device is None:
            device = next(self.parameters()).device
        return {
            "memory": self.memory_embeddings.expand(batch_size, -1, -1).clone(),
            "bank": self.bank.init_bank(batch_size, device),
            "segment_idx": 0,
        }

    def forward(self, input_ids, attention_mask, state):
        """Encode one segment and advance the recurrent state.

        Returns ``(outputs, new_state)``; ``outputs["memory_output"]`` is the
        memory-fused document embedding and ``new_state`` feeds the next call.
        """
        B = input_ids.shape[0]
        device = input_ids.device
        n_mem = self.config.n_memory_tokens
        seq_len = input_ids.shape[1]

        memory_state = state["memory"]
        bank = state["bank"]
        seg_idx = state["segment_idx"]

        # Let memory tokens attend over previously written anchors.
        memory_tokens = self.bank.read(memory_state, bank)

        # Prepend memory tokens to the segment's word embeddings.
        content_embeds = self.bert.embeddings.word_embeddings(input_ids)
        inputs_embeds = torch.cat([memory_tokens, content_embeds], dim=1)

        # Sequential positions, clamped to BERT's absolute-position limit.
        position_ids = torch.cat([
            torch.arange(n_mem, device=device).unsqueeze(0).expand(B, -1),
            torch.arange(n_mem, n_mem + seq_len, device=device).unsqueeze(0).expand(B, -1),
        ], dim=1).clamp(max=self.config.max_position - 1)

        # Token types distinguish memory (type 1) from content (type 0).
        token_type_ids = torch.cat([
            torch.ones(B, n_mem, dtype=torch.long, device=device),
            torch.zeros(B, seq_len, dtype=torch.long, device=device),
        ], dim=1)

        # Memory tokens are always attended to.
        full_mask = torch.cat([
            torch.ones(B, n_mem, device=device, dtype=attention_mask.dtype),
            attention_mask,
        ], dim=1)

        bert_out = self.bert(
            inputs_embeds=inputs_embeds, attention_mask=full_mask,
            position_ids=position_ids, token_type_ids=token_type_ids,
            output_hidden_states=True, return_dict=True)

        # hidden_states[0] is the embedding output, so layer i lives at i+1.
        selected = [bert_out.hidden_states[i + 1] for i in self.config.extract_layers]
        hidden = self.layer_fusion(selected)
        memory_output = hidden[:, :n_mem]
        content_output = hidden[:, n_mem:]

        # Per-layer states of the first content token, shape (B, n_layers, H)
        # (the [CLS] token when the tokenizer prepends one, as in __main__).
        depth_cls = torch.stack([h[:, n_mem, :] for h in selected], dim=1)

        # Gated update of the recurrent memory tokens.
        new_memory = self.gate(memory_state, memory_output)

        # Write this segment's depth-profile anchor into the bank.
        new_bank = self.bank.write(bank, content_output, attention_mask,
                                   seg_idx, depth_cls=depth_cls)

        # Fuse the projected CLS with a memory-conditioned delta.
        cls_output = self.output_proj(content_output[:, 0])
        memory_delta = self.memory_output_fusion(
            torch.cat([cls_output, new_memory.mean(dim=1)], dim=-1))
        fused = cls_output + memory_delta

        outputs = {
            "memory_output": fused,
            "cls_output": cls_output,
            "live_anchor": new_bank["live_anchor"],
            "depth_cls": depth_cls,
            "content_output": content_output,
            "memory_tokens": new_memory,
        }

        new_state = {
            "memory": new_memory,
            "bank": {"anchors": new_bank["anchors"],
                     "n_written": new_bank["n_written"],
                     "live_anchor": new_bank["live_anchor"]},
            "segment_idx": seg_idx + 1,
        }
        return outputs, new_state

    @staticmethod
    def detach_state(state):
        """Detach the recurrent state from the autograd graph (truncated
        BPTT boundary). ``live_anchor`` is not carried across the boundary."""
        return {
            "memory": state["memory"].detach(),
            "bank": {"anchors": state["bank"]["anchors"].detach(),
                     "n_written": state["bank"]["n_written"]},
            "segment_idx": state["segment_idx"],
        }

    def get_trainable_params(self):
        """Return the list of parameters with ``requires_grad=True``."""
        return [p for p in self.parameters() if p.requires_grad]

    def num_trainable_params(self):
        """Return the total element count of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
| |
|
| |
|
| | |
| | |
| | |
| |
|
@torch.no_grad()
def compute_static_procrustes(student_embs, teacher_embs):
    """
    Orthogonal Procrustes alignment between two row-aligned embedding sets.

    Finds the orthogonal R minimizing ||(X - mu_x) @ R.T - (Y - mu_y)||_F,
    i.e. R is applied as ``centered_student @ R.T`` (matching
    ``TeacherProjector.init_from_procrustes``, where weight = R).

    Args:
        student_embs: (N, D) student embeddings.
        teacher_embs: (N, D) teacher embeddings, row-aligned with the student.

    Returns:
        (R, student_mean, teacher_mean): the (D, D) rotation and the (D,)
        centering vectors of each set.
    """
    X = student_embs.float()
    Y = teacher_embs.float()
    mu_x, mu_y = X.mean(0), Y.mean(0)
    Xc, Yc = X - mu_x, Y - mu_y
    # Classic solution: with U S V^T = svd(Xc^T Yc), the optimal map is
    # Xc @ (U V^T); store its transpose so application is Xc @ R.T.
    U, S, Vt = torch.linalg.svd(Xc.T @ Yc)
    R = (U @ Vt).T
    # Log mean cosine alignment before vs. after rotation.
    cos_before = F.cosine_similarity(Xc, Yc, dim=-1).mean()
    cos_after = F.cosine_similarity((Xc @ R.T), Yc, dim=-1).mean()
    print(f" Procrustes: cos {cos_before:.4f} -> {cos_after:.4f}")
    return R, mu_x, mu_y
| |
|
| |
|
| | |
| | |
| | |
| |
|
if __name__ == "__main__":
    banner = "=" * 70
    print(banner)
    print("DEEP BERT v3 β Teacher-Distilled Geometric Memory")
    print(banner)

    config = DeepBertV3Config()
    model = DeepBertV3.from_pretrained(config)

    # Per-component trainable-parameter breakdown.
    def n_params(module):
        return sum(p.numel() for p in module.parameters())

    comps = {
        "memory_embeddings": model.memory_embeddings.numel(),
        "layer_fusion": n_params(model.layer_fusion),
        "bank.depth_compressor": n_params(model.bank.depth_compressor),
        "bank.temporal_proj": n_params(model.bank.temporal_proj),
        "bank.cross_attn": n_params(model.bank.cross_attn),
        "bank.cross_ffns": n_params(model.bank.cross_ffns),
        "gate": n_params(model.gate),
        "output_proj": n_params(model.output_proj),
        "memory_output_fusion": n_params(model.memory_output_fusion),
        "proj_modern": n_params(model.proj_modern),
        "proj_longformer": n_params(model.proj_longformer),
    }
    print(f"\n Component breakdown:")
    for k, v in comps.items():
        print(f" {k:30s}: {v:,}")
    print(f" {'TOTAL':30s}: {sum(comps.values()):,}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Smoke test: stream three short segments through the recurrent state.
    from transformers import BertTokenizer
    tok = BertTokenizer.from_pretrained(config.bert_model)
    state = model.init_state(1, device)
    texts = [
        "The quick brown fox jumps over the lazy dog near the riverbank.",
        "Meanwhile the cat sat on the mat observing everything carefully.",
        "Both animals eventually fell asleep under the warm afternoon sun.",
    ]
    for i, text in enumerate(texts):
        tokens = tok(text, return_tensors="pt", padding="max_length",
                     truncation=True, max_length=config.max_content_tokens)
        with torch.no_grad():
            out, state = model(tokens["input_ids"].to(device),
                               tokens["attention_mask"].to(device), state)
        print(f"\n Seg {i+1}: anchor={out['live_anchor'].shape}, "
              f"fused={out['memory_output'].shape}, "
              f"bank={state['bank']['anchors'].shape[1]}")

    print(f"\nDone.")