| """ |
| MobiusNet - CIFAR-100 (Dynamic Stages) |
| ====================================== |
| Properly handles variable stage counts. |
| |
| Author: AbstractPhil |
| https://huggingface.co/AbstractPhil/mobiusnet |
| License: Apache 2.0 |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch import Tensor |
| from typing import Tuple |
| from torchvision import datasets, transforms |
| from torch.utils.data import DataLoader |
| from tqdm.auto import tqdm |
|
|
# Pick the compute device once at import time: CUDA when PyTorch can see a
# GPU, CPU otherwise. The choice is echoed so runs are easy to tell apart.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}")
|
|
|
|
| |
| |
| |
|
|
class MobiusLens(nn.Module):
    """Learnable periodic gate sandwiched between two "twist" mixing steps.

    The forward pass is ``twist_out(center_lens(twist_in(x)))``:

    * Each twist blends ``x`` with a bias-free linear projection of ``x``
      using the cosine and sine of a learnable angle. The projection
      weights are initialised orthogonally.
    * The centre lens multiplies its input by a gate clamped to ``[0, 1]``.
      The gate is assembled from three wave responses (left, middle, right)
      evaluated at the two fixed scales stored in the ``scales`` buffer.

    The feature axis must be the last one: ``nn.Linear`` acts on it and the
    gate statistics are taken over ``dim=-1``.

    Args:
        dim: Size of the last (feature) axis.
        layer_idx: Position of this layer in the network. It is mapped to
            ``t = layer_idx / max(total_layers - 1, 1)``.
        total_layers: Total number of lens layers in the network.
        scale_range: ``(low, high)`` pair. The lower scale of this layer is
            ``low + t * (high - low)`` and the upper scale lies one
            ``(high - low) / max(total_layers, 1)`` step above it.
    """

    def __init__(
        self,
        dim: int,
        layer_idx: int,
        total_layers: int,
        scale_range: Tuple[float, float] = (1.0, 9.0),
    ):
        super().__init__()

        # Normalised depth of this layer (0 for the first layer).
        self.t = layer_idx / max(total_layers - 1, 1)

        span = scale_range[1] - scale_range[0]
        lower = scale_range[0] + self.t * span
        upper = lower + span / max(total_layers, 1)
        # Fixed (non-trainable) pair of wave scales for this layer.
        self.register_buffer('scales', torch.tensor([lower, upper]))

        # Input twist: initial angle grows with depth.
        self.twist_in_angle = nn.Parameter(torch.tensor(self.t * math.pi))
        self.twist_in_proj = nn.Linear(dim, dim, bias=False)
        nn.init.orthogonal_(self.twist_in_proj.weight)

        # Shared wave frequency and sharpness.
        self.omega = nn.Parameter(torch.tensor(math.pi))
        self.alpha = nn.Parameter(torch.tensor(1.5))

        # One (phase, drift) pair per scale for each of the three waves.
        # The drifts start at +1 (left), 0 (middle) and -1 (right).
        self.phase_l = nn.Parameter(torch.zeros(2))
        self.drift_l = nn.Parameter(torch.ones(2))
        self.phase_m = nn.Parameter(torch.zeros(2))
        self.drift_m = nn.Parameter(torch.zeros(2))
        self.phase_r = nn.Parameter(torch.zeros(2))
        self.drift_r = nn.Parameter(-torch.ones(2))

        # Logits for mixing the three waves and for the XOR/AND blend.
        self.accum_weights = nn.Parameter(torch.tensor([0.4, 0.2, 0.4]))
        self.xor_weight = nn.Parameter(torch.tensor(0.7))

        # Output twist: starts at the negated input angle.
        self.twist_out_angle = nn.Parameter(torch.tensor(-self.t * math.pi))
        self.twist_out_proj = nn.Linear(dim, dim, bias=False)
        nn.init.orthogonal_(self.twist_out_proj.weight)

    def _twist_in(self, x: Tensor) -> Tensor:
        """Blend ``x`` with its input projection by the input twist angle."""
        angle = self.twist_in_angle
        return x * torch.cos(angle) + self.twist_in_proj(x) * torch.sin(angle)

    def _wave(self, x_exp: Tensor, t: Tensor, phase: Tensor, drift: Tensor) -> Tensor:
        """Return one wave response, multiplied across the two scales."""
        sharpness = self.alpha.abs() + 0.1
        s = self.scales.view(-1, 1)
        pos = s * self.omega * (x_exp + drift.view(-1, 1) * t) + phase.view(-1, 1)
        response = torch.exp(-sharpness * torch.sin(pos).pow(2))
        # Collapse the scale axis that was inserted at position -2.
        return response.prod(dim=-2)

    def _center_lens(self, x: Tensor) -> Tensor:
        """Scale ``x`` element-wise by the clamped three-wave gate."""
        squashed = torch.tanh(x)
        # Mean magnitude over the feature axis, shaped to broadcast against
        # the (scale, feature) axes used inside ``_wave``.
        t = squashed.abs().mean(dim=-1, keepdim=True).unsqueeze(-2)
        x_exp = squashed.unsqueeze(-2)

        left = self._wave(x_exp, t, self.phase_l, self.drift_l)
        mid = self._wave(x_exp, t, self.phase_m, self.drift_m)
        right = self._wave(x_exp, t, self.phase_r, self.drift_r)

        mix = torch.softmax(self.accum_weights, dim=0)
        xor_w = torch.sigmoid(self.xor_weight)

        # Soft XOR and soft AND of the outer waves, blended by ``xor_w``.
        xor_comp = (left + right - 2 * left * right).abs()
        and_comp = left * right
        lr = xor_w * xor_comp + (1 - xor_w) * and_comp

        gate = mix[0] * left + mix[1] * mid + mix[2] * right
        gate = gate * (0.5 + 0.5 * lr)
        # The mean runs over every element of the gate tensor, so the gate
        # is rescaled to average about 0.5 across the whole input.
        gate = gate / (gate.mean() + 1e-6) * 0.5

        return x * gate.clamp(0, 1)

    def _twist_out(self, x: Tensor) -> Tensor:
        """Blend ``x`` with its output projection by the output twist angle."""
        angle = self.twist_out_angle
        return x * torch.cos(angle) + self.twist_out_proj(x) * torch.sin(angle)

    def forward(self, x: Tensor) -> Tensor:
        """Apply twist-in, the centre lens gate, then twist-out."""
        twisted = self._twist_in(x)
        gated = self._center_lens(twisted)
        return self._twist_out(gated)
|
|
|
|
| |
| |
| |
|
|
class MobiusConvBlock(nn.Module):
    """Residual block: depthwise-separable conv, MobiusLens gate, channel mask.

    The branch is a 3x3 depthwise convolution, a 1x1 pointwise convolution
    and batch norm. Its output is passed through a :class:`MobiusLens` in
    channels-last layout, then multiplied by a fixed per-channel mask that
    scales one third of the channels by ``reduction``. The result is blended
    with the input through a learnable sigmoid-weighted residual.

    Args:
        channels: Number of input and output channels.
        layer_idx: Global index of this block. It is forwarded to the lens
            and also selects which third of the channels is attenuated
            (``layer_idx % 3``).
        total_layers: Total number of blocks in the network, forwarded to
            the lens.
        scale_range: Scale range forwarded to the lens.
        reduction: Multiplier applied to the selected third of the channels.
    """

    def __init__(
        self,
        channels: int,
        layer_idx: int,
        total_layers: int,
        scale_range: Tuple[float, float] = (1.0, 9.0),
        reduction: float = 0.5,
    ):
        super().__init__()

        self.conv = nn.Sequential(
            # Depthwise 3x3 (one filter per channel), then pointwise 1x1.
            nn.Conv2d(channels, channels, 3, padding=1, groups=channels, bias=False),
            nn.Conv2d(channels, channels, 1, bias=False),
            nn.BatchNorm2d(channels),
        )

        self.lens = MobiusLens(channels, layer_idx, total_layers, scale_range)

        # Attenuate one third of the channels, cycling with the layer index.
        # The last third also absorbs the remainder when ``channels`` is not
        # divisible by 3.
        third = channels // 3
        which_third = layer_idx % 3
        mask = torch.ones(channels)
        start = which_third * third
        end = start + third + (channels % 3 if which_third == 2 else 0)
        mask[start:end] = reduction
        self.register_buffer('thirds_mask', mask.view(1, -1, 1, 1))

        # Logit of the identity weight; sigmoid(0.9) is about 0.71.
        self.residual_weight = nn.Parameter(torch.tensor(0.9))

    def forward(self, x: Tensor) -> Tensor:
        """Return ``rw * x + (1 - rw) * branch(x)`` for a 4-D NCHW input."""
        h = self.conv(x)
        # The lens works on the last axis, so move channels last and back.
        h = self.lens(h.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        h = h * self.thirds_mask

        rw = torch.sigmoid(self.residual_weight)
        return rw * x + (1 - rw) * h
|
|
|
|
| |
| |
| |
|
|
class MobiusNet(nn.Module):
    """Convolutional classifier built from stages of MobiusConvBlock.

    The number of stages equals ``len(depths)``. Stage ``i`` holds
    ``depths[i]`` blocks at width ``channels[i]``, and consecutive stages are
    joined by a stride-2 convolution. When ``channels`` is shorter than
    ``depths`` the last width is repeated; extra widths are ignored.

    Args:
        in_chans: Number of input image channels.
        num_classes: Size of the classifier output.
        channels: Per-stage channel widths.
        depths: Number of blocks in each stage.
        scale_range: Scale range passed to every block.
    """

    def __init__(
        self,
        in_chans: int = 3,
        num_classes: int = 100,
        channels: Tuple[int, ...] = (64, 64, 128, 128),
        depths: Tuple[int, ...] = (8, 4, 2),
        scale_range: Tuple[float, float] = (0.5, 2.5),
    ):
        super().__init__()

        # Keep the configuration exactly as given, for the reporting helpers.
        self.total_layers = sum(depths)
        self.scale_range = scale_range
        self.channels = channels
        self.depths = depths
        self.num_stages = len(depths)

        # Working copy of the widths, padded to one entry per stage.
        widths = list(channels)
        shortfall = self.num_stages - len(widths)
        if shortfall > 0:
            widths.extend([widths[-1]] * shortfall)

        stem_width = widths[0]
        self.stem = nn.Sequential(
            nn.Conv2d(in_chans, stem_width, 3, padding=1, bias=False),
            nn.BatchNorm2d(stem_width),
        )

        self.stages = nn.ModuleList()
        self.downsamples = nn.ModuleList()

        next_layer = 0  # running global block index across all stages
        for stage_idx in range(self.num_stages):
            width = widths[stage_idx]

            blocks = nn.ModuleList([
                MobiusConvBlock(width, next_layer + offset, self.total_layers, scale_range)
                for offset in range(depths[stage_idx])
            ])
            next_layer += len(blocks)
            self.stages.append(blocks)

            # Every stage except the last is followed by a stride-2 conv.
            if stage_idx < self.num_stages - 1:
                width_next = widths[stage_idx + 1]
                transition = nn.Sequential(
                    nn.Conv2d(width, width_next, 3, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(width_next),
                )
                self.downsamples.append(transition)

        self.pool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(widths[self.num_stages - 1], num_classes)

    def forward(self, x: Tensor) -> Tensor:
        """Return class logits for a batch of images."""
        feats = self.stem(x)

        n_transitions = len(self.downsamples)
        for idx, blocks in enumerate(self.stages):
            for blk in blocks:
                feats = blk(feats)
            if idx < n_transitions:
                feats = self.downsamples[idx](feats)

        pooled = self.pool(feats)
        return self.head(pooled.flatten(1))

    def get_info(self) -> str:
        """Return a one-line summary of the stored configuration."""
        summary = (
            f"MobiusNet: channels={self.channels}, depths={self.depths}, "
            f"total_layers={self.total_layers}, scale_range={self.scale_range}"
        )
        return summary

    def get_topology_info(self) -> str:
        """Return a multi-line table of each layer's ``t`` and scale pair."""
        rule = "=" * 60
        report = ["Möbius Ribbon Topology:", rule]

        base = self.scale_range[0]
        span = self.scale_range[1] - base

        def width_at(i: int) -> int:
            # Same padding rule as the constructor: repeat the last width.
            return self.channels[i] if i < len(self.channels) else self.channels[-1]

        layer_idx = 0
        for stage_idx, depth in enumerate(self.depths):
            ch = width_at(stage_idx)
            for _ in range(depth):
                t = layer_idx / max(self.total_layers - 1, 1)
                scale_low = base + t * span
                scale_high = scale_low + span / self.total_layers

                report.append(
                    f"Layer {layer_idx:2d} (Stage {stage_idx+1}, ch={ch:3d}): "
                    f"t={t:.3f}, scales=[{scale_low:.3f}, {scale_high:.3f}]"
                )
                layer_idx += 1

            if stage_idx < self.num_stages - 1:
                ch_next = width_at(stage_idx + 1)
                report.append(f" ↓ Downsample {ch} → {ch_next}")

        report.append(rule)
        return "\n".join(report)
|
|
|
|
| |
| |
| |
|
|
# Named model configurations. Each entry is passed as keyword arguments to
# MobiusNet (channels, depths, scale_range).
PRESETS = {
    'mobius_xs': {
        'channels': (64, 64, 128),
        'depths': (4, 2, 2),
        'scale_range': (0.5, 2.5),
    },
    # 'mobius_s' is the default preset name of train_mobius_cifar100 but had
    # no entry here, so the default call failed with KeyError.
    # NOTE(review): these values copy MobiusNet's constructor defaults;
    # confirm that this is the intended "small" configuration.
    'mobius_s': {
        'channels': (64, 64, 128, 128),
        'depths': (8, 4, 2),
        'scale_range': (0.5, 2.5),
    },
    'mobius_stretched': {
        'channels': (32, 64, 96, 128, 192, 256, 320, 384, 448),
        'depths': (4, 4, 4, 3, 3, 3, 2, 2, 2),
        'scale_range': (0.2915, 2.85),
    },
    'mobius_m': {
        'channels': (64, 128, 256, 256),
        'depths': (8, 4, 2),
        'scale_range': (0.5, 3.0),
    },
    'mobius_deep': {
        'channels': (64, 64, 128, 128),
        'depths': (12, 6, 4),
        'scale_range': (0.5, 3.5),
    },
    'mobius_wide': {
        'channels': (96, 96, 192, 192),
        'depths': (8, 4, 2),
        'scale_range': (0.5, 2.5),
    },
}
|
|
|
|
| |
| |
| |
|
|
def train_mobius_cifar100(
    preset: str = 'mobius_s',
    epochs: int = 100,
    lr: float = 1e-3,
    batch_size: int = 128,
    use_autoaugment: bool = True,
):
    """Train a MobiusNet preset on CIFAR-100 and report accuracy per epoch.

    The dataset is downloaded to ``./data`` if needed. Training uses AdamW
    with cosine learning-rate annealing and gradient-norm clipping at 1.0.
    The test split is evaluated after every epoch.

    Args:
        preset: Key into ``PRESETS`` selecting the model configuration.
        epochs: Number of training epochs.
        lr: Initial learning rate for AdamW.
        batch_size: Training batch size (evaluation uses 256).
        use_autoaugment: Whether to add AutoAugment to the training
            transforms.

    Returns:
        ``(model, best_acc)``: the trained model and the best test accuracy
        seen in any epoch.

    Raises:
        KeyError: If ``preset`` is not a key of ``PRESETS``.
    """
    if preset not in PRESETS:
        # Same exception type as the plain lookup, but with the valid names.
        raise KeyError(
            f"Unknown preset {preset!r}; available presets: {', '.join(sorted(PRESETS))}"
        )
    config = PRESETS[preset]

    print("=" * 70)
    print(f"MÖBIUS NET - {preset.upper()} - CIFAR-100")
    print("=" * 70)
    print(f"Device: {device}")
    print(f"Channels: {config['channels']}")
    print(f"Depths: {config['depths']}")
    print(f"Scale range: {config['scale_range']}")
    print(f"AutoAugment: {use_autoaugment}")
    print()

    # Per-channel normalisation constants used for both splits.
    mean = (0.5071, 0.4867, 0.4408)
    std = (0.2675, 0.2565, 0.2761)

    train_transforms = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    if use_autoaugment:
        train_transforms.append(transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10))
    train_transforms.extend([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_tf = transforms.Compose(train_transforms)
    test_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_ds = datasets.CIFAR100('./data', train=True, download=True, transform=train_tf)
    test_ds = datasets.CIFAR100('./data', train=False, download=True, transform=test_tf)

    # Pinned host memory only helps (and only avoids a warning) when the
    # batches are actually copied to a CUDA device.
    use_pinned = device.type == 'cuda'

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=8, pin_memory=use_pinned, persistent_workers=True
    )
    test_loader = DataLoader(
        test_ds, batch_size=256, num_workers=2, pin_memory=use_pinned, persistent_workers=True,
    )

    model = MobiusNet(
        in_chans=3,
        num_classes=100,
        **config
    ).to(device)

    print(model.get_info())
    print()
    print(model.get_topology_info())
    print()

    # Compiles the module's forward in place.
    model.compile(mode='reduce-overhead')

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total params: {total_params:,}")
    print()

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    best_acc = 0.0

    for epoch in range(1, epochs + 1):
        model.train()
        train_loss, train_correct, train_total = 0.0, 0, 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch:3d}")
        for x, y in pbar:
            x = x.to(device, non_blocking=use_pinned)
            y = y.to(device, non_blocking=use_pinned)

            optimizer.zero_grad()
            logits = model(x)
            loss = F.cross_entropy(logits, y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # Read the loss back from the device once and reuse the value.
            loss_value = loss.item()
            n_samples = x.size(0)
            train_loss += loss_value * n_samples
            train_correct += (logits.argmax(1) == y).sum().item()
            train_total += n_samples

            pbar.set_postfix(loss=f"{loss_value:.4f}")

        scheduler.step()

        model.eval()
        val_correct, val_total = 0, 0
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device, non_blocking=use_pinned)
                y = y.to(device, non_blocking=use_pinned)
                logits = model(x)
                val_correct += (logits.argmax(1) == y).sum().item()
                val_total += x.size(0)

        train_acc = train_correct / train_total
        val_acc = val_correct / val_total

        # Decide whether this epoch is a new best before updating the record,
        # so the star marks strict improvements only.
        is_best = val_acc > best_acc
        if is_best:
            best_acc = val_acc
        marker = " ★" if is_best else ""

        print(f"Epoch {epoch:3d} | Loss: {train_loss/train_total:.4f} | "
              f"Train: {train_acc:.4f} | Val: {val_acc:.4f} | Best: {best_acc:.4f}{marker}")

    print()
    print("=" * 70)
    print("FINAL RESULTS")
    print("=" * 70)
    print(model.get_info())
    print(f"Best accuracy: {best_acc:.4f}")
    print(f"Total params: {total_params:,}")
    print("=" * 70)

    return model, best_acc
|
|
|
|
| |
| |
| |
|
|
if __name__ == '__main__':
    # Script entry point: train the 'mobius_stretched' preset for 100 epochs
    # with AutoAugment enabled.
    run_kwargs = dict(
        preset='mobius_stretched',
        epochs=100,
        lr=1e-3,
        use_autoaugment=True,
    )
    model, best_acc = train_mobius_cifar100(**run_kwargs)