""" Train DavidBeans V2: Wormhole Routing Architecture =================================================== ┌─────────────────┐ │ BEANS V2.1 │ "I learn where to look..." │ (Wormhole ViT)│ │ 🌀 → 🌀 → 🌀 │ Learned sparse routing └────────┬────────┘ │ ▼ ┌─────────────────┐ │ DAVID │ "I know the crystals..." │ (Classifier) │ │ 💎 → 💎 → 💎 │ Multi-scale projection └────────┬────────┘ │ ▼ [Prediction] Key findings from wormhole experiments: 1. When routing IS the task, routing learns structure 2. Auxiliary losses can be gamed - removed in V2 3. Gradient flow through router is critical - verified 4. Cross-contrastive aligns patch↔scale features V2.1 additions: - AlphaMix augmentation (localized transparent overlay) - Configurable normalization (standard, none, center_only, unit_var) - Support for redundant scales, conv spine, collective mode - Configurable belly depth Author: AbstractPhil Date: November 30, 2025 """ import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR from tqdm.auto import tqdm import time import math from pathlib import Path from typing import Dict, Optional, Tuple, List, Union from dataclasses import dataclass, field import json from datetime import datetime import os import shutil from google.colab import userdata os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN') HF_TOKEN = userdata.get('HF_TOKEN') try: from google.colab import userdata HF_TOKEN = userdata.get('HF_TOKEN') os.environ['HF_TOKEN'] = HF_TOKEN except: pass # Import both model versions from geofractal.model.david_beans.model import DavidBeans, DavidBeansConfig from geofractal.model.david_beans.model_v2 import DavidBeansV2, DavidBeansV2Config # HuggingFace Hub integration try: from huggingface_hub import HfApi, create_repo, upload_folder HF_HUB_AVAILABLE = True except ImportError: HF_HUB_AVAILABLE = False print(" [!] huggingface_hub not installed. Run: pip install huggingface_hub") # Safetensors support try: from safetensors.torch import save_file as save_safetensors SAFETENSORS_AVAILABLE = True except ImportError: SAFETENSORS_AVAILABLE = False # TensorBoard support try: from torch.utils.tensorboard import SummaryWriter TENSORBOARD_AVAILABLE = True except ImportError: TENSORBOARD_AVAILABLE = False print(" [!] tensorboard not installed. 
# ============================================================================
# TRAINING CONFIGURATION V2.1
# ============================================================================

@dataclass
class TrainingConfigV2:
    """Training configuration for DavidBeans V2 with wormhole routing."""

    # Run identification
    run_name: str = "default"
    run_number: Optional[int] = None

    # Model version
    model_version: int = 2  # 1 = original, 2 = wormhole

    # Data
    dataset: str = "cifar100"
    image_size: int = 32
    batch_size: int = 128
    num_workers: int = 4

    # Normalization
    normalization: str = "standard"  # "standard", "none", "center_only", "unit_var"

    # Training schedule
    epochs: int = 200
    warmup_epochs: int = 10

    # Optimizer
    learning_rate: float = 3e-4
    weight_decay: float = 0.05
    betas: Tuple[float, float] = (0.9, 0.999)

    # Learning rate schedule
    scheduler: str = "cosine"
    min_lr: float = 1e-6

    # Loss weights (based on experimental findings)
    ce_weight: float = 1.0
    contrast_weight: float = 0.5
    # NOTE: No auxiliary routing loss - routing learns from task pressure

    # Regularization
    gradient_clip: float = 1.0
    label_smoothing: float = 0.1

    # Augmentation
    use_augmentation: bool = True
    mixup_alpha: float = 0.2
    cutmix_alpha: float = 1.0

    # AlphaMix augmentation
    use_alphamix: bool = False
    alphamix_alpha_range: Tuple[float, float] = (0.3, 0.7)
    alphamix_spatial_ratio: float = 0.25

    # Checkpointing
    save_interval: int = 10
    output_dir: str = "./checkpoints"
    resume_from: Optional[str] = None

    # TensorBoard
    use_tensorboard: bool = True
    log_interval: int = 50
    log_routing: bool = True  # Log routing patterns

    # HuggingFace Hub
    push_to_hub: bool = False
    hub_repo_id: str = "AbstractPhil/geovit-david-beans"
    hub_private: bool = False

    # Device
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    def to_dict(self) -> Dict:
        return {k: v for k, v in self.__dict__.items()}

    def __post_init__(self):
        assert self.normalization in ["standard", "none", "center_only", "unit_var"], \
            f"Invalid normalization mode: {self.normalization}"


# ============================================================================
# ROUTING METRICS
# ============================================================================

class RoutingMetrics:
    """Track and analyze wormhole routing patterns."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.route_entropies = []
        self.route_diversities = []
        self.grad_norms = {'query': [], 'key': []}

    @torch.no_grad()
    def compute_route_entropy(self, soft_routes: torch.Tensor) -> float:
        """Compute average entropy of routing distributions."""
        eps = 1e-8
        entropy = -(soft_routes * (soft_routes + eps).log()).sum(dim=-1)
        return entropy.mean().item()

    @torch.no_grad()
    def compute_route_diversity(self, routes: torch.Tensor, num_positions: int) -> float:
        """Compute how many unique destinations are used."""
        unique_per_sample = []
        for b in range(routes.shape[0]):
            unique = routes[b].unique().numel()
            unique_per_sample.append(unique / num_positions)
        return sum(unique_per_sample) / len(unique_per_sample)

    def update_from_routing_info(self, routing_info: List[Dict], model: nn.Module):
        """Extract metrics from routing info returned by V2 model."""
        if not routing_info:
            return
        for layer_info in routing_info:
            if layer_info.get('attention'):
                attn = layer_info['attention']
                if attn.get('weights') is not None:
                    entropy = self.compute_route_entropy(attn['weights'])
                    self.route_entropies.append(entropy)
                if attn.get('routes') is not None:
                    P = attn['routes'].shape[1]
                    diversity = self.compute_route_diversity(attn['routes'], P)
                    self.route_diversities.append(diversity)
            if layer_info.get('expert'):
                exp = layer_info['expert']
                if exp.get('weights') is not None:
                    entropy = self.compute_route_entropy(exp['weights'])
                    self.route_entropies.append(entropy)

    def update_grad_norms(self, model: nn.Module):
        """Track gradient norms through router projections."""
        for name, param in model.named_parameters():
            if param.grad is not None:
                if 'query_proj' in name and 'weight' in name:
                    self.grad_norms['query'].append(param.grad.norm().item())
                elif 'key_proj' in name and 'weight' in name:
                    self.grad_norms['key'].append(param.grad.norm().item())

    def get_summary(self) -> Dict[str, float]:
        """Get summary statistics."""
        summary = {}
        if self.route_entropies:
            summary['route_entropy'] = sum(self.route_entropies) / len(self.route_entropies)
        if self.route_diversities:
            summary['route_diversity'] = sum(self.route_diversities) / len(self.route_diversities)
        if self.grad_norms['query']:
            summary['grad_query'] = sum(self.grad_norms['query']) / len(self.grad_norms['query'])
        if self.grad_norms['key']:
            summary['grad_key'] = sum(self.grad_norms['key']) / len(self.grad_norms['key'])
        return summary
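
# Illustrative sanity check for the metrics above (a sketch, not part of the
# training path; the tensor shapes are assumptions, not the model's contract):
# route entropy should be ~log(K) for uniform routing weights, ~0 for one-hot
# routing, and diversity is the fraction of positions with unique destinations.
def _routing_metrics_smoke_test():
    metrics = RoutingMetrics()
    uniform = torch.full((2, 4, 8), 1.0 / 8)  # [B, positions, K] uniform weights
    peaked = F.one_hot(torch.zeros(2, 4, dtype=torch.long), 8).float()
    assert abs(metrics.compute_route_entropy(uniform) - math.log(8)) < 1e-4
    assert metrics.compute_route_entropy(peaked) < 1e-3
    # All positions routing to the same destination -> diversity = 1/P
    routes = torch.zeros(2, 4, dtype=torch.long)  # [B, P] destination indices
    assert abs(metrics.compute_route_diversity(routes, 4) - 0.25) < 1e-6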
# ============================================================================
# DATA LOADING WITH NORMALIZATION OPTIONS
# ============================================================================

def get_normalization_transform(config: TrainingConfigV2, dataset: str):
    """Get normalization transform based on config."""
    import torchvision.transforms as T

    if dataset == "cifar10":
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2470, 0.2435, 0.2616)
    elif dataset == "cifar100":
        mean = (0.5071, 0.4867, 0.4408)
        std = (0.2675, 0.2565, 0.2761)
    else:
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)

    if config.normalization == "standard":
        return T.Normalize(mean, std)
    elif config.normalization == "none":
        # No normalization - raw [0, 1] from ToTensor
        return None
    elif config.normalization == "center_only":
        # Center at 0 but don't scale variance
        return T.Normalize(mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0))
    elif config.normalization == "unit_var":
        # Scale variance but don't center
        return T.Normalize(mean=(0.0, 0.0, 0.0), std=std)
    else:
        return T.Normalize(mean, std)


def get_dataloaders(config: TrainingConfigV2) -> Tuple[DataLoader, DataLoader, int]:
    """Get train and test dataloaders with configurable normalization."""
    try:
        import torchvision
        import torchvision.transforms as T

        if config.dataset == "cifar10":
            num_classes = 10
            DatasetClass = torchvision.datasets.CIFAR10
        elif config.dataset == "cifar100":
            num_classes = 100
            DatasetClass = torchvision.datasets.CIFAR100
        else:
            raise ValueError(f"Unknown dataset: {config.dataset}")

        # Get normalization transform
        norm_transform = get_normalization_transform(config, config.dataset)

        # Build train transforms
        train_transforms = [
            T.RandomCrop(32, padding=4),
            T.RandomHorizontalFlip(),
        ]
        if config.use_augmentation:
            train_transforms.append(T.AutoAugment(T.AutoAugmentPolicy.CIFAR10))
        train_transforms.append(T.ToTensor())
        if norm_transform is not None:
            train_transforms.append(norm_transform)
        train_transform = T.Compose(train_transforms)

        # Build test transforms
        test_transforms = [T.ToTensor()]
        if norm_transform is not None:
            test_transforms.append(norm_transform)
        test_transform = T.Compose(test_transforms)

        print(f" Normalization: {config.normalization}")

        train_dataset = DatasetClass(
            root='./data', train=True, download=True, transform=train_transform
        )
        test_dataset = DatasetClass(
            root='./data', train=False, download=True, transform=test_transform
        )

        train_loader = DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=config.num_workers,
            pin_memory=True,
            persistent_workers=config.num_workers > 0,
            drop_last=True
        )
        test_loader = DataLoader(
            test_dataset,
            batch_size=config.batch_size,
            shuffle=False,
            num_workers=config.num_workers,
            pin_memory=True,
            persistent_workers=config.num_workers > 0
        )

        return train_loader, test_loader, num_classes

    except ImportError:
        print(" [!] torchvision not available, using synthetic data")
        return get_synthetic_dataloaders(config)
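
# Illustrative check of the four normalization modes (a sketch assuming
# torchvision is installed; not used in training). T.Normalize computes
# (x - mean) / std, so "center_only" shifts [0, 1] pixels to [-0.5, 0.5]
# without rescaling, while "unit_var" divides by per-channel std without
# centering.
def _normalization_modes_demo():
    x = torch.rand(3, 32, 32)  # stand-in for a ToTensor() output in [0, 1]
    for mode in ["standard", "none", "center_only", "unit_var"]:
        cfg = TrainingConfigV2(normalization=mode)
        t = get_normalization_transform(cfg, "cifar100")
        out = x if t is None else t(x.clone())
        print(f"{mode:12s} min={out.min().item():.3f} max={out.max().item():.3f}")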
def get_synthetic_dataloaders(config: TrainingConfigV2) -> Tuple[DataLoader, DataLoader, int]:
    """Fallback synthetic data for testing."""

    class SyntheticDataset(torch.utils.data.Dataset):
        def __init__(self, size: int, image_size: int, num_classes: int):
            self.size = size
            self.image_size = image_size
            self.num_classes = num_classes

        def __len__(self):
            return self.size

        def __getitem__(self, idx):
            x = torch.randn(3, self.image_size, self.image_size)
            y = idx % self.num_classes
            return x, y

    num_classes = 100 if config.dataset == "cifar100" else 10
    train_dataset = SyntheticDataset(5000, config.image_size, num_classes)
    test_dataset = SyntheticDataset(1000, config.image_size, num_classes)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

    return train_loader, test_loader, num_classes


# ============================================================================
# MIXING AUGMENTATIONS
# ============================================================================

def mixup_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 0.2):
    """Mixup augmentation."""
    if alpha > 0:
        lam = torch.distributions.Beta(alpha, alpha).sample().item()
    else:
        lam = 1.0

    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]

    return mixed_x, y_a, y_b, lam


def cutmix_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
    """CutMix augmentation."""
    if alpha > 0:
        lam = torch.distributions.Beta(alpha, alpha).sample().item()
    else:
        lam = 1.0

    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=x.device)

    _, _, H, W = x.shape
    cut_ratio = math.sqrt(1 - lam)
    cut_h = int(H * cut_ratio)
    cut_w = int(W * cut_ratio)

    cx = torch.randint(0, H, (1,)).item()
    cy = torch.randint(0, W, (1,)).item()

    x1 = max(0, cx - cut_h // 2)
    x2 = min(H, cx + cut_h // 2)
    y1 = max(0, cy - cut_w // 2)
    y2 = min(W, cy + cut_w // 2)

    mixed_x = x.clone()
    mixed_x[:, :, x1:x2, y1:y2] = x[index, :, x1:x2, y1:y2]

    lam = 1 - ((x2 - x1) * (y2 - y1)) / (H * W)
    y_a, y_b = y, y[index]

    return mixed_x, y_a, y_b, lam


def alphamix_data(
    x: torch.Tensor,
    y: torch.Tensor,
    alpha_range: Tuple[float, float] = (0.3, 0.7),
    spatial_ratio: float = 0.25
):
    """
    AlphaMix: Spatially localized transparent overlay.

    Unlike CutMix (full replacement) or Mixup (global blend), AlphaMix
    creates a localized alpha-blended region.

    Args:
        x: [B, C, H, W] input images
        y: [B] labels
        alpha_range: (min, max) for alpha blending in overlay region
        spatial_ratio: Fraction of image area for overlay

    Returns:
        mixed_x, y_a, y_b, lam (effective lambda for loss weighting)
    """
    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=x.device)
    y_a, y_b = y, y[index]

    # Sample alpha from beta distribution within range
    alpha_min, alpha_max = alpha_range
    beta_sample = np.random.beta(2, 2)
    alpha = alpha_min + (alpha_max - alpha_min) * beta_sample

    _, _, H, W = x.shape

    # Compute overlay region size
    overlay_ratio = np.sqrt(spatial_ratio)
    overlay_h = max(4, int(H * overlay_ratio))
    overlay_w = max(4, int(W * overlay_ratio))

    # Random position for overlay
    top = np.random.randint(0, max(1, H - overlay_h + 1))
    left = np.random.randint(0, max(1, W - overlay_w + 1))

    # Create composited image
    composited_x = x.clone()

    # Alpha blend in the overlay region
    overlay_region = alpha * x[:, :, top:top + overlay_h, left:left + overlay_w]
    background_region = (1 - alpha) * x[index, :, top:top + overlay_h, left:left + overlay_w]
    composited_x[:, :, top:top + overlay_h, left:left + overlay_w] = overlay_region + background_region

    # Compute effective lambda based on blended area
    blended_area = (overlay_h * overlay_w) / (H * W)

    # lam represents the contribution of the original sample:
    # in the non-blended region it is 100% original;
    # in the blended region it is the alpha fraction of the original.
    lam = 1.0 - blended_area * (1 - alpha)

    return composited_x, y_a, y_b, lam
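
# Worked example of AlphaMix's effective-lambda arithmetic (illustrative only,
# not used in training). For a 32x32 image with spatial_ratio=0.25 the overlay
# is 16x16, so blended_area = 0.25; with alpha = 0.5 the original contributes
# lam = 1 - 0.25 * (1 - 0.5) = 0.875 of the loss weight.
def _alphamix_lam_demo():
    x = torch.rand(8, 3, 32, 32)
    y = torch.randint(0, 100, (8,))
    mixed_x, y_a, y_b, lam = alphamix_data(
        x, y, alpha_range=(0.5, 0.5), spatial_ratio=0.25
    )
    assert mixed_x.shape == x.shape
    assert abs(lam - 0.875) < 1e-6  # blended_area=0.25, alpha=0.5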
# ============================================================================
# METRICS TRACKER
# ============================================================================

class MetricsTracker:
    """Track training metrics with EMA smoothing."""

    def __init__(self, ema_decay: float = 0.9):
        self.ema_decay = ema_decay
        self.metrics = {}
        self.ema_metrics = {}
        self.history = {}

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            if k not in self.metrics:
                self.metrics[k] = []
                self.ema_metrics[k] = v
                self.history[k] = []
            self.metrics[k].append(v)
            self.ema_metrics[k] = self.ema_decay * self.ema_metrics[k] + (1 - self.ema_decay) * v

    def get_ema(self, key: str) -> float:
        return self.ema_metrics.get(key, 0.0)

    def get_epoch_mean(self, key: str) -> float:
        values = self.metrics.get(key, [])
        return sum(values) / len(values) if values else 0.0

    def end_epoch(self):
        for k, v in self.metrics.items():
            if v:
                self.history[k].append(sum(v) / len(v))
        self.metrics = {k: [] for k in self.metrics}

    def get_history(self) -> Dict:
        return self.history
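
# Quick illustration of the EMA smoothing above (sketch only): with
# ema_decay=0.9, each update keeps 90% of the running value and takes 10%
# of the new one, so after updates 0.0 then 1.0 the EMA sits at 0.1.
def _metrics_tracker_demo():
    tracker = MetricsTracker(ema_decay=0.9)
    tracker.update(loss=0.0)   # first value seeds the EMA at 0.0
    tracker.update(loss=1.0)   # 0.9 * 0.0 + 0.1 * 1.0 = 0.1
    assert abs(tracker.get_ema('loss') - 0.1) < 1e-9
    tracker.end_epoch()
    assert tracker.get_history()['loss'] == [0.5]  # epoch mean of [0.0, 1.0]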
# ============================================================================
# CHECKPOINT UTILITIES
# ============================================================================

def find_latest_checkpoint(output_dir: Path) -> Optional[Path]:
    """Find the most recent checkpoint in output directory."""
    checkpoints = list(output_dir.glob("checkpoint_epoch_*.pt"))
    if not checkpoints:
        best_model = output_dir / "best_model.pt"
        if best_model.exists():
            return best_model
        return None

    def get_epoch(p):
        try:
            return int(p.stem.split("_")[-1])
        except (IndexError, ValueError):
            return 0

    checkpoints.sort(key=get_epoch, reverse=True)
    return checkpoints[0]


def get_next_run_number(base_dir: Path) -> int:
    """Get the next run number by scanning existing run directories."""
    if not base_dir.exists():
        return 1
    max_num = 0
    for d in base_dir.iterdir():
        if d.is_dir() and d.name.startswith("run_"):
            try:
                num = int(d.name.split("_")[1])
                max_num = max(max_num, num)
            except (IndexError, ValueError):
                continue
    return max_num + 1


def generate_run_dir_name(run_number: int, run_name: str, version: int = 2) -> str:
    """Generate a run directory name."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    safe_name = "".join(c if c.isalnum() or c == "_" else "_" for c in run_name.lower())
    safe_name = "_".join(filter(None, safe_name.split("_")))
    return f"run_{run_number:03d}_v{version}_{safe_name}_{timestamp}"


def find_latest_run_dir(base_dir: Path) -> Optional[Path]:
    """Find the most recent run directory."""
    if not base_dir.exists():
        return None
    run_dirs = [d for d in base_dir.iterdir() if d.is_dir() and d.name.startswith("run_")]
    if not run_dirs:
        return None
    run_dirs.sort(key=lambda d: d.stat().st_mtime, reverse=True)
    return run_dirs[0]


def load_checkpoint(
    checkpoint_path: Path,
    model: nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    device: str = "cuda"
) -> Tuple[int, float]:
    """Load checkpoint and return (start_epoch, best_acc)."""
    print(f"\n📂 Loading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)

    model.load_state_dict(checkpoint['model_state_dict'])
    print(f" ✓ Loaded model weights")

    if optimizer is not None and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print(f" ✓ Loaded optimizer state")

    epoch = checkpoint.get('epoch', 0)
    best_acc = checkpoint.get('best_acc', 0.0)
    print(f" ✓ Resuming from epoch {epoch + 1}, best_acc={best_acc:.2f}%")

    return epoch + 1, best_acc


# ============================================================================
# HUGGINGFACE HUB INTEGRATION
# ============================================================================

def generate_run_readme(
    model_config: Union[DavidBeansConfig, DavidBeansV2Config],
    train_config: TrainingConfigV2,
    best_acc: float,
    run_dir_name: str
) -> str:
    """Generate README for a specific run."""
    scales_str = ", ".join([str(s) for s in model_config.scales])

    # V2 specific info
    if isinstance(model_config, DavidBeansV2Config):
        copies_str = ""
        if model_config.scale_copies:
            copies_str = f"\n| Scale Copies | {model_config.scale_copies} |"
        routing_info = f"""
## Wormhole Routing (V2)

| Parameter | Value |
|-----------|-------|
| Mode | {model_config.wormhole_mode} |
| Wormholes/Position | {model_config.num_wormholes} |
| Temperature | {model_config.wormhole_temperature} |
| Tiles | {model_config.num_tiles} |
| Tile Wormholes | {model_config.tile_wormholes} |

## Crystal Head

| Parameter | Value |
|-----------|-------|
| Scales | [{scales_str}] |{copies_str}
| Weighting Mode | {model_config.weighting_mode} |
| Belly Layers | {model_config.belly_layers} |
| Belly Residual | {model_config.belly_residual} |
| Use Spine | {model_config.use_spine} |
| Use Collective | {model_config.use_collective} |
"""
    else:
        routing_info = f"""
## Routing (V1)

| Parameter | Value |
|-----------|-------|
| k_neighbors | {model_config.k_neighbors} |
| Cantor Weight | {model_config.cantor_weight} |
"""

    aug_info = f"""
## Augmentation

| Parameter | Value |
|-----------|-------|
| Normalization | {train_config.normalization} |
| Mixup Alpha | {train_config.mixup_alpha} |
| CutMix Alpha | {train_config.cutmix_alpha} |
| AlphaMix | {train_config.use_alphamix} |
| Label Smoothing | {train_config.label_smoothing} |
"""

    return f"""# Run: {run_dir_name}

## Results

- **Best Accuracy**: {best_acc:.2f}%
- **Dataset**: {train_config.dataset}
- **Epochs**: {train_config.epochs}
- **Model Version**: V{train_config.model_version}

## Model Config

| Parameter | Value |
|-----------|-------|
| Dim | {model_config.dim} |
| Layers | {model_config.num_layers} |
| Heads | {model_config.num_heads} |
| Patch Size | {model_config.patch_size} |
{routing_info}
## Training Config

| Parameter | Value |
|-----------|-------|
| Learning Rate | {train_config.learning_rate} |
| Weight Decay | {train_config.weight_decay} |
| Batch Size | {train_config.batch_size} |
| CE Weight | {train_config.ce_weight} |
| Contrast Weight | {train_config.contrast_weight} |
{aug_info}
## Key Findings Applied

- Routing learns from task pressure (no auxiliary routing losses)
- Gradients verified to flow through router
- Cross-contrastive aligns patch↔scale features
"""
def prepare_run_for_hub(
    model: nn.Module,
    model_config: Union[DavidBeansConfig, DavidBeansV2Config],
    train_config: TrainingConfigV2,
    best_acc: float,
    output_dir: Path,
    run_dir_name: str,
    training_history: Optional[Dict] = None
) -> Path:
    """Prepare run files for upload to HuggingFace Hub."""
    hub_dir = output_dir / "hub_upload"
    run_hub_dir = hub_dir / "weights" / run_dir_name
    run_hub_dir.mkdir(parents=True, exist_ok=True)

    state_dict = {k: v.clone() for k, v in model.state_dict().items()}
    if SAFETENSORS_AVAILABLE:
        try:
            save_safetensors(state_dict, run_hub_dir / "best.safetensors")
            print(f" ✓ Saved best.safetensors")
        except Exception as e:
            print(f" [!] Safetensors failed ({e}), using pytorch format")
            torch.save(state_dict, run_hub_dir / "best.pt")
    else:
        torch.save(state_dict, run_hub_dir / "best.pt")

    config_dict = {
        "architecture": f"DavidBeans_V{train_config.model_version}",
        "model_type": "david_beans_v2" if train_config.model_version == 2 else "david_beans",
        **model_config.__dict__
    }
    with open(run_hub_dir / "config.json", "w") as f:
        json.dump(config_dict, f, indent=2, default=str)

    with open(run_hub_dir / "training_config.json", "w") as f:
        json.dump(train_config.to_dict(), f, indent=2, default=str)

    run_readme = generate_run_readme(model_config, train_config, best_acc, run_dir_name)
    with open(run_hub_dir / "README.md", "w") as f:
        f.write(run_readme)

    if training_history:
        with open(run_hub_dir / "training_history.json", "w") as f:
            json.dump(training_history, f, indent=2)

    tb_dir = output_dir / "tensorboard"
    if tb_dir.exists():
        hub_tb_dir = run_hub_dir / "tensorboard"
        if hub_tb_dir.exists():
            shutil.rmtree(hub_tb_dir)
        shutil.copytree(tb_dir, hub_tb_dir)

    return hub_dir


def push_run_to_hub(
    hub_dir: Path,
    repo_id: str,
    run_dir_name: str,
    private: bool = False,
    commit_message: Optional[str] = None
) -> str:
    """Push run files to HuggingFace Hub."""
    if not HF_HUB_AVAILABLE:
        raise RuntimeError("huggingface_hub not installed")

    api = HfApi()
    try:
        create_repo(repo_id, private=private, exist_ok=True)
    except Exception as e:
        print(f" [!] Repo creation note: {e}")

    run_upload_dir = hub_dir / "weights" / run_dir_name
    if commit_message is None:
        commit_message = f"Update {run_dir_name} - {datetime.now().strftime('%Y-%m-%d %H:%M')}"

    url = upload_folder(
        folder_path=str(run_upload_dir),
        repo_id=repo_id,
        path_in_repo=f"weights/{run_dir_name}",
        commit_message=commit_message
    )
    return url
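
# Minimal usage sketch for the two helpers above (assumes a trained `model`,
# its configs, and a valid HF token in the environment; the accuracy value and
# run name below are illustrative placeholders, not real results):
#
#   hub_dir = prepare_run_for_hub(model, model_config, train_config,
#                                 best_acc=71.3, output_dir=Path("./checkpoints"),
#                                 run_dir_name="run_001_v2_demo")
#   url = push_run_to_hub(hub_dir, repo_id="AbstractPhil/geovit-david-beans",
#                         run_dir_name="run_001_v2_demo")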
# ============================================================================
# TRAINING LOOP V2
# ============================================================================

def train_epoch_v2(
    model: nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
    config: TrainingConfigV2,
    epoch: int,
    tracker: MetricsTracker,
    routing_metrics: RoutingMetrics,
    writer: Optional['SummaryWriter'] = None
) -> Dict[str, float]:
    """Train for one epoch with V2 routing metrics and AlphaMix support."""
    model.train()
    device = config.device
    is_v2 = config.model_version == 2

    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    global_step = epoch * len(train_loader)

    routing_metrics.reset()

    pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}", leave=True)
    for batch_idx, (images, targets) in enumerate(pbar):
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        # Apply mixing augmentations
        use_mixup = config.use_augmentation and config.mixup_alpha > 0
        use_cutmix = config.use_augmentation and config.cutmix_alpha > 0
        use_alphamix = config.use_alphamix

        mixed = False
        mix_type = None

        if use_mixup or use_cutmix or use_alphamix:
            r = torch.rand(1).item()

            # Probability distribution for mix types: 40% no mixing, with the
            # remaining 60% split evenly among whichever mixes are enabled
            # (all three enabled: 40% none, 20% mixup, 20% cutmix, 20% alphamix)
            thresholds = [0.4]  # Base: 40% no mixing
            enabled_mixes = []
            if use_mixup:
                enabled_mixes.append(('mixup', config.mixup_alpha))
            if use_cutmix:
                enabled_mixes.append(('cutmix', config.cutmix_alpha))
            if use_alphamix:
                enabled_mixes.append(('alphamix', None))

            if enabled_mixes:
                mix_prob = 0.6 / len(enabled_mixes)  # Split remaining 60% among enabled
                cumulative = 0.4
                for _ in enabled_mixes:
                    cumulative += mix_prob
                    thresholds.append(cumulative)

            # Determine which mix to use
            if r < 0.4:
                pass  # No mixing
            else:
                for i, (mix_name, mix_param) in enumerate(enabled_mixes):
                    if r < thresholds[i + 1]:
                        mix_type = mix_name
                        break

            if mix_type == 'mixup':
                images, targets_a, targets_b, lam = mixup_data(images, targets, config.mixup_alpha)
                mixed = True
            elif mix_type == 'cutmix':
                images, targets_a, targets_b, lam = cutmix_data(images, targets, config.cutmix_alpha)
                mixed = True
            elif mix_type == 'alphamix':
                images, targets_a, targets_b, lam = alphamix_data(
                    images, targets,
                    alpha_range=config.alphamix_alpha_range,
                    spatial_ratio=config.alphamix_spatial_ratio
                )
                mixed = True

        # Forward pass
        if is_v2:
            result = model(
                images,
                targets=targets,
                return_loss=True,
                return_routing=(batch_idx % 10 == 0)
            )
        else:
            result = model(images, targets=targets, return_loss=True)

        losses = result['losses']

        # Handle mixed CE loss
        if mixed:
            logits = result['logits']
            ce_loss = lam * F.cross_entropy(logits, targets_a, label_smoothing=config.label_smoothing) + \
                      (1 - lam) * F.cross_entropy(logits, targets_b, label_smoothing=config.label_smoothing)
            losses['ce'] = ce_loss

        # Compute total loss (NO auxiliary routing loss - key finding!)
        loss = (
            config.ce_weight * losses['ce']
            + config.contrast_weight * losses.get('contrast', torch.tensor(0.0, device=device))
        )

        # Add scale CE losses (handle both regular and copy scales)
        for key, val in losses.items():
            if key.startswith('ce_') and key != 'ce':
                if isinstance(val, torch.Tensor):
                    loss = loss + 0.1 * val

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Track routing gradient norms
        if is_v2:
            routing_metrics.update_grad_norms(model)

        if config.gradient_clip > 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.gradient_clip)
        else:
            grad_norm = 0.0

        optimizer.step()

        if scheduler is not None and config.scheduler == "onecycle":
            scheduler.step()

        # Update routing metrics
        if is_v2 and result.get('routing'):
            routing_metrics.update_from_routing_info(result['routing'], model)

        # Compute accuracy
        with torch.no_grad():
            logits = result['logits']
            preds = logits.argmax(dim=-1)
            if mixed:
                correct = (lam * (preds == targets_a).float()
                           + (1 - lam) * (preds == targets_b).float()).sum()
            else:
                correct = (preds == targets).sum()
            total_correct += correct.item()
            total_samples += targets.size(0)
            total_loss += loss.item()

        # Track metrics
        def to_float(v):
            return v.item() if isinstance(v, torch.Tensor) else float(v)

        contrast_loss = to_float(losses.get('contrast', 0.0))
        current_lr = optimizer.param_groups[0]['lr']
        tracker.update(
            loss=loss.item(),
            ce=losses['ce'].item(),
            contrast=contrast_loss,
            lr=current_lr
        )

        # TensorBoard logging
        if writer is not None and (batch_idx + 1) % config.log_interval == 0:
            step = global_step + batch_idx
            writer.add_scalar('train/loss_total', loss.item(), step)
            writer.add_scalar('train/loss_ce', losses['ce'].item(), step)
            writer.add_scalar('train/loss_contrast', contrast_loss, step)
            writer.add_scalar('train/learning_rate', current_lr, step)
            writer.add_scalar('train/grad_norm', to_float(grad_norm), step)

            if is_v2 and config.log_routing:
                routing_summary = routing_metrics.get_summary()
                for k, v in routing_summary.items():
                    writer.add_scalar(f'routing/{k}', v, step)

        # Progress bar
        routing_summary = routing_metrics.get_summary()
        postfix = {
            'loss': f"{tracker.get_ema('loss'):.3f}",
            'acc': f"{100.0 * total_correct / total_samples:.1f}%",
        }
        if is_v2 and 'grad_query' in routing_summary:
            postfix['∇q'] = f"{routing_summary['grad_query']:.2f}"
        if 'route_entropy' in routing_summary:
            postfix['H'] = f"{routing_summary['route_entropy']:.2f}"
        pbar.set_postfix(postfix)

    if scheduler is not None and config.scheduler == "cosine":
        scheduler.step()

    return {
        'loss': total_loss / len(train_loader),
        'acc': 100.0 * total_correct / total_samples,
        **routing_metrics.get_summary()
    }
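
# Worked example of the mix-selection thresholds above (illustrative only):
# with all three mixes enabled, thresholds become [0.4, 0.6, 0.8, 1.0], so a
# uniform draw r lands in "no mixing" 40% of the time and in each mix 20%.
def _mix_threshold_demo():
    enabled = ['mixup', 'cutmix', 'alphamix']
    thresholds = [0.4]
    mix_prob = 0.6 / len(enabled)
    cumulative = 0.4
    for _ in enabled:
        cumulative += mix_prob
        thresholds.append(round(cumulative, 10))  # round away float drift
    assert thresholds == [0.4, 0.6, 0.8, 1.0]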
@torch.no_grad()
def evaluate_v2(
    model: nn.Module,
    test_loader: DataLoader,
    config: TrainingConfigV2
) -> Dict[str, float]:
    """Evaluate on test set."""
    model.eval()
    device = config.device

    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    # Handle variable number of scale heads (including copies)
    num_heads = len(model.head.heads) if hasattr(model.head, 'heads') else len(model.config.scales)
    head_correct = [0] * num_heads

    for images, targets in test_loader:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        result = model(images, targets=targets, return_loss=True)
        logits = result['logits']
        losses = result['losses']
        loss = losses['total']

        preds = logits.argmax(dim=-1)
        total_loss += loss.item() * targets.size(0)
        total_correct += (preds == targets).sum().item()
        total_samples += targets.size(0)

        # Per-head accuracy
        for i, scale_logits in enumerate(result['scale_logits']):
            scale_preds = scale_logits.argmax(dim=-1)
            head_correct[i] += (scale_preds == targets).sum().item()

    metrics = {
        'loss': total_loss / total_samples,
        'acc': 100.0 * total_correct / total_samples
    }

    # Map head indices to scale names
    if hasattr(model.head, 'head_scale_map'):
        for i, (scale, copy_idx) in enumerate(model.head.head_scale_map):
            key = f'acc_{scale}' if copy_idx == 0 else f'acc_{scale}_c{copy_idx}'
            metrics[key] = 100.0 * head_correct[i] / total_samples
    else:
        for i, scale in enumerate(model.config.scales):
            metrics[f'acc_{scale}'] = 100.0 * head_correct[i] / total_samples

    return metrics
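
# The training loop below applies a linear warmup for `warmup_epochs` before
# the cosine schedule takes over. Sketch of the per-epoch warmup arithmetic
# it uses (illustrative only): lr_e = learning_rate * (e + 1) / warmup_epochs.
def _warmup_lr_demo(learning_rate: float = 3e-4, warmup_epochs: int = 10):
    lrs = [learning_rate * (e + 1) / warmup_epochs for e in range(warmup_epochs)]
    assert abs(lrs[0] - 3e-5) < 1e-12   # epoch 0: lr / 10
    assert abs(lrs[-1] - 3e-4) < 1e-12  # last warmup epoch reaches full lr
    return lrs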
# ============================================================================
# MAIN TRAINING FUNCTION V2
# ============================================================================

def train_david_beans_v2(
    model_config: Optional[Union[DavidBeansConfig, DavidBeansV2Config]] = None,
    train_config: Optional[TrainingConfigV2] = None
):
    """Main training function for DavidBeans V1 or V2."""
    print("=" * 70)
    print(" DAVID-BEANS V2.1 TRAINING: Wormhole Routing")
    print("=" * 70)
    print()
    print(" 🌀 WORMHOLES: Learned sparse routing")
    print(" 💎 CRYSTALS: Multi-scale projection")
    print()
    print(" Key insight: When routing IS the task, routing learns structure")
    print()
    print("=" * 70)

    if train_config is None:
        train_config = TrainingConfigV2()

    base_output_dir = Path(train_config.output_dir)
    base_output_dir.mkdir(parents=True, exist_ok=True)

    # Checkpoint resolution
    checkpoint_path = None
    run_dir = None
    run_dir_name = None

    if train_config.resume_from:
        resume_target = train_config.resume_from
        if resume_target == "latest":
            # Resolve "latest" to the most recently modified run directory
            latest_dir = find_latest_run_dir(base_output_dir)
            if latest_dir is not None:
                resume_target = str(latest_dir)

        resume_path = Path(resume_target)
        if resume_path.is_file():
            checkpoint_path = resume_path
            run_dir = checkpoint_path.parent
            run_dir_name = run_dir.name
            print(f"\n📂 Found checkpoint file: {checkpoint_path.name}")
        elif resume_path.is_dir():
            checkpoint_path = find_latest_checkpoint(resume_path)
            if checkpoint_path:
                run_dir = resume_path
                run_dir_name = resume_path.name
                print(f"\n📂 Found checkpoint in dir: {checkpoint_path.name}")
        else:
            possible_dir = base_output_dir / resume_target
            if possible_dir.is_dir():
                checkpoint_path = find_latest_checkpoint(possible_dir)
                if checkpoint_path:
                    run_dir = possible_dir
                    run_dir_name = possible_dir.name
                    print(f"\n📂 Found checkpoint in: {run_dir_name}")
            if checkpoint_path is None:
                possible_file = base_output_dir / resume_target
                if possible_file.is_file():
                    checkpoint_path = possible_file
                    run_dir = checkpoint_path.parent
                    run_dir_name = run_dir.name
                    print(f"\n📂 Found checkpoint: {checkpoint_path.name}")

        if checkpoint_path is None:
            print(f"\n [!] Could not find checkpoint: {train_config.resume_from}")
            print(f" [!] Starting fresh run instead")
        else:
            print(f" ✓ Will resume from: {checkpoint_path}")

    # Create new run directory if not resuming
    if run_dir is None:
        run_number = train_config.run_number or get_next_run_number(base_output_dir)
        run_dir_name = generate_run_dir_name(run_number, train_config.run_name, train_config.model_version)
        run_dir = base_output_dir / run_dir_name
        run_dir.mkdir(parents=True, exist_ok=True)
        print(f"\n📁 New run: {run_dir_name}")
    else:
        print(f"\n📁 Resuming run: {run_dir_name}")

    output_dir = run_dir

    # Model config
    if checkpoint_path and checkpoint_path.exists() and model_config is None:
        try:
            ckpt = torch.load(checkpoint_path, map_location='cpu')
            if 'model_config' in ckpt:
                saved_config = ckpt['model_config']
                print(f" ✓ Loading model config from checkpoint")
                if train_config.model_version == 2:
                    model_config = DavidBeansV2Config(**saved_config)
                else:
                    model_config = DavidBeansConfig(**saved_config)
        except Exception as e:
            print(f" [!] Could not load config from checkpoint: {e}")

    if model_config is None:
        if train_config.model_version == 2:
            model_config = DavidBeansV2Config(
                image_size=train_config.image_size,
                patch_size=4,
                dim=512,
                num_layers=4,
                num_heads=8,
                num_wormholes=8,
                wormhole_temperature=0.1,
                wormhole_mode="hybrid",
                num_tiles=16,
                tile_wormholes=4,
                scales=[64, 128, 256, 384, 512],
                num_classes=100,
                contrast_weight=train_config.contrast_weight,
                dropout=0.1
            )
        else:
            model_config = DavidBeansConfig(
                image_size=train_config.image_size,
                patch_size=4,
                dim=512,
                num_layers=4,
                num_heads=8,
                num_experts=5,
                k_neighbors=16,
                cantor_weight=0.3,
                scales=[64, 128, 256, 384, 512],
                num_classes=100,
                dropout=0.1
            )

    device = train_config.device
    print(f"\nDevice: {device}")
    print(f"Model version: V{train_config.model_version}")

    # Data
    print("\nLoading data...")
    train_loader, test_loader, num_classes = get_dataloaders(train_config)
    print(f" Dataset: {train_config.dataset}")
    print(f" Train: {len(train_loader.dataset)}, Test: {len(test_loader.dataset)}")
    print(f" Classes: {num_classes}")

    model_config.num_classes = num_classes

    # Model
    print("\nBuilding model...")
    if train_config.model_version == 2:
        model = DavidBeansV2(model_config)
    else:
        model = DavidBeans(model_config)
    model = model.to(device)

    print(f"\n{model}")
    num_params = sum(p.numel() for p in model.parameters())
    print(f"\nParameters: {num_params:,}")

    # Optimizer: no weight decay on biases, norms, or embeddings
    print("\nSetting up optimizer...")
    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if 'bias' in name or 'norm' in name or 'embedding' in name:
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    optimizer = AdamW([
        {'params': decay_params, 'weight_decay': train_config.weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0}
    ], lr=train_config.learning_rate, betas=train_config.betas)

    if train_config.scheduler == "cosine":
        scheduler = CosineAnnealingLR(
            optimizer,
            T_max=train_config.epochs - train_config.warmup_epochs,
            eta_min=train_config.min_lr
        )
    elif train_config.scheduler == "onecycle":
        scheduler = OneCycleLR(
            optimizer,
            max_lr=train_config.learning_rate,
            epochs=train_config.epochs,
            steps_per_epoch=len(train_loader),
            pct_start=train_config.warmup_epochs / train_config.epochs
        )
    else:
        scheduler = None

    print(f" Optimizer: AdamW (lr={train_config.learning_rate}, wd={train_config.weight_decay})")
    print(f" Scheduler: {train_config.scheduler}")

    # Print augmentation config
    print(f"\nAugmentation:")
    print(f" Mixup: {train_config.mixup_alpha if train_config.mixup_alpha > 0 else 'disabled'}")
    print(f" CutMix: {train_config.cutmix_alpha if train_config.cutmix_alpha > 0 else 'disabled'}")
    print(f" AlphaMix: {train_config.alphamix_alpha_range if train_config.use_alphamix else 'disabled'}")

    tracker = MetricsTracker()
    routing_metrics = RoutingMetrics()
    best_acc = 0.0
    start_epoch = 0

    # Load checkpoint
    if checkpoint_path and checkpoint_path.exists():
        start_epoch, best_acc = load_checkpoint(checkpoint_path, model, optimizer, device)
        if scheduler is not None and train_config.scheduler == "cosine":
            for _ in range(start_epoch):
                scheduler.step()
            print(f" ✓ Advanced scheduler to epoch {start_epoch}")

    # TensorBoard
    writer = None
    if train_config.use_tensorboard and TENSORBOARD_AVAILABLE:
        tb_dir = output_dir / "tensorboard"
        tb_dir.mkdir(parents=True, exist_ok=True)
        writer = SummaryWriter(log_dir=str(tb_dir))
        print(f" TensorBoard: {tb_dir}")

    # Save configs
    with open(output_dir / "config.json", "w") as f:
        json.dump({**model_config.__dict__, "architecture": f"DavidBeans_V{train_config.model_version}"},
                  f, indent=2, default=str)
    with open(output_dir / "training_config.json", "w") as f:
        json.dump(train_config.to_dict(), f, indent=2, default=str)

    # Training loop
    print("\n" + "=" * 70)
    print(" TRAINING")
    print("=" * 70)

    for epoch in range(start_epoch, train_config.epochs):
        epoch_start = time.time()

        # Warmup
        if epoch < train_config.warmup_epochs and train_config.scheduler == "cosine":
            warmup_lr = train_config.learning_rate * (epoch + 1) / train_config.warmup_epochs
            for param_group in optimizer.param_groups:
                param_group['lr'] = warmup_lr

        train_metrics = train_epoch_v2(
            model, train_loader, optimizer, scheduler,
            train_config, epoch, tracker, routing_metrics, writer
        )
        test_metrics = evaluate_v2(model, test_loader, train_config)
        epoch_time = time.time() - epoch_start

        # TensorBoard
        if writer is not None:
            writer.add_scalar('epoch/train_loss', train_metrics['loss'], epoch)
            writer.add_scalar('epoch/train_acc', train_metrics['acc'], epoch)
            writer.add_scalar('epoch/test_loss', test_metrics['loss'], epoch)
            writer.add_scalar('epoch/test_acc', test_metrics['acc'], epoch)
            # Log all scale accuracies
            for key, val in test_metrics.items():
                if key.startswith('acc_'):
                    writer.add_scalar(f'scales/{key}', val, epoch)

        # Print summary - show primary scales only (not copies)
        primary_scale_accs = []
        for scale in model.config.scales:
            if f'acc_{scale}' in test_metrics:
                primary_scale_accs.append(f"{scale}:{test_metrics[f'acc_{scale}']:.1f}%")
        scale_accs = " | ".join(primary_scale_accs)

        star = "★" if test_metrics['acc'] > best_acc else ""
        routing_info = ""
        if train_config.model_version == 2 and 'grad_query' in train_metrics:
            routing_info = f" | ∇q:{train_metrics.get('grad_query', 0):.2f}"

        print(f" → Train: {train_metrics['acc']:.1f}% | Test: {test_metrics['acc']:.1f}% | "
              f"[{scale_accs}]{routing_info} | {epoch_time:.0f}s {star}")

        # Save best model
        if test_metrics['acc'] > best_acc:
            best_acc = test_metrics['acc']
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_acc': best_acc,
                'model_config': model_config.__dict__,
                'train_config': train_config.to_dict()
            }, output_dir / "best_model.pt")

        # Periodic checkpoint (and optional Hub push on the same interval)
        if (epoch + 1) % train_config.save_interval == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_acc': best_acc,
                'model_config': model_config.__dict__,
                'train_config': train_config.to_dict()
            }, output_dir / f"checkpoint_epoch_{epoch + 1}.pt")
            if train_config.push_to_hub and HF_HUB_AVAILABLE:
                try:
                    hub_dir = prepare_run_for_hub(
                        model=model,
                        model_config=model_config,
                        train_config=train_config,
                        best_acc=best_acc,
                        output_dir=output_dir,
                        run_dir_name=run_dir_name,
                        training_history=tracker.get_history()
                    )
                    push_run_to_hub(
                        hub_dir=hub_dir,
                        repo_id=train_config.hub_repo_id,
                        run_dir_name=run_dir_name,
                        commit_message=f"Epoch {epoch + 1} - {best_acc:.2f}% acc"
                    )
                    print(f" 📤 Uploaded to hub")
                except Exception as e:
                    print(f" [!] Hub upload failed: {e}")

        tracker.end_epoch()

    # Final summary
    print("\n" + "=" * 70)
    print(" TRAINING COMPLETE")
    print("=" * 70)
    print(f"\n Best Test Accuracy: {best_acc:.2f}%")
    print(f" Model saved to: {output_dir / 'best_model.pt'}")

    if writer is not None:
        writer.close()

    return model, best_acc


# ============================================================================
# PRESETS
# ============================================================================

def train_cifar100_v2_wormhole(
    run_name: str = "wormhole_base",
    push_to_hub: bool = False,
    resume: bool = False
):
    """CIFAR-100 with V2 wormhole routing."""
    model_config = DavidBeansV2Config(
        image_size=32,
        patch_size=2,
        dim=512,
        num_layers=4,
        num_heads=16,
        # Wormhole routing parameters
        num_wormholes=16,
        wormhole_temperature=0.1,
        wormhole_mode="hybrid",
        # Tessellation parameters
        num_tiles=16,
        tile_wormholes=4,
        # Crystal head
        scales=[64, 128, 256, 512, 1024],
        num_classes=100,
        # V2.1 additions
        belly_layers=2,
        belly_residual=False,
        weighting_mode="learned",
        scale_copies=None,
        use_spine=False,
        use_collective=False,
        # Other
        contrast_temperature=0.07,
        contrast_weight=0.5,
        dropout=0.1
    )

    train_config = TrainingConfigV2(
        run_name=run_name,
        model_version=2,
        dataset="cifar100",
        epochs=300,
        batch_size=512,
        learning_rate=3e-4,
        weight_decay=0.05,
        warmup_epochs=15,
        # Normalization
        normalization="standard",
        # Loss weights
        ce_weight=1.0,
        contrast_weight=0.5,
        # Augmentation
        label_smoothing=0.1,
        mixup_alpha=0.2,
        cutmix_alpha=1.0,
        # AlphaMix
        use_alphamix=True,
        alphamix_alpha_range=(0.3, 0.7),
        alphamix_spatial_ratio=0.25,
        # Output
        output_dir="./checkpoints/cifar100_v2",
        resume_from="latest" if resume else None,
        # Hub
        push_to_hub=push_to_hub,
        hub_repo_id="AbstractPhil/geovit-david-beans",
        # Routing logging
        log_routing=True
    )

    return train_david_beans_v2(model_config, train_config)


def train_cifar100_v2_with_spine(
    run_name: str = "wormhole_spine",
    push_to_hub: bool = False,
    resume: bool = False
):
    """CIFAR-100 with V2 wormhole routing + conv spine."""
    model_config = DavidBeansV2Config(
        image_size=32,
        patch_size=4,
        dim=512,
        num_layers=4,
        num_heads=8,
        num_wormholes=8,
        wormhole_temperature=0.1,
        wormhole_mode="hybrid",
        num_tiles=16,
        tile_wormholes=4,
        scales=[64, 128, 256, 384, 512],
        num_classes=100,
        # Enable spine
        use_spine=True,
        spine_channels=[64, 128, 256],
        spine_cross_attn=True,
        spine_gate_init=0.0,
        # Belly
        belly_layers=2,
        weighting_mode="geometric",
        contrast_temperature=0.07,
        contrast_weight=0.5,
        dropout=0.1
    )

    train_config = TrainingConfigV2(
        run_name=run_name,
        model_version=2,
        dataset="cifar100",
        epochs=200,
        batch_size=128,
        learning_rate=3e-4,
        weight_decay=0.05,
        warmup_epochs=10,
        normalization="standard",
        ce_weight=1.0,
        contrast_weight=0.5,
        label_smoothing=0.1,
        mixup_alpha=0.2,
        cutmix_alpha=1.0,
        use_alphamix=True,
        output_dir="./checkpoints/cifar100_v2",
        resume_from="latest" if resume else None,
        push_to_hub=push_to_hub,
        hub_repo_id="AbstractPhil/geovit-david-beans",
        log_routing=True
    )

    return train_david_beans_v2(model_config, train_config)
def train_cifar100_v2_redundant_scales(
    run_name: str = "wormhole_redundant",
    push_to_hub: bool = False,
    resume: bool = False
):
    """CIFAR-100 with redundant small scales for ensemble effect."""
    model_config = DavidBeansV2Config(
        image_size=32,
        patch_size=4,
        dim=512,
        num_layers=4,
        num_heads=8,
        num_wormholes=8,
        wormhole_temperature=0.1,
        wormhole_mode="hybrid",
        num_tiles=16,
        tile_wormholes=4,
        scales=[64, 128, 256, 512],
        # Redundant copies: 4x 64d, 2x 128d, 1x 256d, 1x 512d
        scale_copies=[4, 2, 1, 1],
        copy_theta_step=0.15,
        num_classes=100,
        weighting_mode="geometric",
        belly_layers=2,
        contrast_temperature=0.07,
        contrast_weight=0.5,
        dropout=0.1
    )

    train_config = TrainingConfigV2(
        run_name=run_name,
        model_version=2,
        dataset="cifar100",
        epochs=200,
        batch_size=128,
        learning_rate=3e-4,
        weight_decay=0.05,
        warmup_epochs=10,
        normalization="standard",
        ce_weight=1.0,
        contrast_weight=0.5,
        label_smoothing=0.1,
        mixup_alpha=0.2,
        cutmix_alpha=1.0,
        use_alphamix=True,
        output_dir="./checkpoints/cifar100_v2",
        resume_from="latest" if resume else None,
        push_to_hub=push_to_hub,
        hub_repo_id="AbstractPhil/geovit-david-beans",
        log_routing=True
    )

    return train_david_beans_v2(model_config, train_config)


def train_cifar100_v2_no_norm(
    run_name: str = "wormhole_no_norm",
    push_to_hub: bool = False,
    resume: bool = False
):
    """CIFAR-100 with no normalization (raw pixels) for geometric components."""
    model_config = DavidBeansV2Config(
        image_size=32,
        patch_size=4,
        dim=512,
        num_layers=4,
        num_heads=8,
        num_wormholes=8,
        wormhole_temperature=0.1,
        wormhole_mode="hybrid",
        num_tiles=16,
        tile_wormholes=4,
        scales=[64, 128, 256, 384, 512],
        num_classes=100,
        belly_layers=2,
        weighting_mode="learned",
        contrast_temperature=0.07,
        contrast_weight=0.5,
        dropout=0.1
    )

    train_config = TrainingConfigV2(
        run_name=run_name,
        model_version=2,
        dataset="cifar100",
        epochs=200,
        batch_size=128,
        learning_rate=3e-4,
        weight_decay=0.05,
        warmup_epochs=10,
        # No normalization - raw [0,1] pixels
        normalization="none",
        ce_weight=1.0,
        contrast_weight=0.5,
        label_smoothing=0.1,
        mixup_alpha=0.2,
        cutmix_alpha=1.0,
        use_alphamix=True,
        output_dir="./checkpoints/cifar100_v2",
        resume_from="latest" if resume else None,
        push_to_hub=push_to_hub,
        hub_repo_id="AbstractPhil/geovit-david-beans",
        log_routing=True
    )

    return train_david_beans_v2(model_config, train_config)


def train_cifar100_v1_baseline(
    run_name: str = "v1_baseline",
    push_to_hub: bool = False,
    resume: bool = False
):
    """CIFAR-100 with V1 (fixed Cantor routing) for comparison."""
    model_config = DavidBeansConfig(
        image_size=32,
        patch_size=4,
        dim=512,
        num_layers=4,
        num_heads=8,
        num_experts=5,
        k_neighbors=16,
        cantor_weight=0.3,
        scales=[64, 128, 256, 384, 512],
        num_classes=100,
        dropout=0.1
    )

    train_config = TrainingConfigV2(
        run_name=run_name,
        model_version=1,
        dataset="cifar100",
        epochs=200,
        batch_size=128,
        learning_rate=3e-4,
        weight_decay=0.05,
        warmup_epochs=10,
        normalization="standard",
        ce_weight=1.0,
        contrast_weight=0.5,
        label_smoothing=0.1,
        mixup_alpha=0.2,
        cutmix_alpha=1.0,
        use_alphamix=False,  # V1 doesn't benefit as much
        output_dir="./checkpoints/cifar100_v1",
        resume_from="latest" if resume else None,
        push_to_hub=push_to_hub,
        hub_repo_id="AbstractPhil/geovit-david-beans",
        log_routing=False
    )

    return train_david_beans_v2(model_config, train_config)


# ============================================================================
# MAIN
# ============================================================================

if __name__ == "__main__":
    # =====================================================
    # CONFIGURATION
    # =====================================================
    PRESET = "v2_wormhole"  # Options: "v1_baseline", "v2_wormhole", "v2_spine", "v2_redundant", "v2_no_norm", "test"
    RESUME = False
    RUN_NAME = "5scale_2x2patch_alphamix_d512_4layer"
    PUSH_TO_HUB = True
    # =====================================================
    # RUN
    # =====================================================
    if PRESET == "test":
        print("🧪 Quick test...")
        model_config = DavidBeansV2Config(
            image_size=32,
            patch_size=4,
            dim=128,
            num_layers=2,
            num_heads=4,
            num_wormholes=4,
            num_tiles=8,
            scales=[32, 64, 128],
            num_classes=10,
            belly_layers=2
        )
        train_config = TrainingConfigV2(
            run_name="test",
            model_version=2,
            epochs=2,
            batch_size=32,
            use_augmentation=False,
            mixup_alpha=0.0,
            cutmix_alpha=0.0,
            use_alphamix=False
        )
        model, acc = train_david_beans_v2(model_config, train_config)

    elif PRESET == "v1_baseline":
        print("🫘💎 Training DavidBeans V1 (Cantor routing)...")
        model, acc = train_cifar100_v1_baseline(
            run_name=RUN_NAME,
            push_to_hub=PUSH_TO_HUB,
            resume=RESUME
        )

    elif PRESET == "v2_wormhole":
        print("💎 Training DavidBeans V2 (Wormhole routing)...")
        model, acc = train_cifar100_v2_wormhole(
            run_name=RUN_NAME,
            push_to_hub=PUSH_TO_HUB,
            resume=RESUME
        )

    elif PRESET == "v2_spine":
        print("💎🦴 Training DavidBeans V2 with Conv Spine...")
        model, acc = train_cifar100_v2_with_spine(
            run_name=RUN_NAME,
            push_to_hub=PUSH_TO_HUB,
            resume=RESUME
        )

    elif PRESET == "v2_redundant":
        print("💎✖️ Training DavidBeans V2 with Redundant Scales...")
        model, acc = train_cifar100_v2_redundant_scales(
            run_name=RUN_NAME,
            push_to_hub=PUSH_TO_HUB,
            resume=RESUME
        )

    elif PRESET == "v2_no_norm":
        print("💎📷 Training DavidBeans V2 with No Normalization...")
        model, acc = train_cifar100_v2_no_norm(
            run_name=RUN_NAME,
            push_to_hub=PUSH_TO_HUB,
            resume=RESUME
        )

    else:
        raise ValueError(f"Unknown preset: {PRESET}")

    print(f"\n🎉 Done! Best accuracy: {acc:.2f}%")