|
|
""" |
|
|
MobiusNet - CIFAR-100 (Dynamic Stages) |
|
|
====================================== |
|
|
Properly handles variable stage counts. |
|
|
|
|
|
Author: AbstractPhil |
|
|
https://huggingface.co/AbstractPhil/mobiusnet |
|
|
License: Apache 2.0 |
|
|
""" |
|
|
|
|
|
import math |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch import Tensor |
|
|
from typing import Tuple |
|
|
from torchvision import datasets, transforms |
|
|
from torch.utils.data import DataLoader |
|
|
from tqdm.auto import tqdm |
|
|
|
|
|
# Pick the accelerator once at import time; everything below shares it.
_cuda_ok = torch.cuda.is_available()
device = torch.device('cuda' if _cuda_ok else 'cpu')
print(f"Device: {device}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MobiusLens(nn.Module):
    """Per-layer multiplicative gate with twist-in / twist-out rotations.

    The lens (a) rotates the input by a learned angle between identity and an
    orthogonal projection ("twist in"), (b) gates it with a product of
    Gaussian-of-sine waves evaluated at two layer-specific scales, and
    (c) applies an inverse-initialized rotation ("twist out").

    Args:
        dim: feature dimension the lens operates on (last axis of the input).
        layer_idx: this layer's global index across the whole network.
        total_layers: total number of lens-bearing layers in the network.
        scale_range: (low, high) bounds of the global scale band that is
            partitioned across layers.
    """

    def __init__(
        self,
        dim: int,
        layer_idx: int,
        total_layers: int,
        scale_range: Tuple[float, float] = (1.0, 9.0),
    ):
        super().__init__()

        # Normalized position of this layer along the ribbon, in [0, 1].
        # max(..., 1) guards the single-layer case against division by zero.
        self.t = layer_idx / max(total_layers - 1, 1)

        # Each layer owns a narrow band [scale_low, scale_high] of the global
        # scale range; the band slides upward with depth.
        scale_span = scale_range[1] - scale_range[0]
        step = scale_span / max(total_layers, 1)
        scale_low = scale_range[0] + self.t * scale_span
        scale_high = scale_low + step

        # Buffer (not a Parameter): fixed per-layer scales, saved with state_dict.
        self.register_buffer('scales', torch.tensor([scale_low, scale_high]))

        # Input twist: learned blend of identity and an orthogonal projection.
        # Initial angle grows linearly with depth (t * pi).
        self.twist_in_angle = nn.Parameter(torch.tensor(self.t * math.pi))
        self.twist_in_proj = nn.Linear(dim, dim, bias=False)
        nn.init.orthogonal_(self.twist_in_proj.weight)

        # Shared wave parameters: base frequency and gate sharpness.
        self.omega = nn.Parameter(torch.tensor(math.pi))
        self.alpha = nn.Parameter(torch.tensor(1.5))

        # Left / middle / right wave components: per-scale phase offsets and
        # drift directions (left drifts +1, middle static, right -1 at init).
        self.phase_l = nn.Parameter(torch.zeros(2))
        self.drift_l = nn.Parameter(torch.ones(2))
        self.phase_m = nn.Parameter(torch.zeros(2))
        self.drift_m = nn.Parameter(torch.zeros(2))
        self.phase_r = nn.Parameter(torch.zeros(2))
        self.drift_r = nn.Parameter(-torch.ones(2))

        # Mixing weights over the three waves (softmaxed in _center_lens) and
        # the XOR-vs-AND blend between the outer waves (sigmoided there).
        self.accum_weights = nn.Parameter(torch.tensor([0.4, 0.2, 0.4]))
        self.xor_weight = nn.Parameter(torch.tensor(0.7))

        # Output twist: initialized to the negated input-twist angle, so the
        # pair starts as an (approximate) inverse rotation.
        self.twist_out_angle = nn.Parameter(torch.tensor(-self.t * math.pi))
        self.twist_out_proj = nn.Linear(dim, dim, bias=False)
        nn.init.orthogonal_(self.twist_out_proj.weight)

    def _twist_in(self, x: Tensor) -> Tensor:
        # x*cos(a) + P(x)*sin(a): interpolates identity and the orthogonal map.
        cos_t = torch.cos(self.twist_in_angle)
        sin_t = torch.sin(self.twist_in_angle)
        return x * cos_t + self.twist_in_proj(x) * sin_t

    def _center_lens(self, x: Tensor) -> Tensor:
        """Compute the wave gate from x and return x scaled by it (last dim = features)."""
        x_norm = torch.tanh(x)
        # Per-position mean activation magnitude, with an extra axis inserted
        # so it broadcasts against the length-2 scale axis below.
        t = x_norm.abs().mean(dim=-1, keepdim=True).unsqueeze(-2)

        # Insert the scale axis: x_exp is (..., 1, dim), s is (2, 1), so the
        # wave evaluates both scales in parallel.
        x_exp = x_norm.unsqueeze(-2)
        s = self.scales.view(-1, 1)

        def wave(phase, drift):
            # Gaussian-of-sine comb, sharpened by alpha (kept >= 0.1); the
            # prod over dim=-2 collapses the two scales back out.
            a = self.alpha.abs() + 0.1
            pos = s * self.omega * (x_exp + drift.view(-1, 1) * t) + phase.view(-1, 1)
            return torch.exp(-a * torch.sin(pos).pow(2)).prod(dim=-2)

        L = wave(self.phase_l, self.drift_l)
        M = wave(self.phase_m, self.drift_m)
        R = wave(self.phase_r, self.drift_r)

        w = torch.softmax(self.accum_weights, dim=0)
        xor_w = torch.sigmoid(self.xor_weight)

        # Soft logic on the outer waves: XOR = |L + R - 2LR|, AND = L*R.
        xor_comp = (L + R - 2 * L * R).abs()
        and_comp = L * R
        lr = xor_w * xor_comp + (1 - xor_w) * and_comp

        gate = w[0] * L + w[1] * M + w[2] * R
        gate = gate * (0.5 + 0.5 * lr)
        # NOTE(review): normalizes by the mean over the ENTIRE tensor (batch
        # included), so inference output depends on batch composition —
        # confirm this is intended.
        gate = gate / (gate.mean() + 1e-6) * 0.5

        return x * gate.clamp(0, 1)

    def _twist_out(self, x: Tensor) -> Tensor:
        # Mirror of _twist_in with its own angle and projection.
        cos_t = torch.cos(self.twist_out_angle)
        sin_t = torch.sin(self.twist_out_angle)
        return x * cos_t + self.twist_out_proj(x) * sin_t

    def forward(self, x: Tensor) -> Tensor:
        """twist-in -> wave gate -> twist-out; shape-preserving."""
        return self._twist_out(self._center_lens(self._twist_in(x)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MobiusConvBlock(nn.Module):
    """Residual depthwise-separable conv block gated by a MobiusLens.

    Pipeline: depthwise 3x3 -> pointwise 1x1 -> BN, then the lens gate
    (applied channels-last), then a fixed per-channel damping mask, finally
    blended with the identity path by a learned sigmoid weight.

    Args:
        channels: number of input/output channels (shape-preserving block).
        layer_idx: global layer index, used for the lens and to rotate which
            third of the channels is damped.
        total_layers: total lens-bearing layers in the network.
        scale_range: passed through to MobiusLens.
        reduction: multiplier applied to the selected third of the channels.
    """

    def __init__(
        self,
        channels: int,
        layer_idx: int,
        total_layers: int,
        scale_range: Tuple[float, float] = (1.0, 9.0),
        reduction: float = 0.5,
    ):
        super().__init__()

        # Depthwise 3x3 followed by a pointwise 1x1, normalized with BN.
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, groups=channels, bias=False),
            nn.Conv2d(channels, channels, 1, bias=False),
            nn.BatchNorm2d(channels),
        )

        self.lens = MobiusLens(channels, layer_idx, total_layers, scale_range)

        # Damp one third of the channels, rotating which third with depth;
        # the final third also absorbs any remainder channels.
        seg = channels // 3
        sel = layer_idx % 3
        lo = sel * seg
        hi = channels if sel == 2 else lo + seg
        damp = torch.ones(channels)
        damp[lo:hi] = reduction
        self.register_buffer('thirds_mask', damp.reshape(1, channels, 1, 1))

        # Learned identity/transform blend; sigmoid(0.9) ~ 0.71 at init, so
        # the block starts mostly as a pass-through.
        self.residual_weight = nn.Parameter(torch.tensor(0.9))

    def forward(self, x: Tensor) -> Tensor:
        skip = x

        y = self.conv(x)
        # The lens gates along the last (channel) axis, so flip to
        # channels-last for the gate and back to NCHW afterwards.
        y = self.lens(y.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        y = y * self.thirds_mask

        mix = torch.sigmoid(self.residual_weight)
        return mix * skip + (1 - mix) * y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MobiusNet(nn.Module):
    """
    Pure conv with Möbius topology.
    Dynamic number of stages based on len(depths).
    """

    def __init__(
        self,
        in_chans: int = 3,
        num_classes: int = 100,
        channels: Tuple[int, ...] = (64, 64, 128, 128),
        depths: Tuple[int, ...] = (8, 4, 2),
        scale_range: Tuple[float, float] = (0.5, 2.5),
    ):
        # Args:
        #   in_chans: input image channels.
        #   num_classes: classifier output size.
        #   channels: per-stage widths; padded by repeating the last entry if
        #       shorter than depths. Extra entries beyond len(depths) are unused.
        #   depths: blocks per stage; len(depths) determines the stage count.
        #   scale_range: global scale band shared by every MobiusLens.
        super().__init__()

        num_stages = len(depths)
        # Global layer count: every block's lens is positioned relative to it.
        total_layers = sum(depths)

        self.total_layers = total_layers
        self.scale_range = scale_range
        # Kept as the ORIGINAL (unpadded) tuple for reporting methods below.
        self.channels = channels
        self.depths = depths
        self.num_stages = num_stages

        # Local padded copy so each stage has a width.
        channels = list(channels)
        while len(channels) < num_stages:
            channels.append(channels[-1])

        # Stem: 3x3 conv + BN, no downsampling (CIFAR-sized inputs).
        self.stem = nn.Sequential(
            nn.Conv2d(in_chans, channels[0], 3, padding=1, bias=False),
            nn.BatchNorm2d(channels[0]),
        )

        # layer_idx runs globally across stages so each block knows its
        # absolute position on the ribbon.
        layer_idx = 0
        self.stages = nn.ModuleList()
        self.downsamples = nn.ModuleList()

        for stage_idx in range(num_stages):
            ch = channels[stage_idx]

            stage = nn.ModuleList()
            for _ in range(depths[stage_idx]):
                stage.append(MobiusConvBlock(
                    ch, layer_idx, total_layers, scale_range
                ))
                layer_idx += 1
            self.stages.append(stage)

            # Strided-conv downsample between stages; none after the last.
            if stage_idx < num_stages - 1:
                ch_next = channels[stage_idx + 1]
                self.downsamples.append(nn.Sequential(
                    nn.Conv2d(ch, ch_next, 3, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(ch_next),
                ))

        # Global average pool + linear head on the last stage's width.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(channels[num_stages - 1], num_classes)

    def forward(self, x: Tensor) -> Tensor:
        """stem -> (stage blocks, downsample)* -> GAP -> linear logits."""
        x = self.stem(x)

        for i, stage in enumerate(self.stages):
            for block in stage:
                x = block(x)
            # len(downsamples) == num_stages - 1, so this skips the last stage.
            if i < len(self.downsamples):
                x = self.downsamples[i](x)

        return self.head(self.pool(x).flatten(1))

    def get_info(self) -> str:
        """One-line summary of the architecture hyperparameters."""
        return (
            f"MobiusNet: channels={self.channels}, depths={self.depths}, "
            f"total_layers={self.total_layers}, scale_range={self.scale_range}"
        )

    def get_topology_info(self) -> str:
        """Multi-line report of every layer's ribbon position and scale band.

        Recomputes the same t / scale_low / scale_high values that each
        MobiusLens derives in its constructor (assuming total_layers >= 2).
        """
        lines = ["Möbius Ribbon Topology:"]
        lines.append("=" * 60)

        scale_span = self.scale_range[1] - self.scale_range[0]
        layer_idx = 0

        for stage_idx, depth in enumerate(self.depths):
            # Mirror the constructor's padding: reuse the last width when
            # self.channels is shorter than self.depths.
            ch = self.channels[stage_idx] if stage_idx < len(self.channels) else self.channels[-1]
            for local_idx in range(depth):
                t = layer_idx / max(self.total_layers - 1, 1)
                scale_low = self.scale_range[0] + t * scale_span
                scale_high = scale_low + scale_span / self.total_layers

                lines.append(
                    f"Layer {layer_idx:2d} (Stage {stage_idx+1}, ch={ch:3d}): "
                    f"t={t:.3f}, scales=[{scale_low:.3f}, {scale_high:.3f}]"
                )
                layer_idx += 1

            if stage_idx < self.num_stages - 1:
                ch_next = self.channels[stage_idx + 1] if stage_idx + 1 < len(self.channels) else self.channels[-1]
                lines.append(f" ↓ Downsample {ch} → {ch_next}")

        lines.append("=" * 60)
        return "\n".join(lines)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Named model configurations; each entry maps directly onto MobiusNet keyword
# arguments (the model pads `channels` to len(depths) itself when shorter).
PRESETS = {
    'mobius_xs': {
        'channels': (64, 64, 128),
        'depths': (4, 2, 2),
        'scale_range': (0.5, 2.5),
    },
    # Bug fix: train_mobius_cifar100() defaults to preset='mobius_s', which
    # was missing from this table and raised KeyError on a default call.
    # This entry mirrors the MobiusNet constructor defaults.
    'mobius_s': {
        'channels': (64, 64, 128, 128),
        'depths': (8, 4, 2),
        'scale_range': (0.5, 2.5),
    },
    'mobius_stretched': {
        'channels': (32, 64, 96, 128, 192, 256, 320, 384, 448),
        'depths': (4, 4, 4, 3, 3, 3, 2, 2, 2),
        'scale_range': (0.2915, 2.85),
    },
    'mobius_m': {
        'channels': (64, 128, 256, 256),
        'depths': (8, 4, 2),
        'scale_range': (0.5, 3.0),
    },
    'mobius_deep': {
        'channels': (64, 64, 128, 128),
        'depths': (12, 6, 4),
        'scale_range': (0.5, 3.5),
    },
    'mobius_wide': {
        'channels': (96, 96, 192, 192),
        'depths': (8, 4, 2),
        'scale_range': (0.5, 2.5),
    },
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_mobius_cifar100(
    preset: str = 'mobius_s',
    epochs: int = 100,
    lr: float = 1e-3,
    batch_size: int = 128,
    use_autoaugment: bool = True,
):
    """Train a MobiusNet preset on CIFAR-100.

    Args:
        preset: key into PRESETS selecting the architecture.
        epochs: number of training epochs (cosine LR schedule spans them all).
        lr: peak AdamW learning rate.
        batch_size: training batch size (evaluation uses a fixed 256).
        use_autoaugment: append AutoAugment (CIFAR10 policy) to the train
            transforms.

    Returns:
        (model, best_acc): the trained model and best top-1 val accuracy.

    Raises:
        KeyError: if `preset` is not a key of PRESETS.
    """
    # Fail fast with an actionable message instead of a bare KeyError.
    if preset not in PRESETS:
        raise KeyError(f"Unknown preset {preset!r}; available: {sorted(PRESETS)}")
    config = PRESETS[preset]

    print("=" * 70)
    print(f"MÖBIUS NET - {preset.upper()} - CIFAR-100")
    print("=" * 70)
    print(f"Device: {device}")
    print(f"Channels: {config['channels']}")
    print(f"Depths: {config['depths']}")
    print(f"Scale range: {config['scale_range']}")
    print(f"AutoAugment: {use_autoaugment}")
    print()

    # CIFAR-100 per-channel normalization statistics.
    mean = (0.5071, 0.4867, 0.4408)
    std = (0.2675, 0.2565, 0.2761)

    train_transforms = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    if use_autoaugment:
        # AutoAugment operates on PIL images, so it must precede ToTensor().
        train_transforms.append(transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10))
    train_transforms.extend([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_tf = transforms.Compose(train_transforms)
    test_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_ds = datasets.CIFAR100('./data', train=True, download=True, transform=train_tf)
    test_ds = datasets.CIFAR100('./data', train=False, download=True, transform=test_tf)

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=8, pin_memory=True, persistent_workers=True
    )
    test_loader = DataLoader(
        test_ds, batch_size=256, num_workers=2, pin_memory=True, persistent_workers=True,
    )

    model = MobiusNet(
        in_chans=3,
        num_classes=100,
        **config
    ).to(device)

    print(model.get_info())
    print()
    print(model.get_topology_info())
    print()

    # In-place torch.compile; 'reduce-overhead' targets small-batch latency.
    # NOTE(review): nn.Module.compile requires a recent torch 2.x — confirm
    # the deployment version.
    model.compile(mode='reduce-overhead')

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total params: {total_params:,}")
    print()

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    best_acc = 0.0

    for epoch in range(1, epochs + 1):
        model.train()
        train_loss, train_correct, train_total = 0, 0, 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch:3d}")
        for x, y in pbar:
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            logits = model(x)
            loss = F.cross_entropy(logits, y)
            loss.backward()
            # Clip gradients to stabilize the wave-gate parameters.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # Accumulate sample-weighted loss and top-1 correctness.
            train_loss += loss.item() * x.size(0)
            train_correct += (logits.argmax(1) == y).sum().item()
            train_total += x.size(0)

            pbar.set_postfix(loss=f"{loss.item():.4f}")

        scheduler.step()

        # Validation pass (no gradients, eval-mode BN).
        model.eval()
        val_correct, val_total = 0, 0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                logits = model(x)
                val_correct += (logits.argmax(1) == y).sum().item()
                val_total += x.size(0)

        train_acc = train_correct / train_total
        val_acc = val_correct / val_total
        # Bug fix: decide "new best" BEFORE updating best_acc. The original
        # compared after the max() update, so `val_acc >= best_acc` reduced
        # to equality and starred epochs that merely tied the best.
        is_best = val_acc > best_acc
        best_acc = max(best_acc, val_acc)
        marker = " ★" if is_best else ""

        print(f"Epoch {epoch:3d} | Loss: {train_loss/train_total:.4f} | "
              f"Train: {train_acc:.4f} | Val: {val_acc:.4f} | Best: {best_acc:.4f}{marker}")

    print()
    print("=" * 70)
    print("FINAL RESULTS")
    print("=" * 70)
    print(model.get_info())
    print(f"Best accuracy: {best_acc:.4f}")
    print(f"Total params: {total_params:,}")
    print("=" * 70)

    return model, best_acc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: train the 9-stage "stretched" preset end to end.
    model, best_acc = train_mobius_cifar100(
        preset='mobius_stretched',
        epochs=100,
        lr=1e-3,
        use_autoaugment=True,
    )