# geolip-clip-vit-large-patch14-ctx576 / cell1_trainer_model.py
# (AbstractPhil — Hugging Face listing artifact: "Rename model.py to
#  cell1_trainer_model.py", commit 6b607e6, verified)
# ============================================================================
# MEMORY-EXTENDED CLIP-L TEXT ENCODER
#
# Extends CLIP-ViT-L/14 text encoder from 77 tokens to ~576 effective
# context via geometric memory bank, taught by ModernBERT-large.
#
# Architecture:
# CLIP text encoder (frozen, 768-dim, 77 ctx, causal attn, 12 layers)
# + Geometric memory system (trainable, ~25M)
# Taught by ModernBERT-large (frozen, 1024-dim, 8192 ctx)
#
# Key constraints:
# - Output must remain CLIP-compatible (768-dim, vision-aligned)
# - CLIP uses CAUSAL attention (not bidirectional like BERT)
# - 77-token window is much tighter than BERT's 512
# - Usable tokens per segment: ~60 (77 - 8 memory - SOS - EOS)
#
# Losses:
# L_modern: InfoNCE(proj_modern(student), teacher_cls) × 1.0
# L_clip: InfoNCE(student, clip_full_text) × 0.5
# L_procrustes: Procrustes alignment regularizer × 0.3
# L_cv: |pentachoron_cv(anchors) - 0.20| × 0.05
#
# At inference: CLIP + memory only. No ModernBERT needed.
# ============================================================================
import math
import os
import json
import time
from dataclasses import dataclass, field, asdict
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
# ══════════════════════════════════════════════════════════════════
# CONFIG
# ══════════════════════════════════════════════════════════════════
@dataclass
class MemoryCLIPConfig:
    """Hyper-parameters for the memory-extended CLIP text encoder."""
    # --- CLIP text encoder (frozen student backbone) ---
    clip_model: str = "openai/clip-vit-large-patch14"
    clip_hidden: int = 768
    clip_layers: int = 12
    clip_max_tokens: int = 77
    freeze_clip: bool = True
    # --- Memory system ---
    n_memory_tokens: int = 8   # prepended to each segment (tighter window than BERT)
    bank_size: int = 64        # max anchors (captions shorter than documents)
    anchor_dim: int = 768
    n_bank_heads: int = 8
    bank_cross_layers: int = 2
    # --- Gate ---
    gate_type: str = "gru"
    # --- Multi-layer extraction ---
    # CLIP-L has 12 layers; tap 6 of them spread across the depth.
    extract_layers: Tuple[int, ...] = (1, 3, 5, 7, 9, 11)
    layer_fusion: str = "learned"
    # --- Segment processing ---
    # Smaller segments mean more anchors per caption, so the CV loss can
    # activate: a ~96-token caption at stride 14 gives ~7 segments
    # (at least 5 are needed for a pentachoron).
    max_content_tokens: int = 18
    segment_overlap: int = 4
    max_segments: int = 32     # up to 32 x 18 = 576 tokens effective
    # --- Teacher ---
    teacher_model: str = "answerdotai/ModernBERT-large"
    teacher_hidden: int = 1024
    teacher_max_len: int = 4096  # conservative, fits most captions
    # --- Geometric regularizer ---
    cv_target: float = 0.20

    @property
    def n_extract_layers(self) -> int:
        """Number of CLIP layers tapped for the depth profile."""
        return len(self.extract_layers)

    @property
    def depth_profile_dim(self) -> int:
        """Flattened depth-profile width: n_extract_layers * clip_hidden."""
        return self.n_extract_layers * self.clip_hidden
# ══════════════════════════════════════════════════════════════════
# GEOMETRIC UTILITIES
# ══════════════════════════════════════════════════════════════════
def cayley_menger_vol2(pts):
    """
    Squared simplex volume via the Cayley-Menger determinant.

    pts: (B, V, D) batch of V points each; returns (B,) squared volumes of
    the (V-1)-simplex they span. Runs in float32 with autocast disabled so
    the determinant stays numerically stable under mixed precision.
    """
    with torch.amp.autocast("cuda", enabled=False):
        p = pts.float()
        # Pairwise squared distances, shape (B, V, V).
        delta = p.unsqueeze(-2) - p.unsqueeze(-3)
        dist2 = delta.pow(2).sum(-1)
        batch, n_pts, _ = dist2.shape
        # Bordered Cayley-Menger matrix: first row/col of ones, zero corner.
        cm = dist2.new_zeros(batch, n_pts + 1, n_pts + 1)
        cm[:, 0, 1:] = 1
        cm[:, 1:, 0] = 1
        cm[:, 1:, 1:] = dist2
        sign = (-1.0) ** n_pts
        fact = math.factorial(n_pts - 1)
        scale = sign / ((2.0 ** (n_pts - 1)) * fact * fact)
        return scale * torch.linalg.det(cm)
def pentachoron_cv(embeddings, n_samples=16):
    """
    Coefficient of variation over randomly sampled pentachoron volumes.

    Samples `n_samples` random 5-point subsets of `embeddings` and returns
    std/mean of their 4-simplex volumes. Returns 0 when fewer than 5
    embeddings exist (no pentachoron can be formed).
    """
    count = embeddings.shape[0]
    if count < 5:
        return torch.tensor(0.0, device=embeddings.device)
    volumes = []
    for _ in range(n_samples):
        choice = torch.randperm(count, device=embeddings.device)[:5]
        sq_vol = cayley_menger_vol2(embeddings[choice].unsqueeze(0))
        # relu clamps tiny negative determinants from round-off before sqrt.
        volumes.append(torch.sqrt(F.relu(sq_vol[0]) + 1e-12))
    vols = torch.stack(volumes)
    return vols.std() / (vols.mean() + 1e-8)
def procrustes_alignment_loss(emb_a, emb_b):
    """
    Differentiable Procrustes regularizer.
    Measures rotational alignability — NOT the alignment force (that's
    InfoNCE). Kept as a regularizer per Phil's correction: it shapes
    geometry even if it doesn't create alignment by itself.
    """
    with torch.amp.autocast("cuda", enabled=False):
        a = F.normalize(emb_a.float(), dim=-1)
        b = F.normalize(emb_b.float(), dim=-1)
        # Center both clouds so only rotation (not translation) matters.
        a = a - a.mean(dim=0, keepdim=True)
        b = b - b.mean(dim=0, keepdim=True)
        # Nuclear norm of the cross-covariance = best achievable overlap
        # under an orthogonal map.
        singulars = torch.linalg.svdvals(a.T @ b)
        n_rows, n_dims = a.shape
        return 1.0 - singulars.sum() / (math.sqrt(n_rows) * n_dims)
# ══════════════════════════════════════════════════════════════════
# GEOMETRIC MEMORY BANK
# ══════════════════════════════════════════════════════════════════
class GeometricMemoryBank(nn.Module):
    """
    Rolling anchor store for the CLIP text encoder.

    write(): compresses a multi-layer depth profile into one unit-norm
    anchor (tagged with a small temporal code) and appends it, keeping at
    most `bank_size` most-recent anchors.
    read(): lets memory tokens cross-attend over stored anchors through a
    small pre-norm attention + FFN stack.
    """
    def __init__(self, config: MemoryCLIPConfig):
        super().__init__()
        self.config = config
        self.max_size = config.bank_size
        self.dim = config.anchor_dim
        # Depth-profile compressor: (n_layers * hidden) -> anchor_dim,
        # i.e. 6 x 768 = 4608 -> 768.
        depth_dim = config.depth_profile_dim
        self.depth_compressor = nn.Sequential(
            nn.Linear(depth_dim, config.clip_hidden * 2),
            nn.GELU(),
            nn.LayerNorm(config.clip_hidden * 2),
            nn.Linear(config.clip_hidden * 2, config.anchor_dim),
        )
        # Scalar segment index -> additive temporal code.
        self.temporal_proj = nn.Linear(1, config.anchor_dim, bias=False)
        # Memory tokens (Q) attend over bank anchors (K, V).
        self.cross_attn = nn.ModuleList([
            nn.MultiheadAttention(config.clip_hidden, config.n_bank_heads,
                                  batch_first=True, dropout=0.1)
            for _ in range(config.bank_cross_layers)
        ])
        self.cross_norms = nn.ModuleList([
            nn.LayerNorm(config.clip_hidden)
            for _ in range(config.bank_cross_layers)
        ])
        self.cross_ffns = nn.ModuleList([
            nn.Sequential(
                nn.Linear(config.clip_hidden, config.clip_hidden * 2),
                nn.GELU(),
                nn.Linear(config.clip_hidden * 2, config.clip_hidden))
            for _ in range(config.bank_cross_layers)
        ])
        self.ffn_norms = nn.ModuleList([
            nn.LayerNorm(config.clip_hidden)
            for _ in range(config.bank_cross_layers)
        ])

    def init_bank(self, batch_size, device):
        """Fresh empty bank: zero stored anchors, nothing written yet."""
        empty = torch.zeros(batch_size, 0, self.dim, device=device)
        return {"anchors": empty, "n_written": 0}

    def write(self, bank, depth_cls, segment_idx=0):
        """Compress one depth profile into an anchor and append it (FIFO cap)."""
        batch = depth_cls.shape[0]
        raw = self.depth_compressor(depth_cls.reshape(batch, -1))
        anchor = F.normalize(raw, dim=-1)
        # Tag with the (scaled) segment position, then renormalize.
        step = torch.tensor([[segment_idx]], dtype=anchor.dtype,
                            device=anchor.device)
        anchor = F.normalize(
            anchor + 0.1 * self.temporal_proj(step / max(self.max_size, 1)),
            dim=-1)
        grown = torch.cat([bank["anchors"], anchor.unsqueeze(1)], dim=1)
        if grown.shape[1] > self.max_size:
            grown = grown[:, -self.max_size:]  # evict the oldest anchors
        return {"anchors": grown,
                "n_written": bank["n_written"] + 1,
                "live_anchor": anchor}

    def read(self, memory_tokens, bank):
        """Enrich memory tokens from stored anchors; identity on empty bank."""
        stored = bank["anchors"]
        if stored.shape[1] == 0:
            return memory_tokens
        x = memory_tokens
        stages = zip(self.cross_attn, self.cross_norms,
                     self.cross_ffns, self.ffn_norms)
        for attn, attn_norm, ffn, ffn_norm in stages:
            # Pre-norm cross-attention with residual.
            attended, _ = attn(attn_norm(x), stored, stored)
            x = x + attended
            # Pre-norm FFN with residual.
            x = x + ffn(ffn_norm(x))
        return x
# ══════════════════════════════════════════════════════════════════
# DELTA MEMORY GATE
# ══════════════════════════════════════════════════════════════════
class DeltaMemoryGate(nn.Module):
    """GRU-style gate blending previous memory with freshly enriched memory."""
    def __init__(self, config: MemoryCLIPConfig):
        super().__init__()
        width = config.clip_hidden
        self.reset_proj = nn.Linear(width * 2, width)
        self.update_proj = nn.Linear(width * 2, width)
        self.candidate_proj = nn.Linear(width * 2, width)
        self.norm = nn.LayerNorm(width)

    def forward(self, old, new):
        """Return LayerNorm(z * old + (1 - z) * candidate), GRU-fashion."""
        joint = torch.cat([old, new], dim=-1)
        reset = torch.sigmoid(self.reset_proj(joint))
        keep = torch.sigmoid(self.update_proj(joint))
        candidate = torch.tanh(
            self.candidate_proj(torch.cat([reset * old, new], dim=-1)))
        return self.norm(keep * old + (1 - keep) * candidate)
# ══════════════════════════════════════════════════════════════════
# MULTI-LAYER FUSION
# ══════════════════════════════════════════════════════════════════
class LayerFusion(nn.Module):
    """Softmax-weighted blend of hidden states from several CLIP layers."""
    def __init__(self, config: MemoryCLIPConfig):
        super().__init__()
        n_layers = config.n_extract_layers
        # Uniform init; the softmax in forward keeps weights on the simplex.
        self.weights = nn.Parameter(torch.ones(n_layers) / n_layers)
        self.proj = nn.Linear(config.clip_hidden, config.clip_hidden)
        self.norm = nn.LayerNorm(config.clip_hidden)

    def forward(self, layer_outputs):
        """layer_outputs: list of (B, seq, hidden) tensors, one per layer."""
        mix = F.softmax(self.weights, dim=0).view(-1, 1, 1, 1)
        blended = (torch.stack(layer_outputs) * mix).sum(dim=0)
        return self.norm(self.proj(blended))
# ══════════════════════════════════════════════════════════════════
# TEACHER PROJECTOR
# ══════════════════════════════════════════════════════════════════
class TeacherProjector(nn.Module):
    """
    Linear map from student space (768) to teacher space (1024).

    The top-left student_dim x student_dim block starts as identity (the
    remaining rows keep the default Linear init) with zero bias, and can be
    overwritten from a static Procrustes solution via init_from_procrustes().
    """
    def __init__(self, student_dim, teacher_dim, name=""):
        super().__init__()
        self.name = name
        self.proj = nn.Linear(student_dim, teacher_dim, bias=True)
        # Identity on the shared leading dimensions; zero bias.
        nn.init.eye_(self.proj.weight[:student_dim, :student_dim])
        nn.init.zeros_(self.proj.bias)

    def forward(self, x):
        return self.proj(x)

    @torch.no_grad()
    def init_from_procrustes(self, rotation, student_mean, teacher_mean):
        """
        Copy a padded-SVD Procrustes solution into this layer.
        rotation: (teacher_dim, padded_student_dim) from the padded SVD.
        proj.weight: (teacher_dim, student_dim), so the rotation's columns
        are sliced down to the true student width (the rest were zero-padded).
        """
        in_dim = self.proj.in_features
        out_dim = self.proj.out_features
        r_slice = rotation[:out_dim, :in_dim]
        self.proj.weight.data.copy_(r_slice)
        # Bias aligns the means: teacher_mean - R @ student_mean.
        self.proj.bias.data.copy_(
            teacher_mean[:out_dim] - rotation[:out_dim] @ student_mean)
        print(f" [{self.name}] Procrustes init: |R|={r_slice.norm():.3f}")
# ══════════════════════════════════════════════════════════════════
# MEMORY-EXTENDED CLIP TEXT ENCODER
# ══════════════════════════════════════════════════════════════════
class MemoryExtendedCLIP(nn.Module):
    """
    CLIP-ViT-L/14 text encoder + geometric memory system.

    Forward processes one segment at a time:
    1. Bank READ: memory tokens query past anchors
    2. CLIP text encoder forward (frozen, causal attention, 77-token max)
    3. Multi-layer extraction + learned fusion
    4. Memory tokens cross-attend to the fused CLIP hidden states
    5. GRU gate updates memory state
    6. Depth profile -> anchor -> bank WRITE
    7. Output: CLIP pooled embedding fused with a memory delta (768-dim)

    Memory tokens live OUTSIDE CLIP's causal attention: forward() runs
    CLIP unmodified and lets memory read the hidden states via
    decoder-style cross-attention, instead of injecting extra tokens into
    CLIP's sequence (which would break its mask/position handling).
    """
    def __init__(self, config: MemoryCLIPConfig):
        super().__init__()
        self.config = config
        # Lazy-loaded in setup() to avoid import issues in syntax check
        self.clip_text = None
        self.clip_tokenizer = None
        # Memory system: learned initial memory tokens + fusion/bank/gate.
        self.memory_embeddings = nn.Parameter(
            torch.randn(1, config.n_memory_tokens, config.clip_hidden) * 0.02)
        self.layer_fusion = LayerFusion(config)
        self.bank = GeometricMemoryBank(config)
        self.gate = DeltaMemoryGate(config)
        # Output head: refine CLIP's pooled embedding, then fuse in memory.
        self.output_proj = nn.Sequential(
            nn.Linear(config.clip_hidden, config.clip_hidden),
            nn.GELU(), nn.LayerNorm(config.clip_hidden))
        self.memory_output_fusion = nn.Sequential(
            nn.Linear(config.clip_hidden * 2, config.clip_hidden),
            nn.GELU(),
            nn.Linear(config.clip_hidden, config.clip_hidden))
        # Cross-attention: memory tokens attend to CLIP's hidden states
        # (separate from bank's cross-attn which attends to past anchors)
        self.clip_cross_attn = nn.ModuleList([
            nn.MultiheadAttention(config.clip_hidden, config.n_bank_heads,
                                  batch_first=True, dropout=0.1)
            for _ in range(config.bank_cross_layers)
        ])
        self.clip_cross_norms = nn.ModuleList([
            nn.LayerNorm(config.clip_hidden)
            for _ in range(config.bank_cross_layers)
        ])
        self.clip_cross_ffns = nn.ModuleList([
            nn.Sequential(
                nn.Linear(config.clip_hidden, config.clip_hidden * 2),
                nn.GELU(),
                nn.Linear(config.clip_hidden * 2, config.clip_hidden))
            for _ in range(config.bank_cross_layers)
        ])
        self.clip_cross_ffn_norms = nn.ModuleList([
            nn.LayerNorm(config.clip_hidden)
            for _ in range(config.bank_cross_layers)
        ])
        # Teacher projector: student 768 -> teacher 1024 (training only).
        self.proj_modern = TeacherProjector(
            config.clip_hidden, config.teacher_hidden, "ModernBERT")

    def setup(self, device):
        """Load CLIP text encoder. Call once before training."""
        from transformers import CLIPTextModel, CLIPTokenizer
        print(f" Loading CLIP text encoder: {self.config.clip_model}")
        self.clip_text = CLIPTextModel.from_pretrained(self.config.clip_model).to(device)
        self.clip_text.config.output_hidden_states = True
        if self.config.freeze_clip:
            # Freeze the backbone; only the memory system + projector train.
            for p in self.clip_text.parameters():
                p.requires_grad = False
        self.clip_tokenizer = CLIPTokenizer.from_pretrained(self.config.clip_model)
        n_train = sum(p.numel() for p in self.parameters() if p.requires_grad)
        n_frozen = sum(p.numel() for p in self.parameters() if not p.requires_grad)
        print(f" CLIP: {n_frozen:,} frozen")
        print(f" Memory + projector: {n_train:,} trainable")
        print(f" Extract: {self.config.extract_layers}")
        print(f" Bank: {self.config.bank_size} anchors, "
              f"{self.config.bank_cross_layers} cross-attn")
        print(f" Memory: {self.config.n_memory_tokens} tokens")
        print(f" Effective context: {self.config.max_segments} Γ— "
              f"{self.config.max_content_tokens} = "
              f"{self.config.max_segments * self.config.max_content_tokens} tokens")

    def init_state(self, batch_size, device=None):
        """Fresh per-sequence state: learned memory tokens + empty bank."""
        if device is None:
            device = self.memory_embeddings.device
        return {
            # clone() so in-place state updates never touch the parameter.
            "memory": self.memory_embeddings.expand(batch_size, -1, -1).clone(),
            "bank": self.bank.init_bank(batch_size, device),
            "segment_idx": 0,
        }

    def forward(self, input_ids, attention_mask, state):
        """
        Process one text segment through CLIP + memory.

        Args:
            input_ids: (B, seq) token ids; truncated to 77 for CLIP.
            attention_mask: (B, seq) matching mask.
            state: dict from init_state() / a previous forward call.

        Returns:
            (outputs, new_state) — outputs["memory_output"] is the (B, 768)
            CLIP-compatible embedding for this segment.

        Architecture:
        1. CLIP text encoder runs NORMALLY on truncated 77-token input
        2. Memory tokens cross-attend to CLIP's hidden states (decoder-style)
        3. Bank stores depth-profile anchors
        4. GRU gate updates memory state
        This avoids injecting tokens into CLIP's causal attention,
        which breaks its internal mask/position handling.
        """
        B = input_ids.shape[0]
        device = input_ids.device
        n_mem = self.config.n_memory_tokens
        memory_state = state["memory"]
        bank = state["bank"]
        seg_idx = state["segment_idx"]
        # ── Bank read: enrich memory tokens from past anchors ──
        memory_tokens = self.bank.read(memory_state, bank)
        # ── CLIP forward (standard, frozen, 77-token max) ──
        # Truncate to CLIP's max length
        max_len = self.config.clip_max_tokens
        clip_ids = input_ids[:, :max_len]
        clip_mask = attention_mask[:, :max_len]
        with torch.no_grad():
            clip_out = self.clip_text(
                input_ids=clip_ids,
                attention_mask=clip_mask,
                output_hidden_states=True,
                return_dict=True)
        # all_hiddens: tuple of (B, seq, 768), length = n_layers + 1
        # (index 0 is the embedding output; index i+1 is after layer i,
        # hence the +1 offset in the extraction below)
        all_hiddens = clip_out.hidden_states
        clip_seq_len = all_hiddens[0].shape[1]
        # ── Multi-layer extraction from CLIP ──
        selected = [all_hiddens[i + 1] for i in self.config.extract_layers]
        fused = self.layer_fusion(selected)  # (B, clip_seq_len, 768)
        # ── Memory tokens cross-attend to fused CLIP output ──
        # Decoder-style: memory tokens (Q) attend to CLIP output (K, V)
        mem_enriched = memory_tokens
        for attn, norm, ffn, ffn_norm in zip(
                self.clip_cross_attn, self.clip_cross_norms,
                self.clip_cross_ffns, self.clip_cross_ffn_norms):
            # Pre-norm attention block with residual …
            residual = mem_enriched
            mem_enriched_normed = norm(mem_enriched)
            mem_enriched, _ = attn(mem_enriched_normed, fused, fused)
            mem_enriched = residual + mem_enriched
            # … followed by a pre-norm FFN block with residual.
            residual = mem_enriched
            mem_enriched = residual + ffn(ffn_norm(mem_enriched))
        # ── Depth profile for bank anchor ──
        # First real token (after SOS) from each extracted layer
        depth_cls = torch.stack([h[:, 1, :] for h in selected], dim=1)
        # ── GRU gate: update memory state ──
        new_memory = self.gate(memory_state, mem_enriched)
        # ── Bank write: store anchor ──
        new_bank = self.bank.write(bank, depth_cls, seg_idx)
        # ── Output: combine CLIP's EOS with memory delta ──
        # CLIP's pooler_output is the EOS token after final projection
        clip_pooled = clip_out.pooler_output  # (B, 768) — EOS token
        if clip_pooled is None:
            # Fallback: use last_hidden_state at EOS position
            # NOTE(review): with right-padded inputs the last position may be
            # PAD rather than EOS — approximate; confirm if this path is hit.
            clip_pooled = clip_out.last_hidden_state[:, -1, :]
        cls_output = self.output_proj(clip_pooled)
        memory_delta = self.memory_output_fusion(
            torch.cat([cls_output, new_memory.mean(dim=1)], dim=-1))
        fused_output = cls_output + memory_delta
        outputs = {
            "memory_output": fused_output,  # (B, 768) — CLIP-compatible
            "cls_output": cls_output,
            "live_anchor": new_bank["live_anchor"],
            "depth_cls": depth_cls,
            "content_output": fused,  # (B, clip_seq, 768)
            "clip_pooled": clip_pooled,
        }
        new_state = {
            "memory": new_memory,
            "bank": {"anchors": new_bank["anchors"],
                     "n_written": new_bank["n_written"],
                     "live_anchor": new_bank["live_anchor"]},
            "segment_idx": seg_idx + 1,
        }
        return outputs, new_state

    @staticmethod
    def detach_state(state):
        """Detach state from the graph (for truncated BPTT across segments).

        Deliberately drops "live_anchor": init_bank()/read() only ever use
        "anchors" and "n_written".
        """
        return {
            "memory": state["memory"].detach(),
            "bank": {"anchors": state["bank"]["anchors"].detach(),
                     "n_written": state["bank"]["n_written"]},
            "segment_idx": state["segment_idx"],
        }

    def get_sequence_output(self, input_ids, attention_mask, state):
        """
        Get full sequence hidden states (for SD cross-attention compatibility).

        NOTE(review): despite the docstring's original claim, this returns
        only the CURRENT segment's per-token states, (B, clip_seq_len, 768)
        — one forward() call; the caller accumulates across segments.
        """
        outputs, new_state = self.forward(input_ids, attention_mask, state)
        # The content_output from fused layers is the per-token representation
        return outputs["content_output"], new_state

    def num_trainable_params(self):
        """Total number of parameters with requires_grad=True."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
# ══════════════════════════════════════════════════════════════════
# STATIC PROCRUSTES
# ══════════════════════════════════════════════════════════════════
@torch.no_grad()
def compute_static_procrustes(student_embs, teacher_embs):
    """
    Solve the orthogonal Procrustes problem mapping centered student
    embeddings onto centered teacher embeddings.

    Args:
        student_embs: (N, student_dim) tensor.
        teacher_embs: (N, teacher_dim) tensor.

    Returns:
        R: (teacher_dim, padded_student_dim) rotation with Xc_padded @ R.T ~ Yc.
        mu_x_padded: student mean, zero-padded to teacher_dim when needed.
        mu_y: teacher mean.
    """
    X = student_embs.float()
    Y = teacher_embs.float()
    mu_x, mu_y = X.mean(0), Y.mean(0)
    Xc, Yc = X - mu_x, Y - mu_y
    # Handle dimension mismatch: student (768) -> teacher (1024).
    # Zero-pad the student side so the SVD is square in teacher space.
    if Xc.shape[1] < Yc.shape[1]:
        pad = torch.zeros(Xc.shape[0], Yc.shape[1] - Xc.shape[1],
                          device=Xc.device)
        Xc_padded = torch.cat([Xc, pad], dim=1)
        mu_x_padded = torch.cat([mu_x, torch.zeros(Yc.shape[1] - Xc.shape[1],
                                                   device=mu_x.device)])
    else:
        Xc_padded = Xc
        mu_x_padded = mu_x
    # Classic solution: M = X^T Y = U S V^T, optimal rotation = U V^T.
    U, S, Vt = torch.linalg.svd(Xc_padded.T @ Yc)
    R = (U @ Vt).T  # (teacher_dim, student_padded_dim)
    cos_before = F.cosine_similarity(Xc_padded, Yc, dim=-1).mean()
    cos_after = F.cosine_similarity((Xc_padded @ R.T), Yc, dim=-1).mean()
    # Fixed: the arrow here was mojibake ("Γ†’"-style bytes) in the original.
    print(f" Procrustes: cos {cos_before:.4f} -> {cos_after:.4f}")
    return R, mu_x_padded, mu_y
# ══════════════════════════════════════════════════════════════════
# TEXT SEGMENTATION
# ══════════════════════════════════════════════════════════════════
def segment_text(text, clip_tokenizer, max_content=18, overlap=4, max_segments=32):
    """
    Segment long text into CLIP-compatible chunks.

    Each chunk is a fixed 77-token window laid out as
    [SOS] + content_tokens + [EOS] + [PAD...].

    Args:
        text: raw caption/document string.
        clip_tokenizer: tokenizer exposing encode(), bos_token_id, eos_token_id.
        max_content: content tokens per segment; clamped to 75 so SOS/EOS
            always fit (the original code silently truncated EOS above that).
        overlap: tokens shared between consecutive segments.
        max_segments: hard cap on the number of segments emitted.

    Returns:
        List of dicts with "input_ids" and "attention_mask" (77,) LongTensors.
    """
    window = 77
    # Clamp so [SOS] + chunk + [EOS] always fits the window (fixes EOS loss).
    max_content = min(max_content, window - 2)
    stride = max_content - overlap
    if stride <= 0:
        # Guard against an infinite loop when overlap >= max_content.
        stride = max_content
    # `x or default` would misfire on a falsy token id 0; check None instead.
    sos = clip_tokenizer.bos_token_id if clip_tokenizer.bos_token_id is not None else 49406
    eos = clip_tokenizer.eos_token_id if clip_tokenizer.eos_token_id is not None else 49407
    # Tokenize full text (no truncation)
    full_tokens = clip_tokenizer.encode(text, add_special_tokens=False)
    segments = []
    pos = 0
    while pos < len(full_tokens) and len(segments) < max_segments:
        chunk = full_tokens[pos:pos + max_content]
        # Build CLIP-format input: [SOS] + chunk + [EOS] + padding
        input_ids = [sos] + chunk + [eos]
        n_pad = window - len(input_ids)  # >= 0 thanks to the clamp above
        input_ids = input_ids + [0] * n_pad
        mask = [1] * (len(chunk) + 2) + [0] * n_pad
        segments.append({
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(mask, dtype=torch.long),
        })
        if pos + max_content >= len(full_tokens):
            break
        pos += stride
    return segments
# ══════════════════════════════════════════════════════════════════
# SANITY CHECK
# ══════════════════════════════════════════════════════════════════
if __name__ == "__main__":
    # Sanity check: parameter accounting + segmentation smoke test.
    # Requires network access only for the CLIP tokenizer (no model weights).
    print("=" * 70)
    print("MEMORY-EXTENDED CLIP-L TEXT ENCODER")
    print("=" * 70)
    config = MemoryCLIPConfig()
    model = MemoryExtendedCLIP(config)
    # Count params before loading CLIP: model.clip_text is still None here,
    # so only the trainable memory-system add-ons are visible.
    comps = {
        "memory_embeddings": model.memory_embeddings.numel(),
        "layer_fusion": sum(p.numel() for p in model.layer_fusion.parameters()),
        "bank.depth_compressor": sum(p.numel() for p in model.bank.depth_compressor.parameters()),
        "bank.temporal_proj": sum(p.numel() for p in model.bank.temporal_proj.parameters()),
        "bank.cross_attn": sum(p.numel() for p in model.bank.cross_attn.parameters()),
        "bank.cross_ffns": sum(p.numel() for p in model.bank.cross_ffns.parameters()),
        "gate": sum(p.numel() for p in model.gate.parameters()),
        "output_proj": sum(p.numel() for p in model.output_proj.parameters()),
        "memory_output_fusion": sum(p.numel() for p in model.memory_output_fusion.parameters()),
        "proj_modern": sum(p.numel() for p in model.proj_modern.parameters()),
    }
    print(f"\n Memory system components:")
    for k, v in comps.items():
        print(f" {k:30s}: {v:,}")
    total = sum(comps.values())
    print(f" {'TOTAL':30s}: {total:,}")
    print(f"\n Config:")
    print(f" CLIP: {config.clip_model}")
    print(f" Hidden: {config.clip_hidden}")
    print(f" Window: {config.clip_max_tokens} tokens")
    print(f" Memory tokens: {config.n_memory_tokens}")
    # Fixed: the multiplication sign below was mojibake ("Γ—") originally.
    print(f" Segments: max {config.max_segments} × {config.max_content_tokens} "
          f"= {config.max_segments * config.max_content_tokens} tokens")
    print(f" Bank: {config.bank_size} anchors")
    print(f" Teacher: {config.teacher_model} ({config.teacher_hidden}-dim)")
    # Test segmentation (doesn't need model loaded)
    print(f"\n Testing segmentation...")
    from transformers import CLIPTokenizer
    tok = CLIPTokenizer.from_pretrained(config.clip_model)
    # NOTE(review): implicit string concatenation binds tighter than `* 3`,
    # so only the LAST literal is repeated 3x — preserved as-is since the
    # smoke test just needs a long-enough text.
    long_text = ("A vast sweeping landscape of rolling green hills under dramatic "
                 "storm clouds with a lone oak tree in the foreground its branches "
                 "bent by wind casting long shadows across a field of wildflowers "
                 "in purple yellow and white while in the distance a medieval stone "
                 "castle sits atop a cliff overlooking a turbulent sea with waves "
                 "crashing against ancient rocks and seabirds wheeling overhead " * 3)
    segments = segment_text(long_text, tok, config.max_content_tokens,
                            config.segment_overlap, config.max_segments)
    print(f" Text length: {len(long_text)} chars")
    print(f" Full tokens: {len(tok.encode(long_text, add_special_tokens=False))}")
    print(f" Segments: {len(segments)}")
    for i, seg in enumerate(segments[:3]):
        n_real = seg["attention_mask"].sum().item()
        print(f" Seg {i}: {n_real} tokens (of {seg['input_ids'].shape[0]})")
    print(f"\nReady for training. Run setup(device) to load CLIP.")