|
|
""" |
|
|
ImageNet Multi-CLIP Collective Experiment |
|
|
========================================== |
|
|
Uses pre-extracted CLIP features from multiple model variants. |
|
|
No image processing - pure feature routing at A100 speeds. |
|
|
|
|
|
Dataset: AbstractPhil/imagenet-clip-features
|
|
Streams: b32, b16, l14, laion_b32, laion_bigg14
|
|
|
|
|
Each CLIP variant becomes an expert stream with: |
|
|
- Learnable translation head |
|
|
- Own router with unique fingerprint |
|
|
- Hierarchical coordination via mailbox |
|
|
|
|
|
Training: |
|
|
- AMP mixed precision |
|
|
- 8 workers total, pinned, persistent |
|
|
- Hierarchical chain topology |
|
|
|
|
|
Author: AbstractPhil |
|
|
Date: December 2025 |
|
|
License: Apache 2.0 |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import DataLoader, Dataset |
|
|
from torch.cuda.amp import autocast, GradScaler |
|
|
from datasets import load_dataset |
|
|
from dataclasses import dataclass, field |
|
|
from typing import Dict, Tuple, List, Optional |
|
|
from collections import defaultdict |
|
|
import numpy as np |
|
|
from tqdm.auto import tqdm |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from geofractal.model.blocks.router.global_fractal_router import ( |
|
|
GlobalFractalRouter, |
|
|
GlobalFractalRouterConfig, |
|
|
get_registry, |
|
|
RouterMailbox, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class ImageNetCollectiveConfig:
    """Configuration for ImageNet multi-CLIP collective."""

    # Data source: HuggingFace repo of pre-extracted CLIP features.
    dataset_name: str = "AbstractPhil/imagenet-clip-features"
    num_classes: int = 1000

    # CLIP variant name -> raw feature dimension; each entry becomes one
    # expert stream. NOTE(review): the module docstring also mentions
    # laion_h14, which is not listed here — confirm the intended set.
    clip_variants: Dict[str, int] = field(default_factory=lambda: {
        'clip_vit_b32': 512,
        'clip_vit_b16': 512,
        'clip_vit_l14': 768,
        'clip_vit_laion_b32': 512,
        'clip_vit_laion_bigg14': 1280,
    })

    # Shared width every stream is translated into, and the router
    # fingerprint size.
    feature_dim: int = 512
    fingerprint_dim: int = 64

    # Router geometry (anchors/routes) and slot tokens emitted per stream.
    num_anchors: int = 16
    num_routes: int = 8
    num_slots: int = 16

    # Optimization hyperparameters.
    batch_size: int = 256
    epochs: int = 20
    lr: float = 3e-4
    weight_decay: float = 0.01
    warmup_epochs: int = 2

    # DataLoader settings.
    num_workers: int = 8
    pin_memory: bool = True
    persistent_workers: bool = True
    prefetch_factor: int = 4

    # Mixed-precision (AMP) toggle.
    use_amp: bool = True

    # Evaluated once at class-definition time.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    def workers_per_loader(self) -> int:
        """Distribute workers across loaders.

        NOTE(review): currently unused — get_dataloaders passes the full
        ``num_workers`` to each DataLoader; confirm which policy is intended.
        """
        n_loaders = len(self.clip_variants)
        return max(1, self.num_workers // n_loaders)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CLIPFeatureDataset(Dataset):
    """
    Thin torch ``Dataset`` adapter over a HuggingFace dataset holding one
    CLIP variant's pre-extracted features.

    Each item is a ``(features, label)`` pair where ``features`` is a new
    float32 tensor built from the stored feature vector.
    """

    def __init__(
        self,
        hf_dataset,
        feature_column: str = 'clip_features',
        label_column: str = 'label',
    ):
        self.dataset = hf_dataset
        self.feature_column = feature_column
        self.label_column = label_column

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        vector = torch.tensor(record[self.feature_column], dtype=torch.float32)
        return vector, record[self.label_column]
|
|
|
|
|
|
|
|
class MultiCLIPDataset(Dataset):
    """
    Loads features from multiple CLIP variants simultaneously.

    Each item is ``(features, label)`` where ``features`` maps variant name
    -> float32 feature tensor and ``label`` comes from the first variant's
    record (all variants are assumed to be aligned sample-for-sample).
    """

    def __init__(
        self,
        dataset_name: str,
        split_prefix: str,
        clip_variants: Dict[str, int],
    ):
        """
        Args:
            dataset_name: HuggingFace dataset repo id.
            split_prefix: Per-variant split suffix, e.g. 'train' loads
                '<variant>_train'.
            clip_variants: Mapping of variant name -> feature dim
                (only the keys are used here).

        Raises:
            RuntimeError: If no variant split could be loaded at all.
        """
        self.variants = list(clip_variants.keys())
        self.datasets = {}

        print(f"Loading {split_prefix} splits...")
        for variant in tqdm(self.variants, desc="Loading variants"):
            split_name = f"{variant}_{split_prefix}"
            try:
                ds = load_dataset(dataset_name, split=split_name)
                self.datasets[variant] = ds
                print(f" {variant}: {len(ds):,} samples")
            except Exception as e:
                print(f" WARNING: Could not load {split_name}: {e}")

        # Fail loudly if nothing loaded: the bare next(iter(...)) below
        # would otherwise raise an opaque StopIteration.
        if not self.datasets:
            raise RuntimeError(
                f"No CLIP variant splits could be loaded for prefix '{split_prefix}'"
            )

        self.length = len(next(iter(self.datasets.values())))

        # Every variant must be aligned index-for-index.
        for name, ds in self.datasets.items():
            assert len(ds) == self.length, f"{name} has {len(ds)} != {self.length}"

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        features = {}
        label = None

        for variant, ds in self.datasets.items():
            item = ds[idx]
            features[variant] = torch.tensor(item['clip_features'], dtype=torch.float32)
            # Label is taken from whichever variant iterates first.
            if label is None:
                label = item['label']

        return features, label
|
|
|
|
|
|
|
|
def get_dataloaders(config: ImageNetCollectiveConfig):
    """Create train and validation dataloaders.

    Each loader yields ``(features, labels)`` where ``features`` maps
    variant name -> [B, dim] float32 tensor and ``labels`` is a [B] int64
    tensor.

    Note: each DataLoader receives the full ``config.num_workers``;
    ``config.workers_per_loader()`` is not used because a single
    MultiCLIPDataset serves all variants. (The previously computed
    ``workers_per`` local was dead code and has been removed.)
    """

    train_dataset = MultiCLIPDataset(
        config.dataset_name,
        'train',
        config.clip_variants,
    )

    val_dataset = MultiCLIPDataset(
        config.dataset_name,
        'validation',
        config.clip_variants,
    )

    def collate_fn(batch):
        """Stack per-variant feature lists into batched tensors."""
        features = {k: [] for k in config.clip_variants.keys()}
        labels = []

        for feat_dict, label in batch:
            for k, v in feat_dict.items():
                features[k].append(v)
            labels.append(label)

        features = {k: torch.stack(v) for k, v in features.items()}
        labels = torch.tensor(labels, dtype=torch.long)

        return features, labels

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory,
        persistent_workers=config.persistent_workers if config.num_workers > 0 else False,
        prefetch_factor=config.prefetch_factor if config.num_workers > 0 else None,
        collate_fn=collate_fn,
        drop_last=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory,
        persistent_workers=config.persistent_workers if config.num_workers > 0 else False,
        prefetch_factor=config.prefetch_factor if config.num_workers > 0 else None,
        collate_fn=collate_fn,
    )

    return train_loader, val_loader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FeatureStream(nn.Module):
    """
    Stream for pre-extracted CLIP features.
    No CLIP model - just translation head + router.

    Pipeline: [B, input_dim] features -> translation MLP ->
    [B, num_slots, feature_dim] slot tokens (+ learned slot embeddings) ->
    GlobalFractalRouter.
    """

    def __init__(
        self,
        config: ImageNetCollectiveConfig,
        variant_name: str,
        input_dim: int,
        parent_id: Optional[str] = None,
    ):
        """
        Args:
            config: Experiment configuration.
            variant_name: CLIP variant key; also used as the router's name.
            input_dim: Raw feature dimension of this variant.
            parent_id: Module id of the previous stream's router in the
                chain, or None for the root stream.
        """
        super().__init__()
        self.config = config
        self.variant_name = variant_name
        self.input_dim = input_dim

        # Translation head: expands raw features into num_slots tokens of
        # width feature_dim (the flat output is reshaped in forward()).
        self.translation = nn.Sequential(
            nn.Linear(input_dim, config.feature_dim * 2),
            nn.LayerNorm(config.feature_dim * 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(config.feature_dim * 2, config.feature_dim * config.num_slots),
        )

        # Learned per-slot positional embedding, broadcast over the batch.
        self.slot_embed = nn.Parameter(
            torch.randn(1, config.num_slots, config.feature_dim) * 0.02
        )

        # Router configuration; slots are laid out as a (num_slots, 1) grid.
        # NOTE(review): adjacent-gating / Cantor-prior semantics live inside
        # GlobalFractalRouter and cannot be verified from this file.
        router_config = GlobalFractalRouterConfig(
            feature_dim=config.feature_dim,
            fingerprint_dim=config.fingerprint_dim,
            num_anchors=config.num_anchors,
            num_routes=config.num_routes,
            use_adjacent_gating=True,
            use_cantor_prior=True,
            grid_size=(config.num_slots, 1),
        )

        self.router = GlobalFractalRouter(
            config=router_config,
            parent_id=parent_id,
            cooperation_group="imagenet_collective",
            name=variant_name,
        )

    @property
    def fingerprint(self) -> torch.Tensor:
        # Delegates to the router's identity fingerprint.
        return self.router.fingerprint

    @property
    def module_id(self) -> str:
        # Router-assigned id; used to chain streams parent -> child.
        return self.router.module_id

    def forward(
        self,
        features: torch.Tensor,
        mailbox: RouterMailbox,
        target_fingerprint: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Args:
            features: [B, input_dim] pre-extracted CLIP features
            mailbox: Shared mailbox
            target_fingerprint: Next stream's fingerprint

        Returns:
            routed: [B, num_slots, feature_dim]
            info: Dict with metrics
        """
        B = features.shape[0]

        # Expand the flat feature vector into per-slot tokens.
        translated = self.translation(features)
        slots = translated.view(B, self.config.num_slots, self.config.feature_dim)

        # Add learned slot identities.
        slots = slots + self.slot_embed

        # Route the slots; the router reads/writes the shared mailbox.
        routes, weights, routed = self.router(
            slots,
            mailbox=mailbox,
            target_fingerprint=target_fingerprint,
            skip_first=False,
        )

        # Mean entropy of the routing distribution (higher = more uniform).
        info = {
            'route_entropy': -(weights * (weights + 1e-8).log()).sum(dim=-1).mean().item(),
        }

        return routed, info
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ImageNetCollective(nn.Module):
    """
    Collective of pre-extracted CLIP features from multiple variants.
    Hierarchical chain topology with shared mailbox coordination.

    Each variant gets a FeatureStream; streams are chained parent -> child
    in config order. Pooled stream outputs are concatenated, fused back to
    feature_dim, and classified; each stream also has an auxiliary
    classifier for per-stream accuracy reporting.
    """

    def __init__(self, config: ImageNetCollectiveConfig):
        super().__init__()
        self.config = config

        # Fresh router registry so module ids don't leak across runs.
        get_registry().reset()

        self.streams = nn.ModuleDict()
        self.stream_order = list(config.clip_variants.keys())

        # Build streams as a chain: each stream's parent is the previous
        # stream's router module id.
        parent_id = None
        for variant_name, input_dim in config.clip_variants.items():
            stream = FeatureStream(
                config=config,
                variant_name=variant_name,
                input_dim=input_dim,
                parent_id=parent_id,
            )
            self.streams[variant_name] = stream
            parent_id = stream.module_id
            print(f" Stream: {variant_name} ({input_dim}D) -> parent: {parent_id[:8] if parent_id else 'root'}...")

        # Shared mailbox for inter-stream coordination.
        router_config = GlobalFractalRouterConfig(
            feature_dim=config.feature_dim,
            fingerprint_dim=config.fingerprint_dim,
        )
        self.mailbox = RouterMailbox(router_config)

        # Fuse the concatenated pooled stream outputs back to feature_dim.
        num_streams = len(config.clip_variants)
        self.fusion = nn.Sequential(
            nn.Linear(config.feature_dim * num_streams, config.feature_dim * 2),
            nn.LayerNorm(config.feature_dim * 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(config.feature_dim * 2, config.feature_dim),
            nn.LayerNorm(config.feature_dim),
        )

        # Main head on the fused representation.
        self.classifier = nn.Linear(config.feature_dim, config.num_classes)

        # Auxiliary per-stream heads.
        self.stream_classifiers = nn.ModuleDict({
            name: nn.Linear(config.feature_dim, config.num_classes)
            for name in config.clip_variants.keys()
        })

    def forward(
        self,
        features: Dict[str, torch.Tensor],
        return_individual: bool = False,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Args:
            features: Dict mapping variant name to [B, clip_dim] features
            return_individual: Also return per-stream predictions

        Returns:
            logits: [B, num_classes]
            info: Dict with metrics
        """

        # Start each forward pass with an empty mailbox.
        self.mailbox.clear()

        stream_features = {}
        stream_infos = {}

        for i, name in enumerate(self.stream_order):
            stream = self.streams[name]

            # Each stream targets the next stream's fingerprint; the last
            # stream in the chain has no target.
            if i < len(self.stream_order) - 1:
                next_name = self.stream_order[i + 1]
                target_fp = self.streams[next_name].fingerprint
            else:
                target_fp = None

            routed, info = stream(features[name], self.mailbox, target_fp)

            # Mean-pool routed slots -> one vector per sample.
            pooled = routed.mean(dim=1)
            stream_features[name] = pooled
            stream_infos[name] = info

        # Concatenate pooled outputs in fixed stream order, then fuse.
        fused = torch.cat([stream_features[n] for n in self.stream_order], dim=-1)
        fused = self.fusion(fused)

        logits = self.classifier(fused)

        info = {
            'stream_infos': stream_infos,
            'mailbox_messages': len(self.mailbox.messages),
            'mean_route_entropy': np.mean([i['route_entropy'] for i in stream_infos.values()]),
        }

        if return_individual:
            # Auxiliary per-stream logits from the pooled stream features.
            individual_logits = {
                name: self.stream_classifiers[name](stream_features[name])
                for name in self.stream_order
            }
            info['individual_logits'] = individual_logits

        return logits, info
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SingleStreamBaseline(nn.Module):
    """Linear-probe baseline on one CLIP variant's features (no routing)."""

    def __init__(self, config: ImageNetCollectiveConfig, variant_name: str, input_dim: int):
        super().__init__()
        self.variant_name = variant_name

        # Two-layer MLP probe: project to the shared feature width,
        # then classify. Layers are registered in the same order as the
        # collective's heads for comparable initialization.
        probe_layers = [
            nn.Linear(input_dim, config.feature_dim),
            nn.LayerNorm(config.feature_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(config.feature_dim, config.num_classes),
        ]
        self.classifier = nn.Sequential(*probe_layers)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Map [B, input_dim] features to [B, num_classes] logits."""
        return self.classifier(features)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_collective(
    model: ImageNetCollective,
    train_loader: DataLoader,
    val_loader: DataLoader,
    config: ImageNetCollectiveConfig,
):
    """Train collective with AMP.

    Uses AdamW with a per-step linear-warmup-into-cosine schedule,
    gradient clipping at 1.0, and tracks the best validation accuracy.

    Returns:
        (history, best_acc): history maps metric name -> per-epoch list;
        best_acc is the best validation accuracy observed.
    """

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.lr,
        weight_decay=config.weight_decay,
    )

    # Schedule is stepped once per batch, not per epoch.
    total_steps = len(train_loader) * config.epochs
    warmup_steps = len(train_loader) * config.warmup_epochs

    def lr_lambda(step):
        if step < warmup_steps:
            # max(1, ...) guards against warmup_epochs == 0.
            return step / max(1, warmup_steps)
        # max(1, ...) guards against epochs == warmup_epochs.
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1 + np.cos(np.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    scaler = GradScaler() if config.use_amp else None

    history = defaultdict(list)
    best_acc = 0

    for epoch in range(config.epochs):
        model.train()
        epoch_loss = 0
        correct = 0
        total = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.epochs}")

        for features, labels in pbar:
            features = {k: v.to(config.device, non_blocking=True) for k, v in features.items()}
            labels = labels.to(config.device, non_blocking=True)

            optimizer.zero_grad()

            if config.use_amp:
                with autocast():
                    logits, info = model(features)
                    loss = F.cross_entropy(logits, labels)

                # Unscale before clipping so the 1.0 threshold applies to
                # true gradient magnitudes.
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                logits, info = model(features)
                loss = F.cross_entropy(logits, labels)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            scheduler.step()

            epoch_loss += loss.item() * labels.size(0)
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

            pbar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'acc': f"{correct/total*100:.1f}%",
                'lr': f"{scheduler.get_last_lr()[0]:.2e}",
            })

        # Validation pass (also collects per-stream accuracies).
        val_acc, val_stream_accs = evaluate_collective(model, val_loader, config)

        history['train_loss'].append(epoch_loss / total)
        history['train_acc'].append(correct / total)
        history['val_acc'].append(val_acc)
        history['stream_accs'].append(val_stream_accs)

        stream_str = ' | '.join([f"{k[:4]}: {v*100:.1f}%" for k, v in val_stream_accs.items()])
        tqdm.write(f"Epoch {epoch+1:3d} | Loss: {epoch_loss/total:.4f} | "
                   f"Val: {val_acc*100:.2f}% | {stream_str}")

        if val_acc > best_acc:
            best_acc = val_acc
            # Repaired: the original marker string was mis-encoded and
            # split across lines (a syntax error as extracted).
            tqdm.write(f"  * New best: {best_acc*100:.2f}%")

    return dict(history), best_acc
|
|
|
|
|
|
|
|
def evaluate_collective(
    model: ImageNetCollective,
    loader: DataLoader,
    config: ImageNetCollectiveConfig,
) -> Tuple[float, Dict[str, float]]:
    """Evaluate collective and per-stream accuracy.

    Returns:
        (acc, stream_accs): overall top-1 accuracy, and a mapping of
        stream name -> top-1 accuracy of that stream's auxiliary head.
    """

    model.eval()
    correct = 0
    total = 0
    stream_correct = defaultdict(int)

    with torch.no_grad():
        for features, labels in tqdm(loader, desc="Eval", leave=False):
            features = {k: v.to(config.device, non_blocking=True) for k, v in features.items()}
            labels = labels.to(config.device, non_blocking=True)

            # Single code path: autocast is a no-op when enabled=False,
            # replacing the duplicated if/else branches.
            with autocast(enabled=config.use_amp):
                logits, info = model(features, return_individual=True)

            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

            # Per-stream accuracy from each auxiliary classifier.
            for name, ind_logits in info['individual_logits'].items():
                stream_correct[name] += (ind_logits.argmax(dim=1) == labels).sum().item()

    acc = correct / total
    stream_accs = {k: v / total for k, v in stream_correct.items()}

    return acc, stream_accs
|
|
|
|
|
|
|
|
def train_baseline(
    variant_name: str,
    input_dim: int,
    train_loader: DataLoader,
    val_loader: DataLoader,
    config: ImageNetCollectiveConfig,
):
    """Train single stream baseline.

    Trains a SingleStreamBaseline on one CLIP variant's features, pulling
    only ``features[variant_name]`` from the multi-variant batches. Uses
    AdamW with per-epoch cosine annealing and optional AMP.

    Returns:
        (history, best_acc): per-epoch validation accuracies and the best.
    """

    model = SingleStreamBaseline(config, variant_name, input_dim).to(config.device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.epochs)
    scaler = GradScaler() if config.use_amp else None

    history = defaultdict(list)
    best_acc = 0

    for epoch in range(config.epochs):
        model.train()
        epoch_loss = 0
        correct = 0
        total = 0

        for features, labels in tqdm(train_loader, desc=f"{variant_name} E{epoch+1}", leave=False):
            # Only this variant's features are used from the batch dict.
            feat = features[variant_name].to(config.device, non_blocking=True)
            labels = labels.to(config.device, non_blocking=True)

            optimizer.zero_grad()

            if config.use_amp:
                with autocast():
                    logits = model(feat)
                    loss = F.cross_entropy(logits, labels)
                # NOTE(review): unlike train_collective there is no gradient
                # clipping here — confirm this asymmetry is intended.
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                logits = model(feat)
                loss = F.cross_entropy(logits, labels)
                loss.backward()
                optimizer.step()

            epoch_loss += loss.item() * labels.size(0)
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

        # One scheduler step per epoch (CosineAnnealingLR semantics).
        scheduler.step()

        # Validation pass.
        model.eval()
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for features, labels in val_loader:
                feat = features[variant_name].to(config.device, non_blocking=True)
                labels = labels.to(config.device, non_blocking=True)

                if config.use_amp:
                    with autocast():
                        logits = model(feat)
                else:
                    logits = model(feat)

                val_correct += (logits.argmax(dim=1) == labels).sum().item()
                val_total += labels.size(0)

        val_acc = val_correct / val_total
        history['val_acc'].append(val_acc)

        if val_acc > best_acc:
            best_acc = val_acc

        # Log sparsely to keep output short across many baseline runs.
        if (epoch + 1) % 5 == 0 or epoch == 0:
            tqdm.write(f"{variant_name} Epoch {epoch+1:3d} | Val: {val_acc*100:.2f}%")

    return dict(history), best_acc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_results(
    collective_history: Dict,
    baseline_results: Dict[str, float],
    config: ImageNetCollectiveConfig,
    save_path: str = "imagenet_collective_results.png",
):
    """Plot training results.

    Draws a 2x2 figure: training curves, final-accuracy bars, per-stream
    collective-vs-standalone comparison, and a text summary panel.

    Args:
        collective_history: Output of train_collective (per-epoch metrics).
        baseline_results: Variant name -> best standalone accuracy (0-1).
        config: Experiment configuration.
        save_path: Output image path.
    """

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    epochs = range(1, len(collective_history['val_acc']) + 1)

    # --- Training progress -------------------------------------------------
    ax = axes[0, 0]
    ax.plot(epochs, [a*100 for a in collective_history['val_acc']], 'b-',
            label='Collective', linewidth=2)
    for name in config.clip_variants.keys():
        accs = [sa[name]*100 for sa in collective_history['stream_accs']]
        ax.plot(epochs, accs, '--', label=f'{name} (in coll.)', alpha=0.7)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Validation Accuracy (%)')
    ax.set_title('Training Progress')
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    # --- Final accuracy bars -----------------------------------------------
    ax = axes[0, 1]

    final_collective = collective_history['val_acc'][-1] * 100
    final_streams = {k: v*100 for k, v in collective_history['stream_accs'][-1].items()}

    names = ['Collective'] + list(baseline_results.keys())
    values = [final_collective] + [v*100 for v in baseline_results.values()]
    colors = ['steelblue'] + ['coral'] * len(baseline_results)

    bars = ax.bar(range(len(names)), values, color=colors)
    ax.set_xticks(range(len(names)))
    ax.set_xticklabels([n.replace('clip_vit_', '').replace('_', '\n') for n in names], fontsize=8)
    ax.set_ylabel('Validation Accuracy (%)')
    ax.set_title('Final Accuracy: Collective vs Individual Baselines')

    for bar, val in zip(bars, values):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.3,
                f'{val:.1f}%', ha='center', va='bottom', fontsize=8)

    # --- Per-stream comparison ----------------------------------------------
    ax = axes[1, 0]

    stream_names = list(config.clip_variants.keys())
    x = np.arange(len(stream_names))
    width = 0.35

    in_collective = [final_streams[n] for n in stream_names]
    standalone = [baseline_results[n]*100 for n in stream_names]

    bars1 = ax.bar(x - width/2, in_collective, width, label='In Collective', color='steelblue')
    bars2 = ax.bar(x + width/2, standalone, width, label='Standalone', color='coral')

    ax.set_ylabel('Accuracy (%)')
    ax.set_title('Per-Stream: Collective vs Standalone')
    ax.set_xticks(x)
    ax.set_xticklabels([n.replace('clip_vit_', '') for n in stream_names], fontsize=8, rotation=45)
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')

    # --- Text summary --------------------------------------------------------
    ax = axes[1, 1]
    ax.axis('off')

    best_baseline = max(baseline_results.values()) * 100
    improvement = final_collective - best_baseline

    # Separator strings were mis-encoded in the original source; use a
    # plain ASCII rule instead.
    sep = '-' * 56
    summary = f"""
    IMAGENET COLLECTIVE RESULTS
    {sep}

    Collective:      {final_collective:.2f}%
    Best Individual: {best_baseline:.2f}%

    Improvement:     {improvement:+.2f}%

    {sep}

    Per-stream in collective:
    """

    for name, acc in final_streams.items():
        short_name = name.replace('clip_vit_', '')
        summary += f"\n {short_name:<15}: {acc:.2f}%"

    summary += f"""

    {sep}

    Individual baselines:
    """

    for name, acc in baseline_results.items():
        short_name = name.replace('clip_vit_', '')
        summary += f"\n {short_name:<15}: {acc*100:.2f}%"

    ax.text(0.05, 0.95, summary, fontsize=10, family='monospace',
            verticalalignment='top', transform=ax.transAxes)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"\nSaved: {save_path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Run the full experiment: train the collective, then per-variant
    standalone baselines, and plot/print a comparison.

    Returns:
        (collective, collective_history, baseline_results)
    """
    print("="*70)
    print(" ImageNet Multi-CLIP Collective Experiment")
    print(" Pre-extracted Features via GlobalFractalRouter")
    print("="*70)

    config = ImageNetCollectiveConfig()

    # Echo the configuration.
    print(f"\nConfig:")
    print(f" Dataset: {config.dataset_name}")
    print(f" Variants: {len(config.clip_variants)}")
    for name, dim in config.clip_variants.items():
        print(f" - {name}: {dim}D")
    print(f" Feature dim: {config.feature_dim}")
    print(f" Epochs: {config.epochs}")
    print(f" Batch size: {config.batch_size}")
    print(f" AMP: {config.use_amp}")
    print(f" Device: {config.device}")

    # --- Data ----------------------------------------------------------
    print("\n" + "="*70)
    print(" Loading Data")
    print("="*70)

    train_loader, val_loader = get_dataloaders(config)
    print(f"\n Train batches: {len(train_loader)}")
    print(f" Val batches: {len(val_loader)}")

    # --- Collective ----------------------------------------------------
    print("\n" + "="*70)
    print(" Training COLLECTIVE")
    print("="*70)

    collective = ImageNetCollective(config).to(config.device)

    params = sum(p.numel() for p in collective.parameters())
    print(f"\n Parameters: {params:,}")

    collective_history, collective_best = train_collective(
        collective, train_loader, val_loader, config
    )

    # --- Baselines: one standalone probe per variant ---------------------
    print("\n" + "="*70)
    print(" Training BASELINES (Individual Streams)")
    print("="*70)

    baseline_results = {}

    for variant_name, input_dim in config.clip_variants.items():
        print(f"\n Training: {variant_name}")
        _, best_acc = train_baseline(
            variant_name, input_dim, train_loader, val_loader, config
        )
        baseline_results[variant_name] = best_acc
        print(f" {variant_name} best: {best_acc*100:.2f}%")

    # --- Summary ---------------------------------------------------------
    print("\n" + "="*70)
    print(" FINAL RESULTS")
    print("="*70)

    print(f"\n Collective: {collective_best*100:.2f}%")
    print(f" Best individual: {max(baseline_results.values())*100:.2f}%")
    print(f" Improvement: {(collective_best - max(baseline_results.values()))*100:+.2f}%")

    print("\n Per-stream final (in collective):")
    for name, acc in collective_history['stream_accs'][-1].items():
        print(f" {name}: {acc*100:.2f}%")

    print("\n Individual baselines:")
    for name, acc in baseline_results.items():
        print(f" {name}: {acc*100:.2f}%")

    plot_results(collective_history, baseline_results, config)

    return collective, collective_history, baseline_results
|
|
|
|
|
|
|
|
# Script entry point: run the full experiment and keep the results bound
# for interactive inspection.
if __name__ == "__main__":
    results = main()