| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import math |
| import os |
| import json |
| import time |
| from dataclasses import dataclass, field, asdict |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| from tqdm import tqdm |
|
|
|
|
| |
| |
| |
|
|
@dataclass
class MemoryCLIPConfig:
    """Configuration for the memory-extended CLIP text encoder."""

    # --- Frozen CLIP backbone ---
    clip_model: str = "openai/clip-vit-large-patch14"
    clip_hidden: int = 768        # hidden size of the CLIP-ViT-L/14 text tower
    clip_layers: int = 12         # number of transformer layers in the text tower
    clip_max_tokens: int = 77     # CLIP's hard context window
    freeze_clip: bool = True      # keep CLIP weights frozen during training

    # --- Memory system ---
    n_memory_tokens: int = 8      # learned memory tokens carried across segments
    bank_size: int = 64           # max anchors the bank retains (FIFO)
    anchor_dim: int = 768         # dimensionality of stored anchors
    n_bank_heads: int = 8         # heads for bank/CLIP cross-attention
    bank_cross_layers: int = 2    # cross-attention blocks in bank read & CLIP read

    # --- Gate ---
    gate_type: str = "gru"        # NOTE(review): only consumed as metadata here;
                                  # DeltaMemoryGate is GRU-style regardless — confirm

    # --- Multi-layer extraction ---
    # Indices of CLIP transformer layers whose hidden states are extracted.
    extract_layers: Tuple[int, ...] = (1, 3, 5, 7, 9, 11)
    layer_fusion: str = "learned"

    # --- Text segmentation ---
    max_content_tokens: int = 18  # content tokens per segment (excl. SOS/EOS)
    segment_overlap: int = 4      # token overlap between consecutive segments
    max_segments: int = 32        # cap on segments per document

    # --- Distillation teacher ---
    teacher_model: str = "answerdotai/ModernBERT-large"
    teacher_hidden: int = 1024
    teacher_max_len: int = 4096

    # --- Geometry regularizer target (coefficient of variation) ---
    cv_target: float = 0.20

    @property
    def n_extract_layers(self):
        # Number of CLIP layers extracted and fused.
        return len(self.extract_layers)

    @property
    def depth_profile_dim(self):
        # Flattened size of the per-segment depth profile:
        # one clip_hidden vector per extracted layer.
        return self.n_extract_layers * self.clip_hidden
|
|
|
|
| |
| |
| |
|
|
def cayley_menger_vol2(pts):
    """Squared simplex volume via the Cayley-Menger determinant.

    pts: (B, V, D) batch of V vertices each. Returns a (B,) tensor of
    squared (V-1)-simplex volumes, computed in float32 with autocast
    disabled for numerical stability.
    """
    with torch.amp.autocast("cuda", enabled=False):
        verts = pts.float()
        # Pairwise squared distances between vertices: (B, V, V).
        sq_dist = (verts.unsqueeze(-2) - verts.unsqueeze(-3)).pow(2).sum(-1)
        batch, n_verts, _ = sq_dist.shape
        # Bordered Cayley-Menger matrix: first row/column of ones
        # (zero corner), squared distances in the interior.
        bordered = torch.zeros(batch, n_verts + 1, n_verts + 1,
                               device=sq_dist.device, dtype=torch.float32)
        bordered[:, 0, 1:] = 1
        bordered[:, 1:, 0] = 1
        bordered[:, 1:, 1:] = sq_dist
        sign = (-1.0) ** n_verts
        fact = math.factorial(n_verts - 1)
        scale = sign / ((2.0 ** (n_verts - 1)) * fact * fact)
        return scale * torch.linalg.det(bordered)
|
|
|
|
def pentachoron_cv(embeddings, n_samples=16):
    """Coefficient of variation of random pentachoron volumes.

    Samples n_samples random 5-vertex simplices from the batch and
    returns std/mean of their volumes. Returns 0 for batches with
    fewer than 5 points (no pentachoron can be formed).
    """
    n_points = embeddings.shape[0]
    if n_points < 5:
        return torch.tensor(0.0, device=embeddings.device)
    volumes = []
    for _ in range(n_samples):
        choice = torch.randperm(n_points, device=embeddings.device)[:5]
        sq_vol = cayley_menger_vol2(embeddings[choice].unsqueeze(0))
        # Clamp tiny negatives from the determinant before the sqrt.
        volumes.append(torch.sqrt(F.relu(sq_vol[0]) + 1e-12))
    vols = torch.stack(volumes)
    return vols.std() / (vols.mean() + 1e-8)
|
|
|
|
def procrustes_alignment_loss(emb_a, emb_b):
    """
    Differentiable Procrustes regularizer.

    Scores how well the two embedding sets could be rotationally
    aligned -- NOT the alignment force itself (that's InfoNCE).
    Kept as a regularizer per Phil's correction: it shapes geometry
    even if it doesn't create alignment by itself.
    """
    with torch.amp.autocast("cuda", enabled=False):
        a = F.normalize(emb_a.float(), dim=-1)
        b = F.normalize(emb_b.float(), dim=-1)
        # Center after normalizing; rows are no longer exactly unit norm.
        a = a - a.mean(dim=0, keepdim=True)
        b = b - b.mean(dim=0, keepdim=True)
        # Sum of singular values of the cross-covariance = optimal
        # rotational overlap between the two clouds.
        sing_vals = torch.linalg.svdvals(a.T @ b)
        n_rows, n_dims = a.shape
        return 1.0 - sing_vals.sum() / (math.sqrt(n_rows) * n_dims)
|
|
|
|
| |
| |
| |
|
|
class GeometricMemoryBank(nn.Module):
    """
    Memory bank for CLIP text encoder.
    Stores depth-profile anchors from each text segment.
    Memory tokens query the bank via cross-attention.

    Bank state is a plain dict (see init_bank) so it can be carried
    across segments and detached between truncated-BPTT windows.
    """
    def __init__(self, config: MemoryCLIPConfig):
        super().__init__()
        self.config = config
        self.max_size = config.bank_size   # FIFO capacity in anchors
        self.dim = config.anchor_dim

        # Compresses the concatenated per-layer depth profile
        # (n_extract_layers * clip_hidden) into a single anchor vector.
        depth_dim = config.depth_profile_dim
        self.depth_compressor = nn.Sequential(
            nn.Linear(depth_dim, config.clip_hidden * 2),
            nn.GELU(),
            nn.LayerNorm(config.clip_hidden * 2),
            nn.Linear(config.clip_hidden * 2, config.anchor_dim),
        )

        # Maps a normalized segment index to a positional offset that is
        # mixed into each anchor at write time.
        self.temporal_proj = nn.Linear(1, config.anchor_dim, bias=False)

        # Pre-norm cross-attention stack for read(): memory tokens are
        # the queries, stored anchors the keys/values.
        self.cross_attn = nn.ModuleList([
            nn.MultiheadAttention(config.clip_hidden, config.n_bank_heads,
                                  batch_first=True, dropout=0.1)
            for _ in range(config.bank_cross_layers)
        ])
        self.cross_norms = nn.ModuleList([
            nn.LayerNorm(config.clip_hidden)
            for _ in range(config.bank_cross_layers)
        ])
        self.cross_ffns = nn.ModuleList([
            nn.Sequential(
                nn.Linear(config.clip_hidden, config.clip_hidden * 2),
                nn.GELU(),
                nn.Linear(config.clip_hidden * 2, config.clip_hidden))
            for _ in range(config.bank_cross_layers)
        ])
        self.ffn_norms = nn.ModuleList([
            nn.LayerNorm(config.clip_hidden)
            for _ in range(config.bank_cross_layers)
        ])

    def init_bank(self, batch_size, device):
        # Empty bank: zero-length anchor sequence, nothing written yet.
        return {"anchors": torch.zeros(batch_size, 0, self.dim, device=device),
                "n_written": 0}

    def write(self, bank, depth_cls, segment_idx=0):
        """Compress a depth profile into a unit anchor and append it.

        depth_cls: (B, n_extract_layers, clip_hidden) per-segment profile.
        Returns a NEW bank dict (functional update) that also carries the
        freshly written anchor under "live_anchor".
        """
        B = depth_cls.shape[0]
        anchor = self.depth_compressor(depth_cls.reshape(B, -1))
        anchor = F.normalize(anchor, dim=-1)

        # Small temporal tag so anchors encode their write order;
        # re-normalize to keep anchors on the unit sphere.
        t = torch.tensor([[segment_idx]], dtype=anchor.dtype, device=anchor.device)
        anchor = anchor + 0.1 * self.temporal_proj(t / max(self.max_size, 1))
        anchor = F.normalize(anchor, dim=-1)

        # Append; evict the oldest anchors once capacity is exceeded.
        anchors = torch.cat([bank["anchors"], anchor.unsqueeze(1)], dim=1)
        if anchors.shape[1] > self.max_size:
            anchors = anchors[:, -self.max_size:]

        return {"anchors": anchors,
                "n_written": bank["n_written"] + 1,
                "live_anchor": anchor}

    def read(self, memory_tokens, bank):
        """Let memory tokens cross-attend over stored anchors.

        No-op (returns memory_tokens unchanged) while the bank is empty.
        """
        anchors = bank["anchors"]
        if anchors.shape[1] == 0:
            return memory_tokens
        x = memory_tokens
        for attn, norm, ffn, ffn_norm in zip(
                self.cross_attn, self.cross_norms,
                self.cross_ffns, self.ffn_norms):
            residual = x
            # Pre-norm applies to the queries only; anchors are used as-is
            # for keys/values (they are already unit-normalized at write).
            x, _ = attn(norm(x), anchors, anchors)
            x = residual + x
            residual = x
            x = residual + ffn(ffn_norm(x))
        return x
|
|
|
|
| |
| |
| |
|
|
class DeltaMemoryGate(nn.Module):
    """GRU-style gate blending previous memory with a new candidate.

    Convention note: the update gate z KEEPS the old state — the output
    is z*old + (1-z)*candidate — which mirrors (not copies) the textbook
    GRU formulation. A LayerNorm stabilizes the gated result.
    """

    def __init__(self, config: MemoryCLIPConfig):
        super().__init__()
        hidden = config.clip_hidden
        self.reset_proj = nn.Linear(hidden * 2, hidden)
        self.update_proj = nn.Linear(hidden * 2, hidden)
        self.candidate_proj = nn.Linear(hidden * 2, hidden)
        self.norm = nn.LayerNorm(hidden)

    def forward(self, old, new):
        joint = torch.cat([old, new], dim=-1)
        reset = torch.sigmoid(self.reset_proj(joint))
        keep = torch.sigmoid(self.update_proj(joint))
        gated_old = torch.cat([reset * old, new], dim=-1)
        candidate = torch.tanh(self.candidate_proj(gated_old))
        return self.norm(keep * old + (1 - keep) * candidate)
|
|
|
|
| |
| |
| |
|
|
class LayerFusion(nn.Module):
    """Learned softmax-weighted blend of extracted CLIP layer outputs."""

    def __init__(self, config: MemoryCLIPConfig):
        super().__init__()
        n_layers = config.n_extract_layers
        # Uniform init; softmax in forward keeps the mix a convex combination.
        self.weights = nn.Parameter(torch.ones(n_layers) / n_layers)
        self.proj = nn.Linear(config.clip_hidden, config.clip_hidden)
        self.norm = nn.LayerNorm(config.clip_hidden)

    def forward(self, layer_outputs):
        mix = F.softmax(self.weights, dim=0)
        layers = torch.stack(layer_outputs)            # (n_layers, B, T, H)
        blended = (layers * mix.view(-1, 1, 1, 1)).sum(dim=0)
        return self.norm(self.proj(blended))
|
|
|
|
| |
| |
| |
|
|
class TeacherProjector(nn.Module):
    """Projects student (768) -> teacher (1024). Initialized from Procrustes.

    The weight starts as the identity on its top-left square (rows beyond
    student_dim keep their default Linear init) with zero bias, so the
    projector is a near-pass-through until init_from_procrustes is called.
    """
    def __init__(self, student_dim, teacher_dim, name=""):
        super().__init__()
        self.name = name
        self.proj = nn.Linear(student_dim, teacher_dim, bias=True)
        # Identity on the first student_dim rows; nn.init ops run in-place
        # under their own no_grad.
        nn.init.eye_(self.proj.weight[:student_dim, :student_dim])
        nn.init.zeros_(self.proj.bias)

    def forward(self, x):
        return self.proj(x)

    @torch.no_grad()
    def init_from_procrustes(self, rotation, student_mean, teacher_mean):
        """
        rotation:    (teacher_dim, padded_student_dim) from the padded SVD
        proj.weight: (teacher_dim, student_dim) -- the actual linear layer

        Slice rotation to match: R[:, :student_dim]. The bias uses the
        full (padded) rotation with the padded student mean; this is
        equivalent because the padded mean entries are zero.
        """
        # FIX: removed a redundant inner `with torch.no_grad():` -- the
        # @torch.no_grad() decorator already covers the whole method.
        student_dim = self.proj.in_features
        teacher_dim = self.proj.out_features

        R_sliced = rotation[:teacher_dim, :student_dim]
        self.proj.weight.data.copy_(R_sliced)

        self.proj.bias.data.copy_(teacher_mean[:teacher_dim] - rotation[:teacher_dim] @ student_mean)
        print(f" [{self.name}] Procrustes init: |R|={R_sliced.norm():.3f}")
|
|
|
|
| |
| |
| |
|
|
class MemoryExtendedCLIP(nn.Module):
    """
    CLIP-ViT-L/14 text encoder + geometric memory system.

    Forward processes one segment at a time:
      1. Bank READ: memory tokens query past anchors
      2. CLIP text encoder forward (frozen)
      3. Multi-layer extraction + fusion
      4. Memory tokens cross-attend to the fused CLIP states
      5. GRU gate updates memory state
      6. Depth profile -> anchor -> bank WRITE
      7. Output: fused pooled embedding

    NOTE(review): an earlier design prepended memory tokens into CLIP's
    causal attention; forward() below deliberately does NOT do that — it
    runs CLIP unmodified and uses decoder-style cross-attention instead
    (see forward's docstring).
    """
    def __init__(self, config: MemoryCLIPConfig):
        super().__init__()
        self.config = config

        # Filled in by setup(); kept None so the module can be built and
        # its parameter counts inspected without downloading CLIP.
        self.clip_text = None
        self.clip_tokenizer = None

        # Learned initial memory tokens, expanded per batch in init_state().
        self.memory_embeddings = nn.Parameter(
            torch.randn(1, config.n_memory_tokens, config.clip_hidden) * 0.02)
        self.layer_fusion = LayerFusion(config)
        self.bank = GeometricMemoryBank(config)
        self.gate = DeltaMemoryGate(config)

        # Projects CLIP's pooled output; fused with the memory summary in forward().
        self.output_proj = nn.Sequential(
            nn.Linear(config.clip_hidden, config.clip_hidden),
            nn.GELU(), nn.LayerNorm(config.clip_hidden))
        self.memory_output_fusion = nn.Sequential(
            nn.Linear(config.clip_hidden * 2, config.clip_hidden),
            nn.GELU(),
            nn.Linear(config.clip_hidden, config.clip_hidden))

        # Decoder-style cross-attention: memory tokens (queries) attend to
        # fused CLIP hidden states (keys/values).
        self.clip_cross_attn = nn.ModuleList([
            nn.MultiheadAttention(config.clip_hidden, config.n_bank_heads,
                                  batch_first=True, dropout=0.1)
            for _ in range(config.bank_cross_layers)
        ])
        self.clip_cross_norms = nn.ModuleList([
            nn.LayerNorm(config.clip_hidden)
            for _ in range(config.bank_cross_layers)
        ])
        self.clip_cross_ffns = nn.ModuleList([
            nn.Sequential(
                nn.Linear(config.clip_hidden, config.clip_hidden * 2),
                nn.GELU(),
                nn.Linear(config.clip_hidden * 2, config.clip_hidden))
            for _ in range(config.bank_cross_layers)
        ])
        self.clip_cross_ffn_norms = nn.ModuleList([
            nn.LayerNorm(config.clip_hidden)
            for _ in range(config.bank_cross_layers)
        ])

        # Student (clip_hidden) -> teacher (teacher_hidden) distillation projector.
        self.proj_modern = TeacherProjector(
            config.clip_hidden, config.teacher_hidden, "ModernBERT")

    def setup(self, device):
        """Load CLIP text encoder. Call once before training."""
        from transformers import CLIPTextModel, CLIPTokenizer
        print(f" Loading CLIP text encoder: {self.config.clip_model}")
        self.clip_text = CLIPTextModel.from_pretrained(self.config.clip_model).to(device)
        self.clip_text.config.output_hidden_states = True
        if self.config.freeze_clip:
            for p in self.clip_text.parameters():
                p.requires_grad = False
        self.clip_tokenizer = CLIPTokenizer.from_pretrained(self.config.clip_model)

        # Report the frozen/trainable split and the effective context size.
        n_train = sum(p.numel() for p in self.parameters() if p.requires_grad)
        n_frozen = sum(p.numel() for p in self.parameters() if not p.requires_grad)
        print(f" CLIP: {n_frozen:,} frozen")
        print(f" Memory + projector: {n_train:,} trainable")
        print(f" Extract: {self.config.extract_layers}")
        print(f" Bank: {self.config.bank_size} anchors, "
              f"{self.config.bank_cross_layers} cross-attn")
        print(f" Memory: {self.config.n_memory_tokens} tokens")
        print(f" Effective context: {self.config.max_segments} Γ "
              f"{self.config.max_content_tokens} = "
              f"{self.config.max_segments * self.config.max_content_tokens} tokens")

    def init_state(self, batch_size, device=None):
        """Fresh recurrent state: initial memory tokens + empty bank."""
        if device is None:
            device = self.memory_embeddings.device
        return {
            "memory": self.memory_embeddings.expand(batch_size, -1, -1).clone(),
            "bank": self.bank.init_bank(batch_size, device),
            "segment_idx": 0,
        }

    def forward(self, input_ids, attention_mask, state):
        """
        Process one text segment through CLIP + memory.

        Architecture:
          1. CLIP text encoder runs NORMALLY on truncated 77-token input
          2. Memory tokens cross-attend to CLIP's hidden states (decoder-style)
          3. Bank stores depth-profile anchors
          4. GRU gate updates memory state

        This avoids injecting tokens into CLIP's causal attention,
        which breaks its internal mask/position handling.

        Returns (outputs dict, new state dict); the state threads through
        subsequent segments of the same document.
        """
        B = input_ids.shape[0]
        device = input_ids.device
        n_mem = self.config.n_memory_tokens

        memory_state = state["memory"]
        bank = state["bank"]
        seg_idx = state["segment_idx"]

        # 1) Bank READ: memory tokens attend over anchors from past segments.
        memory_tokens = self.bank.read(memory_state, bank)

        # 2) Frozen CLIP forward on the (truncated) segment.
        max_len = self.config.clip_max_tokens
        clip_ids = input_ids[:, :max_len]
        clip_mask = attention_mask[:, :max_len]

        with torch.no_grad():
            clip_out = self.clip_text(
                input_ids=clip_ids,
                attention_mask=clip_mask,
                output_hidden_states=True,
                return_dict=True)

        all_hiddens = clip_out.hidden_states
        clip_seq_len = all_hiddens[0].shape[1]  # NOTE: currently unused

        # hidden_states[0] is the embedding layer, so layer i's output
        # lives at index i + 1.
        selected = [all_hiddens[i + 1] for i in self.config.extract_layers]
        fused = self.layer_fusion(selected)

        # 3) Memory tokens (queries) cross-attend to fused CLIP states (K/V),
        # pre-norm residual blocks.
        mem_enriched = memory_tokens
        for attn, norm, ffn, ffn_norm in zip(
                self.clip_cross_attn, self.clip_cross_norms,
                self.clip_cross_ffns, self.clip_cross_ffn_norms):
            residual = mem_enriched
            mem_enriched_normed = norm(mem_enriched)
            mem_enriched, _ = attn(mem_enriched_normed, fused, fused)
            mem_enriched = residual + mem_enriched
            residual = mem_enriched
            mem_enriched = residual + ffn(ffn_norm(mem_enriched))

        # Depth profile: one vector per extracted layer, taken at position 1.
        # NOTE(review): CLIP pools at the EOS position, not position 1 (the
        # first token after SOS) — confirm index 1 as a "CLS" proxy is intended.
        depth_cls = torch.stack([h[:, 1, :] for h in selected], dim=1)

        # 4) GRU-style gate folds the enriched memory into the carried state.
        new_memory = self.gate(memory_state, mem_enriched)

        # 5) Bank WRITE: compress the depth profile into a new anchor.
        new_bank = self.bank.write(bank, depth_cls, seg_idx)

        # 6) Pooled output, with memory contribution added as a delta.
        clip_pooled = clip_out.pooler_output
        if clip_pooled is None:
            # NOTE(review): the last position may be padding; CLIP's true EOS
            # pooling locates the eos token via argmax — confirm this fallback
            # is acceptable for padded batches.
            clip_pooled = clip_out.last_hidden_state[:, -1, :]

        cls_output = self.output_proj(clip_pooled)
        memory_delta = self.memory_output_fusion(
            torch.cat([cls_output, new_memory.mean(dim=1)], dim=-1))
        fused_output = cls_output + memory_delta

        outputs = {
            "memory_output": fused_output,
            "cls_output": cls_output,
            "live_anchor": new_bank["live_anchor"],
            "depth_cls": depth_cls,
            "content_output": fused,
            "clip_pooled": clip_pooled,
        }

        new_state = {
            "memory": new_memory,
            "bank": {"anchors": new_bank["anchors"],
                     "n_written": new_bank["n_written"],
                     "live_anchor": new_bank["live_anchor"]},
            "segment_idx": seg_idx + 1,
        }
        return outputs, new_state

    @staticmethod
    def detach_state(state):
        # Truncated-BPTT boundary: detach tensors from the graph.
        # NOTE(review): "live_anchor" is intentionally(?) dropped here —
        # confirm downstream code never reads it from a detached state.
        return {
            "memory": state["memory"].detach(),
            "bank": {"anchors": state["bank"]["anchors"].detach(),
                     "n_written": state["bank"]["n_written"]},
            "segment_idx": state["segment_idx"],
        }

    def get_sequence_output(self, input_ids, attention_mask, state):
        """
        Get full sequence hidden states (for SD cross-attention compatibility).

        NOTE(review): despite the original docstring's claim of a
        concatenated (B, total_seq_len, 768) across all segments, this
        returns only the CURRENT segment's fused hidden states.
        """
        outputs, new_state = self.forward(input_ids, attention_mask, state)
        return outputs["content_output"], new_state

    def num_trainable_params(self):
        # Count of parameters with requires_grad=True (excludes frozen CLIP).
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def compute_static_procrustes(student_embs, teacher_embs):
    """Closed-form orthogonal Procrustes fit from student to teacher space.

    student_embs: (N, D_s) student embeddings
    teacher_embs: (N, D_t) teacher embeddings (D_t >= D_s supported)

    When the student dim is narrower, the centered student matrix and its
    mean are zero-padded up to the teacher dim so the SVD is square.

    Returns (R, mu_x_padded, mu_y): R maps centered padded student vectors
    to centered teacher vectors via `Xc_padded @ R.T`.
    """
    X = student_embs.float()
    Y = teacher_embs.float()
    mu_x, mu_y = X.mean(0), Y.mean(0)
    Xc, Yc = X - mu_x, Y - mu_y

    # Zero-pad the student side when it is narrower than the teacher.
    if Xc.shape[1] < Yc.shape[1]:
        pad = torch.zeros(Xc.shape[0], Yc.shape[1] - Xc.shape[1],
                          device=Xc.device)
        Xc_padded = torch.cat([Xc, pad], dim=1)
        mu_x_padded = torch.cat([mu_x, torch.zeros(Yc.shape[1] - Xc.shape[1],
                                                   device=mu_x.device)])
    else:
        Xc_padded = Xc
        mu_x_padded = mu_x

    # Orthogonal Procrustes: R.T = U @ Vt minimizes ||Xc_padded @ R.T - Yc||_F.
    U, S, Vt = torch.linalg.svd(Xc_padded.T @ Yc)
    R = (U @ Vt).T

    cos_before = F.cosine_similarity(Xc_padded, Yc, dim=-1).mean()
    cos_after = F.cosine_similarity((Xc_padded @ R.T), Yc, dim=-1).mean()
    # FIX: the arrow in this log line was a mis-encoded character ("β").
    print(f" Procrustes: cos {cos_before:.4f} -> {cos_after:.4f}")
    return R, mu_x_padded, mu_y
|
|
|
|
| |
| |
| |
|
|
def segment_text(text, clip_tokenizer, max_content=18, overlap=4, max_segments=32):
    """
    Segment long text into CLIP-compatible chunks.
    Each chunk: [SOS] + content_tokens + [EOS] + [PAD...]
    Returns a list of dicts with "input_ids" / "attention_mask" tensors,
    each padded to CLIP's 77-token window.

    Raises:
        ValueError: if overlap >= max_content. The resulting non-positive
        stride made the original implementation loop forever because
        `pos` never advanced.
    """
    stride = max_content - overlap
    if stride <= 0:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than max_content ({max_content})")

    full_tokens = clip_tokenizer.encode(text, add_special_tokens=False)

    # Hoisted out of the loop. Use `is None` rather than `or` so a
    # legitimate token id of 0 would not be replaced by the CLIP defaults.
    sos = clip_tokenizer.bos_token_id
    if sos is None:
        sos = 49406
    eos = clip_tokenizer.eos_token_id
    if eos is None:
        eos = 49407

    segments = []
    pos = 0
    while pos < len(full_tokens) and len(segments) < max_segments:
        end = min(pos + max_content, len(full_tokens))
        chunk = full_tokens[pos:end]

        input_ids = [sos] + chunk + [eos]
        n_pad = 77 - len(input_ids)
        if n_pad > 0:
            input_ids = input_ids + [0] * n_pad
        else:
            input_ids = input_ids[:77]

        # Real tokens (chunk + SOS/EOS, capped at the window) get mask 1.
        mask = [1] * min(len(chunk) + 2, 77) + [0] * max(n_pad, 0)
        mask = mask[:77]

        segments.append({
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(mask, dtype=torch.long),
        })

        if end >= len(full_tokens):
            break
        pos += stride

    return segments
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Smoke test: build the model (without loading CLIP weights), report
    # parameter counts per component, and exercise the text segmenter.
    print("=" * 70)
    print("MEMORY-EXTENDED CLIP-L TEXT ENCODER")
    print("=" * 70)

    config = MemoryCLIPConfig()
    model = MemoryExtendedCLIP(config)

    # Per-component trainable parameter breakdown (CLIP not loaded yet,
    # so everything counted here belongs to the memory system).
    comps = {
        "memory_embeddings": model.memory_embeddings.numel(),
        "layer_fusion": sum(p.numel() for p in model.layer_fusion.parameters()),
        "bank.depth_compressor": sum(p.numel() for p in model.bank.depth_compressor.parameters()),
        "bank.temporal_proj": sum(p.numel() for p in model.bank.temporal_proj.parameters()),
        "bank.cross_attn": sum(p.numel() for p in model.bank.cross_attn.parameters()),
        "bank.cross_ffns": sum(p.numel() for p in model.bank.cross_ffns.parameters()),
        "gate": sum(p.numel() for p in model.gate.parameters()),
        "output_proj": sum(p.numel() for p in model.output_proj.parameters()),
        "memory_output_fusion": sum(p.numel() for p in model.memory_output_fusion.parameters()),
        "proj_modern": sum(p.numel() for p in model.proj_modern.parameters()),
    }
    print(f"\n Memory system components:")
    for k, v in comps.items():
        print(f" {k:30s}: {v:,}")
    total = sum(comps.values())
    print(f" {'TOTAL':30s}: {total:,}")

    print(f"\n Config:")
    print(f" CLIP: {config.clip_model}")
    print(f" Hidden: {config.clip_hidden}")
    print(f" Window: {config.clip_max_tokens} tokens")
    print(f" Memory tokens: {config.n_memory_tokens}")
    print(f" Segments: max {config.max_segments} Γ {config.max_content_tokens} "
          f"= {config.max_segments * config.max_content_tokens} tokens")
    print(f" Bank: {config.bank_size} anchors")
    print(f" Teacher: {config.teacher_model} ({config.teacher_hidden}-dim)")

    # Segmentation sanity check with a long repeated prompt.
    print(f"\n Testing segmentation...")
    from transformers import CLIPTokenizer
    tok = CLIPTokenizer.from_pretrained(config.clip_model)
    long_text = ("A vast sweeping landscape of rolling green hills under dramatic "
                 "storm clouds with a lone oak tree in the foreground its branches "
                 "bent by wind casting long shadows across a field of wildflowers "
                 "in purple yellow and white while in the distance a medieval stone "
                 "castle sits atop a cliff overlooking a turbulent sea with waves "
                 "crashing against ancient rocks and seabirds wheeling overhead " * 3)
    segments = segment_text(long_text, tok, config.max_content_tokens,
                            config.segment_overlap, config.max_segments)
    print(f" Text length: {len(long_text)} chars")
    print(f" Full tokens: {len(tok.encode(long_text, add_special_tokens=False))}")
    print(f" Segments: {len(segments)}")
    for i, seg in enumerate(segments[:3]):
        n_real = seg["attention_mask"].sum().item()
        print(f" Seg {i}: {n_real} tokens (of {seg['input_ids'].shape[0]})")

    print(f"\nReady for training. Run setup(device) to load CLIP.")