AbstractPhil committed
Commit 5e1998b · verified · 1 parent: 1ebdd36

Create trainer.py

Files changed (1)
  1. trainer.py +1376 -0
trainer.py ADDED
@@ -0,0 +1,1376 @@
1
+ """
2
+ MobiusNet Trainer with TensorBoard, SafeTensors, and HuggingFace Upload
3
+ =======================================================================
4
+ """
5
+
6
+ import os
7
+ import re
8
+ import json
9
+ import math
10
+ import shutil
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from torch import Tensor
15
+ from typing import Tuple, Optional, Dict, Any
16
+ from torchvision import datasets, transforms
17
+ from torch.utils.data import DataLoader, Dataset
18
+ from torch.utils.tensorboard import SummaryWriter
19
+ from tqdm.auto import tqdm
20
+ from datetime import datetime
21
+ from pathlib import Path
22
+ from safetensors.torch import save_file as save_safetensors, load_file as load_safetensors
23
+ from huggingface_hub import HfApi, login
24
+
25
+ # Colab HF login
26
+ try:
27
+ from google.colab import userdata
28
+ token = userdata.get('HF_TOKEN')
29
+ os.environ['HF_TOKEN'] = token
30
+ login(token=token)
31
+ print("Logged in to HuggingFace via Colab")
32
+ except Exception:
33
+ # Not in Colab or token not set
34
+ pass
35
+
36
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
37
+ print(f"Device: {device}")
38
+
39
+ # Enable TF32 for faster computation on Ampere+ GPUs
40
+ torch.backends.cuda.matmul.allow_tf32 = True
41
+ torch.backends.cudnn.allow_tf32 = True
42
+ torch.set_float32_matmul_precision('high')
43
+
44
+
45
+ # ============================================================================
46
+ # MÖBIUS LENS
47
+ # ============================================================================
48
+
49
+ class MobiusLens(nn.Module):
50
+ def __init__(
51
+ self,
52
+ dim: int,
53
+ layer_idx: int,
54
+ total_layers: int,
55
+ scale_range: Tuple[float, float] = (1.0, 9.0),
56
+ ):
57
+ super().__init__()
58
+
59
+ self.dim = dim
60
+ self.layer_idx = layer_idx
61
+ self.total_layers = total_layers
62
+ self.t = layer_idx / max(total_layers - 1, 1)
63
+
64
+ scale_span = scale_range[1] - scale_range[0]
65
+ step = scale_span / max(total_layers, 1)
66
+ scale_low = scale_range[0] + self.t * scale_span
67
+ scale_high = scale_low + step
68
+
69
+ self.register_buffer('scales', torch.tensor([scale_low, scale_high]))
70
+
71
+ self.twist_in_angle = nn.Parameter(torch.tensor(self.t * math.pi))
72
+ self.twist_in_proj = nn.Linear(dim, dim, bias=False)
73
+ nn.init.orthogonal_(self.twist_in_proj.weight)
74
+
75
+ self.omega = nn.Parameter(torch.tensor(math.pi))
76
+ self.alpha = nn.Parameter(torch.tensor(1.5))
77
+
78
+ self.phase_l = nn.Parameter(torch.zeros(2))
79
+ self.drift_l = nn.Parameter(torch.ones(2))
80
+ self.phase_m = nn.Parameter(torch.zeros(2))
81
+ self.drift_m = nn.Parameter(torch.zeros(2))
82
+ self.phase_r = nn.Parameter(torch.zeros(2))
83
+ self.drift_r = nn.Parameter(-torch.ones(2))
84
+
85
+ self.accum_weights = nn.Parameter(torch.tensor([0.4, 0.2, 0.4]))
86
+ self.xor_weight = nn.Parameter(torch.tensor(0.7))
87
+
88
+ self.gate_norm = nn.LayerNorm(dim)
89
+
90
+ self.twist_out_angle = nn.Parameter(torch.tensor(-self.t * math.pi))
91
+ self.twist_out_proj = nn.Linear(dim, dim, bias=False)
92
+ nn.init.orthogonal_(self.twist_out_proj.weight)
93
+
94
+ def _twist_in(self, x: Tensor) -> Tensor:
95
+ cos_t = torch.cos(self.twist_in_angle)
96
+ sin_t = torch.sin(self.twist_in_angle)
97
+ return x * cos_t + self.twist_in_proj(x) * sin_t
98
+
99
+ def _center_lens(self, x: Tensor) -> Tensor:
100
+ x_norm = torch.tanh(x)
101
+ t = x_norm.abs().mean(dim=-1, keepdim=True).unsqueeze(-2)
102
+
103
+ x_exp = x_norm.unsqueeze(-2)
104
+ s = self.scales.view(-1, 1)
105
+
106
+ def wave(phase, drift):
107
+ a = self.alpha.abs() + 0.1
108
+ pos = s * self.omega * (x_exp + drift.view(-1, 1) * t) + phase.view(-1, 1)
109
+ return torch.exp(-a * torch.sin(pos).pow(2)).prod(dim=-2)
110
+
111
+ L = wave(self.phase_l, self.drift_l)
112
+ M = wave(self.phase_m, self.drift_m)
113
+ R = wave(self.phase_r, self.drift_r)
114
+
115
+ w = torch.softmax(self.accum_weights, dim=0)
116
+ xor_w = torch.sigmoid(self.xor_weight)
117
+
118
+ xor_comp = (L + R - 2 * L * R).abs()
119
+ and_comp = L * R
120
+ lr = xor_w * xor_comp + (1 - xor_w) * and_comp
121
+
122
+ gate = w[0] * L + w[1] * M + w[2] * R
123
+ gate = gate * (0.5 + 0.5 * lr)
124
+ gate = torch.sigmoid(self.gate_norm(gate))
125
+
126
+ return x * gate
127
+
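+ # Note (added commentary): exp(-alpha * sin^2(.)) is a smooth periodic comb;
+ # the prod over the two buffered scales ANDs the combs, and L/M/R share the
+ # same comb with different learned phase/drift, so the XOR/AND mix above
+ # compares shifted copies of it.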
128
+ def _twist_out(self, x: Tensor) -> Tensor:
129
+ cos_t = torch.cos(self.twist_out_angle)
130
+ sin_t = torch.sin(self.twist_out_angle)
131
+ return x * cos_t + self.twist_out_proj(x) * sin_t
132
+
133
+ def forward(self, x: Tensor) -> Tensor:
134
+ return self._twist_out(self._center_lens(self._twist_in(x)))
135
+
136
+ def get_lens_stats(self) -> Dict[str, float]:
137
+ """Return lens parameters for logging."""
138
+ return {
139
+ 'omega': self.omega.item(),
140
+ 'alpha': self.alpha.item(),
141
+ 'twist_in_angle': self.twist_in_angle.item(),
142
+ 'twist_out_angle': self.twist_out_angle.item(),
143
+ 'xor_weight': torch.sigmoid(self.xor_weight).item(),
144
+ 'accum_weights_l': torch.softmax(self.accum_weights, dim=0)[0].item(),
145
+ 'accum_weights_m': torch.softmax(self.accum_weights, dim=0)[1].item(),
146
+ 'accum_weights_r': torch.softmax(self.accum_weights, dim=0)[2].item(),
147
+ }
148
+
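+ # Usage sketch (illustrative, not part of the training path): the lens is
+ # shape-preserving over the channel (last) dim, e.g. for channels-last input:
+ #   lens = MobiusLens(dim=64, layer_idx=0, total_layers=8)
+ #   y = lens(torch.randn(2, 16, 16, 64))  # -> torch.Size([2, 16, 16, 64])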
149
+
150
+ # ============================================================================
151
+ # MÖBIUS CONV BLOCK
152
+ # ============================================================================
153
+
154
+ class MobiusConvBlock(nn.Module):
155
+ def __init__(
156
+ self,
157
+ channels: int,
158
+ layer_idx: int,
159
+ total_layers: int,
160
+ scale_range: Tuple[float, float] = (1.0, 9.0),
161
+ reduction: float = 0.5,
162
+ ):
163
+ super().__init__()
164
+
165
+ self.conv = nn.Sequential(
166
+ nn.Conv2d(channels, channels, 3, padding=1, groups=channels, bias=False),
167
+ nn.Conv2d(channels, channels, 1, bias=False),
168
+ nn.BatchNorm2d(channels),
169
+ )
170
+
171
+ self.lens = MobiusLens(channels, layer_idx, total_layers, scale_range)
172
+
173
+ third = channels // 3
174
+ which_third = layer_idx % 3
175
+ mask = torch.ones(channels)
176
+ start = which_third * third
177
+ end = start + third + (channels % 3 if which_third == 2 else 0)
178
+ mask[start:end] = reduction
179
+ self.register_buffer('thirds_mask', mask.view(1, -1, 1, 1))
180
+
181
+ self.residual_weight = nn.Parameter(torch.tensor(0.9))
182
+
183
+ def forward(self, x: Tensor) -> Tensor:
184
+ identity = x
185
+
186
+ h = self.conv(x)
187
+ B, D, H, W = h.shape
188
+ h = h.permute(0, 2, 3, 1)
189
+ h = self.lens(h)
190
+ h = h.permute(0, 3, 1, 2)
191
+ h = h * self.thirds_mask
192
+
193
+ rw = torch.sigmoid(self.residual_weight)
194
+ return rw * identity + (1 - rw) * h
195
+
196
+ def get_residual_weight(self) -> float:
197
+ return torch.sigmoid(self.residual_weight).item()
198
+
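+ # The block is a learned lerp: y = s*x + (1 - s)*f(x) with s = sigmoid(residual_weight).
+ # With the 0.9 init, s ~ 0.71, so blocks start close to identity (illustrative
+ # arithmetic, not from the original comments).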
199
+
200
+ # ============================================================================
201
+ # MÖBIUS NET
202
+ # ============================================================================
203
+
204
+ class MobiusNet(nn.Module):
205
+ def __init__(
206
+ self,
207
+ in_chans: int = 3,
208
+ num_classes: int = 200,
209
+ channels: Tuple[int, ...] = (64, 128, 256, 512),
210
+ depths: Tuple[int, ...] = (2, 2, 2, 2),
211
+ scale_range: Tuple[float, float] = (0.5, 2.5),
212
+ use_integrator: bool = True,
213
+ ):
214
+ super().__init__()
215
+
216
+ num_stages = len(depths)
217
+ total_layers = sum(depths)
218
+
219
+ self.total_layers = total_layers
220
+ self.scale_range = scale_range
221
+ self.channels = tuple(channels)
222
+ self.depths = tuple(depths)
223
+ self.num_stages = num_stages
224
+ self.use_integrator = use_integrator
225
+ self.num_classes = num_classes
226
+ self.in_chans = in_chans
227
+
228
+ channels = list(channels)
229
+ while len(channels) < num_stages:
230
+ channels.append(channels[-1])
231
+
232
+ self.stem = nn.Sequential(
233
+ nn.Conv2d(in_chans, channels[0], 3, stride=1, padding=1, bias=False),
234
+ nn.BatchNorm2d(channels[0]),
235
+ )
236
+
237
+ layer_idx = 0
238
+ self.stages = nn.ModuleList()
239
+ self.downsamples = nn.ModuleList()
240
+
241
+ for stage_idx in range(num_stages):
242
+ ch = channels[stage_idx]
243
+
244
+ stage = nn.ModuleList()
245
+ for _ in range(depths[stage_idx]):
246
+ stage.append(MobiusConvBlock(ch, layer_idx, total_layers, scale_range))
247
+ layer_idx += 1
248
+ self.stages.append(stage)
249
+
250
+ if stage_idx < num_stages - 1:
251
+ ch_next = channels[stage_idx + 1]
252
+ self.downsamples.append(nn.Sequential(
253
+ nn.Conv2d(ch, ch_next, 3, stride=2, padding=1, bias=False),
254
+ nn.BatchNorm2d(ch_next),
255
+ ))
256
+
257
+ final_ch = channels[num_stages - 1]
258
+ if use_integrator:
259
+ self.integrator = nn.Sequential(
260
+ nn.Conv2d(final_ch, final_ch, 3, padding=1, bias=False),
261
+ nn.BatchNorm2d(final_ch),
262
+ nn.GELU(),
263
+ )
264
+ else:
265
+ self.integrator = nn.Identity()
266
+
267
+ self.pool = nn.AdaptiveAvgPool2d(1)
268
+ self.head = nn.Linear(final_ch, num_classes)
269
+
270
+ def forward(self, x: Tensor) -> Tensor:
271
+ x = self.stem(x)
272
+
273
+ for i, stage in enumerate(self.stages):
274
+ for block in stage:
275
+ x = block(x)
276
+ if i < len(self.downsamples):
277
+ x = self.downsamples[i](x)
278
+
279
+ x = self.integrator(x)
280
+ return self.head(self.pool(x).flatten(1))
281
+
282
+ def get_config(self) -> Dict[str, Any]:
283
+ """Return model configuration for saving."""
284
+ return {
285
+ 'in_chans': self.in_chans,
286
+ 'num_classes': self.num_classes,
287
+ 'channels': self.channels,
288
+ 'depths': self.depths,
289
+ 'scale_range': self.scale_range,
290
+ 'use_integrator': self.use_integrator,
291
+ 'total_layers': self.total_layers,
292
+ 'num_stages': self.num_stages,
293
+ }
294
+
295
+ def get_all_lens_stats(self) -> Dict[str, Dict[str, float]]:
296
+ """Return stats from all lenses for logging."""
297
+ stats = {}
298
+ layer_idx = 0
299
+ for stage_idx, stage in enumerate(self.stages):
300
+ for block_idx, block in enumerate(stage):
301
+ key = f"stage{stage_idx}_block{block_idx}"
302
+ stats[key] = block.lens.get_lens_stats()
303
+ stats[key]['residual_weight'] = block.get_residual_weight()
304
+ layer_idx += 1
305
+ return stats
306
+
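+ # Smoke-test sketch (illustrative): a Tiny-ImageNet-sized forward pass.
+ #   net = MobiusNet(in_chans=3, num_classes=200, channels=(64, 128, 256), depths=(2, 2, 2))
+ #   net(torch.randn(1, 3, 64, 64)).shape  # -> torch.Size([1, 200])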
307
+
308
+ # ============================================================================
309
+ # TINY IMAGENET DATASET
310
+ # ============================================================================
311
+
312
+ def get_tiny_imagenet_loaders(data_dir='./data/tiny-imagenet-200', batch_size=128):
313
+ train_dir = os.path.join(data_dir, 'train')
314
+ val_dir = os.path.join(data_dir, 'val')
315
+
316
+ val_images_dir = os.path.join(val_dir, 'images')
317
+ if os.path.exists(val_images_dir):
318
+ print("Reorganizing validation folder...")
319
+ reorganize_val_folder(val_dir)
320
+
321
+ train_transform = transforms.Compose([
322
+ transforms.RandomCrop(64, padding=8),
323
+ transforms.RandomHorizontalFlip(),
324
+ transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET),
325
+ transforms.ToTensor(),
326
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
327
+ ])
328
+
329
+ val_transform = transforms.Compose([
330
+ transforms.ToTensor(),
331
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
332
+ ])
333
+
334
+ train_dataset = datasets.ImageFolder(train_dir, transform=train_transform)
335
+ val_dataset = datasets.ImageFolder(val_dir, transform=val_transform)
336
+
337
+ train_loader = DataLoader(
338
+ train_dataset, batch_size=batch_size, shuffle=True,
339
+ num_workers=8, pin_memory=True, persistent_workers=True
340
+ )
341
+ val_loader = DataLoader(
342
+ val_dataset, batch_size=256, shuffle=False,
343
+ num_workers=4, pin_memory=True, persistent_workers=True
344
+ )
345
+
346
+ return train_loader, val_loader
347
+
348
+
349
+ def reorganize_val_folder(val_dir):
350
+ """Reorganize Tiny ImageNet val folder into class subfolders."""
351
+ val_images_dir = os.path.join(val_dir, 'images')
352
+ val_annotations = os.path.join(val_dir, 'val_annotations.txt')
353
+
354
+ if not os.path.exists(val_images_dir):
355
+ return
356
+
357
+ with open(val_annotations, 'r') as f:
358
+ for line in f:
359
+ parts = line.strip().split('\t')
360
+ img_name, class_id = parts[0], parts[1]
361
+
362
+ class_dir = os.path.join(val_dir, class_id)
363
+ os.makedirs(class_dir, exist_ok=True)
364
+
365
+ src = os.path.join(val_images_dir, img_name)
366
+ dst = os.path.join(class_dir, img_name)
367
+
368
+ if os.path.exists(src):
369
+ shutil.move(src, dst)
370
+
371
+ if os.path.exists(val_images_dir):
372
+ shutil.rmtree(val_images_dir)
373
+ if os.path.exists(val_annotations):
374
+ os.remove(val_annotations)
375
+
376
+ print("Validation folder reorganized.")
377
+
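+ # val_annotations.txt is tab-separated, e.g. (per the Tiny ImageNet release):
+ #   val_0.JPEG<TAB>n01443537<TAB>0<TAB>32<TAB>44<TAB>62
+ # Only the first two columns (image name, class id) are consumed above.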
378
+
379
+ # ============================================================================
380
+ # CLIP FEATURES DATASET
381
+ # ============================================================================
382
+
383
+ # CLIP feature dims and reshape targets
384
+ CLIP_SHAPES = {
385
+ 'clip_vit_b16': (512, 1, 16, 32), # 512 = 16*32
386
+ 'clip_vit_b32': (512, 1, 16, 32),
387
+ 'clip_vit_l14': (768, 1, 24, 32), # 768 = 24*32
388
+ 'clip_vit_laion_b32': (512, 1, 16, 32),
389
+ 'clip_vit_laion_bigg14': (1280, 1, 32, 40), # 1280 = 32*40
390
+ 'clip_vit_laion_h14': (1024, 1, 32, 32), # 1024 = 32*32
391
+ }
392
+
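+ # Invariant behind every entry: in_chans * H * W == feature dim, e.g.
+ #   assert all(d == c * h * w for d, c, h, w in CLIP_SHAPES.values())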
393
+
394
+ class CLIPFeaturesDataset(Dataset):
395
+ """Dataset wrapper that reshapes CLIP features to 2D spatial format."""
396
+
397
+ def __init__(self, hf_dataset, target_shape: Tuple[int, int, int]):
398
+ """
399
+ Args:
400
+ hf_dataset: HuggingFace dataset split
401
+ target_shape: (channels, height, width) to reshape features into
402
+ """
403
+ self.dataset = hf_dataset
404
+ self.target_shape = target_shape # (C, H, W)
405
+
406
+ def __len__(self):
407
+ return len(self.dataset)
408
+
409
+ def __getitem__(self, idx):
410
+ item = self.dataset[idx]
411
+ features = torch.tensor(item['clip_features'], dtype=torch.float32)
412
+ label = torch.tensor(item['label'], dtype=torch.long)
413
+
414
+ # Reshape [dim] -> [C, H, W]
415
+ features = features.view(*self.target_shape)
416
+
417
+ return features, label
418
+
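+ # Reshape sanity check (illustrative): a 512-dim clip_vit_b32 vector becomes
+ # a single-channel 16x32 map:
+ #   torch.arange(512.).view(1, 16, 32).shape  # -> torch.Size([1, 16, 32])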
419
+
420
+ def get_clip_feature_loaders(
421
+ subset: str = 'clip_vit_b32',
422
+ batch_size: int = 256,
423
+ num_workers: int = 8,
424
+ ):
425
+ """
426
+ Load CLIP features from HuggingFace and reshape for conv processing.
427
+
428
+ Args:
429
+ subset: Which CLIP model features ('clip_vit_b32', 'clip_vit_l14', etc.)
430
+ batch_size: Batch size
431
+ num_workers: DataLoader workers
432
+
433
+ Returns:
434
+ train_loader, val_loader, (in_chans, height, width)
435
+ """
436
+ from datasets import load_dataset
437
+
438
+ if subset not in CLIP_SHAPES:
439
+ raise ValueError(f"Unknown subset: {subset}. Choose from {list(CLIP_SHAPES.keys())}")
440
+
441
+ feat_dim, in_chans, h, w = CLIP_SHAPES[subset]
442
+
443
+ print(f"Loading dataset: AbstractPhil/imagenet-clip-features-orderly ({subset})")
444
+ print(f"Feature dim: {feat_dim} -> [{in_chans}, {h}, {w}]")
445
+
446
+ dataset = load_dataset(
447
+ "AbstractPhil/imagenet-clip-features-orderly",
448
+ subset,
449
+ trust_remote_code=True,
450
+ )
451
+
452
+ target_shape = (in_chans, h, w)
453
+
454
+ train_data = CLIPFeaturesDataset(dataset['train'], target_shape)
455
+ val_data = CLIPFeaturesDataset(dataset['validation'], target_shape)
456
+
457
+ print(f"Train samples: {len(train_data):,}")
458
+ print(f"Val samples: {len(val_data):,}")
459
+
460
+ train_loader = DataLoader(
461
+ train_data,
462
+ batch_size=batch_size,
463
+ shuffle=True,
464
+ num_workers=num_workers,
465
+ pin_memory=True,
466
+ persistent_workers=num_workers > 0,
467
+ drop_last=True,
468
+ )
469
+
470
+ val_loader = DataLoader(
471
+ val_data,
472
+ batch_size=batch_size * 2,
473
+ shuffle=False,
474
+ num_workers=max(1, num_workers // 2),
475
+ pin_memory=True,
476
+ persistent_workers=num_workers > 1,
477
+ )
478
+
479
+ return train_loader, val_loader, (in_chans, h, w)
480
+
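+ # Usage sketch (assumes the AbstractPhil/imagenet-clip-features-orderly dataset
+ # is reachable and any required HF authentication is in place):
+ #   train_loader, val_loader, (c, h, w) = get_clip_feature_loaders('clip_vit_b32')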
481
+
482
+ # ============================================================================
483
+ # PRESETS
484
+ # ============================================================================
485
+
486
+ PRESETS = {
487
+ 'mobius_tiny_s': {
488
+ 'channels': (64, 128, 256),
489
+ 'depths': (2, 2, 2),
490
+ 'scale_range': (0.5, 2.5),
491
+ },
492
+ 'mobius_tiny_m': {
493
+ 'channels': (64, 128, 256, 512, 768),
494
+ 'depths': (2, 2, 4, 2, 2),
495
+ 'scale_range': (0.25, 2.75),
496
+ },
497
+ 'mobius_tiny_l': {
498
+ 'channels': (96, 192, 384, 768),
499
+ 'depths': (3, 3, 3, 3),
500
+ 'scale_range': (0.5, 3.5),
501
+ },
502
+ 'mobius_base': {
503
+ 'channels': (128, 256, 512, 768, 1024),
504
+ 'depths': (2, 2, 2, 2, 2),
505
+ 'scale_range': (0.25, 2.75),
506
+ },
507
+ }
508
+
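+ # Each preset expands one stage per entry; e.g. 'mobius_tiny_s' builds three
+ # stages of 64/128/256 channels with two MobiusConvBlocks each (6 lenses total).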
509
+
510
+ # ============================================================================
511
+ # CHECKPOINT MANAGER
512
+ # ============================================================================
513
+
514
+ class CheckpointManager:
515
+ def __init__(
516
+ self,
517
+ base_dir: str,
518
+ variant_name: str,
519
+ dataset_name: str,
520
+ hf_repo: str = "AbstractPhil/mobiusnet",
521
+ upload_every_n_epochs: int = 10,
522
+ save_every_n_epochs: int = 10,
523
+ timestamp: Optional[str] = None,
524
+ ):
525
+ self.timestamp = timestamp or datetime.now().strftime("%Y%m%d_%H%M%S")
526
+ self.variant_name = variant_name
527
+ self.dataset_name = dataset_name
528
+ self.hf_repo = hf_repo
529
+ self.upload_every_n_epochs = upload_every_n_epochs
530
+ self.save_every_n_epochs = save_every_n_epochs
531
+
532
+ # Directory structure
533
+ self.run_name = f"{variant_name}_{dataset_name}"
534
+ self.run_dir = Path(base_dir) / "checkpoints" / self.run_name / self.timestamp
535
+ self.checkpoints_dir = self.run_dir / "checkpoints"
536
+ self.tensorboard_dir = self.run_dir / "tensorboard"
537
+
538
+ # Create directories
539
+ self.checkpoints_dir.mkdir(parents=True, exist_ok=True)
540
+ self.tensorboard_dir.mkdir(parents=True, exist_ok=True)
541
+
542
+ # TensorBoard writer
543
+ self.writer = SummaryWriter(log_dir=str(self.tensorboard_dir))
544
+
545
+ # HuggingFace API
546
+ self.hf_api = HfApi()
547
+ self.uploaded_files = set()
548
+
549
+ # Track best
550
+ self.best_acc = 0.0
551
+ self.best_epoch = 0
552
+ self.best_changed_since_upload = False
553
+
554
+ print(f"Checkpoint directory: {self.run_dir}")
555
+
556
+ @staticmethod
557
+ def extract_timestamp(checkpoint_path: str) -> Optional[str]:
558
+ """Extract timestamp from checkpoint path."""
559
+ # Match YYYYMMDD_HHMMSS pattern
560
+ match = re.search(r'(\d{8}_\d{6})', checkpoint_path)
561
+ if match:
562
+ return match.group(1)
563
+ return None
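+ # Example (illustrative):
+ #   CheckpointManager.extract_timestamp("checkpoints/run/20240101_120000/best_model.pt")
+ #   -> "20240101_120000"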
564
+
565
+ def save_config(self, config: Dict[str, Any], training_config: Dict[str, Any]):
566
+ """Save model and training configuration."""
567
+ full_config = {
568
+ 'model': config,
569
+ 'training': training_config,
570
+ 'timestamp': self.timestamp,
571
+ 'variant_name': self.variant_name,
572
+ 'dataset_name': self.dataset_name,
573
+ }
574
+
575
+ config_path = self.run_dir / "config.json"
576
+ with open(config_path, 'w') as f:
577
+ json.dump(full_config, f, indent=2)
578
+
579
+ return config_path
580
+
581
+ def save_checkpoint(
582
+ self,
583
+ model: nn.Module,
584
+ optimizer: torch.optim.Optimizer,
585
+ scheduler: Any,
586
+ epoch: int,
587
+ train_acc: float,
588
+ val_acc: float,
589
+ train_loss: float,
590
+ is_best: bool = False,
591
+ ):
592
+ """Save checkpoint every N epochs, always save best (overwriting)."""
593
+
594
+ # Unwrap compiled model if necessary
595
+ raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
596
+
597
+ # Checkpoint data
598
+ checkpoint = {
599
+ 'epoch': epoch,
600
+ 'train_acc': train_acc,
601
+ 'val_acc': val_acc,
602
+ 'train_loss': train_loss,
603
+ 'best_acc': val_acc if is_best else self.best_acc,  # record a new best immediately
604
+ 'optimizer_state_dict': optimizer.state_dict(),
605
+ 'scheduler_state_dict': scheduler.state_dict(),
606
+ }
607
+
608
+ # Save epoch checkpoint every N epochs
609
+ if epoch % self.save_every_n_epochs == 0:
610
+ epoch_pt_path = self.checkpoints_dir / f"checkpoint_epoch_{epoch:04d}.pt"
611
+ torch.save({**checkpoint, 'model_state_dict': raw_model.state_dict()}, epoch_pt_path)
612
+
613
+ epoch_st_path = self.checkpoints_dir / f"checkpoint_epoch_{epoch:04d}.safetensors"
614
+ save_safetensors(raw_model.state_dict(), str(epoch_st_path))
615
+
616
+ # Save best model (overwrites previous best)
617
+ if is_best:
618
+ self.best_acc = val_acc
619
+ self.best_epoch = epoch
620
+ self.best_changed_since_upload = True
621
+
622
+ # PyTorch best
623
+ best_pt_path = self.checkpoints_dir / "best_model.pt"
624
+ torch.save({**checkpoint, 'model_state_dict': raw_model.state_dict()}, best_pt_path)
625
+
626
+ # SafeTensors best
627
+ best_st_path = self.checkpoints_dir / "best_model.safetensors"
628
+ save_safetensors(raw_model.state_dict(), str(best_st_path))
629
+
630
+ # Save accuracy info
631
+ acc_path = self.run_dir / "best_accuracy.json"
632
+ with open(acc_path, 'w') as f:
633
+ json.dump({
634
+ 'best_acc': val_acc,
635
+ 'best_epoch': epoch,
636
+ 'train_acc': train_acc,
637
+ 'train_loss': train_loss,
638
+ }, f, indent=2)
639
+
640
+ def save_final(self, model: nn.Module, final_acc: float, final_epoch: int):
641
+ """Save final model."""
642
+ raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
643
+
644
+ # SafeTensors final
645
+ final_st_path = self.checkpoints_dir / "final_model.safetensors"
646
+ save_safetensors(raw_model.state_dict(), str(final_st_path))
647
+
648
+ # PyTorch final
649
+ final_pt_path = self.checkpoints_dir / "final_model.pt"
650
+ torch.save({
651
+ 'model_state_dict': raw_model.state_dict(),
652
+ 'final_acc': final_acc,
653
+ 'final_epoch': final_epoch,
654
+ 'best_acc': self.best_acc,
655
+ 'best_epoch': self.best_epoch,
656
+ }, final_pt_path)
657
+
658
+ # Final accuracy info
659
+ acc_path = self.run_dir / "final_accuracy.json"
660
+ with open(acc_path, 'w') as f:
661
+ json.dump({
662
+ 'final_acc': final_acc,
663
+ 'final_epoch': final_epoch,
664
+ 'best_acc': self.best_acc,
665
+ 'best_epoch': self.best_epoch,
666
+ }, f, indent=2)
667
+
668
+ return final_st_path, final_pt_path
669
+
670
+ def log_scalars(self, epoch: int, scalars: Dict[str, float], prefix: str = ""):
671
+ """Log scalars to TensorBoard."""
672
+ for name, value in scalars.items():
673
+ tag = f"{prefix}/{name}" if prefix else name
674
+ self.writer.add_scalar(tag, value, epoch)
675
+
676
+ def log_lens_stats(self, epoch: int, model: nn.Module):
677
+ """Log lens statistics to TensorBoard."""
678
+ raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
679
+ stats = raw_model.get_all_lens_stats()
680
+
681
+ for block_name, block_stats in stats.items():
682
+ for stat_name, value in block_stats.items():
683
+ self.writer.add_scalar(f"lens/{block_name}/{stat_name}", value, epoch)
684
+
685
+ def log_histograms(self, epoch: int, model: nn.Module):
686
+ """Log weight histograms to TensorBoard."""
687
+ raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
688
+
689
+ for name, param in raw_model.named_parameters():
690
+ if param.requires_grad:
691
+ self.writer.add_histogram(f"weights/{name}", param.data, epoch)
692
+ if param.grad is not None:
693
+ self.writer.add_histogram(f"gradients/{name}", param.grad, epoch)
694
+
695
+ def upload_to_hf(self, epoch: int, force: bool = False):
696
+ """Upload checkpoint every N epochs. Best uploads only on upload epochs if changed."""
697
+ if not force and epoch % self.upload_every_n_epochs != 0:
698
+ return
699
+
700
+ try:
701
+ hf_base_path = f"checkpoints/{self.run_name}/{self.timestamp}"
702
+
703
+ files_to_upload = []
704
+
705
+ # Always upload config
706
+ config_path = self.run_dir / "config.json"
707
+ if config_path.exists():
708
+ files_to_upload.append(config_path)
709
+
710
+ # Upload checkpoint if saved this epoch
711
+ if epoch % self.save_every_n_epochs == 0:
712
+ ckpt_st = self.checkpoints_dir / f"checkpoint_epoch_{epoch:04d}.safetensors"
713
+ ckpt_pt = self.checkpoints_dir / f"checkpoint_epoch_{epoch:04d}.pt"
714
+ if ckpt_st.exists():
715
+ files_to_upload.append(ckpt_st)
716
+ if ckpt_pt.exists():
717
+ files_to_upload.append(ckpt_pt)
718
+
719
+ # Upload best if it changed since last upload
720
+ if self.best_changed_since_upload:
721
+ best_files = [
722
+ self.checkpoints_dir / "best_model.safetensors",
723
+ self.checkpoints_dir / "best_model.pt",
724
+ self.run_dir / "best_accuracy.json",
725
+ ]
726
+ for f in best_files:
727
+ if f.exists():
728
+ files_to_upload.append(f)
729
+ self.best_changed_since_upload = False
730
+
731
+ # Upload files
732
+ for local_path in files_to_upload:
733
+ rel_path = local_path.relative_to(self.run_dir)
734
+ hf_path = f"{hf_base_path}/{rel_path}"
735
+
736
+ try:
737
+ self.hf_api.upload_file(
738
+ path_or_fileobj=str(local_path),
739
+ path_in_repo=hf_path,
740
+ repo_id=self.hf_repo,
741
+ repo_type="model",
742
+ )
743
+ print(f"Uploaded: {hf_path}")
744
+ except Exception as e:
745
+ print(f"Failed to upload {rel_path}: {e}")
746
+
747
+ except Exception as e:
748
+ print(f"HuggingFace upload error: {e}")
749
+
750
+ def close(self):
751
+ """Close TensorBoard writer."""
752
+ self.writer.close()
753
+
754
+ @staticmethod
755
+ def load_checkpoint(
756
+ checkpoint_path: str,
757
+ model: nn.Module,
758
+ optimizer: Optional[torch.optim.Optimizer] = None,
759
+ scheduler: Optional[Any] = None,
760
+ hf_repo: str = "AbstractPhil/mobiusnet",
761
+ device: torch.device = torch.device('cpu'),
762
+ ) -> Dict[str, Any]:
763
+ """
764
+ Load checkpoint from local path or HuggingFace repo.
765
+
766
+ Args:
767
+ checkpoint_path: Either:
768
+ - Local file path to .pt checkpoint
769
+ - Local directory containing checkpoints
770
+ - HuggingFace path like "checkpoints/variant_dataset/timestamp"
771
+ model: Model to load weights into
772
+ optimizer: Optional optimizer to restore state
773
+ scheduler: Optional scheduler to restore state
774
+ hf_repo: HuggingFace repo ID
775
+ device: Device to load tensors to
776
+
777
+ Returns:
778
+ Dict with checkpoint info (epoch, best_acc, etc.)
779
+ """
780
+ from huggingface_hub import hf_hub_download, list_repo_files
781
+
782
+ checkpoint_file = None
783
+
784
+ # Check if it's a local file
785
+ if os.path.isfile(checkpoint_path):
786
+ checkpoint_file = checkpoint_path
787
+
788
+ # Check if it's a local directory
789
+ elif os.path.isdir(checkpoint_path):
790
+ # Look for best_model.pt or latest checkpoint
791
+ best_path = os.path.join(checkpoint_path, "checkpoints", "best_model.pt")
792
+ if os.path.exists(best_path):
793
+ checkpoint_file = best_path
794
+ else:
795
+ # Find latest epoch checkpoint
796
+ ckpt_dir = os.path.join(checkpoint_path, "checkpoints")
797
+ if os.path.isdir(ckpt_dir):
798
+ pt_files = sorted([f for f in os.listdir(ckpt_dir) if f.startswith("checkpoint_epoch_") and f.endswith(".pt")])
799
+ if pt_files:
800
+ checkpoint_file = os.path.join(ckpt_dir, pt_files[-1])
801
+
802
+ # Try HuggingFace download
803
+ if checkpoint_file is None:
804
+ print(f"Attempting to download from HuggingFace: {hf_repo}/{checkpoint_path}")
805
+ try:
806
+ # If checkpoint_path is a directory path in the repo
807
+ if not checkpoint_path.endswith(".pt"):
808
+ # Try to download best_model.pt
809
+ try:
810
+ checkpoint_file = hf_hub_download(
811
+ repo_id=hf_repo,
812
+ filename=f"{checkpoint_path}/checkpoints/best_model.pt",
813
+ repo_type="model",
814
+ )
815
+ print(f"Downloaded best_model.pt from {hf_repo}")
816
+ except Exception:
817
+ # List files and find latest checkpoint
818
+ files = list_repo_files(repo_id=hf_repo, repo_type="model")
819
+ ckpt_files = sorted([f for f in files if checkpoint_path in f and f.endswith(".pt") and "checkpoint_epoch_" in f])
820
+ if ckpt_files:
821
+ checkpoint_file = hf_hub_download(
822
+ repo_id=hf_repo,
823
+ filename=ckpt_files[-1],
824
+ repo_type="model",
825
+ )
826
+ print(f"Downloaded {ckpt_files[-1]} from {hf_repo}")
827
+ else:
828
+ # Direct file path
829
+ checkpoint_file = hf_hub_download(
830
+ repo_id=hf_repo,
831
+ filename=checkpoint_path,
832
+ repo_type="model",
833
+ )
834
+ print(f"Downloaded {checkpoint_path} from {hf_repo}")
835
+ except Exception as e:
836
+ raise FileNotFoundError(f"Could not find or download checkpoint: {checkpoint_path}. Error: {e}")
837
+
838
+ if checkpoint_file is None:
839
+ raise FileNotFoundError(f"Could not find checkpoint: {checkpoint_path}")
840
+
841
+ print(f"Loading checkpoint from: {checkpoint_file}")
842
+ checkpoint = torch.load(checkpoint_file, map_location=device, weights_only=False)
843
+
844
+ # Load model weights
845
+ raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
846
+ raw_model.load_state_dict(checkpoint['model_state_dict'])
847
+ print(f"Loaded model weights")
848
+
849
+ # Load optimizer state
850
+ if optimizer is not None and 'optimizer_state_dict' in checkpoint:
851
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
852
+ print(f"Loaded optimizer state")
853
+
854
+ # Load scheduler state
855
+ if scheduler is not None and 'scheduler_state_dict' in checkpoint:
856
+ scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
857
+ print(f"Loaded scheduler state")
858
+
859
+ info = {
860
+ 'epoch': checkpoint.get('epoch', 0),
861
+ 'best_acc': checkpoint.get('best_acc', 0.0),
862
+ 'train_acc': checkpoint.get('train_acc', 0.0),
863
+ 'val_acc': checkpoint.get('val_acc', 0.0),
864
+ 'train_loss': checkpoint.get('train_loss', 0.0),
865
+ }
866
+
867
+ print(f"Resuming from epoch {info['epoch']} (best_acc: {info['best_acc']:.4f})")
868
+
869
+ return info
870
+
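+ # Resume sketch (illustrative; 'model', 'opt', and 'sched' are placeholders):
+ #   info = CheckpointManager.load_checkpoint("outputs/checkpoints/run/ts", model, opt, sched)
+ #   start_epoch = info['epoch'] + 1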
871
+
872
+ # ============================================================================
873
+ # TRAINING
874
+ # ============================================================================
875
+
876
+ def train_tiny_imagenet(
877
+ preset: str = 'mobius_tiny_m',
878
+ epochs: int = 100,
879
+ lr: float = 1e-3,
880
+ batch_size: int = 128,
881
+ use_integrator: bool = True,
882
+ data_dir: str = './data/tiny-imagenet-200',
883
+ output_dir: str = './outputs',
884
+ hf_repo: str = "AbstractPhil/mobiusnet",
885
+ save_every_n_epochs: int = 10,
886
+ upload_every_n_epochs: int = 10,
887
+ log_histograms_every: int = 10,
888
+ use_compile: bool = True,
889
+ continue_from: Optional[str] = None,
890
+ ):
891
+ """
892
+ Train MobiusNet on Tiny ImageNet.
893
+
894
+ Args:
895
+ preset: Model preset name
896
+ epochs: Total epochs to train
897
+ lr: Learning rate
898
+ batch_size: Batch size
899
+ use_integrator: Whether to use integrator layer
900
+ data_dir: Path to Tiny ImageNet data
901
+ output_dir: Output directory for checkpoints
902
+ hf_repo: HuggingFace repo for uploads/downloads
903
+ save_every_n_epochs: Save checkpoint every N epochs
904
+ upload_every_n_epochs: Upload to HF every N epochs
905
+ log_histograms_every: Log weight histograms every N epochs
906
+ use_compile: Whether to use torch.compile
907
+ continue_from: Resume from checkpoint. Can be:
908
+ - Local .pt file path
909
+ - Local checkpoint directory
910
+ - HuggingFace path (e.g., "checkpoints/mobius_base_tiny_imagenet/20240101_120000")
911
+ """
912
+ config = PRESETS[preset]
913
+ dataset_name = "tiny_imagenet"
914
+
915
+ print("=" * 70)
916
+ print(f"MÖBIUS NET - {preset.upper()} - TINY IMAGENET")
917
+ print("=" * 70)
918
+ print(f"Device: {device}")
919
+ print(f"Channels: {config['channels']}")
920
+ print(f"Depths: {config['depths']}")
921
+ print(f"Scale range: {config['scale_range']}")
922
+ print(f"Integrator: {use_integrator}")
923
+ if continue_from:
924
+ print(f"Continuing from: {continue_from}")
925
+ print()
926
+
927
+ # Extract timestamp from checkpoint path if continuing
928
+ resume_timestamp = None
929
+ if continue_from:
930
+ resume_timestamp = CheckpointManager.extract_timestamp(continue_from)
931
+ if resume_timestamp:
932
+ print(f"Using original timestamp: {resume_timestamp}")
933
+
934
+ # Initialize checkpoint manager
935
+ ckpt_manager = CheckpointManager(
936
+ base_dir=output_dir,
937
+ variant_name=preset,
938
+ dataset_name=dataset_name,
939
+ hf_repo=hf_repo,
940
+ upload_every_n_epochs=upload_every_n_epochs,
941
+ save_every_n_epochs=save_every_n_epochs,
942
+ timestamp=resume_timestamp,
943
+ )
944
+
945
+ # Data
946
+ train_loader, val_loader = get_tiny_imagenet_loaders(data_dir, batch_size)
947
+
948
+ # Model
949
+ model = MobiusNet(
950
+ in_chans=3,
951
+ num_classes=200,
952
+ use_integrator=use_integrator,
953
+ **config
954
+ ).to(device)
955
+
956
+ total_params = sum(p.numel() for p in model.parameters())
957
+ print(f"Total params: {total_params:,}")
958
+ print()
959
+
960
+ # Save config
961
+ training_config = {
962
+ 'epochs': epochs,
963
+ 'lr': lr,
964
+ 'batch_size': batch_size,
965
+ 'optimizer': 'AdamW',
966
+ 'weight_decay': 0.05,
967
+ 'scheduler': 'CosineAnnealingLR',
968
+ 'total_params': total_params,
969
+ }
970
+ ckpt_manager.save_config(model.get_config(), training_config)
971
+
972
+ # Compile model
973
+ if use_compile:
974
+ model = torch.compile(model, mode='reduce-overhead')
975
+
976
+ # Optimizer and scheduler
977
+ optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
978
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
979
+
980
+ # Load checkpoint if continuing
981
+ start_epoch = 1
982
+ best_acc = 0.0
983
+
984
+ if continue_from:
985
+ ckpt_info = CheckpointManager.load_checkpoint(
986
+ checkpoint_path=continue_from,
987
+ model=model,
988
+ optimizer=optimizer,
989
+ scheduler=scheduler,
990
+ hf_repo=hf_repo,
991
+ device=device,
992
+ )
993
+ start_epoch = ckpt_info['epoch'] + 1
994
+ best_acc = ckpt_info['best_acc']
995
+ ckpt_manager.best_acc = best_acc
996
+ ckpt_manager.best_epoch = ckpt_info['epoch']
997
+ print(f"Resuming training from epoch {start_epoch}")
998
+
999
+ for epoch in range(start_epoch, epochs + 1):
1000
+ # Training
1001
+ model.train()
1002
+ train_loss, train_correct, train_total = 0, 0, 0
1003
+
1004
+ pbar = tqdm(train_loader, desc=f"Epoch {epoch:3d}")
1005
+ for x, y in pbar:
1006
+ x, y = x.to(device), y.to(device)
1007
+
1008
+ optimizer.zero_grad()
1009
+ logits = model(x)
1010
+ loss = F.cross_entropy(logits, y)
1011
+ loss.backward()
1012
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
1013
+ optimizer.step()
1014
+
1015
+ train_loss += loss.item() * x.size(0)
1016
+ train_correct += (logits.argmax(1) == y).sum().item()
1017
+ train_total += x.size(0)
1018
+
1019
+ pbar.set_postfix(loss=f"{loss.item():.4f}")
1020
+
1021
+ scheduler.step()
1022
+
1023
+ # Validation
1024
+ model.eval()
1025
+ val_correct, val_total = 0, 0
1026
+ with torch.no_grad():
1027
+ for x, y in val_loader:
1028
+ x, y = x.to(device), y.to(device)
1029
+ logits = model(x)
1030
+ val_correct += (logits.argmax(1) == y).sum().item()
1031
+ val_total += x.size(0)
1032
+
1033
+ # Metrics
1034
+ train_acc = train_correct / train_total
1035
+ val_acc = val_correct / val_total
1036
+ avg_loss = train_loss / train_total
1037
+ current_lr = scheduler.get_last_lr()[0]
1038
+
1039
+ is_best = val_acc > best_acc
1040
+ if is_best:
1041
+ best_acc = val_acc
1042
+
1043
+ marker = " ★" if is_best else ""
1044
+ print(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | "
1045
+ f"Train: {train_acc:.4f} | Val: {val_acc:.4f} | Best: {best_acc:.4f}{marker}")
1046
+
1047
+ # TensorBoard logging
1048
+ ckpt_manager.log_scalars(epoch, {
1049
+ 'loss': avg_loss,
1050
+ 'train_acc': train_acc,
1051
+ 'val_acc': val_acc,
1052
+ 'best_acc': best_acc,
1053
+ 'learning_rate': current_lr,
1054
+ }, prefix="train")
1055
+
1056
+ # Log lens stats
1057
+ ckpt_manager.log_lens_stats(epoch, model)
1058
+
1059
+ # Log histograms periodically
1060
+ if epoch % log_histograms_every == 0:
1061
+ ckpt_manager.log_histograms(epoch, model)
1062
+
1063
+ # Save checkpoint
1064
+ ckpt_manager.save_checkpoint(
1065
+ model=model,
1066
+ optimizer=optimizer,
1067
+ scheduler=scheduler,
1068
+ epoch=epoch,
1069
+ train_acc=train_acc,
1070
+ val_acc=val_acc,
1071
+ train_loss=avg_loss,
1072
+ is_best=is_best,
1073
+ )
1074
+
1075
+ # Upload to HuggingFace (handles both checkpoint and best)
1076
+ ckpt_manager.upload_to_hf(epoch)
1077
+
1078
+ # Save final model
1079
+ ckpt_manager.save_final(model, val_acc, epochs)
1080
+
1081
+ # Final upload
1082
+ ckpt_manager.upload_to_hf(epochs, force=True)
1083
+ ckpt_manager.close()
1084
+
1085
+ print()
1086
+ print("=" * 70)
1087
+ print("FINAL RESULTS")
1088
+ print("=" * 70)
1089
+ print(f"Preset: {preset}")
1090
+ print(f"Best accuracy: {best_acc:.4f}")
1091
+ print(f"Total params: {total_params:,}")
1092
+ print(f"Checkpoints: {ckpt_manager.run_dir}")
1093
+ print("=" * 70)
1094
+
1095
+ return model, best_acc
1096
+
1097
+
1098
+ # ============================================================================
1099
+ # CLIP FEATURES TRAINING
1100
+ # ============================================================================
1101
+
1102
+ def train_clip_features(
1103
+ preset: str = 'mobius_tiny_m',
1104
+ clip_subset: str = 'clip_vit_b32',
1105
+ epochs: int = 50,
1106
+ lr: float = 1e-3,
1107
+ batch_size: int = 256,
1108
+ use_integrator: bool = True,
1109
+ output_dir: str = './outputs',
1110
+ hf_repo: str = "AbstractPhil/mobiusnet",
1111
+ save_every_n_epochs: int = 5,
1112
+ upload_every_n_epochs: int = 5,
1113
+ log_histograms_every: int = 10,
1114
+ use_compile: bool = True,
1115
+ continue_from: Optional[str] = None,
1116
+ num_workers: int = 8,
1117
+ ):
1118
+ """
1119
+ Train MobiusNet on CLIP features for ImageNet classification.
1120
+
1121
+ Args:
1122
+ preset: Model preset name
1123
+ clip_subset: CLIP model features to use ('clip_vit_b32', 'clip_vit_l14', etc.)
1124
+ epochs: Total epochs
1125
+ lr: Learning rate
1126
+ batch_size: Batch size (can be larger since no image augmentation)
1127
+ use_integrator: Whether to use integrator layer
1128
+ output_dir: Output directory
1129
+ hf_repo: HuggingFace repo
1130
+ save_every_n_epochs: Save checkpoint interval
1131
+ upload_every_n_epochs: Upload to HF interval
1132
+ log_histograms_every: Histogram logging interval
1133
+ use_compile: Use torch.compile
1134
+ continue_from: Resume checkpoint path
1135
+ num_workers: DataLoader workers
1136
+ """
1137
+ config = PRESETS[preset]
1138
+ dataset_name = f"imagenet_{clip_subset}"
1139
+
1140
+ print("=" * 70)
1141
+ print(f"MÖBIUS NET - {preset.upper()} - IMAGENET CLIP FEATURES")
1142
+ print(f"CLIP Subset: {clip_subset}")
1143
+ print("=" * 70)
1144
+ print(f"Device: {device}")
1145
+ print(f"Channels: {config['channels']}")
1146
+ print(f"Depths: {config['depths']}")
1147
+ print(f"Scale range: {config['scale_range']}")
1148
+ print(f"Integrator: {use_integrator}")
1149
+ if continue_from:
1150
+ print(f"Continuing from: {continue_from}")
1151
+ print()
1152
+
1153
+ # Extract timestamp if continuing
1154
+ resume_timestamp = None
1155
+ if continue_from:
1156
+ resume_timestamp = CheckpointManager.extract_timestamp(continue_from)
1157
+ if resume_timestamp:
1158
+ print(f"Using original timestamp: {resume_timestamp}")
1159
+
1160
+ # Initialize checkpoint manager
1161
+ ckpt_manager = CheckpointManager(
1162
+ base_dir=output_dir,
1163
+ variant_name=preset,
1164
+ dataset_name=dataset_name,
1165
+ hf_repo=hf_repo,
1166
+ upload_every_n_epochs=upload_every_n_epochs,
1167
+ save_every_n_epochs=save_every_n_epochs,
1168
+ timestamp=resume_timestamp,
1169
+ )
1170
+
1171
+ # Data
1172
+ train_loader, val_loader, (in_chans, h, w) = get_clip_feature_loaders(
1173
+ subset=clip_subset,
1174
+ batch_size=batch_size,
1175
+ num_workers=num_workers,
1176
+ )
1177
+
1178
+ print(f"Input shape: [{in_chans}, {h}, {w}]")
1179
+
1180
+ # Model - in_chans comes from the CLIP loader (1 channel after reshape)
1181
+ model = MobiusNet(
1182
+ in_chans=in_chans,
1183
+ num_classes=1000, # ImageNet
1184
+ use_integrator=use_integrator,
1185
+ **config
1186
+ ).to(device)
1187
+
1188
+ total_params = sum(p.numel() for p in model.parameters())
1189
+ print(f"Total params: {total_params:,}")
1190
+ print()
1191
+
1192
+ # Save config
1193
+ training_config = {
1194
+ 'epochs': epochs,
1195
+ 'lr': lr,
1196
+ 'batch_size': batch_size,
1197
+ 'clip_subset': clip_subset,
1198
+ 'input_shape': [in_chans, h, w],
1199
+ 'optimizer': 'AdamW',
1200
+ 'weight_decay': 0.05,
1201
+ 'scheduler': 'CosineAnnealingLR',
1202
+ 'total_params': total_params,
1203
+ }
1204
+ ckpt_manager.save_config(model.get_config(), training_config)
1205
+
1206
+ # Compile
1207
+ if use_compile:
1208
+ model = torch.compile(model, mode='reduce-overhead')
1209
+
1210
+ # Optimizer and scheduler
1211
+ optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
1212
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
1213
+
1214
+ # Load checkpoint if continuing
1215
+ start_epoch = 1
1216
+ best_acc = 0.0
1217
+
1218
+ if continue_from:
1219
+ ckpt_info = CheckpointManager.load_checkpoint(
1220
+ checkpoint_path=continue_from,
1221
+ model=model,
1222
+ optimizer=optimizer,
1223
+ scheduler=scheduler,
1224
+ hf_repo=hf_repo,
1225
+ device=device,
1226
+ )
1227
+ start_epoch = ckpt_info['epoch'] + 1
1228
+ best_acc = ckpt_info['best_acc']
1229
+ ckpt_manager.best_acc = best_acc
1230
+ ckpt_manager.best_epoch = ckpt_info['epoch']
1231
+ print(f"Resuming training from epoch {start_epoch}")
1232
+
1233
+ for epoch in range(start_epoch, epochs + 1):
1234
+ # Training
1235
+ model.train()
1236
+ train_loss, train_correct, train_total = 0, 0, 0
1237
+
1238
+ pbar = tqdm(train_loader, desc=f"Epoch {epoch:3d}")
1239
+ for features, labels in pbar:
1240
+ features, labels = features.to(device), labels.to(device)
1241
+
1242
+ optimizer.zero_grad()
1243
+ logits = model(features)
1244
+ loss = F.cross_entropy(logits, labels)
1245
+ loss.backward()
1246
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
1247
+ optimizer.step()
1248
+
1249
+ train_loss += loss.item() * features.size(0)
1250
+ train_correct += (logits.argmax(1) == labels).sum().item()
1251
+ train_total += features.size(0)
1252
+
1253
+ pbar.set_postfix(loss=f"{loss.item():.4f}")
1254
+
1255
+ scheduler.step()
1256
+
1257
+ # Validation
1258
+ model.eval()
1259
+ val_correct, val_total = 0, 0
1260
+ val_top5_correct = 0
1261
+
1262
+ with torch.no_grad():
1263
+ for features, labels in val_loader:
1264
+ features, labels = features.to(device), labels.to(device)
1265
+ logits = model(features)
1266
+
1267
+ # Top-1
1268
+ val_correct += (logits.argmax(1) == labels).sum().item()
1269
+ val_total += features.size(0)
1270
+
1271
+ # Top-5
1272
+ _, top5_preds = logits.topk(5, dim=1)
1273
+ val_top5_correct += (top5_preds == labels.unsqueeze(1)).any(dim=1).sum().item()
1274
+
1275
+ # Metrics
1276
+ train_acc = train_correct / train_total
1277
+ val_acc = val_correct / val_total
1278
+ val_top5_acc = val_top5_correct / val_total
1279
+ avg_loss = train_loss / train_total
1280
+ current_lr = scheduler.get_last_lr()[0]
1281
+
1282
+ is_best = val_acc > best_acc
1283
+ if is_best:
1284
+ best_acc = val_acc
1285
+
1286
+ marker = " ★" if is_best else ""
1287
+ print(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | "
1288
+ f"Train: {train_acc:.4f} | Val: {val_acc:.4f} (Top5: {val_top5_acc:.4f}) | "
1289
+ f"Best: {best_acc:.4f}{marker}")
1290
+
1291
+ # TensorBoard
1292
+ ckpt_manager.log_scalars(epoch, {
1293
+ 'loss': avg_loss,
1294
+ 'train_acc': train_acc,
1295
+ 'val_acc': val_acc,
1296
+ 'val_top5_acc': val_top5_acc,
1297
+ 'best_acc': best_acc,
1298
+ 'learning_rate': current_lr,
1299
+ }, prefix="train")
1300
+
1301
+ ckpt_manager.log_lens_stats(epoch, model)
1302
+
1303
+ if epoch % log_histograms_every == 0:
1304
+ ckpt_manager.log_histograms(epoch, model)
1305
+
1306
+ # Save
1307
+ ckpt_manager.save_checkpoint(
1308
+ model=model,
1309
+ optimizer=optimizer,
1310
+ scheduler=scheduler,
1311
+ epoch=epoch,
1312
+ train_acc=train_acc,
1313
+ val_acc=val_acc,
1314
+ train_loss=avg_loss,
1315
+ is_best=is_best,
1316
+ )
1317
+
1318
+ # Upload
1319
+ ckpt_manager.upload_to_hf(epoch)
1320
+
1321
+ # Final
1322
+ ckpt_manager.save_final(model, val_acc, epochs)
1323
+ ckpt_manager.upload_to_hf(epochs, force=True)
1324
+ ckpt_manager.close()
1325
+
1326
+ print()
1327
+ print("=" * 70)
1328
+ print("FINAL RESULTS")
1329
+ print("=" * 70)
1330
+ print(f"Preset: {preset}")
1331
+ print(f"CLIP subset: {clip_subset}")
1332
+ print(f"Best Top-1 accuracy: {best_acc:.4f}")
1333
+ print(f"Total params: {total_params:,}")
1334
+ print(f"Checkpoints: {ckpt_manager.run_dir}")
1335
+ print("=" * 70)
1336
+
1337
+ return model, best_acc
1338
+
1339
+
1340
+ # ============================================================================
1341
+ # RUN
1342
+ # ============================================================================
1343
+
1344
+ if __name__ == '__main__':
1345
+ # Choose training mode:
1346
+
1347
+ # Option 1: Train on Tiny ImageNet (raw images)
1348
+ # model, best_acc = train_tiny_imagenet(
1349
+ # preset='mobius_base',
1350
+ # epochs=200,
1351
+ # lr=3e-4,
1352
+ # batch_size=128,
1353
+ # use_integrator=True,
1354
+ # data_dir='./data/tiny-imagenet-200',
1355
+ # output_dir='./outputs',
1356
+ # hf_repo='AbstractPhil/mobiusnet',
1357
+ # save_every_n_epochs=10,
1358
+ # upload_every_n_epochs=10,
1359
+ # continue_from=None,
1360
+ # )
1361
+
1362
+ # Option 2: Train on ImageNet CLIP features
1363
+ model, best_acc = train_clip_features(
1364
+ preset='mobius_tiny_s',
1365
+ clip_subset='clip_vit_laion_b32', # or 'clip_vit_l14', 'clip_vit_laion_h14', etc.
1366
+ epochs=50,
1367
+ lr=1e-3,
1368
+ batch_size=256,
1369
+ use_integrator=True,
1370
+ output_dir='./outputs',
1371
+ hf_repo='AbstractPhil/mobiusnet-distillations',
1372
+ save_every_n_epochs=5,
1373
+ upload_every_n_epochs=5,
1374
+ num_workers=8,
1375
+ continue_from=None,
1376
+ )