AbstractPhil committed on
Commit 09b6e4d · verified · 1 Parent(s): 45bc23d

Create 5clip_imagenet.py

Files changed (1): 5clip_imagenet.py (+904, -0)

"""
ImageNet Multi-CLIP Collective Experiment
==========================================
Uses pre-extracted CLIP features from multiple model variants.
No image processing - pure feature routing at A100 speeds.

Dataset: AbstractPhil/imagenet-clip-features
Streams: b32, b16, l14, laion_b32, laion_bigg14 (laion_h14 optional)

Each CLIP variant becomes an expert stream with:
- Learnable translation head
- Own router with unique fingerprint
- Hierarchical coordination via mailbox

Training:
- AMP mixed precision
- 8 workers total, pinned, persistent
- Hierarchical chain topology

Author: AbstractPhil
Date: December 2025
License: Apache 2.0
"""
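
# Usage sketch (illustrative, not part of the original file): run directly
# with `python 5clip_imagenet.py`. Note the filename begins with a digit, so
# it cannot be imported via a plain `import` statement; use
# importlib.import_module or runpy if the functions are needed elsewhere.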

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler
from datasets import load_dataset
from dataclasses import dataclass, field
from typing import Dict, Tuple, List, Optional
from collections import defaultdict
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# =============================================================================
# IMPORTS FROM GEOFRACTAL
# =============================================================================

from geofractal.model.blocks.router.global_fractal_router import (
    GlobalFractalRouter,
    GlobalFractalRouterConfig,
    get_registry,
    RouterMailbox,
)
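
# Note (assumption, stated for readers without the package): geofractal is an
# external dependency; GlobalFractalRouter is assumed to expose `fingerprint`
# and `module_id` attributes and a forward returning (routes, weights, routed),
# as used by the streams below.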

# =============================================================================
# CONFIG
# =============================================================================

@dataclass
class ImageNetCollectiveConfig:
    """Configuration for ImageNet multi-CLIP collective."""

    # Dataset
    dataset_name: str = "AbstractPhil/imagenet-clip-features"
    num_classes: int = 1000

    # CLIP variants and their dimensions
    clip_variants: Dict[str, int] = field(default_factory=lambda: {
        'clip_vit_b32': 512,
        'clip_vit_b16': 512,
        'clip_vit_l14': 768,
        'clip_vit_laion_b32': 512,
        'clip_vit_laion_bigg14': 1280,
        # 'clip_vit_laion_h14': 1024,  # Can add if memory permits
    })

    # Feature dimensions
    feature_dim: int = 512  # Internal routing dimension
    fingerprint_dim: int = 64

    # Router
    num_anchors: int = 16
    num_routes: int = 8
    num_slots: int = 16  # Sequence length for routing

    # Training
    batch_size: int = 256
    epochs: int = 20
    lr: float = 3e-4
    weight_decay: float = 0.01
    warmup_epochs: int = 2

    # DataLoader - A100 optimized
    num_workers: int = 8  # Total across all loaders
    pin_memory: bool = True
    persistent_workers: bool = True
    prefetch_factor: int = 4

    # AMP
    use_amp: bool = True

    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    def workers_per_loader(self) -> int:
        """Distribute workers across loaders."""
        n_loaders = len(self.clip_variants)
        return max(1, self.num_workers // n_loaders)
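

# A minimal sanity sketch of the worker-split arithmetic above (illustrative;
# never called by the training run, safe to delete).
def _demo_worker_split() -> int:
    cfg = ImageNetCollectiveConfig()
    # With the 5 default variants and num_workers=8: max(1, 8 // 5) == 1.
    return cfg.workers_per_loader()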


# =============================================================================
# DATASET
# =============================================================================

class CLIPFeatureDataset(Dataset):
    """
    Wraps HuggingFace dataset for a single CLIP variant.
    Returns pre-extracted features and labels.
    """

    def __init__(
        self,
        hf_dataset,
        feature_column: str = 'clip_features',
        label_column: str = 'label',
    ):
        self.dataset = hf_dataset
        self.feature_column = feature_column
        self.label_column = label_column

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        features = torch.tensor(item[self.feature_column], dtype=torch.float32)
        label = item[self.label_column]
        return features, label


class MultiCLIPDataset(Dataset):
    """
    Loads features from multiple CLIP variants simultaneously.
    Returns dict of features + label.
    """

    def __init__(
        self,
        dataset_name: str,
        split_prefix: str,  # e.g., 'train' or 'validation'
        clip_variants: Dict[str, int],
    ):
        self.variants = list(clip_variants.keys())
        self.datasets = {}

        print(f"Loading {split_prefix} splits...")
        for variant in tqdm(self.variants, desc="Loading variants"):
            split_name = f"{variant}_{split_prefix}"
            try:
                ds = load_dataset(dataset_name, split=split_name)
                self.datasets[variant] = ds
                print(f"  {variant}: {len(ds):,} samples")
            except Exception as e:
                print(f"  WARNING: Could not load {split_name}: {e}")

        # Use first dataset for length (all should be the same)
        self.length = len(next(iter(self.datasets.values())))

        # Verify all same length
        for name, ds in self.datasets.items():
            assert len(ds) == self.length, f"{name} has {len(ds)} != {self.length}"

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        features = {}
        label = None

        for variant, ds in self.datasets.items():
            item = ds[idx]
            features[variant] = torch.tensor(item['clip_features'], dtype=torch.float32)
            if label is None:
                label = item['label']

        return features, label


def get_dataloaders(config: ImageNetCollectiveConfig):
    """Create train and validation dataloaders."""

    train_dataset = MultiCLIPDataset(
        config.dataset_name,
        'train',
        config.clip_variants,
    )

    val_dataset = MultiCLIPDataset(
        config.dataset_name,
        'validation',
        config.clip_variants,
    )

    # Collate function for dict of features
    def collate_fn(batch):
        features = {k: [] for k in config.clip_variants.keys()}
        labels = []

        for feat_dict, label in batch:
            for k, v in feat_dict.items():
                features[k].append(v)
            labels.append(label)

        features = {k: torch.stack(v) for k, v in features.items()}
        labels = torch.tensor(labels, dtype=torch.long)

        return features, labels

    # Note: a single MultiCLIPDataset loader serves all variants at once, so
    # the full num_workers pool goes to it (workers_per_loader() is not needed
    # here).
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory,
        persistent_workers=config.persistent_workers if config.num_workers > 0 else False,
        prefetch_factor=config.prefetch_factor if config.num_workers > 0 else None,
        collate_fn=collate_fn,
        drop_last=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory,
        persistent_workers=config.persistent_workers if config.num_workers > 0 else False,
        prefetch_factor=config.prefetch_factor if config.num_workers > 0 else None,
        collate_fn=collate_fn,
    )

    return train_loader, val_loader
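

# A minimal sketch of the collate contract using synthetic tensors
# (illustrative only; the real loaders stream pre-extracted CLIP features).
def _demo_collate_contract():
    cfg = ImageNetCollectiveConfig()
    fake_batch = [
        ({k: torch.randn(d) for k, d in cfg.clip_variants.items()}, 0)
        for _ in range(4)
    ]
    # Mirrors the stacking performed inside get_dataloaders' collate_fn.
    feats = {k: torch.stack([fd[k] for fd, _ in fake_batch]) for k in cfg.clip_variants}
    labels = torch.tensor([lbl for _, lbl in fake_batch], dtype=torch.long)
    assert feats['clip_vit_b32'].shape == (4, 512)
    assert labels.shape == (4,)
    return feats, labels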


# =============================================================================
# FEATURE STREAM (No CLIP model - just translation + routing)
# =============================================================================

class FeatureStream(nn.Module):
    """
    Stream for pre-extracted CLIP features.
    No CLIP model - just translation head + router.
    """

    def __init__(
        self,
        config: ImageNetCollectiveConfig,
        variant_name: str,
        input_dim: int,
        parent_id: Optional[str] = None,
    ):
        super().__init__()
        self.config = config
        self.variant_name = variant_name
        self.input_dim = input_dim

        # Translation head: CLIP dim → routing space
        self.translation = nn.Sequential(
            nn.Linear(input_dim, config.feature_dim * 2),
            nn.LayerNorm(config.feature_dim * 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(config.feature_dim * 2, config.feature_dim * config.num_slots),
        )

        # Learnable slot embeddings (unique per stream)
        self.slot_embed = nn.Parameter(
            torch.randn(1, config.num_slots, config.feature_dim) * 0.02
        )

        # Router with unique fingerprint
        router_config = GlobalFractalRouterConfig(
            feature_dim=config.feature_dim,
            fingerprint_dim=config.fingerprint_dim,
            num_anchors=config.num_anchors,
            num_routes=config.num_routes,
            use_adjacent_gating=True,
            use_cantor_prior=True,
            grid_size=(config.num_slots, 1),
        )

        self.router = GlobalFractalRouter(
            config=router_config,
            parent_id=parent_id,
            cooperation_group="imagenet_collective",
            name=variant_name,
        )

    @property
    def fingerprint(self) -> torch.Tensor:
        return self.router.fingerprint

    @property
    def module_id(self) -> str:
        return self.router.module_id

    def forward(
        self,
        features: torch.Tensor,
        mailbox: RouterMailbox,
        target_fingerprint: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Args:
            features: [B, input_dim] pre-extracted CLIP features
            mailbox: Shared mailbox
            target_fingerprint: Next stream's fingerprint

        Returns:
            routed: [B, num_slots, feature_dim]
            info: Dict with metrics
        """
        B = features.shape[0]

        # Translate to routing space
        translated = self.translation(features)  # [B, feature_dim * num_slots]
        slots = translated.view(B, self.config.num_slots, self.config.feature_dim)

        # Add slot embeddings
        slots = slots + self.slot_embed

        # Route
        routes, weights, routed = self.router(
            slots,
            mailbox=mailbox,
            target_fingerprint=target_fingerprint,
            skip_first=False,
        )

        info = {
            'route_entropy': -(weights * (weights + 1e-8).log()).sum(dim=-1).mean().item(),
        }

        return routed, info
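
# Shape walkthrough for FeatureStream.forward (descriptive, default config):
#   features [B, clip_dim] -> translation -> [B, feature_dim * num_slots]
#   -> view [B, num_slots, feature_dim] -> (+ slot_embed) -> router
#   -> routed [B, num_slots, feature_dim].
# route_entropy is the mean Shannon entropy of the routing weights,
# H = -sum_r w_r * log(w_r); lower values indicate more peaked routing.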


# =============================================================================
# MULTI-CLIP COLLECTIVE
# =============================================================================

class ImageNetCollective(nn.Module):
    """
    Collective of pre-extracted CLIP features from multiple variants.
    Hierarchical chain topology with shared mailbox coordination.
    """

    def __init__(self, config: ImageNetCollectiveConfig):
        super().__init__()
        self.config = config

        # Reset registry for fresh start
        get_registry().reset()

        # Build streams in hierarchical chain
        self.streams = nn.ModuleDict()
        self.stream_order = list(config.clip_variants.keys())

        parent_id = None
        for variant_name, input_dim in config.clip_variants.items():
            stream = FeatureStream(
                config=config,
                variant_name=variant_name,
                input_dim=input_dim,
                parent_id=parent_id,
            )
            self.streams[variant_name] = stream
            # Print the parent this stream was attached to, then advance the
            # chain so the next stream parents onto this one.
            print(f"  Stream: {variant_name} ({input_dim}D) -> parent: {parent_id[:8] + '...' if parent_id else 'root'}")
            parent_id = stream.module_id

        # Shared mailbox
        router_config = GlobalFractalRouterConfig(
            feature_dim=config.feature_dim,
            fingerprint_dim=config.fingerprint_dim,
        )
        self.mailbox = RouterMailbox(router_config)

        # Fusion layer
        num_streams = len(config.clip_variants)
        self.fusion = nn.Sequential(
            nn.Linear(config.feature_dim * num_streams, config.feature_dim * 2),
            nn.LayerNorm(config.feature_dim * 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(config.feature_dim * 2, config.feature_dim),
            nn.LayerNorm(config.feature_dim),
        )

        # Classification head
        self.classifier = nn.Linear(config.feature_dim, config.num_classes)

        # Per-stream classifiers (for measuring individual contribution)
        self.stream_classifiers = nn.ModuleDict({
            name: nn.Linear(config.feature_dim, config.num_classes)
            for name in config.clip_variants.keys()
        })

    def forward(
        self,
        features: Dict[str, torch.Tensor],
        return_individual: bool = False,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Args:
            features: Dict mapping variant name to [B, clip_dim] features
            return_individual: Also return per-stream predictions

        Returns:
            logits: [B, num_classes]
            info: Dict with metrics
        """
        # Clear mailbox
        self.mailbox.clear()

        # Process streams in order
        stream_features = {}
        stream_infos = {}

        for i, name in enumerate(self.stream_order):
            stream = self.streams[name]

            # Get target fingerprint (next stream or None)
            if i < len(self.stream_order) - 1:
                next_name = self.stream_order[i + 1]
                target_fp = self.streams[next_name].fingerprint
            else:
                target_fp = None

            # Forward
            routed, info = stream(features[name], self.mailbox, target_fp)

            # Pool across slots
            pooled = routed.mean(dim=1)  # [B, feature_dim]
            stream_features[name] = pooled
            stream_infos[name] = info

        # Fuse all streams
        fused = torch.cat([stream_features[n] for n in self.stream_order], dim=-1)
        fused = self.fusion(fused)

        # Classify
        logits = self.classifier(fused)

        info = {
            'stream_infos': stream_infos,
            'mailbox_messages': len(self.mailbox.messages),
            'mean_route_entropy': np.mean([i['route_entropy'] for i in stream_infos.values()]),
        }

        if return_individual:
            individual_logits = {
                name: self.stream_classifiers[name](stream_features[name])
                for name in self.stream_order
            }
            info['individual_logits'] = individual_logits

        return logits, info
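
# Chain topology (descriptive): each stream routes toward the next stream's
# fingerprint, so with the default variants the coordination order is
#   clip_vit_b32 -> clip_vit_b16 -> clip_vit_l14 -> clip_vit_laion_b32
#   -> clip_vit_laion_bigg14 (the last stream has no target fingerprint).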


# =============================================================================
# SINGLE STREAM BASELINE
# =============================================================================

class SingleStreamBaseline(nn.Module):
    """Single CLIP variant with a small MLP probe (no routing)."""

    def __init__(self, config: ImageNetCollectiveConfig, variant_name: str, input_dim: int):
        super().__init__()
        self.variant_name = variant_name

        self.classifier = nn.Sequential(
            nn.Linear(input_dim, config.feature_dim),
            nn.LayerNorm(config.feature_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(config.feature_dim, config.num_classes),
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        return self.classifier(features)
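
# Usage sketch (illustrative, not called anywhere): the baseline consumes raw
# CLIP features directly, e.g.
#   model = SingleStreamBaseline(ImageNetCollectiveConfig(), 'clip_vit_b32', 512)
#   logits = model(torch.randn(8, 512))   # -> [8, 1000]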


# =============================================================================
# TRAINING
# =============================================================================

def train_collective(
    model: ImageNetCollective,
    train_loader: DataLoader,
    val_loader: DataLoader,
    config: ImageNetCollectiveConfig,
):
    """Train collective with AMP."""

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.lr,
        weight_decay=config.weight_decay,
    )

    # Warmup + cosine schedule
    total_steps = len(train_loader) * config.epochs
    warmup_steps = len(train_loader) * config.warmup_epochs

    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return 0.5 * (1 + np.cos(np.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    scaler = GradScaler() if config.use_amp else None

    history = defaultdict(list)
    best_acc = 0

    for epoch in range(config.epochs):
        model.train()
        epoch_loss = 0
        correct = 0
        total = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.epochs}")

        for features, labels in pbar:
            # Move to device
            features = {k: v.to(config.device, non_blocking=True) for k, v in features.items()}
            labels = labels.to(config.device, non_blocking=True)

            optimizer.zero_grad()

            if config.use_amp:
                with autocast():
                    logits, info = model(features)
                    loss = F.cross_entropy(logits, labels)

                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                logits, info = model(features)
                loss = F.cross_entropy(logits, labels)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            scheduler.step()

            epoch_loss += loss.item() * labels.size(0)
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

            pbar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'acc': f"{correct/total*100:.1f}%",
                'lr': f"{scheduler.get_last_lr()[0]:.2e}",
            })

        # Validate
        val_acc, val_stream_accs = evaluate_collective(model, val_loader, config)

        history['train_loss'].append(epoch_loss / total)
        history['train_acc'].append(correct / total)
        history['val_acc'].append(val_acc)
        history['stream_accs'].append(val_stream_accs)

        # Log
        stream_str = ' | '.join([f"{k[:4]}: {v*100:.1f}%" for k, v in val_stream_accs.items()])
        tqdm.write(f"Epoch {epoch+1:3d} | Loss: {epoch_loss/total:.4f} | "
                   f"Val: {val_acc*100:.2f}% | {stream_str}")

        if val_acc > best_acc:
            best_acc = val_acc
            tqdm.write(f"  ★ New best: {best_acc*100:.2f}%")

    return dict(history), best_acc
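

# A minimal sketch of the warmup+cosine multiplier used above (hypothetical
# step counts; the real values depend on len(train_loader)). Safe to delete.
def _demo_lr_schedule(total_steps: int = 20_000, warmup_steps: int = 2_000):
    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return 0.5 * (1 + np.cos(np.pi * progress))
    assert abs(lr_lambda(warmup_steps) - 1.0) < 1e-9  # peaks right after warmup
    assert lr_lambda(total_steps) < 1e-9              # decays to ~0 at the end
    return [lr_lambda(s) for s in range(0, total_steps + 1, total_steps // 10)]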


def evaluate_collective(
    model: ImageNetCollective,
    loader: DataLoader,
    config: ImageNetCollectiveConfig,
) -> Tuple[float, Dict[str, float]]:
    """Evaluate collective and per-stream accuracy."""

    model.eval()
    correct = 0
    total = 0
    stream_correct = defaultdict(int)

    with torch.no_grad():
        for features, labels in tqdm(loader, desc="Eval", leave=False):
            features = {k: v.to(config.device, non_blocking=True) for k, v in features.items()}
            labels = labels.to(config.device, non_blocking=True)

            if config.use_amp:
                with autocast():
                    logits, info = model(features, return_individual=True)
            else:
                logits, info = model(features, return_individual=True)

            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

            for name, ind_logits in info['individual_logits'].items():
                stream_correct[name] += (ind_logits.argmax(dim=1) == labels).sum().item()

    acc = correct / total
    stream_accs = {k: v / total for k, v in stream_correct.items()}

    return acc, stream_accs


def train_baseline(
    variant_name: str,
    input_dim: int,
    train_loader: DataLoader,
    val_loader: DataLoader,
    config: ImageNetCollectiveConfig,
):
    """Train single stream baseline."""

    model = SingleStreamBaseline(config, variant_name, input_dim).to(config.device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.epochs)
    scaler = GradScaler() if config.use_amp else None

    history = defaultdict(list)
    best_acc = 0

    for epoch in range(config.epochs):
        model.train()
        epoch_loss = 0
        correct = 0
        total = 0

        for features, labels in tqdm(train_loader, desc=f"{variant_name} E{epoch+1}", leave=False):
            feat = features[variant_name].to(config.device, non_blocking=True)
            labels = labels.to(config.device, non_blocking=True)

            optimizer.zero_grad()

            if config.use_amp:
                with autocast():
                    logits = model(feat)
                    loss = F.cross_entropy(logits, labels)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                logits = model(feat)
                loss = F.cross_entropy(logits, labels)
                loss.backward()
                optimizer.step()

            epoch_loss += loss.item() * labels.size(0)
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

        scheduler.step()

        # Validate
        model.eval()
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for features, labels in val_loader:
                feat = features[variant_name].to(config.device, non_blocking=True)
                labels = labels.to(config.device, non_blocking=True)

                if config.use_amp:
                    with autocast():
                        logits = model(feat)
                else:
                    logits = model(feat)

                val_correct += (logits.argmax(dim=1) == labels).sum().item()
                val_total += labels.size(0)

        val_acc = val_correct / val_total
        history['val_acc'].append(val_acc)

        if val_acc > best_acc:
            best_acc = val_acc

        if (epoch + 1) % 5 == 0 or epoch == 0:
            tqdm.write(f"{variant_name} Epoch {epoch+1:3d} | Val: {val_acc*100:.2f}%")

    return dict(history), best_acc


# =============================================================================
# VISUALIZATION
# =============================================================================

def plot_results(
    collective_history: Dict,
    baseline_results: Dict[str, float],
    config: ImageNetCollectiveConfig,
    save_path: str = "imagenet_collective_results.png",
):
    """Plot training results."""

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    epochs = range(1, len(collective_history['val_acc']) + 1)

    # Validation accuracy over time
    ax = axes[0, 0]
    ax.plot(epochs, [a*100 for a in collective_history['val_acc']], 'b-',
            label='Collective', linewidth=2)
    for name in config.clip_variants.keys():
        accs = [sa[name]*100 for sa in collective_history['stream_accs']]
        ax.plot(epochs, accs, '--', label=f'{name} (in coll.)', alpha=0.7)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Validation Accuracy (%)')
    ax.set_title('Training Progress')
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    # Final comparison bar
    ax = axes[0, 1]

    final_collective = collective_history['val_acc'][-1] * 100
    final_streams = {k: v*100 for k, v in collective_history['stream_accs'][-1].items()}

    names = ['Collective'] + list(baseline_results.keys())
    values = [final_collective] + [v*100 for v in baseline_results.values()]
    colors = ['steelblue'] + ['coral'] * len(baseline_results)

    bars = ax.bar(range(len(names)), values, color=colors)
    ax.set_xticks(range(len(names)))
    ax.set_xticklabels([n.replace('clip_vit_', '').replace('_', '\n') for n in names], fontsize=8)
    ax.set_ylabel('Validation Accuracy (%)')
    ax.set_title('Final Accuracy: Collective vs Individual Baselines')

    for bar, val in zip(bars, values):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.3,
                f'{val:.1f}%', ha='center', va='bottom', fontsize=8)

    # Per-stream accuracy in collective vs baseline
    ax = axes[1, 0]

    stream_names = list(config.clip_variants.keys())
    x = np.arange(len(stream_names))
    width = 0.35

    in_collective = [final_streams[n] for n in stream_names]
    standalone = [baseline_results[n]*100 for n in stream_names]

    bars1 = ax.bar(x - width/2, in_collective, width, label='In Collective', color='steelblue')
    bars2 = ax.bar(x + width/2, standalone, width, label='Standalone', color='coral')

    ax.set_ylabel('Accuracy (%)')
    ax.set_title('Per-Stream: Collective vs Standalone')
    ax.set_xticks(x)
    ax.set_xticklabels([n.replace('clip_vit_', '') for n in stream_names], fontsize=8, rotation=45)
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')

    # Summary
    ax = axes[1, 1]
    ax.axis('off')

    best_baseline = max(baseline_results.values()) * 100
    improvement = final_collective - best_baseline

    summary = f"""
IMAGENET COLLECTIVE RESULTS
════════════════════════════════════════════════════════

Collective:      {final_collective:.2f}%
Best Individual: {best_baseline:.2f}%

Improvement:     {improvement:+.2f}%

════════════════════════════════════════════════════════

Per-stream in collective:
"""

    for name, acc in final_streams.items():
        short_name = name.replace('clip_vit_', '')
        summary += f"\n  {short_name:<15}: {acc:.2f}%"

    summary += """

════════════════════════════════════════════════════════

Individual baselines:
"""

    for name, acc in baseline_results.items():
        short_name = name.replace('clip_vit_', '')
        summary += f"\n  {short_name:<15}: {acc*100:.2f}%"

    ax.text(0.05, 0.95, summary, fontsize=10, family='monospace',
            verticalalignment='top', transform=ax.transAxes)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"\nSaved: {save_path}")


# =============================================================================
# MAIN
# =============================================================================

def main():
    print("="*70)
    print(" ImageNet Multi-CLIP Collective Experiment")
    print(" Pre-extracted Features via GlobalFractalRouter")
    print("="*70)

    config = ImageNetCollectiveConfig()

    print("\nConfig:")
    print(f"  Dataset: {config.dataset_name}")
    print(f"  Variants: {len(config.clip_variants)}")
    for name, dim in config.clip_variants.items():
        print(f"    - {name}: {dim}D")
    print(f"  Feature dim: {config.feature_dim}")
    print(f"  Epochs: {config.epochs}")
    print(f"  Batch size: {config.batch_size}")
    print(f"  AMP: {config.use_amp}")
    print(f"  Device: {config.device}")

    # Data
    print("\n" + "="*70)
    print(" Loading Data")
    print("="*70)

    train_loader, val_loader = get_dataloaders(config)
    print(f"\n  Train batches: {len(train_loader)}")
    print(f"  Val batches:   {len(val_loader)}")

    # =================================================================
    # COLLECTIVE
    # =================================================================
    print("\n" + "="*70)
    print(" Training COLLECTIVE")
    print("="*70)

    collective = ImageNetCollective(config).to(config.device)

    params = sum(p.numel() for p in collective.parameters())
    print(f"\n  Parameters: {params:,}")

    collective_history, collective_best = train_collective(
        collective, train_loader, val_loader, config
    )

    # =================================================================
    # BASELINES
    # =================================================================
    print("\n" + "="*70)
    print(" Training BASELINES (Individual Streams)")
    print("="*70)

    baseline_results = {}

    for variant_name, input_dim in config.clip_variants.items():
        print(f"\n  Training: {variant_name}")
        _, best_acc = train_baseline(
            variant_name, input_dim, train_loader, val_loader, config
        )
        baseline_results[variant_name] = best_acc
        print(f"  {variant_name} best: {best_acc*100:.2f}%")

    # =================================================================
    # RESULTS
    # =================================================================
    print("\n" + "="*70)
    print(" FINAL RESULTS")
    print("="*70)

    print(f"\n  Collective:      {collective_best*100:.2f}%")
    print(f"  Best individual: {max(baseline_results.values())*100:.2f}%")
    print(f"  Improvement:     {(collective_best - max(baseline_results.values()))*100:+.2f}%")

    print("\n  Per-stream final (in collective):")
    for name, acc in collective_history['stream_accs'][-1].items():
        print(f"    {name}: {acc*100:.2f}%")

    print("\n  Individual baselines:")
    for name, acc in baseline_results.items():
        print(f"    {name}: {acc*100:.2f}%")

    plot_results(collective_history, baseline_results, config)

    return collective, collective_history, baseline_results


if __name__ == "__main__":
    results = main()