"""
MobiusNet - CIFAR-100 (Dynamic Stages)
======================================
Properly handles variable stage counts.

Contents:
    MobiusLens             - per-layer gating "lens" applied along the last axis.
    MobiusConvBlock        - depthwise-separable conv + lens + thirds mask + residual.
    MobiusNet              - stem, len(depths) stages with downsampling, linear head.
    PRESETS                - named (channels, depths, scale_range) configurations.
    train_mobius_cifar100  - CIFAR-100 training / evaluation loop.

Author: AbstractPhil
https://huggingface.co/AbstractPhil/mobiusnet
License: Apache 2.0
"""

import math
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import DataLoader

# NOTE: torchvision and tqdm are only needed for training, so they are imported
# inside the training helpers below. This keeps the model definitions importable
# in environments that only have torch installed.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")


# ============================================================================
# MÖBIUS LENS
# ============================================================================

class MobiusLens(nn.Module):
    """Learned gating "lens" applied along the last axis of its input.

    The lens is a sandwich of three stages:
    twist-in mix -> center wave gate -> twist-out mix.

    Args:
        dim: Size of the last axis of the input. The twist projections are
            ``nn.Linear(dim, dim)``.
        layer_idx: Index of this layer in the whole network.
        total_layers: Total number of lens layers in the network. Together
            with ``layer_idx`` it sets this layer's position
            ``t = layer_idx / (total_layers - 1)`` in ``[0, 1]``.
        scale_range: ``(low, high)`` range that the per-layer wave scales
            sweep across as ``t`` goes from 0 to 1.
    """

    def __init__(
        self,
        dim: int,
        layer_idx: int,
        total_layers: int,
        scale_range: Tuple[float, float] = (1.0, 9.0),
    ):
        super().__init__()

        # Normalized depth position of this layer (0 for the first, 1 for the last).
        self.t = layer_idx / max(total_layers - 1, 1)

        # Two wave scales per layer. The low scale moves linearly from
        # scale_range[0] to scale_range[1] with depth. The high scale sits one
        # "step" (span / total_layers) above it.
        scale_span = scale_range[1] - scale_range[0]
        step = scale_span / max(total_layers, 1)
        scale_low = scale_range[0] + self.t * scale_span
        scale_high = scale_low + step
        self.register_buffer('scales', torch.tensor([scale_low, scale_high]))

        # Registration order below is kept stable: it fixes RNG consumption
        # during init and the state_dict key order.

        # TWIST IN: x*cos(a) + P(x)*sin(a), with P orthogonally initialized and
        # the angle starting at t*pi.
        self.twist_in_angle = nn.Parameter(torch.tensor(self.t * math.pi))
        self.twist_in_proj = nn.Linear(dim, dim, bias=False)
        nn.init.orthogonal_(self.twist_in_proj.weight)

        # CENTER LENS: three waves (left / middle / right), each with a
        # per-scale phase and drift.
        self.omega = nn.Parameter(torch.tensor(math.pi))
        self.alpha = nn.Parameter(torch.tensor(1.5))
        self.phase_l = nn.Parameter(torch.zeros(2))
        self.drift_l = nn.Parameter(torch.ones(2))
        self.phase_m = nn.Parameter(torch.zeros(2))
        self.drift_m = nn.Parameter(torch.zeros(2))
        self.phase_r = nn.Parameter(torch.zeros(2))
        self.drift_r = nn.Parameter(-torch.ones(2))
        self.accum_weights = nn.Parameter(torch.tensor([0.4, 0.2, 0.4]))
        self.xor_weight = nn.Parameter(torch.tensor(0.7))

        # TWIST OUT: mirror of twist-in, with the angle starting at -t*pi.
        self.twist_out_angle = nn.Parameter(torch.tensor(-self.t * math.pi))
        self.twist_out_proj = nn.Linear(dim, dim, bias=False)
        nn.init.orthogonal_(self.twist_out_proj.weight)

    def _twist_in(self, x: Tensor) -> Tensor:
        """Blend ``x`` with its projection using the learned twist-in angle."""
        cos_t = torch.cos(self.twist_in_angle)
        sin_t = torch.sin(self.twist_in_angle)
        return x * cos_t + self.twist_in_proj(x) * sin_t

    def _center_lens(self, x: Tensor) -> Tensor:
        """Multiply ``x`` by a gate in ``[0, 1]`` built from three wave responses."""
        x_norm = torch.tanh(x)
        # Mean |activation| over the last axis, shaped (..., 1, 1) so it
        # broadcasts against the (scales, dim) axes used below.
        t = x_norm.abs().mean(dim=-1, keepdim=True).unsqueeze(-2)
        x_exp = x_norm.unsqueeze(-2)          # (..., 1, dim)
        s = self.scales.view(-1, 1)           # (2, 1): one row per scale

        def wave(phase, drift):
            # Response is exp(-a * sin^2(pos)), taken as a product over the
            # two scales. The +0.1 keeps the sharpness a strictly positive.
            a = self.alpha.abs() + 0.1
            pos = s * self.omega * (x_exp + drift.view(-1, 1) * t) + phase.view(-1, 1)
            return torch.exp(-a * torch.sin(pos).pow(2)).prod(dim=-2)

        L = wave(self.phase_l, self.drift_l)
        M = wave(self.phase_m, self.drift_m)
        R = wave(self.phase_r, self.drift_r)

        w = torch.softmax(self.accum_weights, dim=0)
        xor_w = torch.sigmoid(self.xor_weight)
        # Soft XOR / AND of the outer waves, mixed by a learned weight.
        xor_comp = (L + R - 2 * L * R).abs()
        and_comp = L * R
        lr = xor_w * xor_comp + (1 - xor_w) * and_comp

        gate = w[0] * L + w[1] * M + w[2] * R
        gate = gate * (0.5 + 0.5 * lr)
        # NOTE(review): gate.mean() reduces over the whole tensor, including
        # any batch axes. Each sample's gate therefore depends on the rest of
        # the batch.
        gate = gate / (gate.mean() + 1e-6) * 0.5
        return x * gate.clamp(0, 1)

    def _twist_out(self, x: Tensor) -> Tensor:
        """Blend ``x`` with its projection using the learned twist-out angle."""
        cos_t = torch.cos(self.twist_out_angle)
        sin_t = torch.sin(self.twist_out_angle)
        return x * cos_t + self.twist_out_proj(x) * sin_t

    def forward(self, x: Tensor) -> Tensor:
        """Apply twist-in, center lens, then twist-out. ``x`` must end in ``dim``."""
        return self._twist_out(self._center_lens(self._twist_in(x)))


# ============================================================================
# MÖBIUS CONV BLOCK
# ============================================================================

class MobiusConvBlock(nn.Module):
    """Depthwise-separable conv followed by a MobiusLens and a learned residual.

    Input and output are NCHW tensors with ``channels`` channels.

    Args:
        channels: Number of input/output channels.
        layer_idx: Global index of this block. It selects which third of the
            channels is attenuated (``layer_idx % 3``) and positions the lens.
        total_layers: Total number of blocks in the network.
        scale_range: Passed through to :class:`MobiusLens`.
        reduction: Multiplier applied to the selected third of the channels.
    """

    def __init__(
        self,
        channels: int,
        layer_idx: int,
        total_layers: int,
        scale_range: Tuple[float, float] = (1.0, 9.0),
        reduction: float = 0.5,
    ):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, groups=channels, bias=False),
            nn.Conv2d(channels, channels, 1, bias=False),
            nn.BatchNorm2d(channels),
        )
        self.lens = MobiusLens(channels, layer_idx, total_layers, scale_range)

        # Attenuate one third of the channels, rotating through the thirds with
        # depth. The last third also absorbs the remainder when
        # channels % 3 != 0.
        third = channels // 3
        which_third = layer_idx % 3
        mask = torch.ones(channels)
        start = which_third * third
        end = start + third + (channels % 3 if which_third == 2 else 0)
        mask[start:end] = reduction
        self.register_buffer('thirds_mask', mask.view(1, -1, 1, 1))

        # sigmoid(0.9) ~= 0.71 at init, so the block starts identity-dominated.
        self.residual_weight = nn.Parameter(torch.tensor(0.9))

    def forward(self, x: Tensor) -> Tensor:
        """Return ``rw * x + (1 - rw) * branch(x)`` with ``rw = sigmoid(residual_weight)``."""
        identity = x
        h = self.conv(x)
        # The lens mixes along the last axis, so move channels last and back.
        h = h.permute(0, 2, 3, 1)
        h = self.lens(h)
        h = h.permute(0, 3, 1, 2)
        h = h * self.thirds_mask
        rw = torch.sigmoid(self.residual_weight)
        return rw * identity + (1 - rw) * h


# ============================================================================
# MÖBIUS NET - DYNAMIC STAGES
# ============================================================================

class MobiusNet(nn.Module):
    """
    Pure conv with Möbius topology.
    Dynamic number of stages based on len(depths).

    Args:
        in_chans: Input image channels.
        num_classes: Output logits per sample.
        channels: Per-stage channel widths. If shorter than ``depths``, the
            last entry is repeated; extra entries are ignored.
        depths: Number of :class:`MobiusConvBlock` per stage.
        scale_range: Lens scale range shared by all blocks. Each block's lens
            is positioned by its global index.
    """

    def __init__(
        self,
        in_chans: int = 3,
        num_classes: int = 100,
        channels: Tuple[int, ...] = (64, 64, 128, 128),
        depths: Tuple[int, ...] = (8, 4, 2),
        scale_range: Tuple[float, float] = (0.5, 2.5),
    ):
        super().__init__()
        num_stages = len(depths)
        total_layers = sum(depths)
        self.total_layers = total_layers
        self.scale_range = scale_range
        self.channels = channels          # as given (reported by get_info)
        self.depths = depths
        self.num_stages = num_stages

        # Ensure we have enough channel specs (repeat the last width).
        channels = list(channels)
        while len(channels) < num_stages:
            channels.append(channels[-1])
        self._stage_channels = tuple(channels)   # padded per-stage widths

        # Stem
        self.stem = nn.Sequential(
            nn.Conv2d(in_chans, channels[0], 3, padding=1, bias=False),
            nn.BatchNorm2d(channels[0]),
        )

        # Build stages dynamically. layer_idx is global across stages so the
        # lens scales sweep the full scale_range over the whole network.
        layer_idx = 0
        self.stages = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        for stage_idx in range(num_stages):
            ch = channels[stage_idx]

            # Stage blocks
            stage = nn.ModuleList()
            for _ in range(depths[stage_idx]):
                stage.append(MobiusConvBlock(ch, layer_idx, total_layers, scale_range))
                layer_idx += 1
            self.stages.append(stage)

            # Downsample between stages (not after last)
            if stage_idx < num_stages - 1:
                ch_next = channels[stage_idx + 1]
                self.downsamples.append(nn.Sequential(
                    nn.Conv2d(ch, ch_next, 3, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(ch_next),
                ))

        # Head
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(channels[num_stages - 1], num_classes)

    def forward(self, x: Tensor) -> Tensor:
        """Map an NCHW image batch to ``(N, num_classes)`` logits."""
        x = self.stem(x)
        for i, stage in enumerate(self.stages):
            for block in stage:
                x = block(x)
            if i < len(self.downsamples):
                x = self.downsamples[i](x)
        return self.head(self.pool(x).flatten(1))

    def get_info(self) -> str:
        """One-line summary of the constructor configuration."""
        return (
            f"MobiusNet: channels={self.channels}, depths={self.depths}, "
            f"total_layers={self.total_layers}, scale_range={self.scale_range}"
        )

    def get_topology_info(self) -> str:
        """Multi-line listing of each layer's position ``t`` and lens scales.

        Uses the same scale formula as :class:`MobiusLens`.
        """
        lines = ["Möbius Ribbon Topology:"]
        lines.append("=" * 60)

        scale_span = self.scale_range[1] - self.scale_range[0]
        step = scale_span / max(self.total_layers, 1)
        layer_idx = 0

        for stage_idx, depth in enumerate(self.depths):
            ch = self._stage_channels[stage_idx]
            for _ in range(depth):
                t = layer_idx / max(self.total_layers - 1, 1)
                scale_low = self.scale_range[0] + t * scale_span
                scale_high = scale_low + step
                lines.append(
                    f"Layer {layer_idx:2d} (Stage {stage_idx+1}, ch={ch:3d}): "
                    f"t={t:.3f}, scales=[{scale_low:.3f}, {scale_high:.3f}]"
                )
                layer_idx += 1

            if stage_idx < self.num_stages - 1:
                ch_next = self._stage_channels[stage_idx + 1]
                lines.append(f" ↓ Downsample {ch} → {ch_next}")

        lines.append("=" * 60)
        return "\n".join(lines)


# ============================================================================
# PRESETS
# ============================================================================

PRESETS = {
    'mobius_xs': {
        'channels': (64, 64, 128),
        'depths': (4, 2, 2),
        'scale_range': (0.5, 2.5),
    },
    # Mirrors the MobiusNet constructor defaults. It is also the default preset
    # of train_mobius_cifar100, which previously pointed at a missing key.
    'mobius_s': {
        'channels': (64, 64, 128, 128),
        'depths': (8, 4, 2),
        'scale_range': (0.5, 2.5),
    },
    'mobius_stretched': {
        'channels': (32, 64, 96, 128, 192, 256, 320, 384, 448),
        'depths': (4, 4, 4, 3, 3, 3, 2, 2, 2),
        'scale_range': (0.2915, 2.85),
    },
    'mobius_m': {
        'channels': (64, 128, 256, 256),
        'depths': (8, 4, 2),
        'scale_range': (0.5, 3.0),
    },
    'mobius_deep': {
        'channels': (64, 64, 128, 128),
        'depths': (12, 6, 4),
        'scale_range': (0.5, 3.5),
    },
    'mobius_wide': {
        'channels': (96, 96, 192, 192),
        'depths': (8, 4, 2),
        'scale_range': (0.5, 2.5),
    },
}


# ============================================================================
# TRAINING
# ============================================================================

def _build_cifar100_loaders(
    batch_size: int,
    use_autoaugment: bool,
    target_device: torch.device,
) -> Tuple[DataLoader, DataLoader]:
    """Download CIFAR-100 into ./data and return ``(train_loader, test_loader)``."""
    from torchvision import datasets, transforms  # training-only dependency

    # CIFAR-100 normalization
    mean = (0.5071, 0.4867, 0.4408)
    std = (0.2675, 0.2565, 0.2761)

    train_transforms = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    if use_autoaugment:
        # torchvision ships no CIFAR-100 policy; the CIFAR-10 policy is the usual stand-in.
        train_transforms.append(transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10))
    train_transforms.extend([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    train_tf = transforms.Compose(train_transforms)
    test_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_ds = datasets.CIFAR100('./data', train=True, download=True, transform=train_tf)
    test_ds = datasets.CIFAR100('./data', train=False, download=True, transform=test_tf)

    # Pinned host memory only helps (and only avoids a warning) when copying to a GPU.
    pin = target_device.type == 'cuda'
    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=8, pin_memory=pin, persistent_workers=True,
    )
    test_loader = DataLoader(
        test_ds, batch_size=256,
        num_workers=2, pin_memory=pin, persistent_workers=True,
    )
    return train_loader, test_loader


def _train_one_epoch(
    model: nn.Module,
    loader,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    target_device: torch.device,
) -> Tuple[float, float]:
    """Run one training epoch; return ``(mean_loss, accuracy)`` over seen samples."""
    from tqdm.auto import tqdm  # training-only dependency

    model.train()
    loss_sum, correct, seen = 0.0, 0, 0
    pbar = tqdm(loader, desc=f"Epoch {epoch:3d}")
    for x, y in pbar:
        x = x.to(target_device, non_blocking=True)
        y = y.to(target_device, non_blocking=True)

        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        n = x.size(0)
        loss_value = loss.item()
        loss_sum += loss_value * n
        correct += (logits.argmax(1) == y).sum().item()
        seen += n
        pbar.set_postfix(loss=f"{loss_value:.4f}")

    # Guard against an empty loader instead of dividing by zero.
    denom = max(seen, 1)
    return loss_sum / denom, correct / denom


@torch.no_grad()
def _evaluate(model: nn.Module, loader, target_device: torch.device) -> float:
    """Return top-1 accuracy of ``model`` over ``loader`` (0.0 for an empty loader)."""
    model.eval()
    correct, seen = 0, 0
    for x, y in loader:
        x = x.to(target_device, non_blocking=True)
        y = y.to(target_device, non_blocking=True)
        logits = model(x)
        correct += (logits.argmax(1) == y).sum().item()
        seen += x.size(0)
    return correct / max(seen, 1)


def train_mobius_cifar100(
    preset: str = 'mobius_s',
    epochs: int = 100,
    lr: float = 1e-3,
    batch_size: int = 128,
    use_autoaugment: bool = True,
) -> Tuple[nn.Module, float]:
    """Train a preset MobiusNet on CIFAR-100 with AdamW and a cosine LR schedule.

    Args:
        preset: Key into :data:`PRESETS`.
        epochs: Number of training epochs (also the cosine schedule length).
        lr: Peak AdamW learning rate.
        batch_size: Training batch size (evaluation uses 256).
        use_autoaugment: Add torchvision AutoAugment to the train transforms.

    Returns:
        ``(model, best_val_accuracy)``.

    Raises:
        KeyError: If ``preset`` is not a key of :data:`PRESETS`.
    """
    # Fail fast, before any printing or dataset download.
    if preset not in PRESETS:
        raise KeyError(f"Unknown preset {preset!r}; available presets: {sorted(PRESETS)}")
    config = PRESETS[preset]

    print("=" * 70)
    print(f"MÖBIUS NET - {preset.upper()} - CIFAR-100")
    print("=" * 70)
    print(f"Device: {device}")
    print(f"Channels: {config['channels']}")
    print(f"Depths: {config['depths']}")
    print(f"Scale range: {config['scale_range']}")
    print(f"AutoAugment: {use_autoaugment}")
    print()

    train_loader, test_loader = _build_cifar100_loaders(batch_size, use_autoaugment, device)

    model = MobiusNet(
        in_chans=3,
        num_classes=100,
        **config
    ).to(device)

    print(model.get_info())
    print()
    print(model.get_topology_info())
    print()

    model.compile(mode='reduce-overhead')

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total params: {total_params:,}")
    print()

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    best_acc = 0.0

    for epoch in range(1, epochs + 1):
        train_loss, train_acc = _train_one_epoch(model, train_loader, optimizer, epoch, device)
        scheduler.step()
        val_acc = _evaluate(model, test_loader, device)

        # Compare against the *previous* best before updating it. Previously
        # best_acc was updated first, so the marker was printed every epoch.
        is_best = val_acc > best_acc
        if is_best:
            best_acc = val_acc
        marker = " ★" if is_best else ""

        print(f"Epoch {epoch:3d} | Loss: {train_loss:.4f} | "
              f"Train: {train_acc:.4f} | Val: {val_acc:.4f} | Best: {best_acc:.4f}{marker}")

    print()
    print("=" * 70)
    print("FINAL RESULTS")
    print("=" * 70)
    print(model.get_info())
    print(f"Best accuracy: {best_acc:.4f}")
    print(f"Total params: {total_params:,}")
    print("=" * 70)

    return model, best_acc


# ============================================================================
# RUN
# ============================================================================

if __name__ == '__main__':
    model, best_acc = train_mobius_cifar100(
        preset='mobius_stretched',  # or 'mobius_s': channels=(64, 64, 128, 128), depths=(8, 4, 2)
        epochs=100,
        lr=1e-3,
        use_autoaugment=True,
    )