"""
ImageNet Multi-CLIP Collective Experiment
==========================================
Uses pre-extracted CLIP features from multiple model variants.
No image processing - pure feature routing at A100 speeds.
Dataset: AbstractPhil/clip-imagenet-features
Streams: b32, b16, l14, laion_b32, laion_bigg14, laion_h14
Each CLIP variant becomes an expert stream with:
- Learnable translation head
- Own router with unique fingerprint
- Hierarchical coordination via mailbox
Training:
- AMP mixed precision
- 8 workers total, pinned, persistent
- Hierarchical chain topology
Author: AbstractPhil
Date: December 2025
License: Apache 2.0
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler
from datasets import load_dataset
from dataclasses import dataclass, field
from typing import Dict, Tuple, List, Optional
from collections import defaultdict
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
# =============================================================================
# IMPORTS FROM GEOFRACTAL
# =============================================================================
from geofractal.model.blocks.router.global_fractal_router import (
GlobalFractalRouter,
GlobalFractalRouterConfig,
get_registry,
RouterMailbox,
)
# =============================================================================
# CONFIG
# =============================================================================
@dataclass
class ImageNetCollectiveConfig:
"""Configuration for ImageNet multi-CLIP collective."""
# Dataset
dataset_name: str = "AbstractPhil/imagenet-clip-features"
num_classes: int = 1000
# CLIP variants and their dimensions
clip_variants: Dict[str, int] = field(default_factory=lambda: {
'clip_vit_b32': 512,
'clip_vit_b16': 512,
'clip_vit_l14': 768,
'clip_vit_laion_b32': 512,
'clip_vit_laion_bigg14': 1280,
# 'clip_vit_laion_h14': 1024, # Can add if memory permits
})
# Feature dimensions
feature_dim: int = 512 # Internal routing dimension
fingerprint_dim: int = 64
# Router
num_anchors: int = 16
num_routes: int = 8
num_slots: int = 16 # Sequence length for routing
# Training
batch_size: int = 256
epochs: int = 20
lr: float = 3e-4
weight_decay: float = 0.01
warmup_epochs: int = 2
# DataLoader - A100 optimized
num_workers: int = 8 # Total across all loaders
pin_memory: bool = True
persistent_workers: bool = True
prefetch_factor: int = 4
# AMP
use_amp: bool = True
device: str = "cuda" if torch.cuda.is_available() else "cpu"
def workers_per_loader(self) -> int:
"""Distribute workers across loaders."""
n_loaders = len(self.clip_variants)
return max(1, self.num_workers // n_loaders)
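
# Worker-budget sketch (default config): with num_workers=8 and 5 variants,
# workers_per_loader() == max(1, 8 // 5) == 1. In practice get_dataloaders()
# below builds one multi-variant loader per split and hands it the full
# num_workers budget; the per-loader split matters only if you fan out into
# one loader per variant.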
# =============================================================================
# DATASET
# =============================================================================
class CLIPFeatureDataset(Dataset):
"""
Wraps HuggingFace dataset for a single CLIP variant.
Returns pre-extracted features and labels.
"""
def __init__(
self,
hf_dataset,
feature_column: str = 'clip_features',
label_column: str = 'label',
):
self.dataset = hf_dataset
self.feature_column = feature_column
self.label_column = label_column
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
item = self.dataset[idx]
features = torch.tensor(item[self.feature_column], dtype=torch.float32)
label = item[self.label_column]
return features, label
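
# Single-variant usage sketch (this wrapper is not used by main(), which goes
# through MultiCLIPDataset below; split naming follows f"{variant}_{prefix}"):
#
#   hf_split = load_dataset("AbstractPhil/imagenet-clip-features",
#                           split="clip_vit_b32_train")
#   ds = CLIPFeatureDataset(hf_split)
#   features, label = ds[0]  # features: [512] float32 tensor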
class MultiCLIPDataset(Dataset):
"""
Loads features from multiple CLIP variants simultaneously.
Returns dict of features + label.
"""
def __init__(
self,
dataset_name: str,
split_prefix: str, # e.g., 'train' or 'validation'
clip_variants: Dict[str, int],
):
self.variants = list(clip_variants.keys())
self.datasets = {}
print(f"Loading {split_prefix} splits...")
for variant in tqdm(self.variants, desc="Loading variants"):
split_name = f"{variant}_{split_prefix}"
try:
ds = load_dataset(dataset_name, split=split_name)
self.datasets[variant] = ds
print(f" {variant}: {len(ds):,} samples")
            except Exception as e:
                print(f" WARNING: Could not load {split_name}: {e}")
        if not self.datasets:
            raise RuntimeError(f"No '{split_prefix}' splits could be loaded from {dataset_name}")
        # Use first dataset for length (all should be the same)
        self.length = len(next(iter(self.datasets.values())))
# Verify all same length
for name, ds in self.datasets.items():
assert len(ds) == self.length, f"{name} has {len(ds)} != {self.length}"
def __len__(self):
return self.length
def __getitem__(self, idx):
features = {}
label = None
for variant, ds in self.datasets.items():
item = ds[idx]
features[variant] = torch.tensor(item['clip_features'], dtype=torch.float32)
if label is None:
label = item['label']
return features, label
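
# Usage sketch (assumes the named splits exist on the Hub):
#
#   ds = MultiCLIPDataset("AbstractPhil/imagenet-clip-features", "train",
#                         {'clip_vit_b32': 512, 'clip_vit_l14': 768})
#   feats, label = ds[0]
#   feats['clip_vit_b32'].shape  # torch.Size([512])
#   feats['clip_vit_l14'].shape  # torch.Size([768])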
def get_dataloaders(config: ImageNetCollectiveConfig):
"""Create train and validation dataloaders."""
train_dataset = MultiCLIPDataset(
config.dataset_name,
'train',
config.clip_variants,
)
val_dataset = MultiCLIPDataset(
config.dataset_name,
'validation',
config.clip_variants,
)
# Collate function for dict of features
def collate_fn(batch):
features = {k: [] for k in config.clip_variants.keys()}
labels = []
for feat_dict, label in batch:
for k, v in feat_dict.items():
features[k].append(v)
labels.append(label)
features = {k: torch.stack(v) for k, v in features.items()}
labels = torch.tensor(labels, dtype=torch.long)
return features, labels
    # A single multi-variant loader serves each split, so it receives the
    # full num_workers budget (workers_per_loader() would apply only to
    # per-variant loaders).
train_loader = DataLoader(
train_dataset,
batch_size=config.batch_size,
shuffle=True,
num_workers=config.num_workers,
pin_memory=config.pin_memory,
persistent_workers=config.persistent_workers if config.num_workers > 0 else False,
prefetch_factor=config.prefetch_factor if config.num_workers > 0 else None,
collate_fn=collate_fn,
drop_last=True,
)
val_loader = DataLoader(
val_dataset,
batch_size=config.batch_size,
shuffle=False,
num_workers=config.num_workers,
pin_memory=config.pin_memory,
persistent_workers=config.persistent_workers if config.num_workers > 0 else False,
prefetch_factor=config.prefetch_factor if config.num_workers > 0 else None,
collate_fn=collate_fn,
)
return train_loader, val_loader
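
# Batch-shape sketch: each batch is (features, labels) where features is a
# dict of per-variant tensors, e.g. with the default config:
#
#   train_loader, val_loader = get_dataloaders(ImageNetCollectiveConfig())
#   features, labels = next(iter(train_loader))
#   features['clip_vit_b32'].shape  # torch.Size([256, 512])
#   labels.shape                    # torch.Size([256])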
# =============================================================================
# FEATURE STREAM (No CLIP model - just translation + routing)
# =============================================================================
class FeatureStream(nn.Module):
"""
Stream for pre-extracted CLIP features.
No CLIP model - just translation head + router.
"""
def __init__(
self,
config: ImageNetCollectiveConfig,
variant_name: str,
input_dim: int,
parent_id: Optional[str] = None,
):
super().__init__()
self.config = config
self.variant_name = variant_name
self.input_dim = input_dim
# Translation head: CLIP dim β†’ routing space
self.translation = nn.Sequential(
nn.Linear(input_dim, config.feature_dim * 2),
nn.LayerNorm(config.feature_dim * 2),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(config.feature_dim * 2, config.feature_dim * config.num_slots),
)
# Learnable slot embeddings (unique per stream)
self.slot_embed = nn.Parameter(
torch.randn(1, config.num_slots, config.feature_dim) * 0.02
)
# Router with unique fingerprint
router_config = GlobalFractalRouterConfig(
feature_dim=config.feature_dim,
fingerprint_dim=config.fingerprint_dim,
num_anchors=config.num_anchors,
num_routes=config.num_routes,
use_adjacent_gating=True,
use_cantor_prior=True,
grid_size=(config.num_slots, 1),
)
self.router = GlobalFractalRouter(
config=router_config,
parent_id=parent_id,
cooperation_group="imagenet_collective",
name=variant_name,
)
@property
def fingerprint(self) -> torch.Tensor:
return self.router.fingerprint
@property
def module_id(self) -> str:
return self.router.module_id
def forward(
self,
features: torch.Tensor,
mailbox: RouterMailbox,
target_fingerprint: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Dict]:
"""
Args:
features: [B, input_dim] pre-extracted CLIP features
mailbox: Shared mailbox
target_fingerprint: Next stream's fingerprint
Returns:
routed: [B, num_slots, feature_dim]
info: Dict with metrics
"""
B = features.shape[0]
# Translate to routing space
translated = self.translation(features) # [B, feature_dim * num_slots]
slots = translated.view(B, self.config.num_slots, self.config.feature_dim)
# Add slot embeddings
slots = slots + self.slot_embed
# Route
routes, weights, routed = self.router(
slots,
mailbox=mailbox,
target_fingerprint=target_fingerprint,
skip_first=False,
)
info = {
'route_entropy': -(weights * (weights + 1e-8).log()).sum(dim=-1).mean().item(),
}
return routed, info
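
# Minimal smoke-test sketch for a single stream (not called by main()).
# It feeds random features in place of real CLIP embeddings; the
# RouterMailbox construction mirrors ImageNetCollective below.
def _feature_stream_smoke_test() -> None:
    cfg = ImageNetCollectiveConfig()
    stream = FeatureStream(cfg, variant_name='clip_vit_b32', input_dim=512)
    mailbox = RouterMailbox(GlobalFractalRouterConfig(
        feature_dim=cfg.feature_dim,
        fingerprint_dim=cfg.fingerprint_dim,
    ))
    fake_features = torch.randn(4, 512)  # [B, clip_dim] stand-ins
    routed, info = stream(fake_features, mailbox)
    assert routed.shape == (4, cfg.num_slots, cfg.feature_dim)
    assert 'route_entropy' in info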
# =============================================================================
# MULTI-CLIP COLLECTIVE
# =============================================================================
class ImageNetCollective(nn.Module):
"""
Collective of pre-extracted CLIP features from multiple variants.
Hierarchical chain topology with shared mailbox coordination.
"""
def __init__(self, config: ImageNetCollectiveConfig):
super().__init__()
self.config = config
# Reset registry for fresh start
get_registry().reset()
# Build streams in hierarchical chain
self.streams = nn.ModuleDict()
self.stream_order = list(config.clip_variants.keys())
parent_id = None
for variant_name, input_dim in config.clip_variants.items():
stream = FeatureStream(
config=config,
variant_name=variant_name,
input_dim=input_dim,
parent_id=parent_id,
)
            self.streams[variant_name] = stream
            print(f" Stream: {variant_name} ({input_dim}D) -> parent: {parent_id[:8] + '...' if parent_id else 'root'}")
            parent_id = stream.module_id
# Shared mailbox
router_config = GlobalFractalRouterConfig(
feature_dim=config.feature_dim,
fingerprint_dim=config.fingerprint_dim,
)
self.mailbox = RouterMailbox(router_config)
# Fusion layer
num_streams = len(config.clip_variants)
self.fusion = nn.Sequential(
nn.Linear(config.feature_dim * num_streams, config.feature_dim * 2),
nn.LayerNorm(config.feature_dim * 2),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(config.feature_dim * 2, config.feature_dim),
nn.LayerNorm(config.feature_dim),
)
# Classification head
self.classifier = nn.Linear(config.feature_dim, config.num_classes)
# Per-stream classifiers (for measuring individual contribution)
self.stream_classifiers = nn.ModuleDict({
name: nn.Linear(config.feature_dim, config.num_classes)
for name in config.clip_variants.keys()
})
def forward(
self,
features: Dict[str, torch.Tensor],
return_individual: bool = False,
) -> Tuple[torch.Tensor, Dict]:
"""
Args:
features: Dict mapping variant name to [B, clip_dim] features
return_individual: Also return per-stream predictions
Returns:
logits: [B, num_classes]
info: Dict with metrics
"""
# Clear mailbox
self.mailbox.clear()
# Process streams in order
stream_features = {}
stream_infos = {}
for i, name in enumerate(self.stream_order):
stream = self.streams[name]
# Get target fingerprint (next stream or None)
if i < len(self.stream_order) - 1:
next_name = self.stream_order[i + 1]
target_fp = self.streams[next_name].fingerprint
else:
target_fp = None
# Forward
routed, info = stream(features[name], self.mailbox, target_fp)
# Pool across slots
pooled = routed.mean(dim=1) # [B, feature_dim]
stream_features[name] = pooled
stream_infos[name] = info
# Fuse all streams
fused = torch.cat([stream_features[n] for n in self.stream_order], dim=-1)
fused = self.fusion(fused)
# Classify
logits = self.classifier(fused)
info = {
'stream_infos': stream_infos,
'mailbox_messages': len(self.mailbox.messages),
            'mean_route_entropy': np.mean([s['route_entropy'] for s in stream_infos.values()]),
}
if return_individual:
individual_logits = {
name: self.stream_classifiers[name](stream_features[name])
for name in self.stream_order
}
info['individual_logits'] = individual_logits
return logits, info
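
# End-to-end shape sketch on random features (not called by main()); note
# that constructing ImageNetCollective resets the global router registry.
def _collective_smoke_test() -> None:
    cfg = ImageNetCollectiveConfig()
    model = ImageNetCollective(cfg)
    fake = {name: torch.randn(2, dim) for name, dim in cfg.clip_variants.items()}
    logits, info = model(fake, return_individual=True)
    assert logits.shape == (2, cfg.num_classes)
    assert set(info['individual_logits']) == set(cfg.clip_variants)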
# =============================================================================
# SINGLE STREAM BASELINE
# =============================================================================
class SingleStreamBaseline(nn.Module):
"""Single CLIP variant with linear probe (no routing)."""
def __init__(self, config: ImageNetCollectiveConfig, variant_name: str, input_dim: int):
super().__init__()
self.variant_name = variant_name
self.classifier = nn.Sequential(
nn.Linear(input_dim, config.feature_dim),
nn.LayerNorm(config.feature_dim),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(config.feature_dim, config.num_classes),
)
def forward(self, features: torch.Tensor) -> torch.Tensor:
return self.classifier(features)
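
# Probe sketch: the baseline is just a 2-layer MLP on one variant's features.
#
#   probe = SingleStreamBaseline(ImageNetCollectiveConfig(), 'clip_vit_b32', 512)
#   probe(torch.randn(4, 512)).shape  # torch.Size([4, 1000])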
# =============================================================================
# TRAINING
# =============================================================================
def train_collective(
model: ImageNetCollective,
train_loader: DataLoader,
val_loader: DataLoader,
config: ImageNetCollectiveConfig,
):
"""Train collective with AMP."""
optimizer = torch.optim.AdamW(
model.parameters(),
lr=config.lr,
weight_decay=config.weight_decay,
)
# Warmup + cosine schedule
total_steps = len(train_loader) * config.epochs
warmup_steps = len(train_loader) * config.warmup_epochs
def lr_lambda(step):
if step < warmup_steps:
return step / warmup_steps
progress = (step - warmup_steps) / (total_steps - warmup_steps)
return 0.5 * (1 + np.cos(np.pi * progress))
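    # Worked values (default config): the scale ramps linearly 0 -> 1 over
    # warmup_steps, then cosine-decays 1 -> 0, e.g.
    #   step == 0            -> 0.0
    #   step == warmup_steps -> 0.5 * (1 + cos(0))  == 1.0
    #   step == total_steps  -> 0.5 * (1 + cos(pi)) == 0.0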
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
scaler = GradScaler() if config.use_amp else None
history = defaultdict(list)
best_acc = 0
for epoch in range(config.epochs):
model.train()
epoch_loss = 0
correct = 0
total = 0
pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.epochs}")
for features, labels in pbar:
# Move to device
features = {k: v.to(config.device, non_blocking=True) for k, v in features.items()}
labels = labels.to(config.device, non_blocking=True)
optimizer.zero_grad()
if config.use_amp:
with autocast():
logits, info = model(features)
loss = F.cross_entropy(logits, labels)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)
scaler.update()
else:
logits, info = model(features)
loss = F.cross_entropy(logits, labels)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
epoch_loss += loss.item() * labels.size(0)
correct += (logits.argmax(dim=1) == labels).sum().item()
total += labels.size(0)
pbar.set_postfix({
'loss': f"{loss.item():.4f}",
'acc': f"{correct/total*100:.1f}%",
'lr': f"{scheduler.get_last_lr()[0]:.2e}",
})
# Validate
val_acc, val_stream_accs = evaluate_collective(model, val_loader, config)
history['train_loss'].append(epoch_loss / total)
history['train_acc'].append(correct / total)
history['val_acc'].append(val_acc)
history['stream_accs'].append(val_stream_accs)
# Log
stream_str = ' | '.join([f"{k[:4]}: {v*100:.1f}%" for k, v in val_stream_accs.items()])
tqdm.write(f"Epoch {epoch+1:3d} | Loss: {epoch_loss/total:.4f} | "
f"Val: {val_acc*100:.2f}% | {stream_str}")
if val_acc > best_acc:
best_acc = val_acc
tqdm.write(f" β˜… New best: {best_acc*100:.2f}%")
return dict(history), best_acc
def evaluate_collective(
model: ImageNetCollective,
loader: DataLoader,
config: ImageNetCollectiveConfig,
) -> Tuple[float, Dict[str, float]]:
"""Evaluate collective and per-stream accuracy."""
model.eval()
correct = 0
total = 0
stream_correct = defaultdict(int)
with torch.no_grad():
for features, labels in tqdm(loader, desc="Eval", leave=False):
features = {k: v.to(config.device, non_blocking=True) for k, v in features.items()}
labels = labels.to(config.device, non_blocking=True)
if config.use_amp:
with autocast():
logits, info = model(features, return_individual=True)
else:
logits, info = model(features, return_individual=True)
correct += (logits.argmax(dim=1) == labels).sum().item()
total += labels.size(0)
for name, ind_logits in info['individual_logits'].items():
stream_correct[name] += (ind_logits.argmax(dim=1) == labels).sum().item()
acc = correct / total
stream_accs = {k: v / total for k, v in stream_correct.items()}
return acc, stream_accs
def train_baseline(
variant_name: str,
input_dim: int,
train_loader: DataLoader,
val_loader: DataLoader,
config: ImageNetCollectiveConfig,
):
"""Train single stream baseline."""
model = SingleStreamBaseline(config, variant_name, input_dim).to(config.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.epochs)
scaler = GradScaler() if config.use_amp else None
history = defaultdict(list)
best_acc = 0
for epoch in range(config.epochs):
model.train()
epoch_loss = 0
correct = 0
total = 0
for features, labels in tqdm(train_loader, desc=f"{variant_name} E{epoch+1}", leave=False):
feat = features[variant_name].to(config.device, non_blocking=True)
labels = labels.to(config.device, non_blocking=True)
optimizer.zero_grad()
if config.use_amp:
with autocast():
logits = model(feat)
loss = F.cross_entropy(logits, labels)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
logits = model(feat)
loss = F.cross_entropy(logits, labels)
loss.backward()
optimizer.step()
epoch_loss += loss.item() * labels.size(0)
correct += (logits.argmax(dim=1) == labels).sum().item()
total += labels.size(0)
scheduler.step()
# Validate
model.eval()
val_correct = 0
val_total = 0
with torch.no_grad():
for features, labels in val_loader:
feat = features[variant_name].to(config.device, non_blocking=True)
labels = labels.to(config.device, non_blocking=True)
if config.use_amp:
with autocast():
logits = model(feat)
else:
logits = model(feat)
val_correct += (logits.argmax(dim=1) == labels).sum().item()
val_total += labels.size(0)
val_acc = val_correct / val_total
history['val_acc'].append(val_acc)
if val_acc > best_acc:
best_acc = val_acc
if (epoch + 1) % 5 == 0 or epoch == 0:
tqdm.write(f"{variant_name} Epoch {epoch+1:3d} | Val: {val_acc*100:.2f}%")
return dict(history), best_acc
# =============================================================================
# VISUALIZATION
# =============================================================================
def plot_results(
collective_history: Dict,
baseline_results: Dict[str, float],
config: ImageNetCollectiveConfig,
save_path: str = "imagenet_collective_results.png",
):
"""Plot training results."""
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
epochs = range(1, len(collective_history['val_acc']) + 1)
# Validation accuracy over time
ax = axes[0, 0]
ax.plot(epochs, [a*100 for a in collective_history['val_acc']], 'b-',
label='Collective', linewidth=2)
for name in config.clip_variants.keys():
accs = [sa[name]*100 for sa in collective_history['stream_accs']]
ax.plot(epochs, accs, '--', label=f'{name} (in coll.)', alpha=0.7)
ax.set_xlabel('Epoch')
ax.set_ylabel('Validation Accuracy (%)')
ax.set_title('Training Progress')
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)
# Final comparison bar
ax = axes[0, 1]
final_collective = collective_history['val_acc'][-1] * 100
final_streams = {k: v*100 for k, v in collective_history['stream_accs'][-1].items()}
names = ['Collective'] + list(baseline_results.keys())
values = [final_collective] + [v*100 for v in baseline_results.values()]
colors = ['steelblue'] + ['coral'] * len(baseline_results)
bars = ax.bar(range(len(names)), values, color=colors)
ax.set_xticks(range(len(names)))
ax.set_xticklabels([n.replace('clip_vit_', '').replace('_', '\n') for n in names], fontsize=8)
ax.set_ylabel('Validation Accuracy (%)')
ax.set_title('Final Accuracy: Collective vs Individual Baselines')
for bar, val in zip(bars, values):
ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.3,
f'{val:.1f}%', ha='center', va='bottom', fontsize=8)
# Per-stream accuracy in collective vs baseline
ax = axes[1, 0]
stream_names = list(config.clip_variants.keys())
x = np.arange(len(stream_names))
width = 0.35
in_collective = [final_streams[n] for n in stream_names]
standalone = [baseline_results[n]*100 for n in stream_names]
bars1 = ax.bar(x - width/2, in_collective, width, label='In Collective', color='steelblue')
bars2 = ax.bar(x + width/2, standalone, width, label='Standalone', color='coral')
ax.set_ylabel('Accuracy (%)')
ax.set_title('Per-Stream: Collective vs Standalone')
ax.set_xticks(x)
ax.set_xticklabels([n.replace('clip_vit_', '') for n in stream_names], fontsize=8, rotation=45)
ax.legend()
ax.grid(True, alpha=0.3, axis='y')
# Summary
ax = axes[1, 1]
ax.axis('off')
best_baseline = max(baseline_results.values()) * 100
improvement = final_collective - best_baseline
summary = f"""
IMAGENET COLLECTIVE RESULTS
════════════════════════════════════════════════════════
Collective: {final_collective:.2f}%
Best Individual: {best_baseline:.2f}%
Improvement: {improvement:+.2f}%
════════════════════════════════════════════════════════
Per-stream in collective:
"""
for name, acc in final_streams.items():
short_name = name.replace('clip_vit_', '')
summary += f"\n {short_name:<15}: {acc:.2f}%"
summary += """
════════════════════════════════════════════════════════
Individual baselines:
"""
for name, acc in baseline_results.items():
short_name = name.replace('clip_vit_', '')
summary += f"\n {short_name:<15}: {acc*100:.2f}%"
ax.text(0.05, 0.95, summary, fontsize=10, family='monospace',
verticalalignment='top', transform=ax.transAxes)
plt.tight_layout()
plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.show()
print(f"\nSaved: {save_path}")
# =============================================================================
# MAIN
# =============================================================================
def main():
print("="*70)
print(" ImageNet Multi-CLIP Collective Experiment")
print(" Pre-extracted Features via GlobalFractalRouter")
print("="*70)
config = ImageNetCollectiveConfig()
print(f"\nConfig:")
print(f" Dataset: {config.dataset_name}")
print(f" Variants: {len(config.clip_variants)}")
for name, dim in config.clip_variants.items():
print(f" - {name}: {dim}D")
print(f" Feature dim: {config.feature_dim}")
print(f" Epochs: {config.epochs}")
print(f" Batch size: {config.batch_size}")
print(f" AMP: {config.use_amp}")
print(f" Device: {config.device}")
# Data
print("\n" + "="*70)
print(" Loading Data")
print("="*70)
train_loader, val_loader = get_dataloaders(config)
print(f"\n Train batches: {len(train_loader)}")
print(f" Val batches: {len(val_loader)}")
# =================================================================
# COLLECTIVE
# =================================================================
print("\n" + "="*70)
print(" Training COLLECTIVE")
print("="*70)
collective = ImageNetCollective(config).to(config.device)
params = sum(p.numel() for p in collective.parameters())
print(f"\n Parameters: {params:,}")
collective_history, collective_best = train_collective(
collective, train_loader, val_loader, config
)
# =================================================================
# BASELINES
# =================================================================
print("\n" + "="*70)
print(" Training BASELINES (Individual Streams)")
print("="*70)
baseline_results = {}
for variant_name, input_dim in config.clip_variants.items():
print(f"\n Training: {variant_name}")
_, best_acc = train_baseline(
variant_name, input_dim, train_loader, val_loader, config
)
baseline_results[variant_name] = best_acc
print(f" {variant_name} best: {best_acc*100:.2f}%")
# =================================================================
# RESULTS
# =================================================================
print("\n" + "="*70)
print(" FINAL RESULTS")
print("="*70)
print(f"\n Collective: {collective_best*100:.2f}%")
print(f" Best individual: {max(baseline_results.values())*100:.2f}%")
print(f" Improvement: {(collective_best - max(baseline_results.values()))*100:+.2f}%")
print("\n Per-stream final (in collective):")
for name, acc in collective_history['stream_accs'][-1].items():
print(f" {name}: {acc*100:.2f}%")
print("\n Individual baselines:")
for name, acc in baseline_results.items():
print(f" {name}: {acc*100:.2f}%")
plot_results(collective_history, baseline_results, config)
return collective, collective_history, baseline_results
if __name__ == "__main__":
results = main()