# NOTE: residue from the hosting web page, kept as comments so the file parses:
# AbstractPhil's picture
# Create trainer.py
# e889862 verified
#!/usr/bin/env python3
"""
Train CantorLinear classifier on pre-extracted ImageNet CLIP features.
Uses AbstractPhil/imagenet-clip-features-orderly dataset from HuggingFace.
Author: AbstractPhil
License: MIT
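Uses the geometricvocab GitHub implementation (imported below as ``geovocab2``).

Notebook setup (run in a Colab/Jupyter cell before importing this module):
    try:
        !pip uninstall -qy geometricvocab
    except:
        pass
    !pip install -q git+https://github.com/AbstractEyes/lattice_vocabulary.git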
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from tqdm import tqdm
import wandb
from dataclasses import dataclass
import sys
import math
# Import your CantorLinear layer
# Adjust the import path as needed for your setup
from geovocab2.train.model.layers.linear import CantorLinear, CantorLinearConfig
# ============================================================
# CONFIGURATION
# ============================================================
@dataclass
class TrainConfig:
    """Hyperparameters and runtime settings for the CLIP-feature classifier run.

    Every field has a default, so ``TrainConfig()`` is a complete
    configuration. ``hidden_dims=None`` is normalised to an empty list in
    ``__post_init__``.
    """

    # --- Dataset ---
    dataset_name: str = "AbstractPhil/imagenet-clip-features-orderly"  # HuggingFace dataset id
    clip_dim: int = 512  # CLIP ViT-B/16 feature dimension (model input width)
    num_classes: int = 1000  # ImageNet classes (model output width)

    # --- Model ---
    hidden_dims: list = None  # e.g. [2048, 1024] for two hidden layers; None -> direct mapping
    cantor_depth: int = 8  # forwarded as CantorLinearConfig(depth=...)
    mask_mode: str = "alpha"  # forwarded unchanged to CantorLinearConfig
    alpha_mode: str = "sigmoid"  # forwarded unchanged to CantorLinearConfig
    alpha_min: float = 0.1  # forwarded unchanged to CantorLinearConfig
    alpha_max: float = 1.0  # forwarded unchanged to CantorLinearConfig
    per_output_alpha: bool = False  # forwarded unchanged to CantorLinearConfig
    dropout: float = 0.1  # nn.Dropout probability applied after each hidden layer

    # --- Training ---
    batch_size: int = 512
    num_epochs: int = 50
    learning_rate: float = 1e-3  # base LR for parameters whose name lacks 'alpha'
    weight_decay: float = 1e-4
    warmup_epochs: int = 5  # epochs of linear LR warmup before cosine decay

    # --- Optimizer ---
    alpha_lr_mult: float = 0.1  # LR multiplier for parameters whose name contains 'alpha'

    # --- Logging ---
    use_wandb: bool = False
    wandb_project: str = "cantor-imagenet"
    log_every: int = 50  # batches between progress-bar / wandb training updates
    eval_every: int = 500  # NOTE(review): not referenced by the code visible in this file

    # --- System ---
    # The device default is evaluated once, when the class body is executed.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    num_workers: int = 4  # DataLoader worker processes
    seed: int = 42  # seeds torch and the train/val split

    def __post_init__(self):
        """Replace a ``None`` ``hidden_dims`` with ``[]`` (direct CLIP → classes)."""
        self.hidden_dims = [] if self.hidden_dims is None else self.hidden_dims
# ============================================================
# DATASET
# ============================================================
class CLIPFeaturesDataset(Dataset):
    """Adapt an indexable collection of CLIP feature records to a torch ``Dataset``.

    Each record is expected to expose ``'clip_features'`` and ``'label'``
    entries; they are converted to tensors on access.
    """

    def __init__(self, hf_dataset):
        # Held by reference; nothing is converted until an item is requested.
        self.dataset = hf_dataset

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(features, label)`` as float32 and int64 tensors for record ``idx``."""
        record = self.dataset[idx]
        return (
            torch.tensor(record['clip_features'], dtype=torch.float32),
            torch.tensor(record['label'], dtype=torch.long),
        )
# ============================================================
# MODEL
# ============================================================
class CantorCLIPClassifier(nn.Module):
    """
    Classifier assembled from CantorLinear layers.

    Layout: CLIP features → [CantorLinear → ReLU → Dropout] for each entry of
    ``cfg.hidden_dims`` → CantorLinear with ``cfg.num_classes`` outputs.
    """

    def __init__(self, cfg: TrainConfig):
        super().__init__()
        self.cfg = cfg

        def cantor_layer(n_in, n_out):
            # Every CantorLinear shares the same Cantor/alpha settings from cfg;
            # only the input and output widths vary.
            return CantorLinear(CantorLinearConfig(
                in_features=n_in,
                out_features=n_out,
                depth=cfg.cantor_depth,
                mask_mode=cfg.mask_mode,
                alpha_mode=cfg.alpha_mode,
                alpha_min=cfg.alpha_min,
                alpha_max=cfg.alpha_max,
                per_output_alpha=cfg.per_output_alpha
            ))

        modules = []
        width = cfg.clip_dim
        # Hidden stack: one CantorLinear/ReLU/Dropout triple per hidden width
        for next_width in cfg.hidden_dims:
            modules += [
                cantor_layer(width, next_width),
                nn.ReLU(),
                nn.Dropout(cfg.dropout),
            ]
            width = next_width

        # Output projection to class logits (no activation afterwards)
        modules.append(cantor_layer(width, cfg.num_classes))
        self.classifier = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return its output."""
        return self.classifier(x)

    def get_alpha_stats(self):
        """Gather alpha statistics from every CantorLinear submodule.

        Returns a dict of four parallel lists (``layer_names``,
        ``alpha_means``, ``alpha_stds``, ``mask_densities``). Layers whose
        own ``get_alpha_stats()`` returns a falsy value are skipped.
        """
        names, means, stds, densities = [], [], [], []
        for name, module in self.named_modules():
            if not isinstance(module, CantorLinear):
                continue
            layer_stats = module.get_alpha_stats()
            if not layer_stats:
                continue
            names.append(name)
            means.append(layer_stats["alpha_mean"])
            # A missing std is reported as 0.0
            stds.append(layer_stats.get("alpha_std", 0.0))
            densities.append(module.mask.mean().item())
        return {
            "layer_names": names,
            "alpha_means": means,
            "alpha_stds": stds,
            "mask_densities": densities
        }
# ============================================================
# TRAINING
# ============================================================
def train_epoch(model, dataloader, criterion, optimizer, scheduler, cfg, epoch):
    """Run one optimisation pass over ``dataloader``.

    The scheduler, when given, is stepped once per batch. Every
    ``cfg.log_every`` batches the progress bar (and wandb, if enabled) is
    updated with the running loss and accuracy.

    Returns:
        ``(loss, accuracy)``: the sum of per-batch losses divided by the
        number of batches, and the accuracy in percent over all samples seen.
    """
    model.train()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{cfg.num_epochs}")
    for step, (features, labels) in enumerate(progress):
        features = features.to(cfg.device)
        labels = labels.to(cfg.device)

        # Forward / backward / parameter update
        optimizer.zero_grad()
        logits = model(features)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        # Running metrics
        running_loss += loss.item()
        predicted = logits.max(1).indices
        n_seen += labels.size(0)
        n_correct += predicted.eq(labels).sum().item()

        # Only report on every log_every-th batch
        if step % cfg.log_every != 0:
            continue

        avg_loss = running_loss / (step + 1)
        acc = 100. * n_correct / n_seen
        progress.set_postfix({
            'loss': f'{avg_loss:.4f}',
            'acc': f'{acc:.2f}%'
        })
        if cfg.use_wandb:
            wandb.log({
                'train/loss': avg_loss,
                'train/acc': acc,
                'train/lr': optimizer.param_groups[0]['lr']
            })

    return running_loss / len(dataloader), 100. * n_correct / n_seen
def evaluate(model, dataloader, criterion, cfg):
    """Score ``model`` on ``dataloader`` with gradients disabled.

    Puts the model in eval mode and leaves it there.

    Returns:
        ``(loss, accuracy)``: the sum of per-batch losses divided by the
        number of batches, and the accuracy in percent over all samples.
    """
    model.eval()
    loss_sum, hits, seen = 0.0, 0, 0

    with torch.no_grad():
        for features, labels in tqdm(dataloader, desc="Evaluating"):
            features = features.to(cfg.device)
            labels = labels.to(cfg.device)

            logits = model(features)
            loss_sum += criterion(logits, labels).item()

            seen += labels.size(0)
            hits += logits.max(1).indices.eq(labels).sum().item()

    return loss_sum / len(dataloader), 100. * hits / seen
def _split_param_groups(model):
    """Partition ``model``'s parameters into ``(other, alpha)`` lists by name."""
    alpha_params, other_params = [], []
    for name, param in model.named_parameters():
        (alpha_params if 'alpha' in name else other_params).append(param)
    return other_params, alpha_params


def _make_lr_lambda(warmup_steps, total_steps):
    """Build the LambdaLR multiplier: linear warmup, then half-cosine decay."""
    # When warmup spans the whole run the decay window is 0 steps long; the
    # unguarded division raised ZeroDivisionError on the final scheduler.step().
    decay_steps = max(1, total_steps - warmup_steps)

    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps
        return 0.5 * (1 + math.cos(math.pi * (step - warmup_steps) / decay_steps))

    return lr_lambda


def main():
    """Train and validate the CantorLinear classifier end to end.

    Steps: seed torch, load the HuggingFace feature dataset and split it
    90/10, build the model and a two-group AdamW optimizer (lower LR for
    parameters whose name contains 'alpha'), train with a per-batch
    warmup+cosine schedule, evaluate after every epoch, and save the
    checkpoint with the best validation accuracy to
    ``best_cantor_imagenet.pt`` in the working directory.
    """
    cfg = TrainConfig()

    # Set seed
    torch.manual_seed(cfg.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(cfg.seed)

    print("=" * 60)
    print("CantorLinear ImageNet CLIP Features Training")
    print("=" * 60)
    print(f"\nConfiguration:")
    print(f" Dataset: {cfg.dataset_name}")
    print(f" CLIP dim: {cfg.clip_dim}")
    print(f" Hidden dims: {cfg.hidden_dims if cfg.hidden_dims else 'Direct'}")
    print(f" Cantor depth: {cfg.cantor_depth}")
    print(f" Batch size: {cfg.batch_size}")
    print(f" Learning rate: {cfg.learning_rate}")
    print(f" Device: {cfg.device}")

    # Initialize wandb
    if cfg.use_wandb:
        wandb.init(project=cfg.wandb_project, config=vars(cfg))

    # Load dataset
    print("\nLoading dataset...")
    dataset = load_dataset(cfg.dataset_name, name="clip_vit_b16", split="train")

    # Split into train/val (90/10); seeded so the split is reproducible
    dataset = dataset.train_test_split(test_size=0.1, seed=cfg.seed)
    train_dataset = CLIPFeaturesDataset(dataset['train'])
    val_dataset = CLIPFeaturesDataset(dataset['test'])
    print(f"Train samples: {len(train_dataset)}")
    print(f"Val samples: {len(val_dataset)}")

    # Create dataloaders. Pinned memory only helps host→GPU copies and
    # triggers a warning on CPU-only hosts, so enable it only with CUDA.
    loader_kwargs = dict(
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        pin_memory=torch.cuda.is_available(),
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # Create model
    print("\nBuilding model...")
    model = CantorCLIPClassifier(cfg).to(cfg.device)

    # Print model info
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Alpha statistics
    stats = model.get_alpha_stats()
    if stats['alpha_means']:
        print(f"CantorLinear layers: {len(stats['alpha_means'])}")
        print(f"Avg mask density: {sum(stats['mask_densities'])/len(stats['mask_densities']):.4f}")

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()

    # Separate learning rates for alpha parameters
    other_params, alpha_params = _split_param_groups(model)
    optimizer = optim.AdamW([
        {'params': other_params, 'lr': cfg.learning_rate},
        {'params': alpha_params, 'lr': cfg.learning_rate * cfg.alpha_lr_mult}
    ], weight_decay=cfg.weight_decay)

    # Learning rate scheduler with warmup; counted in optimizer steps because
    # train_epoch steps the scheduler once per batch.
    total_steps = len(train_loader) * cfg.num_epochs
    warmup_steps = len(train_loader) * cfg.warmup_epochs
    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer, _make_lr_lambda(warmup_steps, total_steps)
    )

    # Training loop
    print("\nStarting training...")
    best_val_acc = 0.0
    for epoch in range(cfg.num_epochs):
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, scheduler, cfg, epoch
        )
        val_loss, val_acc = evaluate(model, val_loader, criterion, cfg)

        print(f"\nEpoch {epoch+1}/{cfg.num_epochs}")
        print(f" Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f" Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

        # Validation metrics are always logged; alpha metrics are added only
        # when at least one layer reported alpha statistics. Previously the
        # whole wandb record was skipped in that case.
        epoch_log = {
            'val/loss': val_loss,
            'val/acc': val_acc,
            'epoch': epoch
        }

        # Log alpha evolution
        stats = model.get_alpha_stats()
        if stats['alpha_means']:
            mean_alpha = sum(stats['alpha_means']) / len(stats['alpha_means'])
            mean_density = sum(stats['mask_densities']) / len(stats['mask_densities'])
            print(f" Mean Alpha: {mean_alpha:.4f} | Mean Density: {mean_density:.4f}")
            epoch_log['alpha/mean'] = mean_alpha
            epoch_log['alpha/density'] = mean_density

        if cfg.use_wandb:
            wandb.log(epoch_log)

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # NOTE(review): 'config' stores the TrainConfig instance itself, so
            # loading this checkpoint needs the class to be importable and may
            # be refused by torch.load(weights_only=True). Confirm how
            # checkpoints are consumed before changing the format.
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'config': cfg
            }, 'best_cantor_imagenet.pt')
            print(f" ✓ New best model saved! (Val Acc: {val_acc:.2f}%)")

    print("\n" + "=" * 60)
    print(f"Training complete! Best Val Acc: {best_val_acc:.2f}%")
    print("=" * 60)

    if cfg.use_wandb:
        wandb.finish()


if __name__ == "__main__":
    main()