#!/usr/bin/env python3
"""Train CantorLinear classifier on pre-extracted ImageNet CLIP features.

Uses AbstractPhil/imagenet-clip-features-orderly dataset from HuggingFace.

Author: AbstractPhil
License: MIT

Uses the geometricvocab github implementation; in a notebook install with:

    try:
        !pip uninstall -qy geometricvocab
    except:
        pass
    !pip install -q git+https://github.com/AbstractEyes/lattice_vocabulary.git
"""

import math
import sys
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import wandb
from datasets import load_dataset
from tqdm import tqdm

# Import your CantorLinear layer.
# Adjust the import path as needed for your setup.
from geovocab2.train.model.layers.linear import CantorLinear, CantorLinearConfig


# ============================================================
# CONFIGURATION
# ============================================================

@dataclass
class TrainConfig:
    """Hyperparameters for training a CantorLinear head on CLIP features."""

    # Dataset
    dataset_name: str = "AbstractPhil/imagenet-clip-features-orderly"
    clip_dim: int = 512        # CLIP ViT-B/16 feature dimension
    num_classes: int = 1000    # ImageNet classes

    # Model
    # e.g. [2048, 1024] for a 2-layer head; None (normalized to []) maps
    # CLIP features directly to class logits.
    hidden_dims: Optional[list] = None
    cantor_depth: int = 8
    mask_mode: str = "alpha"
    alpha_mode: str = "sigmoid"
    alpha_min: float = 0.1
    alpha_max: float = 1.0
    per_output_alpha: bool = False
    dropout: float = 0.1

    # Training
    batch_size: int = 512
    num_epochs: int = 50
    learning_rate: float = 1e-3
    weight_decay: float = 1e-4
    warmup_epochs: int = 5

    # Optimizer
    alpha_lr_mult: float = 0.1  # Separate LR multiplier for alpha parameters

    # Logging
    use_wandb: bool = False
    wandb_project: str = "cantor-imagenet"
    log_every: int = 50
    eval_every: int = 500

    # System
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    num_workers: int = 4
    seed: int = 42

    def __post_init__(self):
        # Normalize None -> [] so downstream code can always iterate.
        if self.hidden_dims is None:
            self.hidden_dims = []  # Direct CLIP → classes
# ============================================================
# DATASET
# ============================================================

class CLIPFeaturesDataset(Dataset):
    """Wrapper for a HuggingFace dataset of pre-extracted CLIP features."""

    def __init__(self, hf_dataset):
        self.dataset = hf_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Each record holds a 'clip_features' vector and an integer 'label'.
        item = self.dataset[idx]
        features = torch.tensor(item['clip_features'], dtype=torch.float32)
        label = torch.tensor(item['label'], dtype=torch.long)
        return features, label


# ============================================================
# MODEL
# ============================================================

class CantorCLIPClassifier(nn.Module):
    """
    Multi-layer classifier using CantorLinear layers.

    Maps CLIP features → [hidden layers] → ImageNet classes.
    """

    def __init__(self, cfg: TrainConfig):
        super().__init__()
        self.cfg = cfg

        def _make_cantor(in_dim: int, out_dim: int) -> CantorLinear:
            # Every CantorLinear layer shares the same mask/alpha settings.
            return CantorLinear(CantorLinearConfig(
                in_features=in_dim,
                out_features=out_dim,
                depth=cfg.cantor_depth,
                mask_mode=cfg.mask_mode,
                alpha_mode=cfg.alpha_mode,
                alpha_min=cfg.alpha_min,
                alpha_max=cfg.alpha_max,
                per_output_alpha=cfg.per_output_alpha
            ))

        layers = []
        in_dim = cfg.clip_dim

        # Hidden layers (empty cfg.hidden_dims → direct CLIP-to-logits map).
        for hidden_dim in cfg.hidden_dims:
            layers.append(_make_cantor(in_dim, hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(cfg.dropout))
            in_dim = hidden_dim

        # Output layer
        layers.append(_make_cantor(in_dim, cfg.num_classes))

        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        return self.classifier(x)

    def get_alpha_stats(self):
        """Collect alpha statistics from all CantorLinear layers.

        Returns a dict of parallel lists keyed by layer_names / alpha_means /
        alpha_stds / mask_densities; lists are empty when no layer reports stats.
        """
        stats = {
            "layer_names": [],
            "alpha_means": [],
            "alpha_stds": [],
            "mask_densities": []
        }
        for name, module in self.named_modules():
            if isinstance(module, CantorLinear):
                alpha_stats = module.get_alpha_stats()
                if alpha_stats:
                    stats["layer_names"].append(name)
                    stats["alpha_means"].append(alpha_stats["alpha_mean"])
                    stats["alpha_stds"].append(alpha_stats.get("alpha_std", 0.0))
                    # NOTE(review): assumes CantorLinear exposes a float
                    # `mask` tensor — confirm against the layer implementation.
                    stats["mask_densities"].append(module.mask.mean().item())
        return stats


# ============================================================
# TRAINING
# ============================================================

def train_epoch(model, dataloader, criterion, optimizer, scheduler, cfg, epoch):
    """Train for one epoch; returns (mean loss, accuracy %)."""
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{cfg.num_epochs}")
    for batch_idx, (features, labels) in enumerate(pbar):
        features = features.to(cfg.device)
        labels = labels.to(cfg.device)

        # Forward
        optimizer.zero_grad()
        outputs = model(features)
        loss = criterion(outputs, labels)

        # Backward
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()  # per-step schedule (warmup + cosine)

        # Metrics
        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        # Logging
        if batch_idx % cfg.log_every == 0:
            avg_loss = total_loss / (batch_idx + 1)
            acc = 100. * correct / total
            pbar.set_postfix({
                'loss': f'{avg_loss:.4f}',
                'acc': f'{acc:.2f}%'
            })
            if cfg.use_wandb:
                wandb.log({
                    'train/loss': avg_loss,
                    'train/acc': acc,
                    'train/lr': optimizer.param_groups[0]['lr']
                })

    return total_loss / len(dataloader), 100. * correct / total


def evaluate(model, dataloader, criterion, cfg):
    """Evaluate model; returns (mean loss, accuracy %)."""
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for features, labels in tqdm(dataloader, desc="Evaluating"):
            features = features.to(cfg.device)
            labels = labels.to(cfg.device)

            outputs = model(features)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    avg_loss = total_loss / len(dataloader)
    acc = 100. * correct / total
    return avg_loss, acc


def main():
    cfg = TrainConfig()

    # Set seed
    torch.manual_seed(cfg.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(cfg.seed)

    print("=" * 60)
    print("CantorLinear ImageNet CLIP Features Training")
    print("=" * 60)
    print(f"\nConfiguration:")
    print(f"  Dataset: {cfg.dataset_name}")
    print(f"  CLIP dim: {cfg.clip_dim}")
    print(f"  Hidden dims: {cfg.hidden_dims if cfg.hidden_dims else 'Direct'}")
    print(f"  Cantor depth: {cfg.cantor_depth}")
    print(f"  Batch size: {cfg.batch_size}")
    print(f"  Learning rate: {cfg.learning_rate}")
    print(f"  Device: {cfg.device}")

    # Initialize wandb
    if cfg.use_wandb:
        wandb.init(project=cfg.wandb_project, config=vars(cfg))

    # Load dataset
    print("\nLoading dataset...")
    dataset = load_dataset(cfg.dataset_name, name="clip_vit_b16", split="train")

    # Split into train/val (90/10)
    dataset = dataset.train_test_split(test_size=0.1, seed=cfg.seed)
    train_dataset = CLIPFeaturesDataset(dataset['train'])
    val_dataset = CLIPFeaturesDataset(dataset['test'])

    print(f"Train samples: {len(train_dataset)}")
    print(f"Val samples: {len(val_dataset)}")

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True
    )

    # Create model
    print("\nBuilding model...")
    model = CantorCLIPClassifier(cfg).to(cfg.device)

    # Print model info
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Alpha statistics
    stats = model.get_alpha_stats()
    if stats['alpha_means']:
        print(f"CantorLinear layers: {len(stats['alpha_means'])}")
        print(f"Avg mask density: {sum(stats['mask_densities'])/len(stats['mask_densities']):.4f}")

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()

    # Separate learning rates for alpha parameters
    alpha_params = []
    other_params = []
    for name, param in model.named_parameters():
        if 'alpha' in name:
            alpha_params.append(param)
        else:
            other_params.append(param)

    optimizer = optim.AdamW([
        {'params': other_params, 'lr': cfg.learning_rate},
        {'params': alpha_params, 'lr': cfg.learning_rate * cfg.alpha_lr_mult}
    ], weight_decay=cfg.weight_decay)

    # Learning rate scheduler with warmup
    total_steps = len(train_loader) * cfg.num_epochs
    warmup_steps = len(train_loader) * cfg.warmup_epochs

    def lr_lambda(step):
        # Linear warmup, then cosine decay. max(1, ...) guards the
        # zero-division when warmup_epochs == 0 or num_epochs == warmup_epochs.
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        decay_steps = max(1, total_steps - warmup_steps)
        return 0.5 * (1 + math.cos(math.pi * (step - warmup_steps) / decay_steps))

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # Training loop
    print("\nStarting training...")
    best_val_acc = 0.0

    for epoch in range(cfg.num_epochs):
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, scheduler, cfg, epoch
        )

        val_loss, val_acc = evaluate(model, val_loader, criterion, cfg)

        print(f"\nEpoch {epoch+1}/{cfg.num_epochs}")
        print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"  Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

        # Log alpha evolution. Val metrics are logged unconditionally;
        # alpha metrics only when the model actually reports them.
        wandb_metrics = {
            'val/loss': val_loss,
            'val/acc': val_acc,
            'epoch': epoch
        }
        stats = model.get_alpha_stats()
        if stats['alpha_means']:
            mean_alpha = sum(stats['alpha_means']) / len(stats['alpha_means'])
            mean_density = sum(stats['mask_densities']) / len(stats['mask_densities'])
            print(f"  Mean Alpha: {mean_alpha:.4f} | Mean Density: {mean_density:.4f}")
            wandb_metrics['alpha/mean'] = mean_alpha
            wandb_metrics['alpha/density'] = mean_density
        if cfg.use_wandb:
            wandb.log(wandb_metrics)

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'config': cfg
            }, 'best_cantor_imagenet.pt')
            print(f"  ✓ New best model saved! (Val Acc: {val_acc:.2f}%)")

    print("\n" + "=" * 60)
    print(f"Training complete! Best Val Acc: {best_val_acc:.2f}%")
    print("=" * 60)

    if cfg.use_wandb:
        wandb.finish()


if __name__ == "__main__":
    main()