# mobiusnet-distillations / make_chart_2.py
# Last commit by AbstractPhil: "Create make_chart_2.py" (e9e6a78, verified)
#@title Diagnose L/M/R Wave Values (Fixed)
!pip install -q datasets safetensors huggingface_hub
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as load_safetensors
import matplotlib.pyplot as plt
import numpy as np
import math
import json
# Run on the GPU whenever torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# ============================================================================
# FULL MOBIUSNET WITH RAW WAVE ACCESS
# ============================================================================
class MobiusLensRaw(nn.Module):
    """Möbius lens: twist -> L/M/R wave gate -> untwist, along the last axis.

    ``forward`` is the model path; ``forward_raw`` returns the intermediates
    of the very same gate computation for inspection. Both go through
    ``_twist_in`` and ``_gate_terms``, so the diagnostics cannot drift from
    what ``forward`` actually computes.

    Args:
        dim: Size of the last (feature) axis the lens operates on.
        layer_idx: Position of this lens in the network.
        total_layers: Number of lenses in the network; with ``layer_idx`` it
            sets ``t = layer_idx / max(total_layers - 1, 1)``.
        scale_range: ``(lo, hi)``; the two wave scales are placed at
            ``lo + t * (hi - lo)`` and one ``(hi - lo) / total_layers`` step above.
    """

    def __init__(self, dim, layer_idx, total_layers, scale_range=(1.0, 9.0)):
        super().__init__()
        self.dim = dim
        self.t = layer_idx / max(total_layers - 1, 1)
        scale_span = scale_range[1] - scale_range[0]
        step = scale_span / max(total_layers, 1)
        self.register_buffer('scales', torch.tensor([scale_range[0] + self.t * scale_span,
                                                     scale_range[0] + self.t * scale_span + step]))
        self.twist_in_angle = nn.Parameter(torch.tensor(self.t * math.pi))
        self.twist_in_proj = nn.Linear(dim, dim, bias=False)
        self.omega = nn.Parameter(torch.tensor(math.pi))
        self.alpha = nn.Parameter(torch.tensor(1.5))
        # One (phase, drift) pair per scale for each of the three waves.
        self.phase_l, self.drift_l = nn.Parameter(torch.zeros(2)), nn.Parameter(torch.ones(2))
        self.phase_m, self.drift_m = nn.Parameter(torch.zeros(2)), nn.Parameter(torch.zeros(2))
        self.phase_r, self.drift_r = nn.Parameter(torch.zeros(2)), nn.Parameter(-torch.ones(2))
        # Softmax-normalised mixing weights for L, M, R.
        self.accum_weights = nn.Parameter(torch.tensor([0.4, 0.2, 0.4]))
        # sigmoid(xor_weight) blends the XOR-like and AND-like L/R interactions.
        self.xor_weight = nn.Parameter(torch.tensor(0.7))
        self.gate_norm = nn.LayerNorm(dim)
        self.twist_out_angle = nn.Parameter(torch.tensor(-self.t * math.pi))
        self.twist_out_proj = nn.Linear(dim, dim, bias=False)

    def _twist_in(self, x):
        """Blend ``x`` with its learned projection by ``twist_in_angle``."""
        cos_t, sin_t = torch.cos(self.twist_in_angle), torch.sin(self.twist_in_angle)
        return x * cos_t + self.twist_in_proj(x) * sin_t

    def _gate_terms(self, x):
        """Compute the gate and all of its intermediates from the twisted input ``x``.

        Returns a dict with ``x_norm``, ``L``/``M``/``R`` (same shape as ``x``),
        ``L_sin``/``L_exp`` (with an extra size-2 scale axis at -2),
        ``xor_comp``, ``and_comp``, ``gate_pre``, ``gate``, plus the effective
        sharpness ``a`` (= |alpha| + 0.1), softmaxed weights ``w`` and ``xor_w``.
        """
        x_norm = torch.tanh(x)
        # Per-position mean |x_norm| over features, shaped (..., 1, 1) so it
        # broadcasts against the (scale, feature) axes below.
        t = x_norm.abs().mean(dim=-1, keepdim=True).unsqueeze(-2)
        x_exp = x_norm.unsqueeze(-2)  # (..., 1, dim): new axis for the scales
        s = self.scales.view(-1, 1)   # (2, 1)
        a = self.alpha.abs() + 0.1    # effective sharpness, always >= 0.1

        def wave(phase, drift):
            pos = s * self.omega * (x_exp + drift.view(-1, 1) * t) + phase.view(-1, 1)
            sin_val = torch.sin(pos)
            exp_val = torch.exp(-a * sin_val.pow(2))
            # Product over the scale axis: close to 1 only where every scale is near a peak.
            return exp_val.prod(dim=-2), sin_val, exp_val

        L, L_sin, L_exp = wave(self.phase_l, self.drift_l)
        M, _, _ = wave(self.phase_m, self.drift_m)
        R, _, _ = wave(self.phase_r, self.drift_r)
        w = torch.softmax(self.accum_weights, dim=0)
        xor_w = torch.sigmoid(self.xor_weight)
        xor_comp = (L + R - 2*L*R).abs()
        and_comp = L * R
        # `(1 - xor_w) * L * R` (rather than `* and_comp`) keeps forward()'s
        # floating-point evaluation order, so the model path is numerically unchanged.
        lr = xor_w * xor_comp + (1 - xor_w) * L * R
        gate_pre = (w[0]*L + w[1]*M + w[2]*R) * (0.5 + 0.5*lr)
        gate = torch.sigmoid(self.gate_norm(gate_pre))
        return {'x_norm': x_norm, 'L': L, 'M': M, 'R': R, 'L_sin': L_sin, 'L_exp': L_exp,
                'xor_comp': xor_comp, 'and_comp': and_comp, 'gate_pre': gate_pre,
                'gate': gate, 'a': a, 'w': w, 'xor_w': xor_w}

    def forward(self, x):
        """Gate ``x`` along its last axis, then apply the output twist.

        Returns ``(output, gate)``; both have the same shape as ``x``.
        """
        x = self._twist_in(x)
        gate = self._gate_terms(x)['gate']
        x = x * gate
        cos_t, sin_t = torch.cos(self.twist_out_angle), torch.sin(self.twist_out_angle)
        return x * cos_t + self.twist_out_proj(x) * sin_t, gate

    def forward_raw(self, x):
        """Return raw L/M/R values for inspection.

        Runs exactly the gate computation used by ``forward`` and returns its
        intermediates plus summaries of the learned parameters. ``alpha`` is
        the effective ``|alpha| + 0.1``; ``weights`` is the softmaxed
        ``accum_weights``; ``xor_weight`` is ``sigmoid(xor_weight)``.
        """
        terms = self._gate_terms(self._twist_in(x))
        return {
            'x_norm': terms['x_norm'], 'L': terms['L'], 'M': terms['M'], 'R': terms['R'],
            'L_sin': terms['L_sin'], 'L_exp': terms['L_exp'],
            'xor_comp': terms['xor_comp'], 'and_comp': terms['and_comp'],
            'gate_pre': terms['gate_pre'], 'gate': terms['gate'],
            'omega': self.omega.item(), 'alpha': terms['a'].item(),
            'scales': self.scales.cpu().numpy(),
            'weights': terms['w'].detach().cpu().numpy(),
            'xor_weight': terms['xor_w'].item(),
        }
class MobiusBlockRaw(nn.Module):
    """Residual Möbius block: separable conv -> lens -> thirds mask -> blend.

    The channel third selected by ``layer_idx % 3`` is multiplied by
    ``reduction`` after the lens. The last third also takes the
    ``channels % 3`` leftover channels. The result is mixed with the block
    input using the learned weight ``sigmoid(residual_weight)``.
    """

    def __init__(self, channels, layer_idx, total_layers, scale_range=(1.0, 9.0), reduction=0.5):
        super().__init__()
        depthwise = nn.Conv2d(channels, channels, 3, padding=1, groups=channels, bias=False)
        pointwise = nn.Conv2d(channels, channels, 1, bias=False)
        self.conv = nn.Sequential(depthwise, pointwise, nn.BatchNorm2d(channels))
        self.lens = MobiusLensRaw(channels, layer_idx, total_layers, scale_range)

        third = channels // 3
        which = layer_idx % 3
        start = which * third
        stop = start + third
        if which == 2:
            stop += channels % 3  # the last third absorbs the remainder
        mask = torch.ones(channels)
        mask[start:stop] = reduction
        self.register_buffer('thirds_mask', mask.view(1, -1, 1, 1))
        self.residual_weight = nn.Parameter(torch.tensor(0.9))

    def forward(self, x):
        # Conv2d works channels-first; the lens acts on the last axis, so go
        # channels-last for the lens and back afterwards.
        h, _ = self.lens(self.conv(x).permute(0, 2, 3, 1))
        h = h.permute(0, 3, 1, 2) * self.thirds_mask
        keep = torch.sigmoid(self.residual_weight)
        return keep * x + (1 - keep) * h

    def forward_raw(self, x):
        """Run the conv and return the lens diagnostics dict (MobiusLensRaw.forward_raw)."""
        return self.lens.forward_raw(self.conv(x).permute(0, 2, 3, 1))
class MobiusNetRaw(nn.Module):
    """MobiusNet rebuilt with a hook for reading one block's raw wave values.

    Submodule names (stem, stages, downsamples, integrator, pool, head)
    determine the state_dict keys the checkpoint is loaded into.
    ``get_block_raw`` never runs the integrator, pool or head; they are kept
    so the checkpoint weights load.
    """

    def __init__(self, in_chans=1, num_classes=1000, channels=(64,128,256),
                 depths=(2,2,2), scale_range=(0.5,2.5), use_integrator=True):
        super().__init__()
        widths = list(channels)
        total_layers = sum(depths)
        self.stem = nn.Sequential(
            nn.Conv2d(in_chans, widths[0], 3, padding=1, bias=False),
            nn.BatchNorm2d(widths[0]),
        )
        self.stages = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        next_idx = 0  # global index of the first block in the current stage
        n_stages = len(depths)
        for si, depth in enumerate(depths):
            stage_blocks = [MobiusBlockRaw(widths[si], next_idx + k, total_layers, scale_range)
                            for k in range(depth)]
            self.stages.append(nn.ModuleList(stage_blocks))
            next_idx += depth
            # Stride-2 transition into the next stage's width; none after the last stage.
            if si < n_stages - 1:
                self.downsamples.append(nn.Sequential(
                    nn.Conv2d(widths[si], widths[si + 1], 3, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(widths[si + 1]),
                ))
        # Integrator and head are built only so the checkpoint's weights load.
        if use_integrator:
            self.integrator = nn.Sequential(
                nn.Conv2d(widths[-1], widths[-1], 3, padding=1, bias=False),
                nn.BatchNorm2d(widths[-1]),
                nn.GELU(),
            )
        else:
            self.integrator = nn.Identity()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(widths[-1], num_classes)

    def get_block_raw(self, x, target_stage, target_block):
        """Forward to target block and return raw wave data.

        Runs the stem and every block and downsample before
        ``(target_stage, target_block)``, then returns that block's
        ``forward_raw`` dict. Returns ``None`` when no such block exists.
        """
        target = (target_stage, target_block)
        h = self.stem(x)
        for si, stage in enumerate(self.stages):
            for bi, block in enumerate(stage):
                if (si, bi) == target:
                    return block.forward_raw(h)
                h = block(h)
            if si < len(self.downsamples):
                h = self.downsamples[si](h)
        return None
# ============================================================================
# LOAD MODEL
# ============================================================================
print("Loading model...")
# Config and weights must come from the same training run; keep the repo and
# run directory in one place so the two downloads cannot drift apart.
REPO_ID = "AbstractPhil/mobiusnet-distillations"
RUN_DIR = "checkpoints/mobius_tiny_s_imagenet_clip_vit_l14/20260111_000512"
config_path = hf_hub_download(REPO_ID, f"{RUN_DIR}/config.json")
with open(config_path, encoding="utf-8") as f:
    config = json.load(f)
model_path = hf_hub_download(REPO_ID, f"{RUN_DIR}/checkpoints/best_model.safetensors")
cfg = config['model']
# Rebuild the architecture from the run's own config.
model = MobiusNetRaw(
    in_chans=cfg['in_chans'],
    num_classes=cfg['num_classes'],
    channels=tuple(cfg['channels']),
    depths=tuple(cfg['depths']),
    scale_range=tuple(cfg['scale_range']),
    use_integrator=cfg['use_integrator'],
).to(device)
# Strict load (the default): fails loudly if the rebuilt module keys do not match the checkpoint.
model.load_state_dict(load_safetensors(model_path))
model.eval()
print("✓ Loaded")
# ============================================================================
# GET SAMPLE DATA
# ============================================================================
# Stream the validation split and take a single batch of 16 samples.
ds = load_dataset(
    "AbstractPhil/imagenet-clip-features-orderly",
    "clip_vit_l14",
    split="validation",
    streaming=True,
).with_format("torch")
loader = DataLoader(ds, batch_size=16)
batch = next(iter(loader))
# Fold each feature vector into a single-channel 24x32 "image" for the conv stem.
# NOTE(review): this assumes every 'clip_features' row holds exactly 24 * 32 = 768 values.
x = batch['clip_features'].view(-1, 1, 24, 32).to(device)
# ============================================================================
# INSPECT EACH BLOCK
# ============================================================================
# Enumerate every (stage, block) pair the loaded model actually has instead of
# a hard-coded list of six. The architecture comes from the checkpoint config,
# so a different `depths` would otherwise skip blocks or ask get_block_raw for
# a block it does not have (it returns None and raw['omega'] then fails).
blocks = [(si, bi) for si, stage in enumerate(model.stages) for bi in range(len(stage))]
block_names = [f'S{si}B{bi}' for si, bi in blocks]
if not blocks:
    raise RuntimeError("model has no Mobius blocks to inspect")
# One row per block, six panels per row; squeeze=False keeps `axes` 2-D even for one row.
fig, axes = plt.subplots(len(blocks), 6, figsize=(24, 4 * len(blocks)), squeeze=False)
for row, ((si, blk), name) in enumerate(zip(blocks, block_names)):
    with torch.no_grad():
        raw = model.get_block_raw(x, si, blk)
    print(f"\n{'='*60}")
    print(f"{name}: ω={raw['omega']:.3f}, α={raw['alpha']:.3f}, scales={raw['scales']}")
    print(f" Weights: L={raw['weights'][0]:.3f}, M={raw['weights'][1]:.3f}, R={raw['weights'][2]:.3f}")
    print(f" XOR weight: {raw['xor_weight']:.3f}")
    L, M, R = raw['L'], raw['M'], raw['R']
    gate = raw['gate']
    print(f" L: min={L.min():.6f}, max={L.max():.6f}, mean={L.mean():.6f}, std={L.std():.6f}")
    print(f" M: min={M.min():.6f}, max={M.max():.6f}, mean={M.mean():.6f}, std={M.std():.6f}")
    print(f" R: min={R.min():.6f}, max={R.max():.6f}, mean={R.mean():.6f}, std={R.std():.6f}")
    print(f" Gate: min={gate.min():.4f}, max={gate.max():.4f}, mean={gate.mean():.4f}")
    # Intermediates of the L wave and the lens input.
    print(f" L_sin range: [{raw['L_sin'].min():.4f}, {raw['L_sin'].max():.4f}]")
    print(f" L_exp range: [{raw['L_exp'].min():.6f}, {raw['L_exp'].max():.6f}]")
    print(f" x_norm range: [{raw['x_norm'].min():.4f}, {raw['x_norm'].max():.4f}]")
    # Columns 0-3: value distributions of L, M, R and the gate.
    axes[row, 0].hist(L.cpu().numpy().flatten(), bins=50, color='red', alpha=0.7, density=True)
    axes[row, 0].set_title(f'{name} L\nμ={L.mean():.4f}, σ={L.std():.4f}', fontsize=10)
    axes[row, 0].axvline(x=L.mean().item(), color='black', linestyle='--')
    axes[row, 1].hist(M.cpu().numpy().flatten(), bins=50, color='green', alpha=0.7, density=True)
    axes[row, 1].set_title(f'{name} M\nμ={M.mean():.4f}', fontsize=10)
    axes[row, 2].hist(R.cpu().numpy().flatten(), bins=50, color='blue', alpha=0.7, density=True)
    axes[row, 2].set_title(f'{name} R\nμ={R.mean():.4f}', fontsize=10)
    axes[row, 3].hist(gate.cpu().numpy().flatten(), bins=50, color='purple', alpha=0.7, density=True)
    axes[row, 3].set_title(f'{name} Gate\nμ={gate.mean():.4f}', fontsize=10)
    # Columns 4-5: spatial maps for sample 0, averaged over the channel (last) axis.
    L_spatial = L[0].mean(dim=-1).cpu().numpy()
    axes[row, 4].imshow(L_spatial, cmap='hot', aspect='auto')
    axes[row, 4].set_title(f'{name} L spatial\nα={raw["alpha"]:.2f}', fontsize=10)
    axes[row, 4].axis('off')
    gate_spatial = gate[0].mean(dim=-1).cpu().numpy()
    axes[row, 5].imshow(gate_spatial, cmap='viridis', aspect='auto', vmin=0, vmax=1)
    axes[row, 5].set_title(f'{name} Gate spatial', fontsize=10)
    axes[row, 5].axis('off')
plt.suptitle('Raw Wave Diagnostics: L/M/R Distributions', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig("mobius_raw_diagnostics.png", dpi=150, bbox_inches="tight")
plt.show()
# ============================================================================
# ANALYSIS: Why are L/M/R uniform?
# ============================================================================
print("\n" + "="*70)
print("ANALYSIS: Wave Function Behavior")
print("="*70)
# The wave as implemented in MobiusLensRaw's wave():
#   L/M/R = prod over the 2 scales s of exp(-α * sin²(ω * s * (x + drift*t) + phase))
# where α is the effective |alpha| + 0.1.
# NOTE(review): the α=5.12 quoted for S2B1 below is a hard-coded figure and is
# not recomputed here. Compare it with the per-block α printed by the inspection loop above.
print("""
Wave function: prod over scales s of exp(-α * sin²(ω * s * (x + drift*t) + phase))
(α is the effective |alpha| + 0.1 printed per block above)
For high α (like 5.12 at S2B1):
- This becomes a VERY narrow peak around sin(...)=0
- i.e., when ω*s*(x+drift*t) + phase = n*π
The prod over 2 scales means BOTH scales must hit a peak simultaneously.
This is extremely rare, so most values are tiny: one fully off-peak scale
caps the product at exp(-5.12) ≈ 0.006, and both off-peak gives exp(-10.24) ≈ 3.6e-5
BUT: The gate is computed AFTER LayerNorm on gate_pre!
gate = sigmoid(LayerNorm(weighted_sum * (0.5 + 0.5*lr)))
LayerNorm rescales the near-zero values to have mean=0, std=1 across channels
(then applies its learned per-channel weight/bias)
Then sigmoid maps that to ~0.5 centered distribution.
This is why gates are ~0.4-0.5 even when raw L/M/R are tiny.
""")
# Verify on S2B1: gate_pre is the weighted L/M/R mix, gate = sigmoid(LayerNorm(gate_pre)).
with torch.no_grad():
    raw = model.get_block_raw(x, 2, 1)  # S2B1
gate_pre_t = raw['gate_pre']
gate_t = raw['gate']
print(f"\nS2B1 gate_pre: min={gate_pre_t.min():.6f}, max={gate_pre_t.max():.6f}, mean={gate_pre_t.mean():.6f}")
print(f"S2B1 gate: min={gate_t.min():.4f}, max={gate_t.max():.4f}, mean={gate_t.mean():.4f}")
# The "signal" is in the relative differences, not the absolute values.
print("\nThe information is in relative L/M/R differences across channels:")
# Sample 0 of each wave, averaged over both spatial axes -> one value per channel [C].
L_per_channel, M_per_channel, R_per_channel = (
    raw[key][0].mean(dim=(0, 1)).cpu().numpy() for key in ('L', 'M', 'R')
)
series = (('L', L_per_channel, 'r-'), ('M', M_per_channel, 'g-'), ('R', R_per_channel, 'b-'))
fig2, ax2 = plt.subplots(1, 1, figsize=(14, 4))
channels = np.arange(len(L_per_channel))
for label, values, fmt in series:
    ax2.plot(channels, values, fmt, alpha=0.7, label=label)
ax2.set_xlabel('Channel')
ax2.set_ylabel('Mean activation')
ax2.set_title('S2B1: L/M/R per channel (the signal is in the variance)')
ax2.legend()
plt.tight_layout()
plt.savefig("mobius_channel_variance.png", dpi=150)
plt.show()
print("\nPer-channel variance:")
for label, values, _ in series:
    print(f" {label} channels std: {values.std():.6f}")