AbstractPhil commited on
Commit
a05e552
·
verified ·
1 Parent(s): e161cd7

Create fashionmnist_trainer.py

Browse files
Files changed (1) hide show
  1. fashionmnist_trainer.py +655 -0
fashionmnist_trainer.py ADDED
@@ -0,0 +1,655 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fashion-MNIST Trainer with MobiusCollective
3
+ ============================================
4
+
5
+ Train a wide collective of MobiusLens towers on Fashion-MNIST.
6
+ Designed for Colab with TensorBoard logging and HuggingFace upload.
7
+
8
+ License: Apache 2.0
9
+ Date: 2025-01-10
10
+ Author: AbstractPhil
11
+ """
12
+
13
+ import os
14
+ import json
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from torch import Tensor
19
+ from typing import Tuple, Dict, Any, Optional
20
+ from torchvision import datasets, transforms
21
+ from torch.utils.data import DataLoader
22
+ from torch.utils.tensorboard import SummaryWriter
23
+ from tqdm.auto import tqdm
24
+ from datetime import datetime
25
+ from pathlib import Path
26
+ from safetensors.torch import save_file as save_safetensors
27
+
28
+ # HuggingFace login for Colab
29
+ try:
30
+ from huggingface_hub import HfApi, login
31
+ from google.colab import userdata
32
+ token = userdata.get('HF_TOKEN')
33
+ os.environ['HF_TOKEN'] = token
34
+ login(token=token)
35
+ print("Logged in to HuggingFace via Colab")
36
+ HF_AVAILABLE = True
37
+ except:
38
+ HF_AVAILABLE = False
39
+ print("HuggingFace upload disabled (not in Colab or no token)")
40
+
41
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
42
+ print(f"Device: {device}")
43
+
44
+ # TF32 for Ampere+
45
+ torch.backends.cuda.matmul.allow_tf32 = True
46
+ torch.backends.cudnn.allow_tf32 = True
47
+ torch.set_float32_matmul_precision('high')
48
+
49
+
50
+ # ============================================================================
51
+ # IMPORTS FROM GEOFRACTAL
52
+ # ============================================================================
53
+
54
+ from geofractal.router.wide_router import WideRouter
55
+ from geofractal.router.base_tower import BaseTower
56
+ from geofractal.router.components.torch_component import TorchComponent
57
+ from geofractal.router.components.lens_component import MobiusLens, TriWaveLens
58
+ from geofractal.router.components.fusion_component import AdaptiveFusion
59
+
60
+
61
+ # ============================================================================
62
+ # CONV LENS BLOCK
63
+ # ============================================================================
64
+
65
+ class ConvLensBlock(TorchComponent):
66
+ """Depthwise-separable conv with MobiusLens activation."""
67
+
68
+ def __init__(
69
+ self,
70
+ name: str,
71
+ channels: int,
72
+ layer_idx: int,
73
+ total_layers: int,
74
+ scale_range: Tuple[float, float] = (0.5, 2.5),
75
+ use_mobius: bool = True,
76
+ ):
77
+ super().__init__(name)
78
+
79
+ self.conv = nn.Sequential(
80
+ nn.Conv2d(channels, channels, 3, padding=1, groups=channels, bias=False),
81
+ nn.Conv2d(channels, channels, 1, bias=False),
82
+ nn.BatchNorm2d(channels),
83
+ )
84
+
85
+ if use_mobius:
86
+ self.lens = MobiusLens(f'{name}_lens', channels, layer_idx, total_layers, scale_range)
87
+ else:
88
+ self.lens = TriWaveLens(f'{name}_lens', channels, layer_idx, total_layers, scale_range)
89
+
90
+ self.residual_weight = nn.Parameter(torch.tensor(0.9))
91
+
92
+ def forward(self, x: Tensor) -> Tensor:
93
+ identity = x
94
+ h = self.conv(x)
95
+ B, C, H, W = h.shape
96
+ h = h.permute(0, 2, 3, 1)
97
+ h = self.lens(h)
98
+ h = h.permute(0, 3, 1, 2)
99
+ rw = torch.sigmoid(self.residual_weight)
100
+ return rw * identity + (1 - rw) * h
101
+
102
+
103
+ # ============================================================================
104
+ # LENS TOWER
105
+ # ============================================================================
106
+
107
+ class LensTower(BaseTower):
108
+ """Shallow tower covering a segment of the scale continuum."""
109
+
110
+ def __init__(
111
+ self,
112
+ name: str,
113
+ channels: int,
114
+ depth: int,
115
+ tower_idx: int,
116
+ num_towers: int,
117
+ scale_range: Tuple[float, float] = (0.5, 2.5),
118
+ use_mobius: bool = True,
119
+ ):
120
+ super().__init__(name, strict=False)
121
+
122
+ self.tower_idx = tower_idx
123
+ self.channels = channels
124
+
125
+ total_layers = num_towers * depth
126
+ start_layer = tower_idx * depth
127
+
128
+ for i in range(depth):
129
+ global_idx = start_layer + i
130
+ block = ConvLensBlock(
131
+ f'{name}_block_{i}',
132
+ channels,
133
+ layer_idx=global_idx,
134
+ total_layers=total_layers,
135
+ scale_range=scale_range,
136
+ use_mobius=use_mobius,
137
+ )
138
+ self.append(block)
139
+
140
+ self.attach('norm', nn.BatchNorm2d(channels))
141
+
142
+ def forward(self, x: Tensor) -> Tensor:
143
+ for stage in self.stages:
144
+ x = stage(x)
145
+ return self['norm'](x)
146
+
147
+
148
+ # ============================================================================
149
+ # VISION ADAPTIVE FUSION (wraps AdaptiveFusion for BCHW tensors)
150
+ # ============================================================================
151
+
152
+ class VisionAdaptiveFusion(TorchComponent):
153
+ """
154
+ Wraps AdaptiveFusion for vision tensors (B, C, H, W).
155
+
156
+ Permutes to channel-last, fuses, permutes back.
157
+ """
158
+
159
+ def __init__(self, name: str, num_towers: int, channels: int):
160
+ super().__init__(name)
161
+
162
+ self.num_towers = num_towers
163
+ self.fusion = AdaptiveFusion(
164
+ f'{name}_adaptive',
165
+ num_inputs=num_towers,
166
+ in_features=channels,
167
+ )
168
+
169
+ # Output projection (conv for spatial tensors)
170
+ self.proj = nn.Sequential(
171
+ nn.Conv2d(channels, channels, 1, bias=False),
172
+ nn.BatchNorm2d(channels),
173
+ )
174
+
175
+ def forward(self, *opinions: Tensor) -> Tensor:
176
+ """
177
+ Args:
178
+ *opinions: N tensors of shape (B, C, H, W)
179
+ Returns:
180
+ Fused tensor of shape (B, C, H, W)
181
+ """
182
+ # Permute all to channel-last: (B, H, W, C)
183
+ channel_last = [op.permute(0, 2, 3, 1) for op in opinions]
184
+
185
+ # Fuse using AdaptiveFusion
186
+ fused = self.fusion(*channel_last) # (B, H, W, C)
187
+
188
+ # Permute back: (B, C, H, W)
189
+ fused = fused.permute(0, 3, 1, 2)
190
+
191
+ return self.proj(fused)
192
+
193
+
194
+ # ============================================================================
195
+ # MOBIUS COLLECTIVE
196
+ # ============================================================================
197
+
198
+ class MobiusCollective(WideRouter):
199
+ """
200
+ Wide collective with MobiusLens towers.
201
+
202
+ Architecture:
203
+ - Light stem (configurable stride)
204
+ - Multiple shallow towers in parallel (scale continuum)
205
+ - Adaptive fusion + classification head
206
+ """
207
+
208
+ def __init__(
209
+ self,
210
+ name: str = 'mobius_collective',
211
+ in_channels: int = 1,
212
+ channels: int = 64,
213
+ num_towers: int = 4,
214
+ depth_per_tower: int = 2,
215
+ scale_range: Tuple[float, float] = (0.5, 2.5),
216
+ use_mobius: bool = True,
217
+ num_classes: int = 10,
218
+ stem_stride: int = 2,
219
+ ):
220
+ super().__init__(name, auto_discover=True)
221
+
222
+ self.in_channels = in_channels
223
+ self.channels = channels
224
+ self.num_towers = num_towers
225
+ self.depth_per_tower = depth_per_tower
226
+ self.scale_range = scale_range
227
+ self.use_mobius = use_mobius
228
+ self.num_classes = num_classes
229
+ self.stem_stride = stem_stride
230
+
231
+ # Stem
232
+ self.attach('stem', nn.Sequential(
233
+ nn.Conv2d(in_channels, channels, 3, stride=stem_stride, padding=1, bias=False),
234
+ nn.BatchNorm2d(channels),
235
+ nn.ReLU(inplace=True),
236
+ ))
237
+
238
+ # Towers
239
+ for i in range(num_towers):
240
+ tower = LensTower(
241
+ f'tower_{i}',
242
+ channels=channels,
243
+ depth=depth_per_tower,
244
+ tower_idx=i,
245
+ num_towers=num_towers,
246
+ scale_range=scale_range,
247
+ use_mobius=use_mobius,
248
+ )
249
+ self.attach(f'tower_{i}', tower)
250
+
251
+ self.discover_towers()
252
+
253
+ # Fusion (wraps geofractal's AdaptiveFusion for vision tensors)
254
+ self.attach('fusion', VisionAdaptiveFusion('fusion', num_towers, channels))
255
+
256
+ # Head
257
+ self.attach('pool', nn.AdaptiveAvgPool2d(1))
258
+ self.attach('head', nn.Linear(channels, num_classes))
259
+
260
+ def forward(self, x: Tensor) -> Tensor:
261
+ x = self['stem'](x)
262
+
263
+ opinions = self.wide_forward(x)
264
+ opinion_list = [opinions[f'tower_{i}'] for i in range(self.num_towers)]
265
+
266
+ fused = self['fusion'](*opinion_list)
267
+ fused = self['pool'](fused).flatten(1)
268
+
269
+ return self['head'](fused)
270
+
271
+ def get_config(self) -> Dict[str, Any]:
272
+ return {
273
+ 'in_channels': self.in_channels,
274
+ 'channels': self.channels,
275
+ 'num_towers': self.num_towers,
276
+ 'depth_per_tower': self.depth_per_tower,
277
+ 'scale_range': self.scale_range,
278
+ 'use_mobius': self.use_mobius,
279
+ 'num_classes': self.num_classes,
280
+ 'stem_stride': self.stem_stride,
281
+ }
282
+
283
+ def get_all_lens_stats(self) -> Dict[str, Dict[str, float]]:
284
+ """Return stats from all lenses for logging."""
285
+ stats = {}
286
+ for tower_name in self.tower_names:
287
+ tower = self[tower_name]
288
+ for i, stage in enumerate(tower.stages):
289
+ key = f"{tower_name}_block_{i}"
290
+ stats[key] = stage.lens.get_lens_stats()
291
+ return stats
292
+
293
+
294
+ # ============================================================================
295
+ # PRESETS
296
+ # ============================================================================
297
+
298
+ PRESETS = {
299
+ 'fashion_mobius_tiny': {
300
+ 'channels': 32,
301
+ 'num_towers': 3,
302
+ 'depth_per_tower': 2,
303
+ 'scale_range': (0.5, 2.0),
304
+ 'use_mobius': True,
305
+ },
306
+ 'fashion_mobius_small': {
307
+ 'channels': 64,
308
+ 'num_towers': 4,
309
+ 'depth_per_tower': 2,
310
+ 'scale_range': (0.5, 2.5),
311
+ 'use_mobius': True,
312
+ },
313
+ 'fashion_mobius_base': {
314
+ 'channels': 96,
315
+ 'num_towers': 4,
316
+ 'depth_per_tower': 3,
317
+ 'scale_range': (0.25, 2.75),
318
+ 'use_mobius': True,
319
+ },
320
+ 'fashion_tri_small': {
321
+ 'channels': 64,
322
+ 'num_towers': 4,
323
+ 'depth_per_tower': 2,
324
+ 'scale_range': (0.5, 2.5),
325
+ 'use_mobius': False,
326
+ },
327
+ }
328
+
329
+
330
+ # ============================================================================
331
+ # DATA
332
+ # ============================================================================
333
+
334
+ def get_fashion_mnist_loaders(data_dir: str = './data', batch_size: int = 128):
335
+ """Get Fashion-MNIST train/val loaders with augmentation."""
336
+
337
+ train_transform = transforms.Compose([
338
+ transforms.RandomCrop(28, padding=4),
339
+ transforms.RandomHorizontalFlip(),
340
+ transforms.ToTensor(),
341
+ transforms.Normalize((0.2860,), (0.3530,)),
342
+ ])
343
+
344
+ val_transform = transforms.Compose([
345
+ transforms.ToTensor(),
346
+ transforms.Normalize((0.2860,), (0.3530,)),
347
+ ])
348
+
349
+ train_dataset = datasets.FashionMNIST(
350
+ data_dir, train=True, download=True, transform=train_transform
351
+ )
352
+ val_dataset = datasets.FashionMNIST(
353
+ data_dir, train=False, download=True, transform=val_transform
354
+ )
355
+
356
+ train_loader = DataLoader(
357
+ train_dataset, batch_size=batch_size, shuffle=True,
358
+ num_workers=4, pin_memory=True, persistent_workers=True
359
+ )
360
+ val_loader = DataLoader(
361
+ val_dataset, batch_size=256, shuffle=False,
362
+ num_workers=2, pin_memory=True, persistent_workers=True
363
+ )
364
+
365
+ return train_loader, val_loader
366
+
367
+
368
+ # ============================================================================
369
+ # CHECKPOINT MANAGER
370
+ # ============================================================================
371
+
372
+ class CheckpointManager:
373
+ """Handles saving, logging, and optional HF upload."""
374
+
375
+ def __init__(
376
+ self,
377
+ output_dir: str,
378
+ experiment_name: str,
379
+ hf_repo: Optional[str] = None,
380
+ save_every: int = 10,
381
+ upload_every: int = 20,
382
+ ):
383
+ self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
384
+ self.experiment_name = experiment_name
385
+ self.hf_repo = hf_repo
386
+ self.save_every = save_every
387
+ self.upload_every = upload_every
388
+
389
+ self.run_dir = Path(output_dir) / experiment_name / self.timestamp
390
+ self.ckpt_dir = self.run_dir / "checkpoints"
391
+ self.tb_dir = self.run_dir / "tensorboard"
392
+
393
+ self.ckpt_dir.mkdir(parents=True, exist_ok=True)
394
+ self.tb_dir.mkdir(parents=True, exist_ok=True)
395
+
396
+ self.writer = SummaryWriter(log_dir=str(self.tb_dir))
397
+ self.hf_api = HfApi() if HF_AVAILABLE and hf_repo else None
398
+
399
+ self.best_acc = 0.0
400
+ self.best_epoch = 0
401
+
402
+ print(f"Checkpoints: {self.run_dir}")
403
+
404
+ def save_config(self, model_config: Dict, train_config: Dict):
405
+ config = {
406
+ 'model': model_config,
407
+ 'training': train_config,
408
+ 'timestamp': self.timestamp,
409
+ }
410
+ with open(self.run_dir / "config.json", 'w') as f:
411
+ json.dump(config, f, indent=2)
412
+
413
+ def log_scalars(self, epoch: int, scalars: Dict[str, float], prefix: str = ""):
414
+ for name, value in scalars.items():
415
+ tag = f"{prefix}/{name}" if prefix else name
416
+ self.writer.add_scalar(tag, value, epoch)
417
+
418
+ def log_lens_stats(self, epoch: int, model: nn.Module):
419
+ raw = model._orig_mod if hasattr(model, '_orig_mod') else model
420
+ stats = raw.get_all_lens_stats()
421
+ for block_name, block_stats in stats.items():
422
+ for stat_name, value in block_stats.items():
423
+ if isinstance(value, (int, float)):
424
+ self.writer.add_scalar(f"lens/{block_name}/{stat_name}", value, epoch)
425
+
426
+ def save_checkpoint(
427
+ self,
428
+ model: nn.Module,
429
+ optimizer: torch.optim.Optimizer,
430
+ scheduler,
431
+ epoch: int,
432
+ train_acc: float,
433
+ val_acc: float,
434
+ train_loss: float,
435
+ ):
436
+ raw = model._orig_mod if hasattr(model, '_orig_mod') else model
437
+ is_best = val_acc > self.best_acc
438
+
439
+ if is_best:
440
+ self.best_acc = val_acc
441
+ self.best_epoch = epoch
442
+
443
+ # Save best
444
+ save_safetensors(raw.state_dict(), str(self.ckpt_dir / "best_model.safetensors"))
445
+ torch.save({
446
+ 'epoch': epoch,
447
+ 'model_state_dict': raw.state_dict(),
448
+ 'optimizer_state_dict': optimizer.state_dict(),
449
+ 'scheduler_state_dict': scheduler.state_dict(),
450
+ 'best_acc': self.best_acc,
451
+ 'train_acc': train_acc,
452
+ 'val_acc': val_acc,
453
+ }, self.ckpt_dir / "best_model.pt")
454
+
455
+ # Periodic save
456
+ if epoch % self.save_every == 0:
457
+ save_safetensors(raw.state_dict(), str(self.ckpt_dir / f"epoch_{epoch:04d}.safetensors"))
458
+
459
+ def upload(self, epoch: int, force: bool = False):
460
+ if not self.hf_api or not self.hf_repo:
461
+ return
462
+ if not force and epoch % self.upload_every != 0:
463
+ return
464
+
465
+ try:
466
+ hf_path = f"fashion_mnist/{self.experiment_name}/{self.timestamp}"
467
+
468
+ for f in [self.run_dir / "config.json", self.ckpt_dir / "best_model.safetensors"]:
469
+ if f.exists():
470
+ self.hf_api.upload_file(
471
+ path_or_fileobj=str(f),
472
+ path_in_repo=f"{hf_path}/{f.name}",
473
+ repo_id=self.hf_repo,
474
+ repo_type="model",
475
+ )
476
+ print(f"Uploaded to {self.hf_repo}/{hf_path}")
477
+ except Exception as e:
478
+ print(f"Upload failed: {e}")
479
+
480
+ def close(self):
481
+ self.writer.close()
482
+
483
+
484
+ # ============================================================================
485
+ # TRAINING
486
+ # ============================================================================
487
+
488
+ def train_fashion_mnist(
489
+ preset: str = 'fashion_mobius_small',
490
+ epochs: int = 50,
491
+ lr: float = 1e-3,
492
+ batch_size: int = 128,
493
+ output_dir: str = './outputs',
494
+ hf_repo: Optional[str] = 'AbstractPhil/mobiusnet-collective',
495
+ use_compile: bool = True,
496
+ save_every: int = 10,
497
+ upload_every: int = 20,
498
+ ):
499
+ """Train MobiusCollective on Fashion-MNIST."""
500
+
501
+ config = PRESETS[preset]
502
+
503
+ print("=" * 70)
504
+ print(f"FASHION-MNIST - {preset.upper()}")
505
+ print("=" * 70)
506
+ print(f"Channels: {config['channels']}")
507
+ print(f"Towers: {config['num_towers']} x {config['depth_per_tower']} depth")
508
+ print(f"Scale range: {config['scale_range']}")
509
+ print(f"Lens: {'Mobius' if config['use_mobius'] else 'TriWave'}")
510
+ print()
511
+
512
+ # Data
513
+ train_loader, val_loader = get_fashion_mnist_loaders('./data', batch_size)
514
+
515
+ # Model
516
+ model = MobiusCollective(
517
+ name=preset,
518
+ in_channels=1, # Fashion-MNIST is grayscale
519
+ num_classes=10,
520
+ stem_stride=2, # 28x28 -> 14x14
521
+ **config,
522
+ ).to(device)
523
+
524
+ total_params = sum(p.numel() for p in model.parameters())
525
+ print(f"Total params: {total_params:,}")
526
+
527
+ # Checkpoint manager
528
+ ckpt = CheckpointManager(
529
+ output_dir=output_dir,
530
+ experiment_name=preset,
531
+ hf_repo=hf_repo,
532
+ save_every=save_every,
533
+ upload_every=upload_every,
534
+ )
535
+
536
+ # Save config
537
+ train_config = {
538
+ 'epochs': epochs,
539
+ 'lr': lr,
540
+ 'batch_size': batch_size,
541
+ 'optimizer': 'AdamW',
542
+ 'scheduler': 'CosineAnnealingLR',
543
+ 'total_params': total_params,
544
+ }
545
+ ckpt.save_config(model.get_config(), train_config)
546
+
547
+ # Compile
548
+ if use_compile and hasattr(torch, 'compile'):
549
+ print("Compiling model...")
550
+ model = torch.compile(model, mode='reduce-overhead')
551
+
552
+ # Optimizer
553
+ optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
554
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
555
+
556
+ best_acc = 0.0
557
+
558
+ for epoch in range(1, epochs + 1):
559
+ # Train
560
+ model.train()
561
+ train_loss, train_correct, train_total = 0, 0, 0
562
+
563
+ pbar = tqdm(train_loader, desc=f"Epoch {epoch:3d}")
564
+ for x, y in pbar:
565
+ x, y = x.to(device), y.to(device)
566
+
567
+ optimizer.zero_grad()
568
+ logits = model(x)
569
+ loss = F.cross_entropy(logits, y)
570
+ loss.backward()
571
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
572
+ optimizer.step()
573
+
574
+ train_loss += loss.item() * x.size(0)
575
+ train_correct += (logits.argmax(1) == y).sum().item()
576
+ train_total += x.size(0)
577
+
578
+ pbar.set_postfix(loss=f"{loss.item():.4f}")
579
+
580
+ scheduler.step()
581
+
582
+ # Validate
583
+ model.eval()
584
+ val_correct, val_total = 0, 0
585
+ with torch.no_grad():
586
+ for x, y in val_loader:
587
+ x, y = x.to(device), y.to(device)
588
+ logits = model(x)
589
+ val_correct += (logits.argmax(1) == y).sum().item()
590
+ val_total += x.size(0)
591
+
592
+ # Metrics
593
+ train_acc = train_correct / train_total
594
+ val_acc = val_correct / val_total
595
+ avg_loss = train_loss / train_total
596
+ current_lr = scheduler.get_last_lr()[0]
597
+
598
+ is_best = val_acc > best_acc
599
+ if is_best:
600
+ best_acc = val_acc
601
+
602
+ marker = " ★" if is_best else ""
603
+ print(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | "
604
+ f"Train: {train_acc:.4f} | Val: {val_acc:.4f} | Best: {best_acc:.4f}{marker}")
605
+
606
+ # Logging
607
+ ckpt.log_scalars(epoch, {
608
+ 'loss': avg_loss,
609
+ 'train_acc': train_acc,
610
+ 'val_acc': val_acc,
611
+ 'best_acc': best_acc,
612
+ 'lr': current_lr,
613
+ }, prefix='train')
614
+
615
+ ckpt.log_lens_stats(epoch, model)
616
+
617
+ # Save
618
+ ckpt.save_checkpoint(model, optimizer, scheduler, epoch, train_acc, val_acc, avg_loss)
619
+
620
+ # Upload
621
+ ckpt.upload(epoch)
622
+
623
+ # Final upload
624
+ ckpt.upload(epochs, force=True)
625
+ ckpt.close()
626
+
627
+ print()
628
+ print("=" * 70)
629
+ print("TRAINING COMPLETE")
630
+ print("=" * 70)
631
+ print(f"Preset: {preset}")
632
+ print(f"Best accuracy: {best_acc:.4f}")
633
+ print(f"Params: {total_params:,}")
634
+ print(f"Checkpoints: {ckpt.run_dir}")
635
+ print("=" * 70)
636
+
637
+ return model, best_acc
638
+
639
+
640
+ # ============================================================================
641
+ # MAIN
642
+ # ============================================================================
643
+
644
+ if __name__ == '__main__':
645
+ model, best_acc = train_fashion_mnist(
646
+ preset='fashion_mobius_small',
647
+ epochs=50,
648
+ lr=1e-3,
649
+ batch_size=128,
650
+ output_dir='./outputs',
651
+ hf_repo='AbstractPhil/mobiusnet-collective', # Set to None to disable upload
652
+ use_compile=True,
653
+ save_every=10,
654
+ upload_every=20,
655
+ )