AbstractPhil committed on
Commit
e9e6a78
·
verified ·
1 Parent(s): 3c5907a

Create make_chart_2.py

Browse files
Files changed (1) hide show
  1. make_chart_2.py +310 -0
make_chart_2.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#@title Diagnose L/M/R Wave Values (Fixed)
# Colab/Jupyter cell: install the runtime dependencies needed for dataset
# streaming and checkpoint download below.
!pip install -q datasets safetensors huggingface_hub

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as load_safetensors
import matplotlib.pyplot as plt
import numpy as np
import math
import json

# Run on GPU when available; model and inputs are both moved to this device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ============================================================================
# FULL MOBIUSNET WITH RAW WAVE ACCESS
# ============================================================================
21
class MobiusLensRaw(nn.Module):
    """Wave-based feature gate with raw L/M/R inspection hooks.

    Modulates the last-dim features with three periodic envelopes (L/M/R),
    each a product over two per-layer scales of exp(-a*sin^2(...)), combined
    with a soft XOR/AND mix of L and R into a sigmoid gate. `forward` applies
    the gate; `forward_raw` repeats the same computation but returns all
    intermediates for diagnostics.
    """

    def __init__(self, dim, layer_idx, total_layers, scale_range=(1.0, 9.0)):
        super().__init__()
        self.dim = dim
        # Normalized depth position of this layer in [0, 1].
        self.t = layer_idx / max(total_layers - 1, 1)
        scale_span = scale_range[1] - scale_range[0]
        step = scale_span / max(total_layers, 1)
        # Two adjacent frequency scales; deeper layers get higher scales.
        self.register_buffer('scales', torch.tensor([scale_range[0] + self.t * scale_span,
                                                     scale_range[0] + self.t * scale_span + step]))
        # Input "twist": rotation-like blend of x with a learned projection.
        self.twist_in_angle = nn.Parameter(torch.tensor(self.t * math.pi))
        self.twist_in_proj = nn.Linear(dim, dim, bias=False)
        self.omega = nn.Parameter(torch.tensor(math.pi))  # base angular frequency
        self.alpha = nn.Parameter(torch.tensor(1.5))      # envelope sharpness
        # Per-wave phase and drift (one entry per scale). Drifts start at
        # +1 (L), 0 (M), -1 (R): the L and R waves shift in opposite directions.
        self.phase_l, self.drift_l = nn.Parameter(torch.zeros(2)), nn.Parameter(torch.ones(2))
        self.phase_m, self.drift_m = nn.Parameter(torch.zeros(2)), nn.Parameter(torch.zeros(2))
        self.phase_r, self.drift_r = nn.Parameter(torch.zeros(2)), nn.Parameter(-torch.ones(2))
        self.accum_weights = nn.Parameter(torch.tensor([0.4, 0.2, 0.4]))  # softmax-ed L/M/R mix
        self.xor_weight = nn.Parameter(torch.tensor(0.7))  # sigmoid-ed XOR-vs-AND blend
        self.gate_norm = nn.LayerNorm(dim)
        # Output twist initialized to the opposite angle of the input twist.
        self.twist_out_angle = nn.Parameter(torch.tensor(-self.t * math.pi))
        self.twist_out_proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        """Gate features of shape (..., dim); return (twisted gated output, gate)."""
        cos_t, sin_t = torch.cos(self.twist_in_angle), torch.sin(self.twist_in_angle)
        x = x * cos_t + self.twist_in_proj(x) * sin_t
        x_norm = torch.tanh(x)  # bound values before the periodic map
        # t: per-position mean magnitude, broadcast over the 2-scale axis.
        t = x_norm.abs().mean(dim=-1, keepdim=True).unsqueeze(-2)
        x_exp = x_norm.unsqueeze(-2)  # insert the scale axis
        s = self.scales.view(-1, 1)
        a = self.alpha.abs() + 0.1    # keep sharpness strictly positive
        def wave(phase, drift):
            # exp(-a*sin^2(.)) peaks where the argument hits n*pi; the product
            # over both scales requires simultaneous peaks.
            pos = s * self.omega * (x_exp + drift.view(-1, 1) * t) + phase.view(-1, 1)
            return torch.exp(-a * torch.sin(pos).pow(2)).prod(dim=-2)
        L, M, R = wave(self.phase_l, self.drift_l), wave(self.phase_m, self.drift_m), wave(self.phase_r, self.drift_r)
        w = torch.softmax(self.accum_weights, dim=0)
        xor_w = torch.sigmoid(self.xor_weight)
        # Soft XOR of L and R blended with their AND (product).
        lr = xor_w * (L + R - 2*L*R).abs() + (1 - xor_w) * L * R
        # NOTE: LayerNorm rescales gate_pre before the sigmoid, so tiny raw
        # wave values can still yield mid-range gates.
        gate = torch.sigmoid(self.gate_norm((w[0]*L + w[1]*M + w[2]*R) * (0.5 + 0.5*lr)))
        x = x * gate
        cos_t, sin_t = torch.cos(self.twist_out_angle), torch.sin(self.twist_out_angle)
        return x * cos_t + self.twist_out_proj(x) * sin_t, gate

    def forward_raw(self, x):
        """Return raw L/M/R values for inspection."""
        # Mirrors `forward` up to the gate, but keeps every intermediate.
        cos_t, sin_t = torch.cos(self.twist_in_angle), torch.sin(self.twist_in_angle)
        x_twisted = x * cos_t + self.twist_in_proj(x) * sin_t
        x_norm = torch.tanh(x_twisted)
        t = x_norm.abs().mean(dim=-1, keepdim=True).unsqueeze(-2)
        x_exp = x_norm.unsqueeze(-2)
        s = self.scales.view(-1, 1)
        a = self.alpha.abs() + 0.1

        def wave_detailed(phase, drift):
            # Same as `wave` in forward, but also returns the sin and exp
            # intermediates so their ranges can be inspected.
            pos = s * self.omega * (x_exp + drift.view(-1, 1) * t) + phase.view(-1, 1)
            sin_val = torch.sin(pos)
            exp_val = torch.exp(-a * sin_val.pow(2))
            prod_val = exp_val.prod(dim=-2)
            return prod_val, sin_val, exp_val

        L, L_sin, L_exp = wave_detailed(self.phase_l, self.drift_l)
        M, M_sin, M_exp = wave_detailed(self.phase_m, self.drift_m)
        R, R_sin, R_exp = wave_detailed(self.phase_r, self.drift_r)

        w = torch.softmax(self.accum_weights, dim=0)
        xor_w = torch.sigmoid(self.xor_weight)
        xor_comp = (L + R - 2*L*R).abs()
        and_comp = L * R
        lr = xor_w * xor_comp + (1 - xor_w) * and_comp
        gate_pre = (w[0]*L + w[1]*M + w[2]*R) * (0.5 + 0.5*lr)
        gate = torch.sigmoid(self.gate_norm(gate_pre))

        # Tensors are returned as-is; scalars/arrays are detached to host
        # values for easy printing.
        return {
            'x_norm': x_norm, 'L': L, 'M': M, 'R': R,
            'L_sin': L_sin, 'L_exp': L_exp,
            'xor_comp': xor_comp, 'and_comp': and_comp,
            'gate_pre': gate_pre, 'gate': gate,
            'omega': self.omega.item(), 'alpha': a.item(),
            'scales': self.scales.cpu().numpy(),
            'weights': w.detach().cpu().numpy(),
            'xor_weight': xor_w.item(),
        }
102
+
103
class MobiusBlockRaw(nn.Module):
    """Depthwise-separable conv followed by a Mobius lens gate, a per-third
    channel attenuation mask, and a learned residual blend."""

    def __init__(self, channels, layer_idx, total_layers, scale_range=(1.0, 9.0), reduction=0.5):
        super().__init__()
        # Depthwise 3x3 -> pointwise 1x1 -> BatchNorm.
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, groups=channels, bias=False),
            nn.Conv2d(channels, channels, 1, bias=False),
            nn.BatchNorm2d(channels),
        )
        self.lens = MobiusLensRaw(channels, layer_idx, total_layers, scale_range)
        # Attenuate one third of the channels, rotating with layer index;
        # the last third absorbs the remainder when channels % 3 != 0.
        third_size = channels // 3
        third_idx = layer_idx % 3
        start = third_idx * third_size
        stop = start + third_size
        if third_idx == 2:
            stop += channels % 3
        damp = torch.ones(channels)
        damp[start:stop] = reduction
        self.register_buffer('thirds_mask', damp.view(1, -1, 1, 1))
        self.residual_weight = nn.Parameter(torch.tensor(0.9))

    def forward(self, x):
        """Conv + lens gate + thirds mask, blended with the identity path."""
        skip = x
        feats = self.conv(x).permute(0, 2, 3, 1)  # NCHW -> NHWC for the lens
        feats, _ = self.lens(feats)
        feats = feats.permute(0, 3, 1, 2) * self.thirds_mask
        mix = torch.sigmoid(self.residual_weight)
        return mix * skip + (1 - mix) * feats

    def forward_raw(self, x):
        """Run the conv path, then return the lens's raw diagnostics dict."""
        feats = self.conv(x).permute(0, 2, 3, 1)
        return self.lens.forward_raw(feats)
127
+
128
class MobiusNetRaw(nn.Module):
    """MobiusNet backbone with per-block raw-wave inspection.

    Mirrors the trained model's module layout (stem -> staged MobiusBlockRaw
    lists with inter-stage downsamples -> integrator -> pool -> head) so the
    checkpoint state_dict loads cleanly, and adds `get_block_raw` for
    diagnostics.
    """

    def __init__(self, in_chans=1, num_classes=1000, channels=(64,128,256),
                 depths=(2,2,2), scale_range=(0.5,2.5), use_integrator=True):
        super().__init__()
        total_layers = sum(depths)
        channels = list(channels)
        self.stem = nn.Sequential(nn.Conv2d(in_chans, channels[0], 3, padding=1, bias=False),
                                  nn.BatchNorm2d(channels[0]))
        self.stages = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        layer_idx = 0  # global block index across all stages
        for si, d in enumerate(depths):
            self.stages.append(nn.ModuleList([
                MobiusBlockRaw(channels[si], layer_idx + i, total_layers, scale_range)
                for i in range(d)
            ]))
            layer_idx += d
            if si < len(depths) - 1:
                # Stride-2 conv between stages: halves spatial dims, changes width.
                self.downsamples.append(nn.Sequential(
                    nn.Conv2d(channels[si], channels[si+1], 3, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(channels[si+1])))
        # Include integrator and head for weight loading
        self.integrator = nn.Sequential(nn.Conv2d(channels[-1], channels[-1], 3, padding=1, bias=False),
                                        nn.BatchNorm2d(channels[-1]), nn.GELU()) if use_integrator else nn.Identity()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(channels[-1], num_classes)

    def get_block_raw(self, x, target_stage, target_block):
        """Forward `x` up to the target block and return its raw wave dict.

        Args:
            x: input batch accepted by the stem conv (NCHW).
            target_stage: stage index into `self.stages`.
            target_block: block index within that stage.

        Raises:
            ValueError: if (target_stage, target_block) does not exist.
                Previously this silently returned None, which surfaced as an
                opaque TypeError at the caller's `raw[...]` lookup.
        """
        x = self.stem(x)
        for si, stage in enumerate(self.stages):
            for bi, block in enumerate(stage):
                if si == target_stage and bi == target_block:
                    return block.forward_raw(x)
                x = block(x)
            if si < len(self.downsamples):
                x = self.downsamples[si](x)
        raise ValueError(f"No block at stage={target_stage}, block={target_block}")
160
+
161
# ============================================================================
# LOAD MODEL
# ============================================================================

print("Loading model...")
# Download the training config and the best checkpoint from the HF Hub repo.
config_path = hf_hub_download("AbstractPhil/mobiusnet-distillations",
                              "checkpoints/mobius_tiny_s_imagenet_clip_vit_l14/20260111_000512/config.json")
with open(config_path) as f:
    config = json.load(f)
model_path = hf_hub_download("AbstractPhil/mobiusnet-distillations",
                             "checkpoints/mobius_tiny_s_imagenet_clip_vit_l14/20260111_000512/checkpoints/best_model.safetensors")

# Rebuild the architecture from the saved config so the state_dict keys match.
cfg = config['model']
model = MobiusNetRaw(cfg['in_chans'], cfg['num_classes'], tuple(cfg['channels']),
                     tuple(cfg['depths']), tuple(cfg['scale_range']), cfg['use_integrator']).to(device)
model.load_state_dict(load_safetensors(model_path))
model.eval()  # inference mode: fixes BatchNorm statistics
print("✓ Loaded")
179
+
180
# ============================================================================
# GET SAMPLE DATA
# ============================================================================

# Stream one validation batch of precomputed CLIP ViT-L/14 features
# (streaming avoids downloading the whole split).
ds = load_dataset("AbstractPhil/imagenet-clip-features-orderly", "clip_vit_l14",
                  split="validation", streaming=True).with_format("torch")
loader = DataLoader(ds, batch_size=16)
batch = next(iter(loader))
# Reshape each CLIP feature vector into a 1x24x32 "image" for the conv net.
x = batch['clip_features'].view(-1, 1, 24, 32).to(device)
189
+
190
# ============================================================================
# INSPECT EACH BLOCK
# ============================================================================

# (stage, block) pairs to probe; assumes depths=(2,2,2) — TODO confirm
# against the loaded config.
blocks = [(0,0), (0,1), (1,0), (1,1), (2,0), (2,1)]
block_names = ['S0B0', 'S0B1', 'S1B0', 'S1B1', 'S2B0', 'S2B1']

# One row of plots per block: L/M/R/gate histograms + two spatial maps.
fig, axes = plt.subplots(6, 6, figsize=(24, 24))

for bi, ((si, bii), name) in enumerate(zip(blocks, block_names)):
    with torch.no_grad():
        raw = model.get_block_raw(x, si, bii)

    # Print the learned scalar parameters for this block's lens.
    print(f"\n{'='*60}")
    print(f"{name}: ω={raw['omega']:.3f}, α={raw['alpha']:.3f}, scales={raw['scales']}")
    print(f" Weights: L={raw['weights'][0]:.3f}, M={raw['weights'][1]:.3f}, R={raw['weights'][2]:.3f}")
    print(f" XOR weight: {raw['xor_weight']:.3f}")

    L, M, R = raw['L'], raw['M'], raw['R']
    gate = raw['gate']

    # Summary statistics of the raw wave values and the final gate.
    print(f" L: min={L.min():.6f}, max={L.max():.6f}, mean={L.mean():.6f}, std={L.std():.6f}")
    print(f" M: min={M.min():.6f}, max={M.max():.6f}, mean={M.mean():.6f}, std={M.std():.6f}")
    print(f" R: min={R.min():.6f}, max={R.max():.6f}, mean={R.mean():.6f}, std={R.std():.6f}")
    print(f" Gate: min={gate.min():.4f}, max={gate.max():.4f}, mean={gate.mean():.4f}")

    # Check intermediate values
    print(f" L_sin range: [{raw['L_sin'].min():.4f}, {raw['L_sin'].max():.4f}]")
    print(f" L_exp range: [{raw['L_exp'].min():.6f}, {raw['L_exp'].max():.6f}]")
    print(f" x_norm range: [{raw['x_norm'].min():.4f}, {raw['x_norm'].max():.4f}]")

    # Plot distributions
    axes[bi, 0].hist(L.cpu().numpy().flatten(), bins=50, color='red', alpha=0.7, density=True)
    axes[bi, 0].set_title(f'{name} L\nμ={L.mean():.4f}, σ={L.std():.4f}', fontsize=10)
    axes[bi, 0].axvline(x=L.mean().item(), color='black', linestyle='--')

    axes[bi, 1].hist(M.cpu().numpy().flatten(), bins=50, color='green', alpha=0.7, density=True)
    axes[bi, 1].set_title(f'{name} M\nμ={M.mean():.4f}', fontsize=10)

    axes[bi, 2].hist(R.cpu().numpy().flatten(), bins=50, color='blue', alpha=0.7, density=True)
    axes[bi, 2].set_title(f'{name} R\nμ={R.mean():.4f}', fontsize=10)

    axes[bi, 3].hist(gate.cpu().numpy().flatten(), bins=50, color='purple', alpha=0.7, density=True)
    axes[bi, 3].set_title(f'{name} Gate\nμ={gate.mean():.4f}', fontsize=10)

    # Spatial - single sample, mean across channels
    L_spatial = L[0].mean(dim=-1).cpu().numpy()
    axes[bi, 4].imshow(L_spatial, cmap='hot', aspect='auto')
    axes[bi, 4].set_title(f'{name} L spatial\nα={raw["alpha"]:.2f}', fontsize=10)
    axes[bi, 4].axis('off')

    gate_spatial = gate[0].mean(dim=-1).cpu().numpy()
    axes[bi, 5].imshow(gate_spatial, cmap='viridis', aspect='auto', vmin=0, vmax=1)
    axes[bi, 5].set_title(f'{name} Gate spatial', fontsize=10)
    axes[bi, 5].axis('off')

plt.suptitle('Raw Wave Diagnostics: L/M/R Distributions', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig("mobius_raw_diagnostics.png", dpi=150, bbox_inches="tight")
plt.show()
250
+
251
# ============================================================================
# ANALYSIS: Why are L/M/R uniform?
# ============================================================================

print("\n" + "="*70)
print("ANALYSIS: Wave Function Behavior")
print("="*70)

# The wave function: exp(-α * sin²(ω * s * (x + drift*t)))
# Let's trace through for S2B1 which has α=5.12
# NOTE(review): the α=5.12 figure is from a previous run of this notebook —
# verify against the printed diagnostics above for the current checkpoint.

print("""
Wave function: exp(-α * sin²(ω * s * (x + drift*t)))

For high α (like 5.12 at S2B1):
- This becomes a VERY narrow peak around sin(...)=0
- i.e., when ω*s*(x+drift*t) = n*π

The prod over 2 scales means BOTH scales must hit a peak simultaneously.
This is extremely rare, so most values → exp(-5.12) ≈ 0.006

BUT: The gate is computed AFTER LayerNorm on gate_pre!
gate = sigmoid(LayerNorm(weighted_sum * (0.5 + 0.5*lr)))

LayerNorm rescales the near-zero values to have mean=0, std=1
Then sigmoid maps that to ~0.5 centered distribution.

This is why gates are ~0.4-0.5 even when raw L/M/R are tiny.
""")

# Verify: check gate_pre vs gate
with torch.no_grad():
    raw = model.get_block_raw(x, 2, 1)  # S2B1

# gate_pre is the pre-LayerNorm value; gate is post LayerNorm + sigmoid.
print(f"\nS2B1 gate_pre: min={raw['gate_pre'].min():.6f}, max={raw['gate_pre'].max():.6f}, mean={raw['gate_pre'].mean():.6f}")
print(f"S2B1 gate: min={raw['gate'].min():.4f}, max={raw['gate'].max():.4f}, mean={raw['gate'].mean():.4f}")

# The "signal" is in the RELATIVE differences, not absolute values
print(f"\nThe information is in relative L/M/R differences across channels:")
# Average over the spatial dims of sample 0, leaving one value per channel.
L_per_channel = raw['L'][0].mean(dim=(0,1)).cpu().numpy()  # [C]
M_per_channel = raw['M'][0].mean(dim=(0,1)).cpu().numpy()
R_per_channel = raw['R'][0].mean(dim=(0,1)).cpu().numpy()

fig2, ax2 = plt.subplots(1, 1, figsize=(14, 4))
channels = np.arange(len(L_per_channel))
ax2.plot(channels, L_per_channel, 'r-', alpha=0.7, label='L')
ax2.plot(channels, M_per_channel, 'g-', alpha=0.7, label='M')
ax2.plot(channels, R_per_channel, 'b-', alpha=0.7, label='R')
ax2.set_xlabel('Channel')
ax2.set_ylabel('Mean activation')
ax2.set_title('S2B1: L/M/R per channel (the signal is in the variance)')
ax2.legend()
plt.tight_layout()
plt.savefig("mobius_channel_variance.png", dpi=150)
plt.show()

print(f"\nPer-channel variance:")
print(f" L channels std: {L_per_channel.std():.6f}")
print(f" M channels std: {M_per_channel.std():.6f}")
print(f" R channels std: {R_per_channel.std():.6f}")
+ print(f" R channels std: {R_per_channel.std():.6f}")