#@title Diagnose L/M/R Wave Values (Fixed) !pip install -q datasets safetensors huggingface_hub import torch import torch.nn as nn from torch.utils.data import DataLoader from datasets import load_dataset from huggingface_hub import hf_hub_download from safetensors.torch import load_file as load_safetensors import matplotlib.pyplot as plt import numpy as np import math import json device = 'cuda' if torch.cuda.is_available() else 'cpu' # ============================================================================ # FULL MOBIUSNET WITH RAW WAVE ACCESS # ============================================================================ class MobiusLensRaw(nn.Module): def __init__(self, dim, layer_idx, total_layers, scale_range=(1.0, 9.0)): super().__init__() self.dim = dim self.t = layer_idx / max(total_layers - 1, 1) scale_span = scale_range[1] - scale_range[0] step = scale_span / max(total_layers, 1) self.register_buffer('scales', torch.tensor([scale_range[0] + self.t * scale_span, scale_range[0] + self.t * scale_span + step])) self.twist_in_angle = nn.Parameter(torch.tensor(self.t * math.pi)) self.twist_in_proj = nn.Linear(dim, dim, bias=False) self.omega = nn.Parameter(torch.tensor(math.pi)) self.alpha = nn.Parameter(torch.tensor(1.5)) self.phase_l, self.drift_l = nn.Parameter(torch.zeros(2)), nn.Parameter(torch.ones(2)) self.phase_m, self.drift_m = nn.Parameter(torch.zeros(2)), nn.Parameter(torch.zeros(2)) self.phase_r, self.drift_r = nn.Parameter(torch.zeros(2)), nn.Parameter(-torch.ones(2)) self.accum_weights = nn.Parameter(torch.tensor([0.4, 0.2, 0.4])) self.xor_weight = nn.Parameter(torch.tensor(0.7)) self.gate_norm = nn.LayerNorm(dim) self.twist_out_angle = nn.Parameter(torch.tensor(-self.t * math.pi)) self.twist_out_proj = nn.Linear(dim, dim, bias=False) def forward(self, x): cos_t, sin_t = torch.cos(self.twist_in_angle), torch.sin(self.twist_in_angle) x = x * cos_t + self.twist_in_proj(x) * sin_t x_norm = torch.tanh(x) t = x_norm.abs().mean(dim=-1, 
keepdim=True).unsqueeze(-2) x_exp = x_norm.unsqueeze(-2) s = self.scales.view(-1, 1) a = self.alpha.abs() + 0.1 def wave(phase, drift): pos = s * self.omega * (x_exp + drift.view(-1, 1) * t) + phase.view(-1, 1) return torch.exp(-a * torch.sin(pos).pow(2)).prod(dim=-2) L, M, R = wave(self.phase_l, self.drift_l), wave(self.phase_m, self.drift_m), wave(self.phase_r, self.drift_r) w = torch.softmax(self.accum_weights, dim=0) xor_w = torch.sigmoid(self.xor_weight) lr = xor_w * (L + R - 2*L*R).abs() + (1 - xor_w) * L * R gate = torch.sigmoid(self.gate_norm((w[0]*L + w[1]*M + w[2]*R) * (0.5 + 0.5*lr))) x = x * gate cos_t, sin_t = torch.cos(self.twist_out_angle), torch.sin(self.twist_out_angle) return x * cos_t + self.twist_out_proj(x) * sin_t, gate def forward_raw(self, x): """Return raw L/M/R values for inspection.""" cos_t, sin_t = torch.cos(self.twist_in_angle), torch.sin(self.twist_in_angle) x_twisted = x * cos_t + self.twist_in_proj(x) * sin_t x_norm = torch.tanh(x_twisted) t = x_norm.abs().mean(dim=-1, keepdim=True).unsqueeze(-2) x_exp = x_norm.unsqueeze(-2) s = self.scales.view(-1, 1) a = self.alpha.abs() + 0.1 def wave_detailed(phase, drift): pos = s * self.omega * (x_exp + drift.view(-1, 1) * t) + phase.view(-1, 1) sin_val = torch.sin(pos) exp_val = torch.exp(-a * sin_val.pow(2)) prod_val = exp_val.prod(dim=-2) return prod_val, sin_val, exp_val L, L_sin, L_exp = wave_detailed(self.phase_l, self.drift_l) M, M_sin, M_exp = wave_detailed(self.phase_m, self.drift_m) R, R_sin, R_exp = wave_detailed(self.phase_r, self.drift_r) w = torch.softmax(self.accum_weights, dim=0) xor_w = torch.sigmoid(self.xor_weight) xor_comp = (L + R - 2*L*R).abs() and_comp = L * R lr = xor_w * xor_comp + (1 - xor_w) * and_comp gate_pre = (w[0]*L + w[1]*M + w[2]*R) * (0.5 + 0.5*lr) gate = torch.sigmoid(self.gate_norm(gate_pre)) return { 'x_norm': x_norm, 'L': L, 'M': M, 'R': R, 'L_sin': L_sin, 'L_exp': L_exp, 'xor_comp': xor_comp, 'and_comp': and_comp, 'gate_pre': gate_pre, 'gate': gate, 
'omega': self.omega.item(), 'alpha': a.item(), 'scales': self.scales.cpu().numpy(), 'weights': w.detach().cpu().numpy(), 'xor_weight': xor_w.item(), } class MobiusBlockRaw(nn.Module): def __init__(self, channels, layer_idx, total_layers, scale_range=(1.0, 9.0), reduction=0.5): super().__init__() self.conv = nn.Sequential(nn.Conv2d(channels, channels, 3, padding=1, groups=channels, bias=False), nn.Conv2d(channels, channels, 1, bias=False), nn.BatchNorm2d(channels)) self.lens = MobiusLensRaw(channels, layer_idx, total_layers, scale_range) third = channels // 3 which_third = layer_idx % 3 mask = torch.ones(channels) mask[which_third*third : which_third*third + third + (channels%3 if which_third==2 else 0)] = reduction self.register_buffer('thirds_mask', mask.view(1, -1, 1, 1)) self.residual_weight = nn.Parameter(torch.tensor(0.9)) def forward(self, x): identity = x h = self.conv(x).permute(0, 2, 3, 1) h, gate = self.lens(h) h = h.permute(0, 3, 1, 2) * self.thirds_mask rw = torch.sigmoid(self.residual_weight) return rw * identity + (1 - rw) * h def forward_raw(self, x): h = self.conv(x).permute(0, 2, 3, 1) return self.lens.forward_raw(h) class MobiusNetRaw(nn.Module): def __init__(self, in_chans=1, num_classes=1000, channels=(64,128,256), depths=(2,2,2), scale_range=(0.5,2.5), use_integrator=True): super().__init__() total_layers = sum(depths) channels = list(channels) self.stem = nn.Sequential(nn.Conv2d(in_chans, channels[0], 3, padding=1, bias=False), nn.BatchNorm2d(channels[0])) self.stages = nn.ModuleList() self.downsamples = nn.ModuleList() layer_idx = 0 for si, d in enumerate(depths): self.stages.append(nn.ModuleList([MobiusBlockRaw(channels[si], layer_idx+i, total_layers, scale_range) for i in range(d)])) layer_idx += d if si < len(depths)-1: self.downsamples.append(nn.Sequential(nn.Conv2d(channels[si], channels[si+1], 3, stride=2, padding=1, bias=False), nn.BatchNorm2d(channels[si+1]))) # Include integrator and head for weight loading self.integrator = 
nn.Sequential(nn.Conv2d(channels[-1], channels[-1], 3, padding=1, bias=False), nn.BatchNorm2d(channels[-1]), nn.GELU()) if use_integrator else nn.Identity() self.pool = nn.AdaptiveAvgPool2d(1) self.head = nn.Linear(channels[-1], num_classes) def get_block_raw(self, x, target_stage, target_block): """Forward to target block and return raw wave data.""" x = self.stem(x) for si, stage in enumerate(self.stages): for bi, block in enumerate(stage): if si == target_stage and bi == target_block: return block.forward_raw(x) x = block(x) if si < len(self.downsamples): x = self.downsamples[si](x) return None # ============================================================================ # LOAD MODEL # ============================================================================ print("Loading model...") config_path = hf_hub_download("AbstractPhil/mobiusnet-distillations", "checkpoints/mobius_tiny_s_imagenet_clip_vit_l14/20260111_000512/config.json") with open(config_path) as f: config = json.load(f) model_path = hf_hub_download("AbstractPhil/mobiusnet-distillations", "checkpoints/mobius_tiny_s_imagenet_clip_vit_l14/20260111_000512/checkpoints/best_model.safetensors") cfg = config['model'] model = MobiusNetRaw(cfg['in_chans'], cfg['num_classes'], tuple(cfg['channels']), tuple(cfg['depths']), tuple(cfg['scale_range']), cfg['use_integrator']).to(device) model.load_state_dict(load_safetensors(model_path)) model.eval() print("✓ Loaded") # ============================================================================ # GET SAMPLE DATA # ============================================================================ ds = load_dataset("AbstractPhil/imagenet-clip-features-orderly", "clip_vit_l14", split="validation", streaming=True).with_format("torch") loader = DataLoader(ds, batch_size=16) batch = next(iter(loader)) x = batch['clip_features'].view(-1, 1, 24, 32).to(device) # ============================================================================ # INSPECT EACH BLOCK # 
# ============================================================================
# INSPECT EACH BLOCK
# ============================================================================
# Derive the block grid from the loaded model so it follows config['model']['depths']
# (for the default depths (2, 2, 2) this is S0B0 ... S2B1).
blocks = [(si, bi) for si, stage in enumerate(model.stages) for bi in range(len(stage))]
block_names = [f'S{si}B{bi}' for si, bi in blocks]

# squeeze=False keeps `axes` 2-D even when there is only one block.
fig, axes = plt.subplots(len(blocks), 6, figsize=(24, 4 * len(blocks)), squeeze=False)

for row, ((si, bii), name) in enumerate(zip(blocks, block_names)):
    with torch.no_grad():
        raw = model.get_block_raw(x, si, bii)

    print(f"\n{'='*60}")
    print(f"{name}: ω={raw['omega']:.3f}, α={raw['alpha']:.3f}, scales={raw['scales']}")
    print(f" Weights: L={raw['weights'][0]:.3f}, M={raw['weights'][1]:.3f}, R={raw['weights'][2]:.3f}")
    print(f" XOR weight: {raw['xor_weight']:.3f}")

    L, M, R = raw['L'], raw['M'], raw['R']
    gate = raw['gate']

    print(f" L: min={L.min():.6f}, max={L.max():.6f}, mean={L.mean():.6f}, std={L.std():.6f}")
    print(f" M: min={M.min():.6f}, max={M.max():.6f}, mean={M.mean():.6f}, std={M.std():.6f}")
    print(f" R: min={R.min():.6f}, max={R.max():.6f}, mean={R.mean():.6f}, std={R.std():.6f}")
    print(f" Gate: min={gate.min():.4f}, max={gate.max():.4f}, mean={gate.mean():.4f}")

    # Check intermediate values
    print(f" L_sin range: [{raw['L_sin'].min():.4f}, {raw['L_sin'].max():.4f}]")
    print(f" L_exp range: [{raw['L_exp'].min():.6f}, {raw['L_exp'].max():.6f}]")
    print(f" x_norm range: [{raw['x_norm'].min():.4f}, {raw['x_norm'].max():.4f}]")

    # Plot distributions
    axes[row, 0].hist(L.cpu().numpy().flatten(), bins=50, color='red', alpha=0.7, density=True)
    axes[row, 0].set_title(f'{name} L\nμ={L.mean():.4f}, σ={L.std():.4f}', fontsize=10)
    axes[row, 0].axvline(x=L.mean().item(), color='black', linestyle='--')

    axes[row, 1].hist(M.cpu().numpy().flatten(), bins=50, color='green', alpha=0.7, density=True)
    axes[row, 1].set_title(f'{name} M\nμ={M.mean():.4f}', fontsize=10)

    axes[row, 2].hist(R.cpu().numpy().flatten(), bins=50, color='blue', alpha=0.7, density=True)
    axes[row, 2].set_title(f'{name} R\nμ={R.mean():.4f}', fontsize=10)

    axes[row, 3].hist(gate.cpu().numpy().flatten(), bins=50, color='purple', alpha=0.7, density=True)
    axes[row, 3].set_title(f'{name} Gate\nμ={gate.mean():.4f}', fontsize=10)

    # Spatial - single sample, mean across channels (L[0] is (H, W, C) -> (H, W))
    L_spatial = L[0].mean(dim=-1).cpu().numpy()
    axes[row, 4].imshow(L_spatial, cmap='hot', aspect='auto')
    axes[row, 4].set_title(f'{name} L spatial\nα={raw["alpha"]:.2f}', fontsize=10)
    axes[row, 4].axis('off')

    gate_spatial = gate[0].mean(dim=-1).cpu().numpy()
    axes[row, 5].imshow(gate_spatial, cmap='viridis', aspect='auto', vmin=0, vmax=1)
    axes[row, 5].set_title(f'{name} Gate spatial', fontsize=10)
    axes[row, 5].axis('off')

plt.suptitle('Raw Wave Diagnostics: L/M/R Distributions', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig("mobius_raw_diagnostics.png", dpi=150, bbox_inches="tight")
plt.show()

# ============================================================================
# ANALYSIS: Why are L/M/R uniform?
# ============================================================================
print("\n" + "="*70)
print("ANALYSIS: Wave Function Behavior")
print("="*70)

# The wave function: exp(-α * sin²(ω * s * (x + drift*t)))
# Trace through the deepest block, using its actual learned α instead of a
# hard-coded value from one particular checkpoint.
last_stage, last_block = blocks[-1]
last_name = block_names[-1]
with torch.no_grad():
    raw = model.get_block_raw(x, last_stage, last_block)
alpha_last = raw['alpha']

print(f"""
Wave function: exp(-α * sin²(ω * s * (x + drift*t)))

For high α (like {alpha_last:.2f} at {last_name}):
- This becomes a VERY narrow peak around sin(...)=0
- i.e., when ω*s*(x+drift*t) = n*π

The prod over 2 scales means BOTH scales must hit a peak simultaneously.
This is rare, so values are pushed toward the floor:
exp(-α) ≈ {math.exp(-alpha_last):.4f} per scale, exp(-2α) ≈ {math.exp(-2 * alpha_last):.2e} for the 2-scale product.

BUT: The gate is computed AFTER LayerNorm on gate_pre!
gate = sigmoid(LayerNorm(weighted_sum * (0.5 + 0.5*lr)))

LayerNorm rescales the near-zero values to mean=0, std=1 across channels
(before its learned weight/bias), then sigmoid maps that to a ~0.5-centered distribution.
This is why gates are ~0.4-0.5 even when raw L/M/R are tiny.
""")

# Verify: check gate_pre vs gate
print(f"\n{last_name} gate_pre: min={raw['gate_pre'].min():.6f}, max={raw['gate_pre'].max():.6f}, mean={raw['gate_pre'].mean():.6f}")
print(f"{last_name} gate: min={raw['gate'].min():.4f}, max={raw['gate'].max():.4f}, mean={raw['gate'].mean():.4f}")

# The "signal" is in the RELATIVE differences, not absolute values
print("\nThe information is in relative L/M/R differences across channels:")

# First sample, averaged over the spatial axes: (H, W, C) -> (C,)
L_per_channel = raw['L'][0].mean(dim=(0, 1)).cpu().numpy()
M_per_channel = raw['M'][0].mean(dim=(0, 1)).cpu().numpy()
R_per_channel = raw['R'][0].mean(dim=(0, 1)).cpu().numpy()

fig2, ax2 = plt.subplots(1, 1, figsize=(14, 4))
channel_idx = np.arange(len(L_per_channel))
ax2.plot(channel_idx, L_per_channel, 'r-', alpha=0.7, label='L')
ax2.plot(channel_idx, M_per_channel, 'g-', alpha=0.7, label='M')
ax2.plot(channel_idx, R_per_channel, 'b-', alpha=0.7, label='R')
ax2.set_xlabel('Channel')
ax2.set_ylabel('Mean activation')
ax2.set_title(f'{last_name}: L/M/R per channel (the signal is in the variance)')
ax2.legend()
plt.tight_layout()
plt.savefig("mobius_channel_variance.png", dpi=150)
plt.show()

print(f"\nPer-channel variance:")
print(f" L channels std: {L_per_channel.std():.6f}")
print(f" M channels std: {M_per_channel.std():.6f}")
print(f" R channels std: {R_per_channel.std():.6f}")