# Source: geovit-david-beans / trainer_v3_v21.py
# Uploaded by AbstractPhil - commit "Create trainer_v3_v21.py" (ff31041, verified)
"""
Train DavidBeans V2: Wormhole Routing Architecture
===================================================
    ┌─────────────────┐
    │   BEANS V2.1    │  "I learn where to look..."
    │  (Wormhole ViT) │
    │  🌀 → 🌀 → 🌀   │  Learned sparse routing
    └────────┬────────┘
             │
             ▼
    ┌─────────────────┐
    │      DAVID      │  "I know the crystals..."
    │  (Classifier)   │
    │  💎 → 💎 → 💎   │  Multi-scale projection
    └────────┬────────┘
             │
             ▼
        [Prediction]
Key findings from wormhole experiments:
1. When routing IS the task, routing learns structure
2. Auxiliary losses can be gamed - removed in V2
3. Gradient flow through router is critical - verified
4. Cross-contrastive aligns patch↔scale features
V2.1 additions:
- AlphaMix augmentation (localized transparent overlay)
- Configurable normalization (standard, none, center_only, unit_var)
- Support for redundant scales, conv spine, collective mode
- Configurable belly depth
Author: AbstractPhil
Date: November 30, 2025
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
from tqdm.auto import tqdm
import time
import math
from pathlib import Path
from typing import Dict, Optional, Tuple, List, Union
from dataclasses import dataclass, field
import json
from datetime import datetime
import os
import shutil
# Resolve the HuggingFace token. Start from whatever is already exported in the
# environment, then let a Colab secret override it when running inside Colab.
# Importing google.colab unconditionally would make this module unusable
# anywhere else, so the import is guarded.
HF_TOKEN = os.environ.get('HF_TOKEN')
try:
    from google.colab import userdata
except ImportError:
    userdata = None  # Not running inside Google Colab

if userdata is not None:
    try:
        _colab_token = userdata.get('HF_TOKEN')
    except Exception as e:
        # Best effort: the secret may be missing or notebook access denied.
        print(f" [!] Could not read HF_TOKEN from Colab secrets: {e}")
        _colab_token = None
    if _colab_token:
        HF_TOKEN = _colab_token
        # os.environ only accepts str values, so assign only a real token.
        os.environ['HF_TOKEN'] = _colab_token
# Import both model versions
from geofractal.model.david_beans.model import DavidBeans, DavidBeansConfig
from geofractal.model.david_beans.model_v2 import DavidBeansV2, DavidBeansV2Config
# HuggingFace Hub integration
try:
from huggingface_hub import HfApi, create_repo, upload_folder
HF_HUB_AVAILABLE = True
except ImportError:
HF_HUB_AVAILABLE = False
print(" [!] huggingface_hub not installed. Run: pip install huggingface_hub")
# Safetensors support
try:
from safetensors.torch import save_file as save_safetensors
SAFETENSORS_AVAILABLE = True
except ImportError:
SAFETENSORS_AVAILABLE = False
# TensorBoard support
try:
from torch.utils.tensorboard import SummaryWriter
TENSORBOARD_AVAILABLE = True
except ImportError:
TENSORBOARD_AVAILABLE = False
print(" [!] tensorboard not installed. Run: pip install tensorboard")
import numpy as np
# ============================================================================
# TRAINING CONFIGURATION V2.1
# ============================================================================
@dataclass
class TrainingConfigV2:
    """Training configuration for DavidBeans V2 with wormhole routing.

    Raises:
        ValueError: if ``normalization`` is not one of the supported modes.
    """
    # Run identification
    run_name: str = "default"
    run_number: Optional[int] = None
    # Model version
    model_version: int = 2  # 1 = original, 2 = wormhole
    # Data
    dataset: str = "cifar100"
    image_size: int = 32
    batch_size: int = 128
    num_workers: int = 4
    # Normalization
    normalization: str = "standard"  # "standard", "none", "center_only", "unit_var"
    # Training schedule
    epochs: int = 200
    warmup_epochs: int = 10
    # Optimizer
    learning_rate: float = 3e-4
    weight_decay: float = 0.05
    betas: Tuple[float, float] = (0.9, 0.999)
    # Learning rate schedule
    scheduler: str = "cosine"
    min_lr: float = 1e-6
    # Loss weights (based on experimental findings)
    ce_weight: float = 1.0
    contrast_weight: float = 0.5
    # NOTE: No auxiliary routing loss - routing learns from task pressure
    # Regularization
    gradient_clip: float = 1.0
    label_smoothing: float = 0.1
    # Augmentation
    use_augmentation: bool = True
    mixup_alpha: float = 0.2
    cutmix_alpha: float = 1.0
    # AlphaMix augmentation
    use_alphamix: bool = False
    alphamix_alpha_range: Tuple[float, float] = (0.3, 0.7)
    alphamix_spatial_ratio: float = 0.25
    # Checkpointing
    save_interval: int = 10
    output_dir: str = "./checkpoints"
    resume_from: Optional[str] = None
    # TensorBoard
    use_tensorboard: bool = True
    log_interval: int = 50
    log_routing: bool = True  # Log routing patterns
    # HuggingFace Hub
    push_to_hub: bool = False
    hub_repo_id: str = "AbstractPhil/geovit-david-beans"
    hub_private: bool = False
    # Device (the default is evaluated once, when the class is defined)
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    def to_dict(self) -> Dict:
        """Return a shallow copy of the instance attributes as a plain dict."""
        return dict(vars(self))

    def __post_init__(self):
        # Raise explicitly instead of using `assert`, which is stripped when
        # Python runs with -O and would let invalid modes through silently.
        valid_modes = ("standard", "none", "center_only", "unit_var")
        if self.normalization not in valid_modes:
            raise ValueError(f"Invalid normalization mode: {self.normalization}")
# ============================================================================
# ROUTING METRICS
# ============================================================================
class RoutingMetrics:
    """Accumulate statistics describing wormhole routing behaviour.

    Entropies, diversities and router gradient norms are appended to plain
    lists; ``get_summary`` reports the mean of each non-empty list.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop everything collected so far."""
        self.route_entropies = []
        self.route_diversities = []
        self.grad_norms = {'query': [], 'key': []}

    @torch.no_grad()
    def compute_route_entropy(self, soft_routes: torch.Tensor) -> float:
        """Return the mean entropy of the distributions along the last dim."""
        eps = 1e-8  # keeps log() finite where a probability is exactly zero
        log_probs = torch.log(soft_routes + eps)
        per_distribution = -(soft_routes * log_probs).sum(dim=-1)
        return per_distribution.mean().item()

    @torch.no_grad()
    def compute_route_diversity(self, routes: torch.Tensor, num_positions: int) -> float:
        """Return the mean fraction of distinct destinations used per sample."""
        fractions = [
            sample_routes.unique().numel() / num_positions
            for sample_routes in routes
        ]
        return sum(fractions) / len(fractions)

    def update_from_routing_info(self, routing_info: List[Dict], model: nn.Module):
        """Record entropy/diversity from the per-layer routing info of a V2 model.

        ``model`` is accepted for interface compatibility and is not used.
        """
        if not routing_info:
            return
        for layer in routing_info:
            attention = layer.get('attention')
            if attention:
                attn_weights = attention.get('weights')
                if attn_weights is not None:
                    self.route_entropies.append(
                        self.compute_route_entropy(attn_weights)
                    )
                attn_routes = attention.get('routes')
                if attn_routes is not None:
                    # The second dimension of `routes` is used as the number
                    # of candidate positions.
                    positions = attn_routes.shape[1]
                    self.route_diversities.append(
                        self.compute_route_diversity(attn_routes, positions)
                    )
            expert = layer.get('expert')
            if expert:
                expert_weights = expert.get('weights')
                if expert_weights is not None:
                    self.route_entropies.append(
                        self.compute_route_entropy(expert_weights)
                    )

    def update_grad_norms(self, model: nn.Module):
        """Record gradient norms of query/key projection weights."""
        for name, param in model.named_parameters():
            grad = param.grad
            if grad is None or 'weight' not in name:
                continue
            if 'query_proj' in name:
                self.grad_norms['query'].append(grad.norm().item())
            elif 'key_proj' in name:
                self.grad_norms['key'].append(grad.norm().item())

    def get_summary(self) -> Dict[str, float]:
        """Return the mean of every metric that has at least one sample."""
        collected = (
            ('route_entropy', self.route_entropies),
            ('route_diversity', self.route_diversities),
            ('grad_query', self.grad_norms['query']),
            ('grad_key', self.grad_norms['key']),
        )
        return {
            label: sum(samples) / len(samples)
            for label, samples in collected
            if samples
        }
# ============================================================================
# DATA LOADING WITH NORMALIZATION OPTIONS
# ============================================================================
def get_normalization_transform(config: TrainingConfigV2, dataset: str):
    """Build the normalization transform selected by ``config.normalization``.

    Args:
        config: Training configuration; only ``normalization`` is read.
        dataset: Dataset name used to pick per-channel statistics
            ("cifar10" / "cifar100"; anything else uses 0.5 / 0.5).

    Returns:
        A ``torchvision.transforms.Normalize`` instance, or ``None`` when the
        mode is "none" (inputs stay in the raw [0, 1] range from ToTensor).
    """
    mode = config.normalization
    if mode == "none":
        # Nothing to build, so return before importing torchvision: this mode
        # then works even where torchvision is not installed.
        return None

    import torchvision.transforms as T

    if dataset == "cifar10":
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2470, 0.2435, 0.2616)
    elif dataset == "cifar100":
        mean = (0.5071, 0.4867, 0.4408)
        std = (0.2675, 0.2565, 0.2761)
    else:
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)

    if mode == "center_only":
        # Shift only, no variance scaling. NOTE: subtracts a fixed 0.5 per
        # channel, not the dataset mean selected above.
        return T.Normalize(mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0))
    if mode == "unit_var":
        # Scale by the dataset std without centering.
        return T.Normalize(mean=(0.0, 0.0, 0.0), std=std)
    # "standard", and (as before) any unrecognized mode.
    return T.Normalize(mean, std)
def get_dataloaders(config: TrainingConfigV2) -> Tuple[DataLoader, DataLoader, int]:
    """Create the CIFAR train/test loaders described by ``config``.

    Falls back to synthetic data when torchvision cannot be imported.

    Returns:
        ``(train_loader, test_loader, num_classes)``.

    Raises:
        ValueError: if ``config.dataset`` is neither "cifar10" nor "cifar100".
    """
    try:
        import torchvision
        import torchvision.transforms as T

        known_datasets = {
            "cifar10": (10, torchvision.datasets.CIFAR10),
            "cifar100": (100, torchvision.datasets.CIFAR100),
        }
        if config.dataset not in known_datasets:
            raise ValueError(f"Unknown dataset: {config.dataset}")
        num_classes, DatasetClass = known_datasets[config.dataset]

        # The same normalization (possibly none) is applied to both splits.
        norm_transform = get_normalization_transform(config, config.dataset)
        normalize_steps = [] if norm_transform is None else [norm_transform]

        # Train pipeline: crop + flip always, AutoAugment only when enabled.
        augment_steps = []
        if config.use_augmentation:
            augment_steps.append(T.AutoAugment(T.AutoAugmentPolicy.CIFAR10))
        train_transform = T.Compose([
            T.RandomCrop(32, padding=4),
            T.RandomHorizontalFlip(),
            *augment_steps,
            T.ToTensor(),
            *normalize_steps,
        ])

        # Test pipeline: tensor conversion and normalization only.
        test_transform = T.Compose([T.ToTensor(), *normalize_steps])

        print(f" Normalization: {config.normalization}")

        train_dataset = DatasetClass(
            root='./data', train=True, download=True, transform=train_transform
        )
        test_dataset = DatasetClass(
            root='./data', train=False, download=True, transform=test_transform
        )

        shared_loader_args = dict(
            batch_size=config.batch_size,
            num_workers=config.num_workers,
            pin_memory=True,
            persistent_workers=config.num_workers > 0,
        )
        train_loader = DataLoader(
            train_dataset, shuffle=True, drop_last=True, **shared_loader_args
        )
        test_loader = DataLoader(
            test_dataset, shuffle=False, **shared_loader_args
        )
        return train_loader, test_loader, num_classes
    except ImportError:
        print(" [!] torchvision not available, using synthetic data")
        return get_synthetic_dataloaders(config)
def get_synthetic_dataloaders(config: TrainingConfigV2) -> Tuple[DataLoader, DataLoader, int]:
    """Build loaders over random-noise images, used as a fallback for testing.

    Returns:
        ``(train_loader, test_loader, num_classes)`` with 5000 train and 1000
        test samples; 100 classes for "cifar100", otherwise 10.
    """

    class SyntheticDataset(torch.utils.data.Dataset):
        """Fresh Gaussian-noise images; the label is the index modulo classes."""

        def __init__(self, size: int, image_size: int, num_classes: int):
            self.size = size
            self.image_size = image_size
            self.num_classes = num_classes

        def __len__(self):
            return self.size

        def __getitem__(self, idx):
            side = self.image_size
            return torch.randn(3, side, side), idx % self.num_classes

    num_classes = 100 if config.dataset == "cifar100" else 10

    def make_loader(num_samples: int, shuffle: bool) -> DataLoader:
        dataset = SyntheticDataset(num_samples, config.image_size, num_classes)
        return DataLoader(dataset, batch_size=config.batch_size, shuffle=shuffle)

    train_loader = make_loader(5000, True)
    test_loader = make_loader(1000, False)
    return train_loader, test_loader, num_classes
# ============================================================================
# MIXING AUGMENTATIONS
# ============================================================================
def mixup_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 0.2):
    """Mixup: blend every sample with a randomly chosen partner from the batch.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``lam`` weights the original sample
        (drawn from Beta(alpha, alpha), or 1.0 when ``alpha <= 0``).
    """
    lam = 1.0
    if alpha > 0:
        lam = torch.distributions.Beta(alpha, alpha).sample().item()

    partner = torch.randperm(x.size(0), device=x.device)
    blended = lam * x + (1 - lam) * x[partner]
    return blended, y, y[partner], lam
def cutmix_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
    """CutMix: paste a random rectangle from a partner sample into each image.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``lam`` is the fraction of the image
        that still belongs to the original sample after clipping the box.
    """
    lam = torch.distributions.Beta(alpha, alpha).sample().item() if alpha > 0 else 1.0

    n, _, H, W = x.shape
    partner = torch.randperm(n, device=x.device)

    # Box side lengths are chosen so the box area is roughly (1 - lam) * H * W.
    side_scale = math.sqrt(1 - lam)
    half_h = int(H * side_scale) // 2
    half_w = int(W * side_scale) // 2

    # Random box centre; the box is clipped to the image bounds.
    centre_row = torch.randint(0, H, (1,)).item()
    centre_col = torch.randint(0, W, (1,)).item()
    row_lo, row_hi = max(0, centre_row - half_h), min(H, centre_row + half_h)
    col_lo, col_hi = max(0, centre_col - half_w), min(W, centre_col + half_w)

    pasted = x.clone()
    pasted[:, :, row_lo:row_hi, col_lo:col_hi] = x[partner, :, row_lo:row_hi, col_lo:col_hi]

    # Recompute lam from the box actually pasted (clipping may have shrunk it).
    lam = 1 - ((row_hi - row_lo) * (col_hi - col_lo)) / (H * W)
    return pasted, y, y[partner], lam
def alphamix_data(
    x: torch.Tensor,
    y: torch.Tensor,
    alpha_range: Tuple[float, float] = (0.3, 0.7),
    spatial_ratio: float = 0.25
):
    """
    AlphaMix: Spatially localized transparent overlay.

    Unlike CutMix (full replacement) or Mixup (global blend),
    AlphaMix creates a localized alpha-blended region.

    Args:
        x: [B, C, H, W] input images
        y: [B] labels
        alpha_range: (min, max) for alpha blending in overlay region
        spatial_ratio: Fraction of image area for overlay

    Returns:
        mixed_x, y_a, y_b, lam (effective lambda for loss weighting)
    """
    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=x.device)
    y_a, y_b = y, y[index]

    # Sample alpha from a Beta(2, 2) draw rescaled into alpha_range
    alpha_min, alpha_max = alpha_range
    beta_sample = np.random.beta(2, 2)
    alpha = alpha_min + (alpha_max - alpha_min) * beta_sample

    _, _, H, W = x.shape

    # Overlay size: at least 4 pixels per side, but never larger than the
    # image. Without the upper clamp, tiny images or spatial_ratio > 1 give a
    # blended area above 1 and therefore a negative lam.
    overlay_ratio = np.sqrt(spatial_ratio)
    overlay_h = min(H, max(4, int(H * overlay_ratio)))
    overlay_w = min(W, max(4, int(W * overlay_ratio)))

    # Random top-left corner such that the overlay fits inside the image
    top = np.random.randint(0, max(1, H - overlay_h + 1))
    left = np.random.randint(0, max(1, W - overlay_w + 1))
    rows = slice(top, top + overlay_h)
    cols = slice(left, left + overlay_w)

    # Alpha blend inside the overlay region: alpha weights the original
    # sample, (1 - alpha) weights the partner sample
    composited_x = x.clone()
    composited_x[:, :, rows, cols] = (
        alpha * x[:, :, rows, cols] + (1 - alpha) * x[index, :, rows, cols]
    )

    # lam is the contribution of the original sample:
    # 100% outside the blended region, alpha inside it
    blended_area = (overlay_h * overlay_w) / (H * W)
    lam = 1.0 - blended_area * (1 - alpha)
    return composited_x, y_a, y_b, lam
# ============================================================================
# METRICS TRACKER
# ============================================================================
class MetricsTracker:
    """Collect per-step metrics, keeping an EMA and per-epoch means."""

    def __init__(self, ema_decay: float = 0.9):
        self.ema_decay = ema_decay
        self.metrics = {}       # name -> values recorded in the current epoch
        self.ema_metrics = {}   # name -> exponential moving average
        self.history = {}       # name -> list of completed-epoch means

    def update(self, **kwargs):
        """Record one value per keyword; tensors are converted with ``.item()``."""
        for name, value in kwargs.items():
            if isinstance(value, torch.Tensor):
                value = value.item()
            if name in self.metrics:
                previous = self.ema_metrics[name]
            else:
                # First observation: the EMA starts at the value itself.
                self.metrics[name] = []
                self.history[name] = []
                previous = value
            self.metrics[name].append(value)
            self.ema_metrics[name] = (
                self.ema_decay * previous + (1 - self.ema_decay) * value
            )

    def get_ema(self, key: str) -> float:
        """Return the current EMA for ``key`` (0.0 if never recorded)."""
        return self.ema_metrics.get(key, 0.0)

    def get_epoch_mean(self, key: str) -> float:
        """Return the mean of ``key`` over the current epoch (0.0 if empty)."""
        recorded = self.metrics.get(key, [])
        if not recorded:
            return 0.0
        return sum(recorded) / len(recorded)

    def end_epoch(self):
        """Append each metric's epoch mean to the history and start a new epoch."""
        for name, recorded in self.metrics.items():
            if recorded:
                self.history[name].append(sum(recorded) / len(recorded))
        self.metrics = {name: [] for name in self.metrics}

    def get_history(self) -> Dict:
        """Return the per-epoch means recorded so far."""
        return self.history
# ============================================================================
# CHECKPOINT UTILITIES
# ============================================================================
def find_latest_checkpoint(output_dir: Path) -> Optional[Path]:
    """Find the most recent checkpoint in ``output_dir``.

    Files named ``checkpoint_epoch_<N>.pt`` are ranked by the integer ``N``
    (a non-numeric suffix counts as epoch 0). When no such file exists,
    ``best_model.pt`` is returned if present.

    Returns:
        Path to the chosen checkpoint, or ``None`` when nothing is found.
    """
    checkpoints = list(output_dir.glob("checkpoint_epoch_*.pt"))
    if not checkpoints:
        best_model = output_dir / "best_model.pt"
        return best_model if best_model.exists() else None

    def get_epoch(p: Path) -> int:
        # Only a non-integer suffix can fail here, so catch exactly that
        # instead of swallowing every exception.
        try:
            return int(p.stem.split("_")[-1])
        except ValueError:
            return 0

    # Compare epochs numerically so that epoch 10 beats epoch 2.
    return max(checkpoints, key=get_epoch)
def get_next_run_number(base_dir: Path) -> int:
    """Return one more than the highest ``run_<N>...`` directory number.

    Returns 1 when ``base_dir`` does not exist or holds no numbered runs.
    """
    if not base_dir.exists():
        return 1

    highest = 0
    for entry in base_dir.iterdir():
        if not (entry.is_dir() and entry.name.startswith("run_")):
            continue
        try:
            highest = max(highest, int(entry.name.split("_")[1]))
        except (IndexError, ValueError):
            # Not a numbered run directory; ignore it.
            pass
    return highest + 1
def generate_run_dir_name(run_number: int, run_name: str, version: int = 2) -> str:
    """Build ``run_<NNN>_v<version>_<sanitized name>_<timestamp>``.

    The run name is lower-cased, every character that is neither alphanumeric
    nor an underscore becomes an underscore, and runs of underscores collapse.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    cleaned_chars = [
        ch if (ch.isalnum() or ch == "_") else "_"
        for ch in run_name.lower()
    ]
    tokens = [tok for tok in "".join(cleaned_chars).split("_") if tok]
    safe_name = "_".join(tokens)
    return f"run_{run_number:03d}_v{version}_{safe_name}_{timestamp}"
def find_latest_run_dir(base_dir: Path) -> Optional[Path]:
    """Return the ``run_*`` sub-directory with the newest modification time.

    Returns ``None`` when ``base_dir`` is missing or contains no run directory.
    """
    if not base_dir.exists():
        return None

    candidates = [
        entry for entry in base_dir.iterdir()
        if entry.is_dir() and entry.name.startswith("run_")
    ]
    if not candidates:
        return None
    return max(candidates, key=lambda entry: entry.stat().st_mtime)
def load_checkpoint(
    checkpoint_path: Path,
    model: nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    device: str = "cuda"
) -> Tuple[int, float]:
    """Restore model (and optionally optimizer) state from a checkpoint file.

    Args:
        checkpoint_path: File readable by ``torch.load`` containing at least
            the key ``'model_state_dict'``.
        model: Module that receives the saved weights.
        optimizer: If given and the checkpoint has ``'optimizer_state_dict'``,
            its state is restored as well.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        ``(start_epoch, best_acc)`` where ``start_epoch`` is the saved epoch
        plus one; missing values default to epoch 0 and accuracy 0.0.
    """
    print(f"\nπŸ“‚ Loading checkpoint: {checkpoint_path}")
    state = torch.load(checkpoint_path, map_location=device)

    model.load_state_dict(state['model_state_dict'])
    print(" βœ“ Loaded model weights")

    has_optimizer_state = 'optimizer_state_dict' in state
    if optimizer is not None and has_optimizer_state:
        optimizer.load_state_dict(state['optimizer_state_dict'])
        print(" βœ“ Loaded optimizer state")

    next_epoch = state.get('epoch', 0) + 1
    best_acc = state.get('best_acc', 0.0)
    print(f" βœ“ Resuming from epoch {next_epoch}, best_acc={best_acc:.2f}%")
    return next_epoch, best_acc
# ============================================================================
# HUGGINGFACE HUB INTEGRATION
# ============================================================================
def generate_run_readme(
    model_config: Union[DavidBeansConfig, DavidBeansV2Config],
    train_config: TrainingConfigV2,
    best_acc: float,
    run_dir_name: str
) -> str:
    """Render the Markdown README describing one training run.

    The routing section depends on whether ``model_config`` is a V2 config
    (wormhole routing and crystal head tables) or a V1 config.

    Returns:
        The README contents as a single string.
    """
    scales_str = ", ".join(str(scale) for scale in model_config.scales)

    if not isinstance(model_config, DavidBeansV2Config):
        # V1 models only expose the neighbour-based routing parameters.
        routing_info = f"""
## Routing (V1)
| Parameter | Value |
|-----------|-------|
| k_neighbors | {model_config.k_neighbors} |
| Cantor Weight | {model_config.cantor_weight} |
"""
    else:
        # The optional "Scale Copies" row is emitted only when copies are set.
        copies_str = (
            f"\n| Scale Copies | {model_config.scale_copies} |"
            if model_config.scale_copies
            else ""
        )
        routing_info = f"""
## Wormhole Routing (V2)
| Parameter | Value |
|-----------|-------|
| Mode | {model_config.wormhole_mode} |
| Wormholes/Position | {model_config.num_wormholes} |
| Temperature | {model_config.wormhole_temperature} |
| Tiles | {model_config.num_tiles} |
| Tile Wormholes | {model_config.tile_wormholes} |
## Crystal Head
| Parameter | Value |
|-----------|-------|
| Scales | [{scales_str}] |{copies_str}
| Weighting Mode | {model_config.weighting_mode} |
| Belly Layers | {model_config.belly_layers} |
| Belly Residual | {model_config.belly_residual} |
| Use Spine | {model_config.use_spine} |
| Use Collective | {model_config.use_collective} |
"""

    aug_info = f"""
## Augmentation
| Parameter | Value |
|-----------|-------|
| Normalization | {train_config.normalization} |
| Mixup Alpha | {train_config.mixup_alpha} |
| CutMix Alpha | {train_config.cutmix_alpha} |
| AlphaMix | {train_config.use_alphamix} |
| Label Smoothing | {train_config.label_smoothing} |
"""

    return f"""# Run: {run_dir_name}
## Results
- **Best Accuracy**: {best_acc:.2f}%
- **Dataset**: {train_config.dataset}
- **Epochs**: {train_config.epochs}
- **Model Version**: V{train_config.model_version}
## Model Config
| Parameter | Value |
|-----------|-------|
| Dim | {model_config.dim} |
| Layers | {model_config.num_layers} |
| Heads | {model_config.num_heads} |
| Patch Size | {model_config.patch_size} |
{routing_info}
## Training Config
| Parameter | Value |
|-----------|-------|
| Learning Rate | {train_config.learning_rate} |
| Weight Decay | {train_config.weight_decay} |
| Batch Size | {train_config.batch_size} |
| CE Weight | {train_config.ce_weight} |
| Contrast Weight | {train_config.contrast_weight} |
{aug_info}
## Key Findings Applied
- Routing learns from task pressure (no auxiliary routing losses)
- Gradients verified to flow through router
- Cross-contrastive aligns patch↔scale features
"""
def prepare_run_for_hub(
    model: nn.Module,
    model_config: Union[DavidBeansConfig, DavidBeansV2Config],
    train_config: TrainingConfigV2,
    best_acc: float,
    output_dir: Path,
    run_dir_name: str,
    training_history: Optional[Dict] = None
) -> Path:
    """Prepare run files for upload to HuggingFace Hub.

    Writes weights, ``config.json``, ``training_config.json``, ``README.md``,
    optionally ``training_history.json`` and a copy of the TensorBoard logs
    into ``<output_dir>/hub_upload/weights/<run_dir_name>``.

    Returns:
        The ``hub_upload`` directory (parent of ``weights/``).
    """
    hub_dir = output_dir / "hub_upload"
    run_hub_dir = hub_dir / "weights" / run_dir_name
    run_hub_dir.mkdir(parents=True, exist_ok=True)

    # Clone every tensor so the saved dict does not alias live parameters.
    state_dict = {k: v.clone() for k, v in model.state_dict().items()}
    if SAFETENSORS_AVAILABLE:
        try:
            save_safetensors(state_dict, run_hub_dir / "best.safetensors")
            print(f" βœ“ Saved best.safetensors")
        except Exception as e:
            # Best effort: fall back to the native PyTorch format.
            print(f" [!] Safetensors failed ({e}), using pytorch format")
            torch.save(state_dict, run_hub_dir / "best.pt")
    else:
        torch.save(state_dict, run_hub_dir / "best.pt")

    config_dict = {
        "architecture": f"DavidBeans_V{train_config.model_version}",
        "model_type": "david_beans_v2" if train_config.model_version == 2 else "david_beans",
        **model_config.__dict__
    }

    # All text files are written as UTF-8 explicitly: the README contains
    # non-ASCII characters, and the platform default encoding (e.g. cp1252 on
    # Windows) could otherwise raise UnicodeEncodeError.
    with open(run_hub_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(config_dict, f, indent=2, default=str)
    with open(run_hub_dir / "training_config.json", "w", encoding="utf-8") as f:
        json.dump(train_config.to_dict(), f, indent=2, default=str)

    run_readme = generate_run_readme(model_config, train_config, best_acc, run_dir_name)
    with open(run_hub_dir / "README.md", "w", encoding="utf-8") as f:
        f.write(run_readme)

    if training_history:
        with open(run_hub_dir / "training_history.json", "w", encoding="utf-8") as f:
            json.dump(training_history, f, indent=2)

    # Mirror TensorBoard logs, replacing any stale copy from a previous call
    # (copytree refuses to write into an existing directory).
    tb_dir = output_dir / "tensorboard"
    if tb_dir.exists():
        hub_tb_dir = run_hub_dir / "tensorboard"
        if hub_tb_dir.exists():
            shutil.rmtree(hub_tb_dir)
        shutil.copytree(tb_dir, hub_tb_dir)

    return hub_dir
def push_run_to_hub(
    hub_dir: Path,
    repo_id: str,
    run_dir_name: str,
    private: bool = False,
    commit_message: Optional[str] = None
) -> str:
    """Upload ``<hub_dir>/weights/<run_dir_name>`` to a HuggingFace Hub repo.

    The repo is created if needed; a failure to create it is only reported,
    and the upload is still attempted.

    Returns:
        Whatever ``upload_folder`` returns for the commit.

    Raises:
        RuntimeError: if ``huggingface_hub`` is not installed.
    """
    if not HF_HUB_AVAILABLE:
        raise RuntimeError("huggingface_hub not installed")

    api = HfApi()  # instantiated as in the original; not used further below

    try:
        create_repo(repo_id, private=private, exist_ok=True)
    except Exception as e:
        print(f" [!] Repo creation note: {e}")

    local_run_dir = hub_dir / "weights" / run_dir_name
    message = commit_message
    if message is None:
        message = f"Update {run_dir_name} - {datetime.now().strftime('%Y-%m-%d %H:%M')}"

    return upload_folder(
        folder_path=str(local_run_dir),
        repo_id=repo_id,
        path_in_repo=f"weights/{run_dir_name}",
        commit_message=message
    )
# ============================================================================
# TRAINING LOOP V2
# ============================================================================
def _select_mix_type(r: float, enabled_mixes: List[str]) -> Optional[str]:
    """Map a uniform draw ``r`` to one of ``enabled_mixes`` or ``None``.

    The first 40% of the unit interval means "no mixing"; the remaining 60%
    is split evenly, in list order, among the enabled mix types.
    """
    if r < 0.4 or not enabled_mixes:
        return None
    mix_prob = 0.6 / len(enabled_mixes)
    cumulative = 0.4
    for mix_name in enabled_mixes:
        cumulative += mix_prob
        if r < cumulative:
            return mix_name
    return None


def train_epoch_v2(
    model: nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
    config: TrainingConfigV2,
    epoch: int,
    tracker: MetricsTracker,
    routing_metrics: RoutingMetrics,
    writer: Optional['SummaryWriter'] = None
) -> Dict[str, float]:
    """Train for one epoch with V2 routing metrics and AlphaMix support.

    Args:
        model: Model called as ``model(images, targets=..., return_loss=True)``
            and expected to return a dict with ``'logits'`` and ``'losses'``.
        train_loader: Iterable of ``(images, targets)`` batches.
        optimizer: Optimizer stepped once per batch.
        scheduler: Stepped per batch when ``config.scheduler == "onecycle"``
            and once per epoch when it is ``"cosine"``.
        config: Training configuration.
        epoch: Zero-based epoch index (used for the progress bar and steps).
        tracker: Receives per-batch loss / lr values.
        routing_metrics: Reset at the start of the epoch, then updated.
        writer: Optional TensorBoard writer.

    Returns:
        Dict with mean ``'loss'``, ``'acc'`` (percent) and the routing summary.
    """
    model.train()
    device = config.device
    is_v2 = config.model_version == 2

    # Which mixing augmentations are enabled depends only on the config, so
    # this is resolved once instead of once per batch.
    enabled_mixes: List[str] = []
    if config.use_augmentation and config.mixup_alpha > 0:
        enabled_mixes.append('mixup')
    if config.use_augmentation and config.cutmix_alpha > 0:
        enabled_mixes.append('cutmix')
    if config.use_alphamix:
        enabled_mixes.append('alphamix')

    def to_float(v):
        return v.item() if isinstance(v, torch.Tensor) else float(v)

    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    global_step = epoch * len(train_loader)
    routing_metrics.reset()

    pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}", leave=True)
    for batch_idx, (images, targets) in enumerate(pbar):
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        # Apply at most one mixing augmentation per batch
        mixed = False
        mix_type = None
        if enabled_mixes:
            mix_type = _select_mix_type(torch.rand(1).item(), enabled_mixes)

        if mix_type == 'mixup':
            images, targets_a, targets_b, lam = mixup_data(images, targets, config.mixup_alpha)
            mixed = True
        elif mix_type == 'cutmix':
            images, targets_a, targets_b, lam = cutmix_data(images, targets, config.cutmix_alpha)
            mixed = True
        elif mix_type == 'alphamix':
            images, targets_a, targets_b, lam = alphamix_data(
                images, targets,
                alpha_range=config.alphamix_alpha_range,
                spatial_ratio=config.alphamix_spatial_ratio
            )
            mixed = True

        # Forward pass (routing info is requested only on every 10th batch)
        if is_v2:
            result = model(
                images,
                targets=targets,
                return_loss=True,
                return_routing=(batch_idx % 10 == 0)
            )
        else:
            result = model(images, targets=targets, return_loss=True)
        losses = result['losses']

        # Handle mixed CE loss.
        # NOTE(review): the model itself still receives the un-mixed
        # `targets`; only losses['ce'] is recomputed against the mixed pair,
        # so any other loss the model derives from targets is unchanged.
        if mixed:
            logits = result['logits']
            ce_loss = lam * F.cross_entropy(logits, targets_a, label_smoothing=config.label_smoothing) + \
                (1 - lam) * F.cross_entropy(logits, targets_b, label_smoothing=config.label_smoothing)
            losses['ce'] = ce_loss

        # Compute total loss (NO auxiliary routing loss - key finding!)
        loss = (
            config.ce_weight * losses['ce'] +
            config.contrast_weight * losses.get('contrast', torch.tensor(0.0, device=device))
        )

        # Add scale CE losses (handle both regular and copy scales)
        for key, val in losses.items():
            if key.startswith('ce_') and key != 'ce':
                if isinstance(val, torch.Tensor):
                    loss = loss + 0.1 * val

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Track routing gradient norms (before clipping modifies them)
        if is_v2:
            routing_metrics.update_grad_norms(model)

        if config.gradient_clip > 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.gradient_clip)
        else:
            grad_norm = 0.0

        optimizer.step()
        if scheduler is not None and config.scheduler == "onecycle":
            scheduler.step()

        # Update routing metrics
        if is_v2 and result.get('routing'):
            routing_metrics.update_from_routing_info(result['routing'], model)

        # Compute accuracy; mixed batches get lam-weighted credit per label
        with torch.no_grad():
            logits = result['logits']
            preds = logits.argmax(dim=-1)
            if mixed:
                correct = (lam * (preds == targets_a).float() +
                           (1 - lam) * (preds == targets_b).float()).sum()
            else:
                correct = (preds == targets).sum()
            total_correct += correct.item()
            total_samples += targets.size(0)
            total_loss += loss.item()

        # Track metrics
        contrast_loss = to_float(losses.get('contrast', 0.0))
        current_lr = optimizer.param_groups[0]['lr']
        tracker.update(
            loss=loss.item(),
            ce=losses['ce'].item(),
            contrast=contrast_loss,
            lr=current_lr
        )

        # TensorBoard logging
        if writer is not None and (batch_idx + 1) % config.log_interval == 0:
            step = global_step + batch_idx
            writer.add_scalar('train/loss_total', loss.item(), step)
            writer.add_scalar('train/loss_ce', losses['ce'].item(), step)
            writer.add_scalar('train/loss_contrast', contrast_loss, step)
            writer.add_scalar('train/learning_rate', current_lr, step)
            writer.add_scalar('train/grad_norm', to_float(grad_norm), step)
            if is_v2 and config.log_routing:
                routing_summary = routing_metrics.get_summary()
                for k, v in routing_summary.items():
                    writer.add_scalar(f'routing/{k}', v, step)

        # Progress bar
        routing_summary = routing_metrics.get_summary()
        postfix = {
            'loss': f"{tracker.get_ema('loss'):.3f}",
            'acc': f"{100.0 * total_correct / total_samples:.1f}%",
        }
        if is_v2 and 'grad_query' in routing_summary:
            postfix['βˆ‡q'] = f"{routing_summary['grad_query']:.2f}"
        if 'route_entropy' in routing_summary:
            postfix['H'] = f"{routing_summary['route_entropy']:.2f}"
        pbar.set_postfix(postfix)

    # Cosine schedule advances once per epoch
    if scheduler is not None and config.scheduler == "cosine":
        scheduler.step()

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    return {
        'loss': total_loss / max(len(train_loader), 1),
        'acc': 100.0 * total_correct / max(total_samples, 1),
        **routing_metrics.get_summary()
    }
@torch.no_grad()
def evaluate_v2(
    model: nn.Module,
    test_loader: DataLoader,
    config: TrainingConfigV2
) -> Dict[str, float]:
    """Evaluate the model on the test set.

    Returns:
        Dict with sample-weighted mean ``'loss'``, overall ``'acc'`` (percent)
        and one ``'acc_<scale>'`` / ``'acc_<scale>_c<copy>'`` entry per head.
    """
    model.eval()
    device = config.device

    loss_sum = 0.0
    hits = 0
    seen = 0

    # One counter per scale head; heads may include redundant copies.
    if hasattr(model.head, 'heads'):
        head_count = len(model.head.heads)
    else:
        head_count = len(model.config.scales)
    head_correct = [0] * head_count

    for images, targets in test_loader:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        result = model(images, targets=targets, return_loss=True)

        logits = result['logits']
        batch_loss = result['losses']['total']
        batch_size = targets.size(0)

        loss_sum += batch_loss.item() * batch_size
        hits += (logits.argmax(dim=-1) == targets).sum().item()
        seen += batch_size

        # Per-head accuracy
        for head_idx, head_logits in enumerate(result['scale_logits']):
            head_hits = (head_logits.argmax(dim=-1) == targets).sum().item()
            head_correct[head_idx] += head_hits

    metrics = {
        'loss': loss_sum / seen,
        'acc': 100.0 * hits / seen
    }

    # Name each head after its scale (plus a copy suffix for redundant copies)
    if hasattr(model.head, 'head_scale_map'):
        for head_idx, (scale, copy_idx) in enumerate(model.head.head_scale_map):
            if copy_idx == 0:
                label = f'acc_{scale}'
            else:
                label = f'acc_{scale}_c{copy_idx}'
            metrics[label] = 100.0 * head_correct[head_idx] / seen
    else:
        for head_idx, scale in enumerate(model.config.scales):
            metrics[f'acc_{scale}'] = 100.0 * head_correct[head_idx] / seen
    return metrics
# ============================================================================
# MAIN TRAINING FUNCTION V2
# ============================================================================
def _newest_run_checkpoint(base_output_dir):
    """Return ``(checkpoint, run_dir)`` for the most recently modified checkpoint
    found in the directories directly under ``base_output_dir``.

    Each sub-directory is searched with ``find_latest_checkpoint``. Returns
    ``(None, None)`` when no sub-directory yields a checkpoint.
    """
    newest = None
    for candidate in base_output_dir.iterdir():
        if not candidate.is_dir():
            continue
        found = find_latest_checkpoint(candidate)
        if not found:
            continue
        found = Path(found)
        try:
            stamp = found.stat().st_mtime
        except OSError:
            # Checkpoint vanished between discovery and stat; ignore it.
            continue
        if newest is None or stamp > newest[0]:
            newest = (stamp, found, candidate)
    if newest is None:
        return None, None
    return newest[1], newest[2]


def _resolve_resume_checkpoint(resume_from, base_output_dir):
    """Locate the checkpoint designated by ``resume_from``.

    Candidates are tried in this order:
      1. ``resume_from`` is an existing file.
      2. ``resume_from`` is an existing directory (searched with
         ``find_latest_checkpoint``).
      3. ``base_output_dir / resume_from`` is a directory, then a file.
      4. ``resume_from == "latest"`` and nothing above matched: the newest
         checkpoint among the run directories under ``base_output_dir``.

    Returns:
        ``(checkpoint_path, run_dir)``; ``run_dir`` is ``None`` when no
        checkpoint was found.
    """
    resume_path = Path(resume_from)
    checkpoint_path = None
    run_dir = None

    if resume_path.is_file():
        checkpoint_path = resume_path
        run_dir = checkpoint_path.parent
        print(f"\nπŸ“‚ Found checkpoint file: {checkpoint_path.name}")
    elif resume_path.is_dir():
        checkpoint_path = find_latest_checkpoint(resume_path)
        if checkpoint_path:
            run_dir = resume_path
            print(f"\nπŸ“‚ Found checkpoint in dir: {checkpoint_path.name}")
    else:
        possible_dir = base_output_dir / resume_from
        if possible_dir.is_dir():
            checkpoint_path = find_latest_checkpoint(possible_dir)
            if checkpoint_path:
                run_dir = possible_dir
                print(f"\nπŸ“‚ Found checkpoint in: {run_dir.name}")
        if checkpoint_path is None:
            possible_file = base_output_dir / resume_from
            if possible_file.is_file():
                checkpoint_path = possible_file
                run_dir = checkpoint_path.parent
                print(f"\nπŸ“‚ Found checkpoint: {checkpoint_path.name}")

    # "latest" is passed by the presets as a keyword rather than a real path;
    # resolve it only after every literal interpretation has failed.
    if not checkpoint_path and str(resume_from) == "latest":
        checkpoint_path, run_dir = _newest_run_checkpoint(base_output_dir)
        if checkpoint_path is not None:
            print(f"\nπŸ“‚ Found checkpoint in: {run_dir.name}")

    if checkpoint_path is None:
        print(f"\n [!] Could not find checkpoint: {resume_from}")
        print(f" [!] Starting fresh run instead")
    else:
        print(f" βœ“ Will resume from: {checkpoint_path}")

    return checkpoint_path, run_dir


def _model_config_from_checkpoint(checkpoint_path, model_version):
    """Rebuild the model config stored under ``'model_config'`` in a checkpoint.

    Returns ``None`` when the checkpoint has no such entry or cannot be read.
    The loaded checkpoint is released when this helper returns.
    """
    try:
        # NOTE: torch.load unpickles the file; only resume from trusted checkpoints.
        ckpt = torch.load(checkpoint_path, map_location='cpu')
        if 'model_config' not in ckpt:
            return None
        saved_config = ckpt['model_config']
        print(f" βœ“ Loading model config from checkpoint")
        if model_version == 2:
            return DavidBeansV2Config(**saved_config)
        return DavidBeansConfig(**saved_config)
    except Exception as e:
        # Best effort: fall back to the default config on any failure.
        print(f" [!] Could not load config from checkpoint: {e}")
        return None


def _default_model_config(train_config):
    """Build the fallback model config for ``train_config.model_version``."""
    if train_config.model_version == 2:
        return DavidBeansV2Config(
            image_size=train_config.image_size,
            patch_size=4,
            dim=512,
            num_layers=4,
            num_heads=8,
            num_wormholes=8,
            wormhole_temperature=0.1,
            wormhole_mode="hybrid",
            num_tiles=16,
            tile_wormholes=4,
            scales=[64, 128, 256, 384, 512],
            num_classes=100,
            contrast_weight=train_config.contrast_weight,
            dropout=0.1
        )
    return DavidBeansConfig(
        image_size=train_config.image_size,
        patch_size=4,
        dim=512,
        num_layers=4,
        num_heads=8,
        num_experts=5,
        k_neighbors=16,
        cantor_weight=0.3,
        scales=[64, 128, 256, 384, 512],
        num_classes=100,
        dropout=0.1
    )


def _save_training_checkpoint(path, epoch, model, optimizer, best_acc, model_config, train_config):
    """Write one training checkpoint (weights, optimizer state, configs) to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_acc': best_acc,
        'model_config': model_config.__dict__,
        'train_config': train_config.to_dict()
    }, path)


def train_david_beans_v2(
    model_config: Optional[Union[DavidBeansConfig, DavidBeansV2Config]] = None,
    train_config: Optional[TrainingConfigV2] = None
):
    """Main training function for DavidBeans V1 or V2.

    Resolves an optional resume checkpoint, creates (or reuses) the run
    directory, builds data loaders, model, AdamW optimizer and LR scheduler,
    then trains for ``train_config.epochs`` epochs. After every epoch the model
    is evaluated; the best model is written to ``best_model.pt`` and a periodic
    checkpoint is written every ``train_config.save_interval`` epochs
    (optionally followed by a hub upload).

    Args:
        model_config: Model configuration. When ``None`` it is read from the
            resume checkpoint if one is found, otherwise a default for
            ``train_config.model_version`` is built. Its ``num_classes`` is
            overwritten with the value reported by ``get_dataloaders``.
        train_config: Training configuration; defaults to ``TrainingConfigV2()``.

    Returns:
        ``(model, best_acc)``: the trained model and the best test accuracy
        (percent) seen during the run.
    """
    print("=" * 70)
    print(" DAVID-BEANS V2.1 TRAINING: Wormhole Routing")
    print("=" * 70)
    print()
    print(" πŸŒ€ WORMHOLES: Learned sparse routing")
    print(" πŸ’Ž CRYSTALS: Multi-scale projection")
    print()
    print(" Key insight: When routing IS the task, routing learns structure")
    print()
    print("=" * 70)

    if train_config is None:
        train_config = TrainingConfigV2()

    base_output_dir = Path(train_config.output_dir)
    base_output_dir.mkdir(parents=True, exist_ok=True)

    # Checkpoint resolution
    checkpoint_path = None
    run_dir = None
    run_dir_name = None
    if train_config.resume_from:
        checkpoint_path, run_dir = _resolve_resume_checkpoint(
            train_config.resume_from, base_output_dir
        )
        if run_dir is not None:
            run_dir_name = run_dir.name

    # Create new run directory if not resuming
    if run_dir is None:
        run_number = train_config.run_number or get_next_run_number(base_output_dir)
        run_dir_name = generate_run_dir_name(run_number, train_config.run_name, train_config.model_version)
        run_dir = base_output_dir / run_dir_name
        run_dir.mkdir(parents=True, exist_ok=True)
        print(f"\nπŸ“ New run: {run_dir_name}")
    else:
        print(f"\nπŸ“ Resuming run: {run_dir_name}")

    output_dir = run_dir

    # Model config: explicit argument > checkpoint > built-in default
    if checkpoint_path and checkpoint_path.exists() and model_config is None:
        model_config = _model_config_from_checkpoint(checkpoint_path, train_config.model_version)
    if model_config is None:
        model_config = _default_model_config(train_config)

    device = train_config.device
    print(f"\nDevice: {device}")
    print(f"Model version: V{train_config.model_version}")

    # Data
    print("\nLoading data...")
    train_loader, test_loader, num_classes = get_dataloaders(train_config)
    print(f" Dataset: {train_config.dataset}")
    print(f" Train: {len(train_loader.dataset)}, Test: {len(test_loader.dataset)}")
    print(f" Classes: {num_classes}")
    model_config.num_classes = num_classes

    # Model
    print("\nBuilding model...")
    if train_config.model_version == 2:
        model = DavidBeansV2(model_config)
    else:
        model = DavidBeans(model_config)
    model = model.to(device)
    print(f"\n{model}")
    num_params = sum(p.numel() for p in model.parameters())
    print(f"\nParameters: {num_params:,}")

    # Optimizer: parameters whose name contains one of these tags get no weight decay
    print("\nSetting up optimizer...")
    no_decay_tags = ('bias', 'norm', 'embedding')
    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if any(tag in name for tag in no_decay_tags):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    optimizer = AdamW([
        {'params': decay_params, 'weight_decay': train_config.weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0}
    ], lr=train_config.learning_rate, betas=train_config.betas)

    if train_config.scheduler == "cosine":
        # NOTE(review): the cosine horizon excludes the warmup epochs, yet the
        # resume path below advances the scheduler once per completed epoch
        # (warmup included) - confirm the schedule is not stepped past T_max
        # near the end of training.
        scheduler = CosineAnnealingLR(
            optimizer,
            # Clamp to 1: T_max == 0 (epochs == warmup_epochs) divides by zero
            # inside the scheduler on its first step.
            T_max=max(1, train_config.epochs - train_config.warmup_epochs),
            eta_min=train_config.min_lr
        )
    elif train_config.scheduler == "onecycle":
        scheduler = OneCycleLR(
            optimizer,
            max_lr=train_config.learning_rate,
            epochs=train_config.epochs,
            steps_per_epoch=len(train_loader),
            pct_start=train_config.warmup_epochs / train_config.epochs
        )
    else:
        scheduler = None

    print(f" Optimizer: AdamW (lr={train_config.learning_rate}, wd={train_config.weight_decay})")
    print(f" Scheduler: {train_config.scheduler}")

    # Print augmentation config
    print(f"\nAugmentation:")
    print(f" Mixup: {train_config.mixup_alpha if train_config.mixup_alpha > 0 else 'disabled'}")
    print(f" CutMix: {train_config.cutmix_alpha if train_config.cutmix_alpha > 0 else 'disabled'}")
    print(f" AlphaMix: {train_config.alphamix_alpha_range if train_config.use_alphamix else 'disabled'}")

    tracker = MetricsTracker()
    routing_metrics = RoutingMetrics()
    best_acc = 0.0
    start_epoch = 0

    # Load checkpoint
    if checkpoint_path and checkpoint_path.exists():
        start_epoch, best_acc = load_checkpoint(checkpoint_path, model, optimizer, device)
        if scheduler is not None and train_config.scheduler == "cosine":
            for _ in range(start_epoch):
                scheduler.step()
            print(f" βœ“ Advanced scheduler to epoch {start_epoch}")

    # TensorBoard
    writer = None
    if train_config.use_tensorboard and TENSORBOARD_AVAILABLE:
        tb_dir = output_dir / "tensorboard"
        tb_dir.mkdir(parents=True, exist_ok=True)
        writer = SummaryWriter(log_dir=str(tb_dir))
        print(f" TensorBoard: {tb_dir}")

    # The writer is closed in the ``finally`` below so that an exception
    # anywhere after this point does not leak it.
    try:
        # Save configs
        with open(output_dir / "config.json", "w") as f:
            json.dump({**model_config.__dict__, "architecture": f"DavidBeans_V{train_config.model_version}"},
                      f, indent=2, default=str)
        with open(output_dir / "training_config.json", "w") as f:
            json.dump(train_config.to_dict(), f, indent=2, default=str)

        # Training loop
        print("\n" + "=" * 70)
        print(" TRAINING")
        print("=" * 70)

        for epoch in range(start_epoch, train_config.epochs):
            epoch_start = time.time()

            # Linear warmup: override the LR by hand during the first epochs
            if epoch < train_config.warmup_epochs and train_config.scheduler == "cosine":
                warmup_lr = train_config.learning_rate * (epoch + 1) / train_config.warmup_epochs
                for param_group in optimizer.param_groups:
                    param_group['lr'] = warmup_lr

            train_metrics = train_epoch_v2(
                model, train_loader, optimizer, scheduler,
                train_config, epoch, tracker, routing_metrics, writer
            )
            test_metrics = evaluate_v2(model, test_loader, train_config)
            epoch_time = time.time() - epoch_start

            # TensorBoard
            if writer is not None:
                writer.add_scalar('epoch/train_loss', train_metrics['loss'], epoch)
                writer.add_scalar('epoch/train_acc', train_metrics['acc'], epoch)
                writer.add_scalar('epoch/test_loss', test_metrics['loss'], epoch)
                writer.add_scalar('epoch/test_acc', test_metrics['acc'], epoch)
                # Log all scale accuracies
                for key, val in test_metrics.items():
                    if key.startswith('acc_'):
                        writer.add_scalar(f'scales/{key}', val, epoch)

            # Print summary - show primary scales only (not copies)
            primary_scale_accs = [
                f"{scale}:{test_metrics[f'acc_{scale}']:.1f}%"
                for scale in model.config.scales
                if f'acc_{scale}' in test_metrics
            ]
            scale_accs = " | ".join(primary_scale_accs)

            is_best = test_metrics['acc'] > best_acc
            star = "β˜…" if is_best else ""

            routing_info = ""
            if train_config.model_version == 2 and 'grad_query' in train_metrics:
                routing_info = f" | βˆ‡q:{train_metrics.get('grad_query', 0):.2f}"

            print(f" β†’ Train: {train_metrics['acc']:.1f}% | Test: {test_metrics['acc']:.1f}% | "
                  f"[{scale_accs}]{routing_info} | {epoch_time:.0f}s {star}")

            # Save best model
            if is_best:
                best_acc = test_metrics['acc']
                _save_training_checkpoint(
                    output_dir / "best_model.pt",
                    epoch, model, optimizer, best_acc, model_config, train_config
                )

            # Periodic checkpoint
            if (epoch + 1) % train_config.save_interval == 0:
                _save_training_checkpoint(
                    output_dir / f"checkpoint_epoch_{epoch + 1}.pt",
                    epoch, model, optimizer, best_acc, model_config, train_config
                )

                if train_config.push_to_hub and HF_HUB_AVAILABLE:
                    try:
                        hub_dir = prepare_run_for_hub(
                            model=model,
                            model_config=model_config,
                            train_config=train_config,
                            best_acc=best_acc,
                            output_dir=output_dir,
                            run_dir_name=run_dir_name,
                            training_history=tracker.get_history()
                        )
                        push_run_to_hub(
                            hub_dir=hub_dir,
                            repo_id=train_config.hub_repo_id,
                            run_dir_name=run_dir_name,
                            commit_message=f"Epoch {epoch + 1} - {best_acc:.2f}% acc"
                        )
                        print(f" πŸ“€ Uploaded to hub")
                    except Exception as e:
                        # Uploading is best effort; a hub failure must not stop training.
                        print(f" [!] Hub upload failed: {e}")

            tracker.end_epoch()
    finally:
        if writer is not None:
            writer.close()

    # Final summary
    print("\n" + "=" * 70)
    print(" TRAINING COMPLETE")
    print("=" * 70)
    print(f"\n Best Test Accuracy: {best_acc:.2f}%")
    print(f" Model saved to: {output_dir / 'best_model.pt'}")

    return model, best_acc
# ============================================================================
# PRESETS
# ============================================================================
def train_cifar100_v2_wormhole(
    run_name: str = "wormhole_base",
    push_to_hub: bool = False,
    resume: bool = False
):
    """CIFAR-100 with V2 wormhole routing.

    Args:
        run_name: Name recorded in the training config for this run.
        push_to_hub: Forwarded to the training config's ``push_to_hub``.
        resume: When True, ``resume_from`` is set to ``"latest"`` (the same
            convention as ``train_cifar100_v1_baseline``); otherwise a fresh
            run is requested. Previously this flag was accepted but ignored.

    Returns:
        Whatever ``train_david_beans_v2`` returns for these configs.
    """
    model_config = DavidBeansV2Config(
        image_size=32,
        patch_size=2,
        dim=512,
        num_layers=4,
        num_heads=16,
        # Wormhole routing parameters
        num_wormholes=16,
        wormhole_temperature=0.1,
        wormhole_mode="hybrid",
        # Tessellation parameters
        num_tiles=16,
        tile_wormholes=4,
        # Crystal head
        scales=[64, 128, 256, 512, 1024],
        num_classes=100,
        # V2.1 additions
        belly_layers=2,
        belly_residual=False,
        weighting_mode="learned",
        scale_copies=None,
        use_spine=False,
        use_collective=False,
        # Other
        contrast_temperature=0.07,
        contrast_weight=0.5,
        dropout=0.1
    )
    train_config = TrainingConfigV2(
        run_name=run_name,
        model_version=2,
        dataset="cifar100",
        epochs=300,
        batch_size=512,
        learning_rate=3e-4,
        weight_decay=0.05,
        warmup_epochs=15,
        # Normalization
        normalization="standard",
        # Loss weights
        ce_weight=1.0,
        contrast_weight=0.5,
        # Augmentation
        label_smoothing=0.1,
        mixup_alpha=0.2,
        cutmix_alpha=1.0,
        # AlphaMix
        use_alphamix=True,
        alphamix_alpha_range=(0.3, 0.7),
        alphamix_spatial_ratio=0.25,
        # Output
        output_dir="./checkpoints/cifar100_v2",
        # Honour the caller's resume flag instead of always starting fresh
        resume_from="latest" if resume else None,
        # Hub
        push_to_hub=push_to_hub,
        hub_repo_id="AbstractPhil/geovit-david-beans",
        # Routing logging
        log_routing=True
    )
    return train_david_beans_v2(model_config, train_config)
def train_cifar100_v2_with_spine(
    run_name: str = "wormhole_spine",
    push_to_hub: bool = False,
    resume: bool = False
):
    """CIFAR-100 with V2 wormhole routing + conv spine.

    Args:
        run_name: Name recorded in the training config for this run.
        push_to_hub: Forwarded to the training config's ``push_to_hub``.
        resume: When True, ``resume_from`` is set to ``"latest"`` (the same
            convention as ``train_cifar100_v1_baseline``); otherwise a fresh
            run is requested. Previously this flag was accepted but ignored.

    Returns:
        Whatever ``train_david_beans_v2`` returns for these configs.
    """
    model_config = DavidBeansV2Config(
        image_size=32,
        patch_size=4,
        dim=512,
        num_layers=4,
        num_heads=8,
        num_wormholes=8,
        wormhole_temperature=0.1,
        wormhole_mode="hybrid",
        num_tiles=16,
        tile_wormholes=4,
        scales=[64, 128, 256, 384, 512],
        num_classes=100,
        # Enable spine
        use_spine=True,
        spine_channels=[64, 128, 256],
        spine_cross_attn=True,
        spine_gate_init=0.0,
        # Belly
        belly_layers=2,
        weighting_mode="geometric",
        contrast_temperature=0.07,
        contrast_weight=0.5,
        dropout=0.1
    )
    train_config = TrainingConfigV2(
        run_name=run_name,
        model_version=2,
        dataset="cifar100",
        epochs=200,
        batch_size=128,
        learning_rate=3e-4,
        weight_decay=0.05,
        warmup_epochs=10,
        normalization="standard",
        ce_weight=1.0,
        contrast_weight=0.5,
        label_smoothing=0.1,
        mixup_alpha=0.2,
        cutmix_alpha=1.0,
        use_alphamix=True,
        output_dir="./checkpoints/cifar100_v2",
        # Honour the caller's resume flag instead of silently dropping it
        resume_from="latest" if resume else None,
        push_to_hub=push_to_hub,
        hub_repo_id="AbstractPhil/geovit-david-beans",
        log_routing=True
    )
    return train_david_beans_v2(model_config, train_config)
def train_cifar100_v2_redundant_scales(
    run_name: str = "wormhole_redundant",
    push_to_hub: bool = False,
    resume: bool = False
):
    """CIFAR-100 with redundant small scales for ensemble effect.

    Args:
        run_name: Name recorded in the training config for this run.
        push_to_hub: Forwarded to the training config's ``push_to_hub``.
        resume: When True, ``resume_from`` is set to ``"latest"`` (the same
            convention as ``train_cifar100_v1_baseline``); otherwise a fresh
            run is requested. Previously this flag was accepted but ignored.

    Returns:
        Whatever ``train_david_beans_v2`` returns for these configs.
    """
    model_config = DavidBeansV2Config(
        image_size=32,
        patch_size=4,
        dim=512,
        num_layers=4,
        num_heads=8,
        num_wormholes=8,
        wormhole_temperature=0.1,
        wormhole_mode="hybrid",
        num_tiles=16,
        tile_wormholes=4,
        scales=[64, 128, 256, 512],
        # Redundant copies: 4x 64d, 2x 128d, 1x 256d, 1x 512d
        scale_copies=[4, 2, 1, 1],
        copy_theta_step=0.15,
        num_classes=100,
        weighting_mode="geometric",
        belly_layers=2,
        contrast_temperature=0.07,
        contrast_weight=0.5,
        dropout=0.1
    )
    train_config = TrainingConfigV2(
        run_name=run_name,
        model_version=2,
        dataset="cifar100",
        epochs=200,
        batch_size=128,
        learning_rate=3e-4,
        weight_decay=0.05,
        warmup_epochs=10,
        normalization="standard",
        ce_weight=1.0,
        contrast_weight=0.5,
        label_smoothing=0.1,
        mixup_alpha=0.2,
        cutmix_alpha=1.0,
        use_alphamix=True,
        output_dir="./checkpoints/cifar100_v2",
        # Honour the caller's resume flag instead of silently dropping it
        resume_from="latest" if resume else None,
        push_to_hub=push_to_hub,
        hub_repo_id="AbstractPhil/geovit-david-beans",
        log_routing=True
    )
    return train_david_beans_v2(model_config, train_config)
def train_cifar100_v2_no_norm(
    run_name: str = "wormhole_no_norm",
    push_to_hub: bool = False,
    resume: bool = False
):
    """CIFAR-100 with no normalization (raw pixels) for geometric components.

    Args:
        run_name: Name recorded in the training config for this run.
        push_to_hub: Forwarded to the training config's ``push_to_hub``.
        resume: When True, ``resume_from`` is set to ``"latest"`` (the same
            convention as ``train_cifar100_v1_baseline``); otherwise a fresh
            run is requested. Previously this flag was accepted but ignored.

    Returns:
        Whatever ``train_david_beans_v2`` returns for these configs.
    """
    model_config = DavidBeansV2Config(
        image_size=32,
        patch_size=4,
        dim=512,
        num_layers=4,
        num_heads=8,
        num_wormholes=8,
        wormhole_temperature=0.1,
        wormhole_mode="hybrid",
        num_tiles=16,
        tile_wormholes=4,
        scales=[64, 128, 256, 384, 512],
        num_classes=100,
        belly_layers=2,
        weighting_mode="learned",
        contrast_temperature=0.07,
        contrast_weight=0.5,
        dropout=0.1
    )
    train_config = TrainingConfigV2(
        run_name=run_name,
        model_version=2,
        dataset="cifar100",
        epochs=200,
        batch_size=128,
        learning_rate=3e-4,
        weight_decay=0.05,
        warmup_epochs=10,
        # No normalization - raw [0,1] pixels
        normalization="none",
        ce_weight=1.0,
        contrast_weight=0.5,
        label_smoothing=0.1,
        mixup_alpha=0.2,
        cutmix_alpha=1.0,
        use_alphamix=True,
        output_dir="./checkpoints/cifar100_v2",
        # Honour the caller's resume flag instead of silently dropping it
        resume_from="latest" if resume else None,
        push_to_hub=push_to_hub,
        hub_repo_id="AbstractPhil/geovit-david-beans",
        log_routing=True
    )
    return train_david_beans_v2(model_config, train_config)
def train_cifar100_v1_baseline(
    run_name: str = "v1_baseline",
    push_to_hub: bool = False,
    resume: bool = False
):
    """Train the V1 model (fixed Cantor routing) on CIFAR-100 as a comparison run.

    Args:
        run_name: Name recorded in the training config for this run.
        push_to_hub: Forwarded to the training config's ``push_to_hub``.
        resume: When True, ``resume_from`` is set to ``"latest"``; otherwise
            it is ``None``.

    Returns:
        Whatever ``train_david_beans_v2`` returns for these configs.
    """
    resume_target = "latest" if resume else None

    architecture = dict(
        image_size=32,
        patch_size=4,
        dim=512,
        num_layers=4,
        num_heads=8,
        num_experts=5,
        k_neighbors=16,
        cantor_weight=0.3,
        scales=[64, 128, 256, 384, 512],
        num_classes=100,
        dropout=0.1,
    )

    schedule = dict(
        run_name=run_name,
        model_version=1,
        dataset="cifar100",
        epochs=200,
        batch_size=128,
        learning_rate=3e-4,
        weight_decay=0.05,
        warmup_epochs=10,
        normalization="standard",
        ce_weight=1.0,
        contrast_weight=0.5,
        label_smoothing=0.1,
        mixup_alpha=0.2,
        cutmix_alpha=1.0,
        use_alphamix=False,  # V1 doesn't benefit as much
        output_dir="./checkpoints/cifar100_v1",
        resume_from=resume_target,
        push_to_hub=push_to_hub,
        hub_repo_id="AbstractPhil/geovit-david-beans",
        log_routing=False,
    )

    # The model config is built first, then the training config, as before.
    return train_david_beans_v2(
        DavidBeansConfig(**architecture),
        TrainingConfigV2(**schedule),
    )
# ============================================================================
# MAIN
# ============================================================================
if __name__ == "__main__":
    # -----------------------------------------------------
    # Run settings - edit these before launching the script
    # -----------------------------------------------------
    # PRESET options: "v1_baseline", "v2_wormhole", "v2_spine",
    #                 "v2_redundant", "v2_no_norm", "test"
    PRESET = "v2_wormhole"
    RESUME = False
    RUN_NAME = "5scale_2x2patch_alphamix_d512_4layer"
    PUSH_TO_HUB = True

    # Each named preset maps to (banner printed before training, entry point).
    preset_table = {
        "v1_baseline": (
            "πŸ«˜πŸ’Ž Training DavidBeans V1 (Cantor routing)...",
            train_cifar100_v1_baseline,
        ),
        "v2_wormhole": (
            "πŸ’Ž Training DavidBeans V2 (Wormhole routing)...",
            train_cifar100_v2_wormhole,
        ),
        "v2_spine": (
            "πŸ’ŽπŸ¦΄ Training DavidBeans V2 with Conv Spine...",
            train_cifar100_v2_with_spine,
        ),
        "v2_redundant": (
            "πŸ’Žβœ–οΈ Training DavidBeans V2 with Redundant Scales...",
            train_cifar100_v2_redundant_scales,
        ),
        "v2_no_norm": (
            "πŸ’ŽπŸ“· Training DavidBeans V2 with No Normalization...",
            train_cifar100_v2_no_norm,
        ),
    }

    if PRESET == "test":
        # Tiny two-epoch smoke run with every augmentation switched off.
        print("πŸ§ͺ Quick test...")
        model_config = DavidBeansV2Config(
            image_size=32, patch_size=4, dim=128, num_layers=2,
            num_heads=4, num_wormholes=4, num_tiles=8,
            scales=[32, 64, 128], num_classes=10,
            belly_layers=2
        )
        train_config = TrainingConfigV2(
            run_name="test", model_version=2,
            epochs=2, batch_size=32,
            use_augmentation=False, mixup_alpha=0.0, cutmix_alpha=0.0,
            use_alphamix=False
        )
        model, acc = train_david_beans_v2(model_config, train_config)
    elif PRESET in preset_table:
        banner, entry_point = preset_table[PRESET]
        print(banner)
        model, acc = entry_point(
            run_name=RUN_NAME,
            push_to_hub=PUSH_TO_HUB,
            resume=RESUME
        )
    else:
        raise ValueError(f"Unknown preset: {PRESET}")

    print(f"\nπŸŽ‰ Done! Best accuracy: {acc:.2f}%")