|
|
|
|
|
!pip install -q datasets safetensors huggingface_hub |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.utils.data import DataLoader |
|
|
from datasets import load_dataset |
|
|
from huggingface_hub import hf_hub_download |
|
|
from safetensors.torch import load_file as load_safetensors |
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
import math |
|
|
import json |
|
|
|
|
|
# Pick the compute device once for the whole script: the GPU when torch can
# see one, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MobiusLensRaw(nn.Module):
    """Wave-based gate applied over the last dimension of its input.

    Pipeline (all over the last dimension, whose size must be ``dim``):

    1. In-twist: ``x*cos(θ_in) + P_in(x)*sin(θ_in)`` with a learned angle and a
       learned bias-free linear map ``P_in``.
    2. ``tanh`` squash, giving ``x_norm``; ``t`` is the mean ``|x_norm|`` over
       the last dimension.
    3. Three wave banks L, M, R, each
       ``prod_s exp(-α * sin²(s*ω*(x_norm + drift*t) + phase))`` over the
       two entries of ``scales``, with ``α = |alpha| + 0.1``.
    4. Gate: ``sigmoid(LayerNorm((wL*L + wM*M + wR*R) * (0.5 + 0.5*lr)))``.
       Here ``w`` is a softmax over ``accum_weights`` and ``lr`` blends a soft
       XOR ``|L + R - 2LR|`` with a soft AND ``L*R``.
    5. The twisted input is multiplied by the gate, then an out-twist (same
       form as step 1, separate angle and map) is applied.

    Args:
        dim: size of the input's last dimension (Linear/LayerNorm width).
        layer_idx: index of this lens in the network.
        total_layers: number of lenses in the network. With ``layer_idx`` it
            sets ``t = layer_idx / (total_layers - 1)`` in ``[0, 1]``, which
            positions the scales and the initial twist angles.
        scale_range: ``(low, high)`` span the first scale moves across.
            The second scale is the first plus ``(high - low) / total_layers``,
            so for the last layer it lies above ``high``.
    """

    def __init__(self, dim, layer_idx, total_layers, scale_range=(1.0, 9.0)):
        super().__init__()
        self.dim = dim
        # Relative depth of this layer in [0, 1].
        self.t = layer_idx / max(total_layers - 1, 1)
        scale_span = scale_range[1] - scale_range[0]
        step = scale_span / max(total_layers, 1)
        # Two frequency scales; registered as a (persistent) buffer, so the
        # values stored in a checkpoint override these initial ones.
        self.register_buffer('scales', torch.tensor([scale_range[0] + self.t * scale_span,
                                                     scale_range[0] + self.t * scale_span + step]))
        self.twist_in_angle = nn.Parameter(torch.tensor(self.t * math.pi))
        self.twist_in_proj = nn.Linear(dim, dim, bias=False)
        self.omega = nn.Parameter(torch.tensor(math.pi))
        self.alpha = nn.Parameter(torch.tensor(1.5))
        # Per-bank phase offsets and drift coefficients (one entry per scale).
        self.phase_l, self.drift_l = nn.Parameter(torch.zeros(2)), nn.Parameter(torch.ones(2))
        self.phase_m, self.drift_m = nn.Parameter(torch.zeros(2)), nn.Parameter(torch.zeros(2))
        self.phase_r, self.drift_r = nn.Parameter(torch.zeros(2)), nn.Parameter(-torch.ones(2))
        # Logits of the L/M/R mixing weights (softmaxed in the forward pass).
        self.accum_weights = nn.Parameter(torch.tensor([0.4, 0.2, 0.4]))
        # Logit of the XOR-vs-AND blend (sigmoided in the forward pass).
        self.xor_weight = nn.Parameter(torch.tensor(0.7))
        self.gate_norm = nn.LayerNorm(dim)
        # The out-twist starts at the opposite angle of the in-twist.
        self.twist_out_angle = nn.Parameter(torch.tensor(-self.t * math.pi))
        self.twist_out_proj = nn.Linear(dim, dim, bias=False)

    def _lens_terms(self, x, detailed=False):
        """Compute every intermediate of the lens on ``x``.

        This is the single implementation of the lens math. Both ``forward``
        and ``forward_raw`` use it, so the diagnostics always describe exactly
        what the forward pass computes.

        Args:
            x: tensor whose last dimension has size ``self.dim``.
            detailed: when False, the per-scale sin/exp tensors of the L bank
                are dropped as soon as L is reduced. They are only needed for
                inspection.

        Returns:
            dict with ``x_twisted`` (input after the in-twist), ``x_norm``,
            ``L``, ``M``, ``R``, ``L_sin``, ``L_exp`` (None unless ``detailed``),
            ``xor_comp``, ``w``, ``xor_w``, ``a``, ``gate_pre`` and ``gate``.
        """
        cos_in, sin_in = torch.cos(self.twist_in_angle), torch.sin(self.twist_in_angle)
        x_twisted = x * cos_in + self.twist_in_proj(x) * sin_in
        x_norm = torch.tanh(x_twisted)
        # (..., 1, 1): per-position mean magnitude, broadcast over (scale, dim).
        t = x_norm.abs().mean(dim=-1, keepdim=True).unsqueeze(-2)
        # (..., 1, dim): insert a scale axis that broadcasts against `s`.
        x_exp = x_norm.unsqueeze(-2)
        s = self.scales.view(-1, 1)  # (n_scales, 1)
        a = self.alpha.abs() + 0.1   # effective sharpness, always >= 0.1

        def wave(phase, drift):
            pos = s * self.omega * (x_exp + drift.view(-1, 1) * t) + phase.view(-1, 1)
            sin_val = torch.sin(pos)
            exp_val = torch.exp(-a * sin_val.pow(2))
            # Product over the scale axis: close to 1 only where every scale
            # sits near a zero of sin.
            return exp_val.prod(dim=-2), sin_val, exp_val

        L, L_sin, L_exp = wave(self.phase_l, self.drift_l)
        if not detailed:
            L_sin = L_exp = None  # release early; forward() never needs them
        M = wave(self.phase_m, self.drift_m)[0]
        R = wave(self.phase_r, self.drift_r)[0]

        w = torch.softmax(self.accum_weights, dim=0)
        xor_w = torch.sigmoid(self.xor_weight)
        xor_comp = (L + R - 2*L*R).abs()
        # Same operation order as the original forward(): ((1 - xor_w) * L) * R.
        # This keeps the trained model's output bit-for-bit unchanged.
        lr = xor_w * xor_comp + (1 - xor_w) * L * R
        gate_pre = (w[0]*L + w[1]*M + w[2]*R) * (0.5 + 0.5*lr)
        gate = torch.sigmoid(self.gate_norm(gate_pre))

        return {
            'x_twisted': x_twisted, 'x_norm': x_norm,
            'L': L, 'M': M, 'R': R, 'L_sin': L_sin, 'L_exp': L_exp,
            'xor_comp': xor_comp, 'w': w, 'xor_w': xor_w, 'a': a,
            'gate_pre': gate_pre, 'gate': gate,
        }

    def forward(self, x):
        """Gate ``x`` through the lens.

        Returns:
            ``(out, gate)``. ``out`` is the in-twisted input multiplied by the
            gate and then out-twisted. ``gate`` is the sigmoid gate. Both have
            the shape of ``x``.
        """
        terms = self._lens_terms(x)
        gate = terms['gate']
        gated = terms['x_twisted'] * gate
        del terms  # drop the wave intermediates before the out projection
        cos_out, sin_out = torch.cos(self.twist_out_angle), torch.sin(self.twist_out_angle)
        return gated * cos_out + self.twist_out_proj(gated) * sin_out, gate

    def forward_raw(self, x):
        """Return raw L/M/R values for inspection.

        Uses the same computation as ``forward``. The returned dict holds:

        - tensors ``x_norm``, ``L``, ``M``, ``R``;
        - ``L_sin`` / ``L_exp``: per-scale terms of the L bank, with the scale
          axis at -2;
        - ``xor_comp``, ``and_comp``, ``gate_pre``, ``gate``;
        - Python floats ``omega``, ``alpha`` (effective ``|alpha| + 0.1``) and
          ``xor_weight`` (after sigmoid);
        - NumPy arrays ``scales`` and ``weights`` (softmaxed L/M/R weights).
        """
        terms = self._lens_terms(x, detailed=True)
        L, R = terms['L'], terms['R']
        return {
            'x_norm': terms['x_norm'], 'L': L, 'M': terms['M'], 'R': R,
            'L_sin': terms['L_sin'], 'L_exp': terms['L_exp'],
            'xor_comp': terms['xor_comp'], 'and_comp': L * R,
            'gate_pre': terms['gate_pre'], 'gate': terms['gate'],
            'omega': self.omega.item(), 'alpha': terms['a'].item(),
            'scales': self.scales.cpu().numpy(),
            'weights': terms['w'].detach().cpu().numpy(),
            'xor_weight': terms['xor_w'].item(),
        }
|
|
|
|
|
class MobiusBlockRaw(nn.Module):
    """Depthwise-separable conv, a ``MobiusLensRaw`` gate, and a learned residual mix.

    Args:
        channels: number of input/output channels (the lens width).
        layer_idx: global index of this block; selects which third of the
            channels is attenuated (``layer_idx % 3``) and positions the lens.
        total_layers: number of blocks in the network (passed to the lens).
        scale_range: forwarded to ``MobiusLensRaw``.
        reduction: multiplier applied to the selected third of the channels
            of the lens output.
    """

    def __init__(self, channels, layer_idx, total_layers, scale_range=(1.0, 9.0), reduction=0.5):
        super().__init__()
        # 3x3 depthwise conv (groups=channels), then 1x1 pointwise conv, then BN.
        depthwise = nn.Conv2d(channels, channels, 3, padding=1, groups=channels, bias=False)
        pointwise = nn.Conv2d(channels, channels, 1, bias=False)
        self.conv = nn.Sequential(depthwise, pointwise, nn.BatchNorm2d(channels))
        self.lens = MobiusLensRaw(channels, layer_idx, total_layers, scale_range)

        band = channels // 3
        band_idx = layer_idx % 3
        start = band_idx * band
        stop = start + band
        if band_idx == 2:
            stop += channels % 3  # the last third also takes the leftover channels
        mask = torch.ones(channels)
        mask[start:stop] = reduction
        # Shaped (1, C, 1, 1) to broadcast over an NCHW tensor.
        self.register_buffer('thirds_mask', mask.view(1, -1, 1, 1))
        # Logit of the residual mix; sigmoid(0.9) keeps most of the identity at init.
        self.residual_weight = nn.Parameter(torch.tensor(0.9))

    def forward(self, x):
        """Blend ``x`` with its lens-gated conv features using a sigmoid weight."""
        # NCHW -> NHWC so the lens works over the channel dimension.
        features = self.conv(x).permute(0, 2, 3, 1)
        lensed, _gate = self.lens(features)
        lensed = lensed.permute(0, 3, 1, 2) * self.thirds_mask
        keep = torch.sigmoid(self.residual_weight)
        return keep * x + (1 - keep) * lensed

    def forward_raw(self, x):
        """Return the lens's raw diagnostic dict for this block's conv output."""
        return self.lens.forward_raw(self.conv(x).permute(0, 2, 3, 1))
|
|
|
|
|
class MobiusNetRaw(nn.Module):
    """Stem -> stages of ``MobiusBlockRaw`` (with strided downsamples) -> integrator/pool/head.

    NOTE(review): this class defines no ``forward``. The ``integrator``, ``pool``
    and ``head`` modules exist so checkpoint keys load, but no method here uses
    them. Only ``get_block_raw`` runs the network.

    Args:
        in_chans: input channels of the stem conv.
        num_classes: output width of ``head``.
        channels: per-stage channel widths.
        depths: number of blocks per stage.
        scale_range: forwarded to every block's lens.
        use_integrator: add a 3x3 conv + BN + GELU after the last stage
            (otherwise ``nn.Identity``).
    """

    def __init__(self, in_chans=1, num_classes=1000, channels=(64,128,256),
                 depths=(2,2,2), scale_range=(0.5,2.5), use_integrator=True):
        super().__init__()
        widths = list(channels)
        n_blocks = sum(depths)
        last_stage = len(depths) - 1

        self.stem = nn.Sequential(
            nn.Conv2d(in_chans, widths[0], 3, padding=1, bias=False),
            nn.BatchNorm2d(widths[0]),
        )
        self.stages = nn.ModuleList()
        self.downsamples = nn.ModuleList()

        first_idx = 0  # global block index of the current stage's first block
        for stage_idx, depth in enumerate(depths):
            width = widths[stage_idx]
            stage_blocks = [MobiusBlockRaw(width, first_idx + k, n_blocks, scale_range)
                            for k in range(depth)]
            self.stages.append(nn.ModuleList(stage_blocks))
            first_idx += depth
            if stage_idx < last_stage:
                # Stride-2 conv halves the spatial size and widens to the next stage.
                self.downsamples.append(nn.Sequential(
                    nn.Conv2d(width, widths[stage_idx + 1], 3, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(widths[stage_idx + 1]),
                ))

        if use_integrator:
            self.integrator = nn.Sequential(
                nn.Conv2d(widths[-1], widths[-1], 3, padding=1, bias=False),
                nn.BatchNorm2d(widths[-1]),
                nn.GELU(),
            )
        else:
            self.integrator = nn.Identity()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(widths[-1], num_classes)

    def get_block_raw(self, x, target_stage, target_block):
        """Forward to target block and return raw wave data.

        Runs the stem and every earlier block (and downsample) on ``x``. Then it
        returns ``forward_raw`` of block ``target_block`` in stage
        ``target_stage``. Returns None when the indices name no block.
        """
        h = self.stem(x)
        for stage_idx, stage in enumerate(self.stages):
            for block_idx, block in enumerate(stage):
                if stage_idx == target_stage and block_idx == target_block:
                    return block.forward_raw(h)
                h = block(h)
            if stage_idx < len(self.downsamples):
                h = self.downsamples[stage_idx](h)
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Download the training run's config and best checkpoint from the Hub and
# rebuild the model from them.
print("Loading model...")

REPO_ID = "AbstractPhil/mobiusnet-distillations"
# Config and weights must come from the same run, so the run directory is
# spelled out once and shared by both downloads.
RUN_DIR = "checkpoints/mobius_tiny_s_imagenet_clip_vit_l14/20260111_000512"

config_path = hf_hub_download(REPO_ID, f"{RUN_DIR}/config.json")
with open(config_path, encoding="utf-8") as f:
    config = json.load(f)
model_path = hf_hub_download(REPO_ID, f"{RUN_DIR}/checkpoints/best_model.safetensors")

cfg = config['model']
# Keyword arguments: the config is wired by name, not by constructor position.
model = MobiusNetRaw(
    in_chans=cfg['in_chans'],
    num_classes=cfg['num_classes'],
    channels=tuple(cfg['channels']),
    depths=tuple(cfg['depths']),
    scale_range=tuple(cfg['scale_range']),
    use_integrator=cfg['use_integrator'],
).to(device)
# load_state_dict is strict by default: every key must match a parameter/buffer.
model.load_state_dict(load_safetensors(model_path))
model.eval()
print("✓ Loaded")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Take one batch of 16 validation CLIP feature vectors. Each one is laid out as
# a 1x24x32 single-channel "image" for the conv stem (24 * 32 = 768;
# presumably the CLIP ViT-L/14 embedding width; confirm against the dataset card).
ds = load_dataset("AbstractPhil/imagenet-clip-features-orderly", "clip_vit_l14",
                  split="validation", streaming=True).with_format("torch")
loader = DataLoader(ds, batch_size=16)
batch = next(iter(loader))

feats = batch['clip_features']
# Use the real batch size instead of view(-1, ...). If the feature width is not
# 24*32, this now fails loudly instead of silently splitting each vector across
# several fake samples. Cast to float32 to match the model's weights.
x = feats.reshape(feats.shape[0], 1, 24, 32).to(device=device, dtype=torch.float32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# One row of diagnostics per block, derived from the loaded model so the grid
# matches any `depths` configuration (the default (2, 2, 2) gives S0B0..S2B1).
blocks = [(stage_idx, block_idx)
          for stage_idx, stage in enumerate(model.stages)
          for block_idx in range(len(stage))]
block_names = [f'S{stage_idx}B{block_idx}' for stage_idx, block_idx in blocks]

n_cols = 6
# squeeze=False keeps `axes` 2-D even when the model has a single block.
fig, axes = plt.subplots(len(blocks), n_cols, figsize=(4 * n_cols, 4 * len(blocks)), squeeze=False)

for bi, ((si, bii), name) in enumerate(zip(blocks, block_names)):
    with torch.no_grad():
        raw = model.get_block_raw(x, si, bii)

    print(f"\n{'='*60}")
    print(f"{name}: ω={raw['omega']:.3f}, α={raw['alpha']:.3f}, scales={raw['scales']}")
    print(f" Weights: L={raw['weights'][0]:.3f}, M={raw['weights'][1]:.3f}, R={raw['weights'][2]:.3f}")
    print(f" XOR weight: {raw['xor_weight']:.3f}")

    L, M, R = raw['L'], raw['M'], raw['R']
    gate = raw['gate']

    print(f" L: min={L.min():.6f}, max={L.max():.6f}, mean={L.mean():.6f}, std={L.std():.6f}")
    print(f" M: min={M.min():.6f}, max={M.max():.6f}, mean={M.mean():.6f}, std={M.std():.6f}")
    print(f" R: min={R.min():.6f}, max={R.max():.6f}, mean={R.mean():.6f}, std={R.std():.6f}")
    print(f" Gate: min={gate.min():.4f}, max={gate.max():.4f}, mean={gate.mean():.4f}")

    # Ranges of the per-scale terms of the L bank and of the squashed input.
    print(f" L_sin range: [{raw['L_sin'].min():.4f}, {raw['L_sin'].max():.4f}]")
    print(f" L_exp range: [{raw['L_exp'].min():.6f}, {raw['L_exp'].max():.6f}]")
    print(f" x_norm range: [{raw['x_norm'].min():.4f}, {raw['x_norm'].max():.4f}]")

    # Columns 0-3: value histograms over the whole batch.
    axes[bi, 0].hist(L.cpu().numpy().flatten(), bins=50, color='red', alpha=0.7, density=True)
    axes[bi, 0].set_title(f'{name} L\nμ={L.mean():.4f}, σ={L.std():.4f}', fontsize=10)
    axes[bi, 0].axvline(x=L.mean().item(), color='black', linestyle='--')

    axes[bi, 1].hist(M.cpu().numpy().flatten(), bins=50, color='green', alpha=0.7, density=True)
    axes[bi, 1].set_title(f'{name} M\nμ={M.mean():.4f}', fontsize=10)

    axes[bi, 2].hist(R.cpu().numpy().flatten(), bins=50, color='blue', alpha=0.7, density=True)
    axes[bi, 2].set_title(f'{name} R\nμ={R.mean():.4f}', fontsize=10)

    axes[bi, 3].hist(gate.cpu().numpy().flatten(), bins=50, color='purple', alpha=0.7, density=True)
    axes[bi, 3].set_title(f'{name} Gate\nμ={gate.mean():.4f}', fontsize=10)

    # Columns 4-5: spatial maps (channel mean) for the first sample.
    L_spatial = L[0].mean(dim=-1).cpu().numpy()
    axes[bi, 4].imshow(L_spatial, cmap='hot', aspect='auto')
    axes[bi, 4].set_title(f'{name} L spatial\nα={raw["alpha"]:.2f}', fontsize=10)
    axes[bi, 4].axis('off')

    gate_spatial = gate[0].mean(dim=-1).cpu().numpy()
    axes[bi, 5].imshow(gate_spatial, cmap='viridis', aspect='auto', vmin=0, vmax=1)
    axes[bi, 5].set_title(f'{name} Gate spatial', fontsize=10)
    axes[bi, 5].axis('off')

plt.suptitle('Raw Wave Diagnostics: L/M/R Distributions', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig("mobius_raw_diagnostics.png", dpi=150, bbox_inches="tight")
plt.show()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("\n" + "="*70)
print("ANALYSIS: Wave Function Behavior")
print("="*70)

# Read the deepest block's effective sharpness (|alpha| + 0.1, as reported by
# forward_raw) from the loaded model instead of hard-coding one checkpoint's value.
deep_stage = len(model.stages) - 1
deep_block = len(model.stages[deep_stage]) - 1
deep_name = f"S{deep_stage}B{deep_block}"
deep_alpha = model.stages[deep_stage][deep_block].lens.alpha.abs().item() + 0.1

print(f"""
Wave function (per bank, product over the block's two scales s):
  prod_s exp(-α * sin²(ω * s * (x + drift*t) + phase))

For high α (like {deep_alpha:.2f} at {deep_name}):
  - This becomes a VERY narrow peak around sin(...)=0
  - i.e., when ω*s*(x+drift*t) + phase = n*π

The prod over 2 scales means BOTH scales must hit a peak simultaneously.
This is extremely rare, so most values → exp(-{deep_alpha:.2f}) ≈ {math.exp(-deep_alpha):.3f}

BUT: The gate is computed AFTER LayerNorm on gate_pre!
  gate = sigmoid(LayerNorm(weighted_sum * (0.5 + 0.5*lr)))

LayerNorm rescales the near-zero values to mean=0, std=1 (before its learned
per-channel weight/bias), then sigmoid maps that to a distribution centred
near 0.5.

This is why gates are ~0.4-0.5 even when raw L/M/R are tiny.
""")
|
|
|
|
|
|
|
|
# Inspect the deepest block in detail (S2B1 for the default (2, 2, 2) depths).
# The target is derived from the model so shallower configs do not get None back.
target_stage = len(model.stages) - 1
target_block = len(model.stages[target_stage]) - 1
target_name = f"S{target_stage}B{target_block}"

with torch.no_grad():
    raw = model.get_block_raw(x, target_stage, target_block)

print(f"\n{target_name} gate_pre: min={raw['gate_pre'].min():.6f}, max={raw['gate_pre'].max():.6f}, mean={raw['gate_pre'].mean():.6f}")
print(f"{target_name} gate: min={raw['gate'].min():.4f}, max={raw['gate'].max():.4f}, mean={raw['gate'].mean():.4f}")

print("\nThe information is in relative L/M/R differences across channels:")
# Channel-wise means over the spatial positions of the first sample only.
L_per_channel = raw['L'][0].mean(dim=(0, 1)).cpu().numpy()
M_per_channel = raw['M'][0].mean(dim=(0, 1)).cpu().numpy()
R_per_channel = raw['R'][0].mean(dim=(0, 1)).cpu().numpy()

fig2, ax2 = plt.subplots(1, 1, figsize=(14, 4))
channels = np.arange(len(L_per_channel))
ax2.plot(channels, L_per_channel, 'r-', alpha=0.7, label='L')
ax2.plot(channels, M_per_channel, 'g-', alpha=0.7, label='M')
ax2.plot(channels, R_per_channel, 'b-', alpha=0.7, label='R')
ax2.set_xlabel('Channel')
ax2.set_ylabel('Mean activation')
ax2.set_title(f'{target_name}: L/M/R per channel (the signal is in the variance)')
ax2.legend()
plt.tight_layout()
plt.savefig("mobius_channel_variance.png", dpi=150)
plt.show()

print("\nPer-channel variance:")
print(f" L channels std: {L_per_channel.std():.6f}")
print(f" M channels std: {M_per_channel.std():.6f}")
print(f" R channels std: {R_per_channel.std():.6f}")