|
|
|
|
|
""" |
|
|
Train CantorLinear classifier on pre-extracted ImageNet CLIP features. |
|
|
Uses AbstractPhil/imagenet-clip-features-orderly dataset from HuggingFace. |
|
|
Author: AbstractPhil |
|
|
License: MIT |
|
|
|
|
|
|
|
|
Uses the geometricvocab github implementation. |
|
|
try: |
|
|
!pip uninstall -qy geometricvocab |
|
|
except: |
|
|
pass |
|
|
|
|
|
!pip install -q git+https://github.com/AbstractEyes/lattice_vocabulary.git |
|
|
|
|
|
""" |
|
|
|
|
|
import math
import sys
from dataclasses import dataclass, field

import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from geovocab2.train.model.layers.linear import CantorLinear, CantorLinearConfig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class TrainConfig:
    """Hyperparameters for training a CantorLinear classifier on CLIP features."""

    # --- Data ---
    dataset_name: str = "AbstractPhil/imagenet-clip-features-orderly"
    clip_dim: int = 512          # CLIP ViT-B/16 feature dimension
    num_classes: int = 1000      # ImageNet-1k

    # --- Model ---
    # Empty list => single CantorLinear layer (direct linear probe).
    # Use field(default_factory=...) rather than a None sentinel so the
    # annotation is honest; __post_init__ still accepts an explicit None
    # for backward compatibility with existing callers.
    hidden_dims: list = field(default_factory=list)
    cantor_depth: int = 8
    mask_mode: str = "alpha"
    alpha_mode: str = "sigmoid"
    alpha_min: float = 0.1
    alpha_max: float = 1.0
    per_output_alpha: bool = False
    dropout: float = 0.1

    # --- Optimization ---
    batch_size: int = 512
    num_epochs: int = 50
    learning_rate: float = 1e-3
    weight_decay: float = 1e-4
    warmup_epochs: int = 5

    # Alpha parameters get a reduced learning rate (lr * alpha_lr_mult).
    alpha_lr_mult: float = 0.1

    # --- Logging ---
    use_wandb: bool = False
    wandb_project: str = "cantor-imagenet"
    log_every: int = 50
    eval_every: int = 500

    # --- Runtime ---
    # NOTE: evaluated once at class-definition time, not per instance.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    num_workers: int = 4
    seed: int = 42

    def __post_init__(self):
        # Backward compatibility: callers that pass hidden_dims=None
        # still get an empty list.
        if self.hidden_dims is None:
            self.hidden_dims = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CLIPFeaturesDataset(Dataset):
    """Thin torch Dataset adapter over a HuggingFace split of CLIP features.

    Each row is expected to carry a 'clip_features' vector and an integer
    'label'; both are converted to tensors on access.
    """

    def __init__(self, hf_dataset):
        self.dataset = hf_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        row = self.dataset[idx]
        return (
            torch.tensor(row['clip_features'], dtype=torch.float32),
            torch.tensor(row['label'], dtype=torch.long),
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CantorCLIPClassifier(nn.Module):
    """Multi-layer classifier built from CantorLinear layers.

    Maps CLIP features -> [hidden CantorLinear+ReLU+Dropout blocks] -> logits
    over cfg.num_classes. With cfg.hidden_dims empty this degenerates to a
    single CantorLinear probe.
    """

    def __init__(self, cfg: TrainConfig):
        super().__init__()
        self.cfg = cfg

        layers = []
        in_dim = cfg.clip_dim

        # Hidden stack: CantorLinear -> ReLU -> Dropout per hidden dim.
        for hidden_dim in cfg.hidden_dims:
            layers.append(self._make_cantor(in_dim, hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(cfg.dropout))
            in_dim = hidden_dim

        # Output projection to class logits (no activation; CE expects logits).
        layers.append(self._make_cantor(in_dim, cfg.num_classes))

        self.classifier = nn.Sequential(*layers)

    def _make_cantor(self, in_features: int, out_features: int) -> "CantorLinear":
        """Build one CantorLinear layer from the shared config fields.

        Centralizes the CantorLinearConfig construction so the hidden and
        output layers cannot drift apart.
        """
        cfg = self.cfg
        return CantorLinear(CantorLinearConfig(
            in_features=in_features,
            out_features=out_features,
            depth=cfg.cantor_depth,
            mask_mode=cfg.mask_mode,
            alpha_mode=cfg.alpha_mode,
            alpha_min=cfg.alpha_min,
            alpha_max=cfg.alpha_max,
            per_output_alpha=cfg.per_output_alpha,
        ))

    def forward(self, x):
        """Return class logits for a batch of CLIP feature vectors."""
        return self.classifier(x)

    def get_alpha_stats(self):
        """Collect alpha statistics from all CantorLinear layers.

        Returns a dict of parallel lists (layer_names, alpha_means,
        alpha_stds, mask_densities); lists are empty if no layer reports
        alpha stats.
        """
        stats = {
            "layer_names": [],
            "alpha_means": [],
            "alpha_stds": [],
            "mask_densities": []
        }

        for name, module in self.named_modules():
            if isinstance(module, CantorLinear):
                alpha_stats = module.get_alpha_stats()
                if alpha_stats:
                    stats["layer_names"].append(name)
                    stats["alpha_means"].append(alpha_stats["alpha_mean"])
                    # alpha_std may be absent (e.g. scalar alpha) — default 0.
                    stats["alpha_stds"].append(alpha_stats.get("alpha_std", 0.0))
                    stats["mask_densities"].append(module.mask.mean().item())

        return stats
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_epoch(model, dataloader, criterion, optimizer, scheduler, cfg, epoch):
    """Run one training epoch; returns (mean loss, accuracy in percent)."""
    model.train()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{cfg.num_epochs}")

    for step, (features, labels) in enumerate(progress):
        features = features.to(cfg.device)
        labels = labels.to(cfg.device)

        optimizer.zero_grad()
        logits = model(features)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # Scheduler is stepped per batch (warmup/cosine schedules expect this).
        if scheduler is not None:
            scheduler.step()

        running_loss += loss.item()
        preds = logits.argmax(dim=1)
        n_seen += labels.size(0)
        n_correct += (preds == labels).sum().item()

        if step % cfg.log_every == 0:
            mean_loss = running_loss / (step + 1)
            acc_pct = 100. * n_correct / n_seen
            progress.set_postfix({
                'loss': f'{mean_loss:.4f}',
                'acc': f'{acc_pct:.2f}%'
            })
            if cfg.use_wandb:
                wandb.log({
                    'train/loss': mean_loss,
                    'train/acc': acc_pct,
                    'train/lr': optimizer.param_groups[0]['lr']
                })

    return running_loss / len(dataloader), 100. * n_correct / n_seen
|
|
|
|
|
|
|
|
def evaluate(model, dataloader, criterion, cfg):
    """Compute mean loss and accuracy (%) over dataloader, gradient-free."""
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for features, labels in tqdm(dataloader, desc="Evaluating"):
            features = features.to(cfg.device)
            labels = labels.to(cfg.device)

            logits = model(features)
            loss_sum += criterion(logits, labels).item()

            n_seen += labels.size(0)
            n_correct += (logits.argmax(dim=1) == labels).sum().item()

    return loss_sum / len(dataloader), 100. * n_correct / n_seen
|
|
|
|
|
|
|
|
def main():
    """Train a CantorLinear classifier on pre-extracted ImageNet CLIP features.

    Loads the HuggingFace feature dataset, builds the model, trains with
    AdamW + warmup/cosine schedule, evaluates each epoch, and checkpoints
    the best model by validation accuracy.
    """
    cfg = TrainConfig()

    # Reproducibility.
    torch.manual_seed(cfg.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(cfg.seed)

    print("=" * 60)
    print("CantorLinear ImageNet CLIP Features Training")
    print("=" * 60)
    print("\nConfiguration:")
    print(f" Dataset: {cfg.dataset_name}")
    print(f" CLIP dim: {cfg.clip_dim}")
    print(f" Hidden dims: {cfg.hidden_dims if cfg.hidden_dims else 'Direct'}")
    print(f" Cantor depth: {cfg.cantor_depth}")
    print(f" Batch size: {cfg.batch_size}")
    print(f" Learning rate: {cfg.learning_rate}")
    print(f" Device: {cfg.device}")

    if cfg.use_wandb:
        wandb.init(project=cfg.wandb_project, config=vars(cfg))

    # --- Data ---
    print("\nLoading dataset...")
    dataset = load_dataset(cfg.dataset_name, name="clip_vit_b16", split="train")

    # Hold out 10% for validation (deterministic via seed).
    dataset = dataset.train_test_split(test_size=0.1, seed=cfg.seed)
    train_dataset = CLIPFeaturesDataset(dataset['train'])
    val_dataset = CLIPFeaturesDataset(dataset['test'])

    print(f"Train samples: {len(train_dataset)}")
    print(f"Val samples: {len(val_dataset)}")

    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True
    )

    # --- Model ---
    print("\nBuilding model...")
    model = CantorCLIPClassifier(cfg).to(cfg.device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    stats = model.get_alpha_stats()
    if stats['alpha_means']:
        print(f"CantorLinear layers: {len(stats['alpha_means'])}")
        print(f"Avg mask density: {sum(stats['mask_densities'])/len(stats['mask_densities']):.4f}")

    # --- Optimizer: alpha params train at a reduced learning rate ---
    criterion = nn.CrossEntropyLoss()

    alpha_params = []
    other_params = []
    for name, param in model.named_parameters():
        if 'alpha' in name:
            alpha_params.append(param)
        else:
            other_params.append(param)

    optimizer = optim.AdamW([
        {'params': other_params, 'lr': cfg.learning_rate},
        {'params': alpha_params, 'lr': cfg.learning_rate * cfg.alpha_lr_mult}
    ], weight_decay=cfg.weight_decay)

    # --- LR schedule: linear warmup then cosine decay to zero ---
    total_steps = len(train_loader) * cfg.num_epochs
    warmup_steps = len(train_loader) * cfg.warmup_epochs

    def lr_lambda(step):
        # max(1, ...) guards division by zero when warmup_epochs == 0
        # or num_epochs == warmup_epochs.
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1 + math.cos(math.pi * progress))

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # --- Training loop ---
    print("\nStarting training...")
    best_val_acc = 0.0

    for epoch in range(cfg.num_epochs):
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, scheduler, cfg, epoch
        )

        val_loss, val_acc = evaluate(model, val_loader, criterion, cfg)

        print(f"\nEpoch {epoch+1}/{cfg.num_epochs}")
        print(f" Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f" Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

        # Build the wandb payload incrementally: alpha stats are only
        # included when they exist. (Previously mean_alpha/mean_density
        # were referenced unconditionally and raised NameError when no
        # CantorLinear layer reported alpha stats.)
        log_payload = {
            'val/loss': val_loss,
            'val/acc': val_acc,
            'epoch': epoch
        }

        stats = model.get_alpha_stats()
        if stats['alpha_means']:
            mean_alpha = sum(stats['alpha_means']) / len(stats['alpha_means'])
            mean_density = sum(stats['mask_densities']) / len(stats['mask_densities'])
            print(f" Mean Alpha: {mean_alpha:.4f} | Mean Density: {mean_density:.4f}")
            log_payload['alpha/mean'] = mean_alpha
            log_payload['alpha/density'] = mean_density

        if cfg.use_wandb:
            wandb.log(log_payload)

        # Checkpoint on validation-accuracy improvement.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'config': cfg
            }, 'best_cantor_imagenet.pt')
            print(f" ✓ New best model saved! (Val Acc: {val_acc:.2f}%)")

    print("\n" + "=" * 60)
    print(f"Training complete! Best Val Acc: {best_val_acc:.2f}%")
    print("=" * 60)

    if cfg.use_wandb:
        wandb.finish()
|
|
|
|
|
|
|
|
# Entry-point guard: run training only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()