AbstractPhil committed on
Commit
034ffc8
·
verified ·
1 Parent(s): c017c1a

Create trainer_v5_alpha_cutmix.py

Files changed (1)
  1. trainer_v5_alpha_cutmix.py +1800 -0
trainer_v5_alpha_cutmix.py ADDED
@@ -0,0 +1,1800 @@
1
+ # trainer_v5_alpha_cutmix.py - PRODUCTION WITH ADAMW + WARM RESTARTS + LR BOOST
2
+
3
+ """
4
+ Cantor Fusion Classifier with AdamW + Cosine Warm Restarts + LR Boost
5
+ ----------------------------------------------------------------------
6
+ Features:
7
+ - AdamW optimizer (best for ViTs)
8
+ - CosineAnnealingWarmRestarts with configurable LR boost at restarts
9
+ - restart_lr_mult: Multiply LR at restart points for aggressive exploration
10
+ - HuggingFace Hub uploads (ONE shared repo, organized by run)
11
+ - TensorBoard logging (loss, accuracy, fusion metrics, LR tracking)
12
+ - Easy CIFAR-10/100 switching
13
+ - Automatic checkpoint management
14
+ - SafeTensors format (ClamAV safe)
15
+
16
+ New Feature: restart_lr_mult
17
+ When restart_lr_mult > 1.0, learning rate at restart is BOOSTED:
18
+ - Normal: 3e-4 → 1e-7 → restart at 3e-4
19
+ - Boosted (1.5x): 3e-4 → 1e-7 → restart at 4.5e-4 → 1e-7
20
+ - Creates wider exploration curves to escape solidified local minima
21
+
22
+ Author: AbstractPhil
23
+ License: MIT
24
+ """
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+ from torch.utils.data import DataLoader
30
+ from torch.utils.tensorboard import SummaryWriter
31
+ from torchvision import datasets, transforms
32
+ from torch.cuda.amp import autocast, GradScaler
33
+ from safetensors.torch import save_file, load_file
34
+
35
+ import math
36
+ import os
37
+ import json
38
+ from typing import Optional, Dict, List, Tuple, Union
39
+ from dataclasses import dataclass, asdict
40
+ import time
41
+ from pathlib import Path
42
+ from tqdm import tqdm
43
+
44
+ # HuggingFace
45
+ from huggingface_hub import HfApi, create_repo, upload_folder, upload_file
46
+ import yaml
47
+
48
+ # Import from your repo
49
+ from geovocab2.train.model.layers.attention.cantor_multiheaded_fusion import (
50
+ CantorMultiheadFusion,
51
+ CantorFusionConfig
52
+ )
53
+ from geovocab2.shapes.factory.cantor_route_factory import (
54
+ CantorRouteFactory,
55
+ RouteMode,
56
+ SimplexConfig
57
+ )
58
+
59
+
60
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
61
+ # Mixing Augmentations (AlphaMix / Fractal AlphaMix)
62
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
63
+
64
+ def alphamix_data(x, y, alpha_range=(0.3, 0.7), spatial_ratio=0.25):
65
+ """
66
+ Standard AlphaMix: Single spatially localized transparent overlay.
67
+
68
+ Args:
69
+ x: Input images [B, C, H, W]
70
+ y: Labels [B]
71
+ alpha_range: Range for transparency sampling
72
+ spatial_ratio: Ratio of image area to overlay
73
+
74
+ Returns:
75
+ composited_x: Mixed images
76
+ y_a: Original labels
77
+ y_b: Mixed labels
78
+ alpha: Effective mixing coefficient
79
+ """
80
+ batch_size = x.size(0)
81
+ index = torch.randperm(batch_size, device=x.device)
82
+
83
+ y_a, y_b = y, y[index]
84
+
85
+ # Sample alpha from Beta distribution
86
+ alpha_min, alpha_max = alpha_range
87
+ beta_sample = torch.distributions.Beta(2.0, 2.0).sample().item()
88
+ alpha = alpha_min + (alpha_max - alpha_min) * beta_sample
89
+
90
+ # Compute overlay region
91
+ _, _, H, W = x.shape
92
+ overlay_ratio = torch.sqrt(torch.tensor(spatial_ratio)).item()
93
+ overlay_h = int(H * overlay_ratio)
94
+ overlay_w = int(W * overlay_ratio)
95
+
96
+ top = torch.randint(0, H - overlay_h + 1, (1,), device=x.device).item()
97
+ left = torch.randint(0, W - overlay_w + 1, (1,), device=x.device).item()
98
+
99
+ # Blend
100
+ composited_x = x.clone()
101
+ overlay_region = alpha * x[:, :, top:top+overlay_h, left:left+overlay_w]
102
+ background_region = (1 - alpha) * x[index, :, top:top+overlay_h, left:left+overlay_w]
103
+ composited_x[:, :, top:top+overlay_h, left:left+overlay_w] = overlay_region + background_region
104
+
105
+ return composited_x, y_a, y_b, alpha
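+
+ # Illustrative usage sketch (assumption: documentation helper only, not part of
+ # the original training flow). Shows how the alpha returned by alphamix_data
+ # weights the two label losses, mirroring Trainer.compute_mixed_loss below.
+ def _demo_alphamix_loss_weighting():
+     criterion = nn.CrossEntropyLoss()
+     images = torch.randn(8, 3, 32, 32)
+     labels = torch.randint(0, 10, (8,))
+     mixed, y_a, y_b, alpha = alphamix_data(images, labels)
+     logits = torch.randn(8, 10, requires_grad=True)  # stand-in for model(mixed)
+     loss = alpha * criterion(logits, y_a) + (1 - alpha) * criterion(logits, y_b)
+     loss.backward()
+     return loss.item()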
106
+
107
+
108
+ def alphamix_fractal(
109
+ x: torch.Tensor,
110
+ y: torch.Tensor,
111
+ alpha_range=(0.3, 0.7),
112
+ steps_range=(1, 3),
113
+ triad_scales=(1/3, 1/9, 1/27),
114
+ beta_shape=(2.0, 2.0),
115
+ seed: Optional[int] = None,
116
+ ):
117
+ """
118
+ Fractal AlphaMix: Triadic multi-patch overlays aligned to Cantor geometry.
119
+ Pure torch, GPU-compatible.
120
+
121
+ Args:
122
+ x: Input images [B, C, H, W]
123
+ y: Labels [B]
124
+ alpha_range: Range for transparency sampling
125
+ steps_range: Range for number of patches to apply
126
+ triad_scales: Triadic scales (1/3, 1/9, 1/27 for Cantor-like)
127
+ beta_shape: Beta distribution parameters for sampling
128
+ seed: Optional random seed
129
+
130
+ Returns:
131
+ x_mix: Mixed images
132
+ y_a: Original labels
133
+ y_b: Mixed labels
134
+ alpha_eff: Effective area-weighted mixing coefficient
135
+ """
136
+ if seed is not None:
137
+ torch.manual_seed(seed)
138
+
139
+ B, C, H, W = x.shape
140
+ device = x.device
141
+
142
+ # Permutation for mixing
143
+ idx = torch.randperm(B, device=device)
144
+ y_a, y_b = y, y[idx]
145
+
146
+ x_mix = x.clone()
147
+ total_area = H * W
148
+
149
+ # Beta distribution for transparency sampling
150
+ k1, k2 = beta_shape
151
+ beta_dist = torch.distributions.Beta(k1, k2)
152
+ alpha_min, alpha_max = alpha_range
153
+
154
+ # Storage for effective alpha calculation
155
+ alpha_elems = []
156
+ area_weights = []
157
+
158
+ # Sample number of patches (same for all images in batch)
159
+ steps = torch.randint(steps_range[0], steps_range[1] + 1, (1,), device=device).item()
160
+
161
+ for _ in range(steps):
162
+ # Choose triadic scale
163
+ scale_idx = torch.randint(0, len(triad_scales), (1,), device=device).item()
164
+ scale = triad_scales[scale_idx]
165
+
166
+ # Compute patch dimensions (triadic area)
167
+ patch_area = max(1, int(total_area * scale))
168
+ side = int(torch.sqrt(torch.tensor(patch_area, dtype=torch.float32)).item())
169
+ h = max(1, min(H, side))
170
+ w = max(1, min(W, side))
171
+
172
+ # Random position
173
+ top = torch.randint(0, H - h + 1, (1,), device=device).item()
174
+ left = torch.randint(0, W - w + 1, (1,), device=device).item()
175
+
176
+ # Sample transparency from Beta distribution
177
+ alpha_raw = beta_dist.sample().item()
178
+ alpha = alpha_min + (alpha_max - alpha_min) * alpha_raw
179
+
180
+ # Track for effective alpha
181
+ alpha_elems.append(alpha)
182
+ area_weights.append(h * w)
183
+
184
+ # Blend patches
185
+ fg = alpha * x[:, :, top:top + h, left:left + w]
186
+ bg = (1 - alpha) * x[idx, :, top:top + h, left:left + w]
187
+ x_mix[:, :, top:top + h, left:left + w] = fg + bg
188
+
189
+ # Compute area-weighted effective alpha
190
+ alpha_t = torch.tensor(alpha_elems, dtype=torch.float32, device=device)
191
+ area_t = torch.tensor(area_weights, dtype=torch.float32, device=device)
192
+ alpha_eff = (alpha_t * area_t).sum() / (area_t.sum() + 1e-12)
193
+ alpha_eff = alpha_eff.item()
194
+
195
+ return x_mix, y_a, y_b, alpha_eff
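+
+ # Illustrative sanity check (assumption: not called anywhere in this trainer).
+ # Verifies that fractal mixing preserves the batch shape and that the effective
+ # area-weighted alpha stays inside the requested transparency range.
+ def _demo_alphamix_fractal_check():
+     x = torch.randn(4, 3, 32, 32)
+     y = torch.randint(0, 10, (4,))
+     x_mix, y_a, y_b, alpha_eff = alphamix_fractal(x, y, alpha_range=(0.3, 0.7), seed=0)
+     assert x_mix.shape == x.shape
+     assert 0.3 <= alpha_eff <= 0.7
+     return alpha_eff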
196
+
197
+
198
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
199
+ # Custom Scheduler with LR Boost at Restarts
200
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
201
+
202
+ class CosineAnnealingWarmRestartsWithBoost(torch.optim.lr_scheduler._LRScheduler):
203
+ """
204
+ Cosine Annealing with Warm Restarts and optional LR boost at restart points.
205
+
206
+ At each restart, the max LR is multiplied by `restart_lr_mult`, creating
207
+ wider exploration curves that can help escape solidified local minima.
208
+
209
+ Args:
210
+ optimizer: Wrapped optimizer
211
+ T_0: Number of iterations for the first restart
212
+ T_mult: Factor to increase T_i after each restart (default: 1)
213
+ eta_min: Minimum learning rate (default: 0)
214
+ restart_lr_mult: Multiply max LR by this at each restart (default: 1.0)
215
+ Values > 1.0 create boosted exploration cycles
216
+ last_epoch: The index of last epoch (default: -1)
217
+
218
+ Example:
219
+ >>> scheduler = CosineAnnealingWarmRestartsWithBoost(
220
+ ... optimizer, T_0=50, T_mult=2, restart_lr_mult=1.5
221
+ ... )
222
+ # Cycle 1: 3e-4 → 1e-7 (50 epochs)
223
+ # Restart: LR jumps to 4.5e-4 (1.5x boost)
224
+ # Cycle 2: 4.5e-4 → 1e-7 (100 epochs)
225
+ # Restart: LR jumps to 6.75e-4 (1.5x boost again)
226
+ # Cycle 3: 6.75e-4 → 1e-7 (200 epochs)
227
+ """
228
+
229
+ def __init__(
230
+ self,
231
+ optimizer: torch.optim.Optimizer,
232
+ T_0: int,
233
+ T_mult: float = 1,
234
+ eta_min: float = 0,
235
+ restart_lr_mult: float = 1.0,
236
+ last_epoch: int = -1
237
+ ):
238
+ if T_0 <= 0 or not isinstance(T_0, int):
239
+ raise ValueError(f"Expected positive integer T_0, but got {T_0}")
240
+ if T_mult < 1:
241
+ raise ValueError(f"Expected T_mult >= 1, but got {T_mult}")
242
+ if restart_lr_mult <= 0:
243
+ raise ValueError(f"Expected positive restart_lr_mult, but got {restart_lr_mult}")
244
+
245
+ self.T_0 = T_0
246
+ self.T_i = T_0
247
+ self.T_mult = T_mult
248
+ self.eta_min = eta_min
249
+ self.restart_lr_mult = restart_lr_mult
250
+ self.T_cur = last_epoch
251
+
252
+ # Track boosted base LRs and restart count
253
+ self.current_base_lrs = None
254
+ self.restart_count = 0
255
+
256
+ super().__init__(optimizer, last_epoch)
257
+
258
+ def get_lr(self):
259
+ if self.T_cur == -1:
260
+ # First step - return base LRs
261
+ return self.base_lrs
262
+
263
+ # Use boosted base LRs if we've had restarts
264
+ if self.current_base_lrs is None:
265
+ base_lrs_to_use = self.base_lrs
266
+ else:
267
+ base_lrs_to_use = self.current_base_lrs
268
+
269
+ # Cosine annealing from current base LR to eta_min
270
+ return [
271
+ self.eta_min + (base_lr - self.eta_min) *
272
+ (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2
273
+ for base_lr in base_lrs_to_use
274
+ ]
275
+
276
+ def step(self, epoch=None):
277
+ if epoch is None and self.last_epoch < 0:
278
+ epoch = 0
279
+
280
+ if epoch is None:
281
+ epoch = self.last_epoch + 1
282
+ self.T_cur = self.T_cur + 1
283
+
284
+ # Check if we hit a restart point
285
+ if self.T_cur >= self.T_i:
286
+ # APPLY BOOST HERE before reset
287
+ self.restart_count += 1
288
+ if self.current_base_lrs is None:
289
+ self.current_base_lrs = list(self.base_lrs)
290
+
291
+ # Boost the base LRs
292
+ self.current_base_lrs = [
293
+ base_lr * self.restart_lr_mult
294
+ for base_lr in self.current_base_lrs
295
+ ]
296
+
297
+ # Now reset cycle
298
+ self.T_cur = self.T_cur - self.T_i
299
+ self.T_i = int(self.T_i * self.T_mult)
300
+ else:
301
+ if epoch < 0:
302
+ raise ValueError(f"Expected non-negative epoch, but got {epoch}")
303
+ if epoch >= self.T_0:
304
+ if self.T_mult == 1:
305
+ self.T_cur = epoch % self.T_0
306
+ # Count how many restarts have occurred
307
+ self.restart_count = epoch // self.T_0
308
+ else:
309
+ n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
310
+ self.restart_count = n
311
+ self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
312
+ self.T_i = self.T_0 * self.T_mult ** n
313
+
314
+ # Apply cumulative boost
315
+ if self.current_base_lrs is None:
316
+ self.current_base_lrs = [
317
+ base_lr * (self.restart_lr_mult ** self.restart_count)
318
+ for base_lr in self.base_lrs
319
+ ]
320
+ else:
321
+ self.T_i = self.T_0
322
+ self.T_cur = epoch
323
+
324
+ self.last_epoch = math.floor(epoch)
325
+
326
+ for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
327
+ param_group['lr'] = lr
328
+
329
+ self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
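+
+ # Illustrative trace of the boosted schedule (assumption: inspection helper only,
+ # not used by the Trainer). With T_0=50, T_mult=2 and restart_lr_mult=1.5 the
+ # per-cycle max LR should climb roughly 3e-4 -> 4.5e-4 -> 6.75e-4 at the
+ # restarts around epochs 50, 150 and 350.
+ def _demo_boosted_restart_schedule(num_epochs: int = 400):
+     probe = nn.Linear(4, 4)
+     opt = torch.optim.AdamW(probe.parameters(), lr=3e-4)
+     sched = CosineAnnealingWarmRestartsWithBoost(
+         opt, T_0=50, T_mult=2, eta_min=1e-7, restart_lr_mult=1.5
+     )
+     prev = sched.get_last_lr()[0]
+     for epoch in range(num_epochs):
+         sched.step()
+         lr = sched.get_last_lr()[0]
+         if lr > prev:  # an upward jump marks a warm restart
+             print(f"restart after epoch {epoch + 1}: LR boosted to {lr:.2e}")
+         prev = lr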
330
+
331
+
332
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
333
+ # Configuration
334
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
335
+
336
+ @dataclass
337
+ class CantorTrainingConfig:
338
+ """Complete configuration for Cantor fusion training with AdamW + Warm Restarts."""
339
+
340
+ # Dataset
341
+ dataset: str = "cifar10" # "cifar10" or "cifar100"
342
+ num_classes: int = 10
343
+
344
+ # Architecture
345
+ image_size: int = 32
346
+ patch_size: int = 4
347
+ embed_dim: int = 384
348
+ num_fusion_blocks: int = 6
349
+ num_heads: int = 8
350
+ fusion_window: int = 32
351
+ fusion_mode: str = "weighted" # "weighted" or "consciousness"
352
+ k_simplex: int = 4
353
+ use_beatrix: bool = False
354
+ beatrix_tau: float = 0.25
355
+
356
+ # Optimization
357
+ precompute_geometric: bool = True
358
+ use_torch_compile: bool = True
359
+ use_mixed_precision: bool = False
360
+
361
+ # Regularization
362
+ dropout: float = 0.1
363
+ drop_path_rate: float = 0.1
364
+ label_smoothing: float = 0.1
365
+
366
+ # Training - Optimizer (AdamW)
367
+ optimizer_type: str = "adamw" # "sgd" or "adamw"
368
+ batch_size: int = 128
369
+ num_epochs: int = 300
370
+ learning_rate: float = 3e-4 # AdamW default
371
+ weight_decay: float = 0.05
372
+ grad_clip: float = 1.0
373
+
374
+ # SGD-specific (if needed)
375
+ sgd_momentum: float = 0.9
376
+ sgd_nesterov: bool = True
377
+
378
+ # AdamW-specific
379
+ adamw_betas: Tuple[float, float] = (0.9, 0.999)
380
+ adamw_eps: float = 1e-8
381
+
382
+ # Learning rate schedule - WARM RESTARTS WITH BOOST
383
+ scheduler_type: str = "cosine_restarts" # "multistep", "cosine", "cosine_restarts"
384
+
385
+ # CosineAnnealingWarmRestarts parameters
386
+ restart_period: int = 50 # T_0: epochs until first restart
387
+ restart_mult: float = 2.0 # T_mult: multiply period after each restart (can be float like 1.5)
388
+ restart_lr_mult: float = 1.0 # NEW: LR multiplier at restarts (>1.0 for boosted exploration)
389
+ min_lr: float = 1e-7 # eta_min: minimum learning rate
390
+
391
+ # MultiStepLR (for SGD fallback)
392
+ lr_milestones: List[int] = None
393
+ lr_gamma: float = 0.2
394
+
395
+ # Cosine annealing (regular, no restarts)
396
+ warmup_epochs: int = 0
397
+
398
+ # Data augmentation
399
+ use_augmentation: bool = True
400
+ use_autoaugment: bool = True
401
+ use_cutout: bool = False
402
+ cutout_length: int = 16
403
+
404
+ # Mixing augmentation (AlphaMix / Fractal AlphaMix)
405
+ use_mixing: bool = False
406
+ mixing_type: str = "alphamix" # "alphamix" or "fractal"
407
+ mixing_alpha_range: Tuple[float, float] = (0.3, 0.7)
408
+ mixing_spatial_ratio: float = 0.25 # For standard alphamix
409
+ mixing_prob: float = 1.0 # Probability of applying mixing
410
+ # Fractal AlphaMix specific
411
+ fractal_steps_range: Tuple[int, int] = (1, 3)
412
+ fractal_triad_scales: Tuple[float, ...] = (1/3, 1/9, 1/27)
413
+
414
+ # System
415
+ device: str = "cuda" if torch.cuda.is_available() else "cpu"
416
+ num_workers: int = 8
417
+ seed: int = 42
418
+
419
+ # Paths
420
+ weights_dir: str = "weights"
421
+ model_name: str = "vit-beans-v3"
422
+ run_name: Optional[str] = None # Auto-generated if None
423
+
424
+ # HuggingFace - ONE SHARED REPO
425
+ hf_username: str = "AbstractPhil"
426
+ hf_repo_name: Optional[str] = None
427
+ upload_to_hf: bool = True
428
+ hf_token: Optional[str] = None
429
+
430
+ # Logging
431
+ log_interval: int = 50
432
+ save_interval: int = 10
433
+ checkpoint_upload_interval: int = 20
434
+
435
+ def __post_init__(self):
436
+ # Auto-set num_classes based on dataset
437
+ if self.dataset == "cifar10":
438
+ self.num_classes = 10
439
+ elif self.dataset == "cifar100":
440
+ self.num_classes = 100
441
+ else:
442
+ raise ValueError(f"Unknown dataset: {self.dataset}")
443
+
444
+ # Set default milestones if None (for multistep fallback)
445
+ if self.lr_milestones is None:
446
+ if self.num_epochs >= 200:
447
+ self.lr_milestones = [60, 120, 160]
448
+ elif self.num_epochs >= 100:
449
+ self.lr_milestones = [30, 60, 80]
450
+ else:
451
+ self.lr_milestones = [
452
+ int(self.num_epochs * 0.5),
453
+ int(self.num_epochs * 0.75)
454
+ ]
455
+
456
+ # Auto-generate run name
457
+ if self.run_name is None:
458
+ timestamp = time.strftime("%Y%m%d_%H%M%S")
459
+ opt_name = self.optimizer_type.upper()
460
+ sched_name = "WarmRestart" if self.scheduler_type == "cosine_restarts" else self.scheduler_type
461
+ boost_str = f"_boost{self.restart_lr_mult}x" if self.restart_lr_mult > 1.0 else ""
462
+ self.run_name = f"{self.dataset}_{self.fusion_mode}_{opt_name}_{sched_name}{boost_str}_{timestamp}"
463
+
464
+ # ONE SHARED REPO for all runs
465
+ if self.hf_repo_name is None:
466
+ self.hf_repo_name = self.model_name
467
+
468
+ # Set HF token from environment if not provided
469
+ if self.hf_token is None:
470
+ self.hf_token = os.environ.get("HF_TOKEN")
471
+
472
+ # Calculate derived values
473
+ assert self.image_size % self.patch_size == 0
474
+ self.num_patches = (self.image_size // self.patch_size) ** 2
475
+ self.patch_dim = self.patch_size * self.patch_size * 3
476
+
477
+ # Create paths
478
+ self.output_dir = Path(self.weights_dir) / self.model_name / self.run_name
479
+ self.checkpoint_dir = self.output_dir / "checkpoints"
480
+ self.tensorboard_dir = self.output_dir / "tensorboard"
481
+
482
+ # Create directories
483
+ self.output_dir.mkdir(parents=True, exist_ok=True)
484
+ self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
485
+ self.tensorboard_dir.mkdir(parents=True, exist_ok=True)
486
+
487
+ def save(self, path: Union[str, Path]):
488
+ """Save config to YAML file."""
489
+ path = Path(path)
490
+ config_dict = asdict(self)
491
+ # Convert tuples to lists for YAML
492
+ if 'adamw_betas' in config_dict:
493
+ config_dict['adamw_betas'] = list(config_dict['adamw_betas'])
494
+ with open(path, 'w') as f:
495
+ yaml.dump(config_dict, f, default_flow_style=False)
496
+
497
+ @classmethod
498
+ def load(cls, path: Union[str, Path]):
499
+ """Load config from YAML file."""
500
+ path = Path(path)
501
+ with open(path, 'r') as f:
502
+ config_dict = yaml.safe_load(f)
503
+ # Convert lists back to tuples
504
+ if 'adamw_betas' in config_dict:
505
+ config_dict['adamw_betas'] = tuple(config_dict['adamw_betas'])
506
+ return cls(**config_dict)
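+
+ # Illustrative configuration sketch (the values below are assumptions for
+ # documentation, not the settings of any published run). Note that constructing
+ # the config creates its output/checkpoint/tensorboard directories as a side effect.
+ #
+ #   cfg = CantorTrainingConfig(
+ #       dataset="cifar100",
+ #       fusion_mode="consciousness",
+ #       restart_period=50,
+ #       restart_mult=2.0,
+ #       restart_lr_mult=1.5,   # boost the max LR by 1.5x at every warm restart
+ #       use_mixing=True,
+ #       mixing_type="fractal",
+ #       upload_to_hf=False,
+ #   )
+ #   cfg.save(cfg.output_dir / "config.yaml")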
507
+
508
+
509
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
510
+ # Model Components (unchanged from previous version)
511
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
512
+
513
+ class PatchEmbedding(nn.Module):
514
+ """Patch embedding layer."""
515
+ def __init__(self, config: CantorTrainingConfig):
516
+ super().__init__()
517
+ self.config = config
518
+ self.proj = nn.Conv2d(3, config.embed_dim, kernel_size=config.patch_size, stride=config.patch_size)
519
+ self.pos_embed = nn.Parameter(torch.randn(1, config.num_patches, config.embed_dim) * 0.02)
520
+
521
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
522
+ x = self.proj(x)
523
+ x = x.flatten(2).transpose(1, 2)
524
+ x = x + self.pos_embed
525
+ return x
526
+
527
+
528
+ class DropPath(nn.Module):
529
+ """Stochastic depth."""
530
+ def __init__(self, drop_prob: float = 0.0):
531
+ super().__init__()
532
+ self.drop_prob = drop_prob
533
+
534
+ def forward(self, x):
535
+ if self.drop_prob == 0. or not self.training:
536
+ return x
537
+ keep_prob = 1 - self.drop_prob
538
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
539
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
540
+ random_tensor.floor_()
541
+ return x.div(keep_prob) * random_tensor
542
+
543
+
544
+ class CantorFusionBlock(nn.Module):
545
+ """Cantor fusion block."""
546
+ def __init__(self, config: CantorTrainingConfig, drop_path: float = 0.0):
547
+ super().__init__()
548
+ self.norm1 = nn.LayerNorm(config.embed_dim)
549
+
550
+ fusion_config = CantorFusionConfig(
551
+ dim=config.embed_dim,
552
+ num_heads=config.num_heads,
553
+ fusion_window=config.fusion_window,
554
+ fusion_mode=config.fusion_mode,
555
+ k_simplex=config.k_simplex,
556
+ use_beatrix_routing=config.use_beatrix,
557
+ use_consciousness_weighting=(config.fusion_mode == "consciousness"),
558
+ beatrix_tau=config.beatrix_tau,
559
+ use_gating=True,
560
+ dropout=config.dropout,
561
+ residual=False,
562
+ precompute_staircase=config.precompute_geometric,
563
+ precompute_routes=config.precompute_geometric,
564
+ precompute_distances=config.precompute_geometric,
565
+ use_optimized_gather=True,
566
+ staircase_cache_sizes=[config.num_patches],
567
+ use_torch_compile=config.use_torch_compile
568
+ )
569
+ self.fusion = CantorMultiheadFusion(fusion_config)
570
+
571
+ self.norm2 = nn.LayerNorm(config.embed_dim)
572
+ mlp_hidden = config.embed_dim * 4
573
+ self.mlp = nn.Sequential(
574
+ nn.Linear(config.embed_dim, mlp_hidden),
575
+ nn.GELU(),
576
+ nn.Dropout(config.dropout),
577
+ nn.Linear(mlp_hidden, config.embed_dim),
578
+ nn.Dropout(config.dropout)
579
+ )
580
+ self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()
581
+
582
+ def forward(self, x: torch.Tensor, return_fusion_info: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict]]:
583
+ fusion_result = self.fusion(self.norm1(x))
584
+ x = x + self.drop_path(fusion_result['output'])
585
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
586
+
587
+ if return_fusion_info:
588
+ fusion_info = {
589
+ 'consciousness': fusion_result.get('consciousness'),
590
+ 'cantor_measure': fusion_result.get('cantor_measure')
591
+ }
592
+ return x, fusion_info
593
+ return x
594
+
595
+
596
+ class CantorClassifier(nn.Module):
597
+ """Cantor fusion classifier."""
598
+ def __init__(self, config: CantorTrainingConfig):
599
+ super().__init__()
600
+ self.config = config
601
+
602
+ self.patch_embed = PatchEmbedding(config)
603
+
604
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_fusion_blocks)]
605
+ self.blocks = nn.ModuleList([
606
+ CantorFusionBlock(config, drop_path=dpr[i])
607
+ for i in range(config.num_fusion_blocks)
608
+ ])
609
+
610
+ self.norm = nn.LayerNorm(config.embed_dim)
611
+ self.head = nn.Linear(config.embed_dim, config.num_classes)
612
+
613
+ self.apply(self._init_weights)
614
+
615
+ def _init_weights(self, m):
616
+ if isinstance(m, nn.Linear):
617
+ nn.init.trunc_normal_(m.weight, std=0.02)
618
+ if m.bias is not None:
619
+ nn.init.constant_(m.bias, 0)
620
+ elif isinstance(m, nn.LayerNorm):
621
+ nn.init.constant_(m.bias, 0)
622
+ nn.init.constant_(m.weight, 1.0)
623
+ elif isinstance(m, nn.Conv2d):
624
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
625
+
626
+ def forward(self, x: torch.Tensor, return_fusion_info: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, List[Dict]]]:
627
+ x = self.patch_embed(x)
628
+
629
+ fusion_infos = []
630
+ for i, block in enumerate(self.blocks):
631
+ if return_fusion_info and i == len(self.blocks) - 1:
632
+ x, fusion_info = block(x, return_fusion_info=True)
633
+ fusion_infos.append(fusion_info)
634
+ else:
635
+ x = block(x)
636
+
637
+ x = self.norm(x)
638
+ x = x.mean(dim=1)
639
+ logits = self.head(x)
640
+
641
+ if return_fusion_info:
642
+ return logits, fusion_infos
643
+ return logits
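+
+ # Illustrative smoke test (assumption: documentation helper; requires the
+ # geovocab2 package imported above and is not called during training).
+ # Constructing the config also creates its output directories as a side effect.
+ def _demo_classifier_shapes():
+     cfg = CantorTrainingConfig(dataset="cifar10", upload_to_hf=False,
+                                use_torch_compile=False, num_fusion_blocks=2)
+     model = CantorClassifier(cfg)
+     x = torch.randn(2, 3, cfg.image_size, cfg.image_size)
+     logits = model(x)
+     assert logits.shape == (2, cfg.num_classes)
+     return logits.shape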
644
+
645
+
646
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
647
+ # HuggingFace Integration
648
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
649
+
650
+ class HuggingFaceUploader:
651
+ """Manages HuggingFace Hub uploads to ONE shared repo."""
652
+
653
+ def __init__(self, config: CantorTrainingConfig):
654
+ self.config = config
655
+ self.api = HfApi(token=config.hf_token) if config.upload_to_hf else None
656
+ self.repo_id = f"{config.hf_username}/{config.hf_repo_name}"
657
+ self.run_prefix = f"runs/{config.run_name}"
658
+
659
+ if config.upload_to_hf:
660
+ self._create_repo()
661
+ self._update_main_readme()
662
+
663
+ def _create_repo(self):
664
+ """Create HuggingFace repo if it doesn't exist."""
665
+ try:
666
+ create_repo(
667
+ repo_id=self.repo_id,
668
+ token=self.config.hf_token,
669
+ exist_ok=True,
670
+ private=False
671
+ )
672
+ print(f"[HF] Repository: https://huggingface.co/{self.repo_id}")
673
+ print(f"[HF] Run folder: {self.run_prefix}")
674
+ except Exception as e:
675
+ print(f"[HF] Warning: Could not create repo: {e}")
676
+
677
+ def _update_main_readme(self):
678
+ """Create or update the main shared README at repo root."""
679
+ if not self.config.upload_to_hf or self.api is None:
680
+ return
681
+
682
+ boost_info = ""
683
+ if self.config.restart_lr_mult > 1.0:
684
+ boost_info = f"""
685
+ ### 🚀 LR Boost at Restarts (NEW!)
686
+ This run uses **restart_lr_mult = {self.config.restart_lr_mult}x**:
687
+ - Normal restart: 3e-4 → 1e-7 → restart at 3e-4
688
+ - **Boosted restart**: 3e-4 → 1e-7 → restart at {self.config.learning_rate * self.config.restart_lr_mult:.2e} ({self.config.restart_lr_mult}x!)
689
+ - Creates **wider exploration curves** to escape solidified local minima
690
+ - Each restart provides progressively stronger exploration boost
691
+ """
692
+
693
+ main_readme = f"""---
694
+ tags:
695
+ - image-classification
696
+ - cantor-fusion
697
+ - geometric-deep-learning
698
+ - safetensors
699
+ - vision-transformer
700
+ - warm-restarts
701
+ library_name: pytorch
702
+ datasets:
703
+ - cifar10
704
+ - cifar100
705
+ metrics:
706
+ - accuracy
707
+ ---
708
+
709
+ # {self.config.hf_repo_name}
710
+
711
+ **Geometric Deep Learning with Cantor Multihead Fusion + AdamW Warm Restarts**
712
+
713
+ This repository contains multiple training runs using Cantor fusion architecture with pentachoron structures, geometric routing, and **CosineAnnealingWarmRestarts** for automatic exploration cycles.
714
+
715
+ ## Training Strategy: AdamW + Warm Restarts
716
+
717
+ This model uses **AdamW with Cosine Annealing Warm Restarts** (SGDR):
718
+ - **Drop phase**: LR decays from {self.config.learning_rate} → {self.config.min_lr} over {self.config.restart_period} epochs
719
+ - **Restart phase**: LR jumps back to {self.config.learning_rate} to explore new regions
720
+ - **Cycle multiplier**: Each cycle is {self.config.restart_mult}x longer than previous
721
+ - **Benefits**: Automatic exploration + exploitation, finds better minima, robust training
722
+ {boost_info}
723
+
724
+ ### Restart Schedule
725
+ ```
726
+ Epochs 0-{self.config.restart_period}: LR: {self.config.learning_rate} → {self.config.min_lr} (first cycle)
727
+ Epoch {self.config.restart_period}: LR: RESTART to {self.config.learning_rate * self.config.restart_lr_mult if self.config.restart_lr_mult > 1.0 else self.config.learning_rate} 🔄
728
+ Epochs {self.config.restart_period}-{self.config.restart_period * (1 + self.config.restart_mult)}: LR: {self.config.learning_rate * self.config.restart_lr_mult if self.config.restart_lr_mult > 1.0 else self.config.learning_rate} → {self.config.min_lr} (longer cycle)
729
+ ...
730
+ ```
731
+
732
+ ## Current Run
733
+
734
+ **Latest**: `{self.config.run_name}`
735
+ - **Dataset**: {self.config.dataset.upper()}
736
+ - **Fusion Mode**: {self.config.fusion_mode}
737
+ - **Optimizer**: AdamW (adaptive moments)
738
+ - **Scheduler**: CosineAnnealingWarmRestarts
739
+ - **Restart LR Mult**: {self.config.restart_lr_mult}x
740
+ - **Architecture**: {self.config.num_fusion_blocks} blocks, {self.config.num_heads} heads
741
+ - **Simplex**: {self.config.k_simplex}-simplex ({self.config.k_simplex + 1} vertices)
742
+
743
+ ## Architecture
744
+
745
+ The Cantor Fusion architecture uses:
746
+ - **Geometric Routing**: Pentachoron (5-simplex) structures for token routing
747
+ - **Cantor Multihead Fusion**: Multiple fusion heads with geometric attention
748
+ - **Beatrix Consciousness Routing**: Optional consciousness-aware token fusion
749
+ - **SafeTensors Format**: All model weights use SafeTensors (not pickle)
750
+
751
+ ## Usage
752
+ ```python
753
+ from huggingface_hub import hf_hub_download
754
+ from safetensors.torch import load_file
755
+
756
+ model_path = hf_hub_download(
757
+ repo_id="{self.repo_id}",
758
+ filename="runs/YOUR_RUN_NAME/checkpoints/best_model.safetensors"
759
+ )
760
+
761
+ state_dict = load_file(model_path)
762
+ model.load_state_dict(state_dict)
763
+ ```
764
+
765
+ ## Citation
766
+ ```bibtex
767
+ @misc{{{self.config.hf_repo_name.replace('-', '_')},
768
+ author = {{AbstractPhil}},
769
+ title = {{{self.config.hf_repo_name}: Geometric Deep Learning with Warm Restarts}},
770
+ year = {{2025}},
771
+ publisher = {{HuggingFace}},
772
+ url = {{https://huggingface.co/{self.repo_id}}}
773
+ }}
774
+ ```
775
+
776
+ ---
777
+
778
+ **Repository maintained by**: [@{self.config.hf_username}](https://huggingface.co/{self.config.hf_username})
779
+
780
+ **Latest update**: {time.strftime("%Y-%m-%d %H:%M:%S")}
781
+ """
782
+
783
+ main_readme_path = Path(self.config.weights_dir) / self.config.model_name / "MAIN_README.md"
784
+ main_readme_path.parent.mkdir(parents=True, exist_ok=True)
785
+ with open(main_readme_path, 'w') as f:
786
+ f.write(main_readme)
787
+
788
+ try:
789
+ upload_file(
790
+ path_or_fileobj=str(main_readme_path),
791
+ path_in_repo="README.md",
792
+ repo_id=self.repo_id,
793
+ token=self.config.hf_token
794
+ )
795
+ print(f"[HF] Updated main README")
796
+ except Exception as e:
797
+ print(f"[HF] Main README upload failed: {e}")
798
+
799
+ def upload_file(self, file_path: Path, repo_path: str):
800
+ """Upload single file to HuggingFace."""
801
+ if not self.config.upload_to_hf or self.api is None:
802
+ return
803
+
804
+ try:
805
+ if not repo_path.startswith(self.run_prefix) and not repo_path.startswith("runs/"):
806
+ full_path = f"{self.run_prefix}/{repo_path}"
807
+ else:
808
+ full_path = repo_path
809
+
810
+ upload_file(
811
+ path_or_fileobj=str(file_path),
812
+ path_in_repo=full_path,
813
+ repo_id=self.repo_id,
814
+ token=self.config.hf_token
815
+ )
816
+ print(f"[HF] βœ“ Uploaded: {full_path}")
817
+ except Exception as e:
818
+ print(f"[HF] βœ— Upload failed ({full_path}): {e}")
819
+
820
+ def upload_folder_contents(self, folder_path: Path, repo_folder: str):
821
+ """Upload entire folder to HuggingFace."""
822
+ if not self.config.upload_to_hf or self.api is None:
823
+ return
824
+
825
+ try:
826
+ full_path = f"{self.run_prefix}/{repo_folder}"
827
+ upload_folder(
828
+ folder_path=str(folder_path),
829
+ repo_id=self.repo_id,
830
+ path_in_repo=full_path,
831
+ token=self.config.hf_token,
832
+ ignore_patterns=["*.pyc", "__pycache__"]
833
+ )
834
+ print(f"[HF] Uploaded folder: {full_path}")
835
+ except Exception as e:
836
+ print(f"[HF] Folder upload failed: {e}")
837
+
838
+ def create_model_card(self, trainer_stats: Dict):
839
+ """Create and upload run-specific model card."""
840
+ if not self.config.upload_to_hf:
841
+ return
842
+
843
+ boost_section = ""
844
+ if self.config.restart_lr_mult > 1.0:
845
+ boost_section = f"""
846
+ ### 🚀 LR Boost Feature
847
+
848
+ This run uses **restart_lr_mult = {self.config.restart_lr_mult}x** for aggressive exploration:
849
+
850
+ **How it works:**
851
+ ```
852
+ Cycle 1: {self.config.learning_rate:.2e} → {self.config.min_lr:.2e} (standard convergence)
+ Restart: → {self.config.learning_rate * self.config.restart_lr_mult:.2e} (BOOSTED!)
+ Cycle 2: {self.config.learning_rate * self.config.restart_lr_mult:.2e} → {self.config.min_lr:.2e} (wider exploration)
+ Restart: → {self.config.learning_rate * (self.config.restart_lr_mult ** 2):.2e} (EVEN MORE BOOSTED!)
+ Cycle 3: {self.config.learning_rate * (self.config.restart_lr_mult ** 2):.2e} → {self.config.min_lr:.2e}
857
+ ...
858
+ ```
859
+
860
+ **Benefits:**
861
+ - 🔓 **Escape solidified local minima** with aggressive LR spikes
+ - 🌊 **Wider exploration curves** after each restart
+ - 💪 **Progressively stronger exploration** as training proceeds
+ - 🎯 **Combat training plateaus** that plague long runs
865
+ """
866
+
867
+ run_card = f"""# Run: {self.config.run_name}
868
+
869
+ ## Configuration
870
+ - **Dataset**: {self.config.dataset.upper()}
871
+ - **Fusion Mode**: {self.config.fusion_mode}
872
+ - **Parameters**: {trainer_stats['total_params']:,}
873
+ - **Simplex**: {self.config.k_simplex}-simplex ({self.config.k_simplex + 1} vertices)
874
+
875
+ ## Performance
876
+ - **Best Validation Accuracy**: {trainer_stats['best_acc']:.2f}%
877
+ - **Training Time**: {trainer_stats['training_time']:.1f} hours
878
+ - **Final Epoch**: {trainer_stats['final_epoch']}
879
+
880
+ ## Training Setup: AdamW + Warm Restarts
881
+ - **Optimizer**: AdamW (lr={self.config.learning_rate}, wd={self.config.weight_decay})
882
+ - **Scheduler**: CosineAnnealingWarmRestarts
883
+ - **Restart Period (T_0)**: {self.config.restart_period} epochs
884
+ - **Cycle Multiplier (T_mult)**: {self.config.restart_mult}x
885
+ - **Restart LR Mult**: {self.config.restart_lr_mult}x {'🚀' if self.config.restart_lr_mult > 1.0 else ''}
886
+ - **Min LR**: {self.config.min_lr}
887
+ - **Batch Size**: {self.config.batch_size}
888
+ - **Mixed Precision**: {trainer_stats.get('mixed_precision', False)}
889
+ {boost_section}
890
+
891
+ ### Learning Rate Schedule
892
+ ```
893
+ Cycle 1: Epochs 0-{self.config.restart_period}
894
+ LR: {self.config.learning_rate} → {self.config.min_lr} (drop)
895
+ Expected: Convergence to local minimum
896
+
897
+ Epoch {self.config.restart_period}: RESTART 🔄
898
+ LR: {self.config.min_lr} → {self.config.learning_rate * self.config.restart_lr_mult if self.config.restart_lr_mult > 1.0 else self.config.learning_rate} (jump{"!" if self.config.restart_lr_mult > 1.0 else ""})
899
+ Expected: Escape local minimum, explore new regions
900
+
901
+ Cycle 2: Epochs {self.config.restart_period}-{self.config.restart_period * (1 + self.config.restart_mult)}
902
+ LR: {self.config.learning_rate * self.config.restart_lr_mult if self.config.restart_lr_mult > 1.0 else self.config.learning_rate} → {self.config.min_lr} (longer cycle)
903
+ Expected: Deeper convergence
904
+
905
+ ... and so on
906
+ ```
907
+
908
+ ## Files
909
+ - `{self.run_prefix}/checkpoints/best_model.safetensors` - Model weights
910
+ - `{self.run_prefix}/checkpoints/best_training_state.pt` - Optimizer state
911
+ - `{self.run_prefix}/config.yaml` - Full configuration
912
+ - `{self.run_prefix}/tensorboard/` - TensorBoard logs (LR tracking!)
913
+
914
+ ## Usage
915
+ ```python
916
+ from safetensors.torch import load_file
917
+ from huggingface_hub import hf_hub_download
918
+
919
+ model_path = hf_hub_download(
920
+ repo_id="{self.repo_id}",
921
+ filename="{self.run_prefix}/checkpoints/best_model.safetensors"
922
+ )
923
+
924
+ state_dict = load_file(model_path)
925
+ model.load_state_dict(state_dict)
926
+ ```
927
+
928
+ ## Training Notes
929
+
930
+ **Warm Restarts Benefits:**
931
+ - 🔄 **Exploration**: Periodic LR jumps escape local minima
+ - 📉 **Exploitation**: Long drop phases converge deeply
+ - 🎯 **Robustness**: Multiple restarts find better solutions
+ - 📊 **Monitoring**: Watch TensorBoard for restart effects!
935
+
936
+ **Expected Behavior:**
937
+ - Accuracy improves during each drop phase
938
+ - Brief accuracy dips after restarts (exploration)
939
+ - Overall upward trend across cycles
940
+ - Best models often found late in long cycles
941
+
942
+ ---
943
+
944
+ Built with geometric consciousness-aware routing using the Devil's Staircase (Beatrix) and pentachoron parameterization.
945
+
946
+ **Training completed**: {time.strftime("%Y-%m-%d %H:%M:%S")}
947
+
948
+ [← Back to main repository](https://huggingface.co/{self.repo_id})
949
+ """
950
+
951
+ readme_path = self.config.output_dir / "RUN_README.md"
952
+ with open(readme_path, 'w') as f:
953
+ f.write(run_card)
954
+
955
+ try:
956
+ upload_file(
957
+ path_or_fileobj=str(readme_path),
958
+ path_in_repo=f"{self.run_prefix}/README.md",
959
+ repo_id=self.repo_id,
960
+ token=self.config.hf_token
961
+ )
962
+ print(f"[HF] Uploaded run README")
963
+ except Exception as e:
964
+ print(f"[HF] Run README upload failed: {e}")
965
+
966
+
967
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
968
+ # Trainer with AdamW + CosineAnnealingWarmRestarts + LR Boost
969
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
970
+
971
+ class Trainer:
972
+ """Training manager with AdamW + Warm Restarts + LR Boost."""
973
+
974
+ def __init__(self, config: CantorTrainingConfig):
975
+ self.config = config
976
+ self.device = torch.device(config.device)
977
+
978
+ # Set seed
979
+ torch.manual_seed(config.seed)
980
+ if torch.cuda.is_available():
981
+ torch.cuda.manual_seed(config.seed)
982
+
983
+ # Model
984
+ print("\n" + "=" * 70)
985
+ print(f"Initializing Cantor Classifier - {config.dataset.upper()}")
986
+ print("=" * 70)
987
+
988
+ init_start = time.time()
989
+ self.model = CantorClassifier(config).to(self.device)
990
+ init_time = time.time() - init_start
991
+
992
+ print(f"\n[Model] Initialization time: {init_time:.2f}s")
993
+ self.print_model_info()
994
+
995
+ # Track restart epochs for logging
996
+ self.restart_epochs = self._calculate_restart_epochs()
997
+
998
+ # Optimizer
999
+ self.optimizer = self.create_optimizer()
1000
+
1001
+ # Scheduler
1002
+ self.scheduler = self.create_scheduler()
1003
+
1004
+ # Loss
1005
+ self.criterion = nn.CrossEntropyLoss(label_smoothing=config.label_smoothing)
1006
+
1007
+ # Mixing info
1008
+ self.use_mixing = config.use_mixing
1009
+ self.mixing_type = config.mixing_type
1010
+ self.mixing_prob = config.mixing_prob
1011
+
1012
+ # Mixed precision
1013
+ self.use_amp = config.use_mixed_precision and config.device == "cuda"
1014
+ self.scaler = GradScaler() if self.use_amp else None
1015
+
1016
+ if self.use_amp:
1017
+ print(f"[Training] Mixed precision enabled")
1018
+
1019
+ # TensorBoard
1020
+ self.writer = SummaryWriter(log_dir=str(config.tensorboard_dir))
1021
+ print(f"[TensorBoard] Logging to: {config.tensorboard_dir}")
1022
+ print(f"[Checkpoints] Format: SafeTensors (ClamAV safe)")
1023
+
1024
+ # HuggingFace
1025
+ self.hf_uploader = HuggingFaceUploader(config) if config.upload_to_hf else None
1026
+
1027
+ # Save config
1028
+ config.save(config.output_dir / "config.yaml")
1029
+
1030
+ # Metrics
1031
+ self.best_acc = 0.0
1032
+ self.global_step = 0
1033
+ self.start_time = time.time()
1034
+ self.upload_count = 0
1035
+
1036
+ def apply_mixing(self, images: torch.Tensor, labels: torch.Tensor):
1037
+ """Apply mixing augmentation if enabled."""
1038
+ if not self.use_mixing or torch.rand(1).item() > self.mixing_prob:
1039
+ return images, labels, None
1040
+
1041
+ if self.mixing_type == "alphamix":
1042
+ mixed_images, y_a, y_b, alpha = alphamix_data(
1043
+ images, labels,
1044
+ alpha_range=self.config.mixing_alpha_range,
1045
+ spatial_ratio=self.config.mixing_spatial_ratio
1046
+ )
1047
+ elif self.mixing_type == "fractal":
1048
+ mixed_images, y_a, y_b, alpha = alphamix_fractal(
1049
+ images, labels,
1050
+ alpha_range=self.config.mixing_alpha_range,
1051
+ steps_range=self.config.fractal_steps_range,
1052
+ triad_scales=self.config.fractal_triad_scales
1053
+ )
1054
+ else:
1055
+ raise ValueError(f"Unknown mixing type: {self.mixing_type}")
1056
+
1057
+ return mixed_images, (y_a, y_b, alpha), alpha
1058
+
1059
+ def compute_mixed_loss(self, logits: torch.Tensor, mixed_labels):
1060
+ """Compute loss for mixed labels."""
1061
+ if mixed_labels is None:
1062
+ # No mixing applied
1063
+ return None
1064
+
1065
+ y_a, y_b, alpha = mixed_labels
1066
+ loss_a = self.criterion(logits, y_a)
1067
+ loss_b = self.criterion(logits, y_b)
1068
+
1069
+ # Weighted combination based on mixing ratio
1070
+ # alpha is the original image's blend weight in the overlay, so it weights loss_a (original labels)
1071
+ loss = alpha * loss_a + (1 - alpha) * loss_b
1072
+ return loss
1073
+
1074
+
1075
+ def _calculate_restart_epochs(self) -> List[int]:
1076
+ """Calculate when restarts will occur."""
1077
+ if self.config.scheduler_type != "cosine_restarts":
1078
+ return []
1079
+
1080
+ restarts = []
1081
+ current = self.config.restart_period
1082
+ period = self.config.restart_period
1083
+
1084
+ while current < self.config.num_epochs:
1085
+ restarts.append(current)
1086
+ period *= self.config.restart_mult
1087
+ current += period
1088
+
1089
+ return restarts
1090
+
1091
+ def create_optimizer(self):
1092
+ """Create optimizer based on config."""
1093
+ if self.config.optimizer_type == "sgd":
1094
+ print(f"\n[Optimizer] SGD")
1095
+ print(f" LR: {self.config.learning_rate}")
1096
+ print(f" Momentum: {self.config.sgd_momentum}")
1097
+ print(f" Nesterov: {self.config.sgd_nesterov}")
1098
+ print(f" Weight decay: {self.config.weight_decay}")
1099
+
1100
+ return torch.optim.SGD(
1101
+ self.model.parameters(),
1102
+ lr=self.config.learning_rate,
1103
+ momentum=self.config.sgd_momentum,
1104
+ weight_decay=self.config.weight_decay,
1105
+ nesterov=self.config.sgd_nesterov
1106
+ )
1107
+
1108
+ elif self.config.optimizer_type == "adamw":
1109
+ print(f"\n[Optimizer] AdamW")
1110
+ print(f" LR: {self.config.learning_rate}")
1111
+ print(f" Betas: {self.config.adamw_betas}")
1112
+ print(f" Weight decay: {self.config.weight_decay}")
1113
+
1114
+ return torch.optim.AdamW(
1115
+ self.model.parameters(),
1116
+ lr=self.config.learning_rate,
1117
+ betas=self.config.adamw_betas,
1118
+ eps=self.config.adamw_eps,
1119
+ weight_decay=self.config.weight_decay
1120
+ )
1121
+
1122
+ else:
1123
+ raise ValueError(f"Unknown optimizer: {self.config.optimizer_type}")
1124
+
1125
+ def create_scheduler(self):
1126
+ """Create LR scheduler based on config."""
1127
+ if self.config.scheduler_type == "cosine_restarts":
1128
+ print(f"\n[Scheduler] CosineAnnealingWarmRestarts with LR Boost")
1129
+ print(f" T_0 (restart period): {self.config.restart_period} epochs")
1130
+ print(f" T_mult (cycle multiplier): {self.config.restart_mult}x")
1131
+ print(f" Restart LR mult: {self.config.restart_lr_mult}x {'πŸš€' if self.config.restart_lr_mult > 1.0 else ''}")
1132
+ print(f" Min LR: {self.config.min_lr}")
1133
+
1134
+ if self.config.restart_lr_mult > 1.0:
1135
+ print(f"\n πŸš€ BOOST MODE ENABLED!")
1136
+ print(f" Baseline LR: {self.config.learning_rate:.2e}")
1137
+ boosted_lrs = [self.config.learning_rate * (self.config.restart_lr_mult ** i) for i in range(1, min(4, len(self.restart_epochs) + 1))]
1138
+ for i, lr in enumerate(boosted_lrs):
1139
+ print(f" After restart #{i+1}: {lr:.2e} ({self.config.restart_lr_mult**(i+1):.2f}x)")
1140
+ print(f" β†’ Creates wider exploration curves to escape local minima!")
1141
+
1142
+ print(f"\n Restart schedule:")
1143
+ for i, epoch in enumerate(self.restart_epochs[:5]): # Show first 5
1144
+ mult = self.config.restart_lr_mult ** (i + 1) if self.config.restart_lr_mult > 1.0 else 1.0
1145
+ print(f" Restart #{i+1}: Epoch {epoch} (LR: {self.config.learning_rate * mult:.2e})")
1146
+ if len(self.restart_epochs) > 5:
1147
+ print(f" ... and {len(self.restart_epochs) - 5} more")
1148
+
1149
+ return CosineAnnealingWarmRestartsWithBoost(
1150
+ self.optimizer,
1151
+ T_0=self.config.restart_period,
1152
+ T_mult=self.config.restart_mult,
1153
+ eta_min=self.config.min_lr,
1154
+ restart_lr_mult=self.config.restart_lr_mult
1155
+ )
1156
+
1157
+ elif self.config.scheduler_type == "multistep":
1158
+ print(f"\n[Scheduler] MultiStepLR")
1159
+ print(f" Milestones: {self.config.lr_milestones}")
1160
+ print(f" Gamma: {self.config.lr_gamma}")
1161
+
1162
+ return torch.optim.lr_scheduler.MultiStepLR(
1163
+ self.optimizer,
1164
+ milestones=self.config.lr_milestones,
1165
+ gamma=self.config.lr_gamma
1166
+ )
1167
+
1168
+ elif self.config.scheduler_type == "cosine":
1169
+ print(f"\n[Scheduler] Cosine annealing with warmup")
1170
+ print(f" Warmup epochs: {self.config.warmup_epochs}")
1171
+ print(f" Min LR: {self.config.min_lr}")
1172
+
1173
+ def lr_lambda(epoch):
1174
+ if epoch < self.config.warmup_epochs:
1175
+ return (epoch + 1) / self.config.warmup_epochs
1176
+ progress = (epoch - self.config.warmup_epochs) / (self.config.num_epochs - self.config.warmup_epochs)
1177
+ return 0.5 * (1 + math.cos(math.pi * progress))
1178
+
1179
+ return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)
1180
+
1181
+ else:
1182
+ raise ValueError(f"Unknown scheduler: {self.config.scheduler_type}")
1183
+
1184
+ def print_model_info(self):
1185
+ """Print model info."""
1186
+ total_params = sum(p.numel() for p in self.model.parameters())
1187
+ print(f"\nParameters: {total_params:,}")
1188
+ print(f"Dataset: {self.config.dataset.upper()}")
1189
+ print(f"Classes: {self.config.num_classes}")
1190
+ print(f"Fusion mode: {self.config.fusion_mode}")
1191
+ print(f"Optimizer: {self.config.optimizer_type.upper()}")
1192
+ print(f"Scheduler: {self.config.scheduler_type}")
1193
+ if self.config.restart_lr_mult > 1.0:
1194
+ print(f"LR Boost: {self.config.restart_lr_mult}x at restarts πŸš€")
1195
+ if self.config.use_mixing:
1196
+ print(f"Mixing: {self.config.mixing_type} (prob={self.config.mixing_prob})")
1197
+ print(f"Output: {self.config.output_dir}")
1198
+
1199
+ def train_epoch(self, train_loader: DataLoader, epoch: int) -> Tuple[float, float]:
1200
+ """Train one epoch."""
1201
+ self.model.train()
1202
+ total_loss, correct, total = 0.0, 0, 0
1203
+ mixing_applied_count = 0
1204
+ total_batches = 0
1205
+
1206
+ # Check if this is a restart epoch
1207
+ is_restart = (epoch in self.restart_epochs)
1208
+ epoch_desc = f"Epoch {epoch+1}/{self.config.num_epochs}"
1209
+ if is_restart:
1210
+ restart_num = self.restart_epochs.index(epoch) + 1
1211
+ boost_mult = self.config.restart_lr_mult ** restart_num if self.config.restart_lr_mult > 1.0 else 1.0
1212
+ epoch_desc += f" πŸ”„ RESTART #{restart_num}"
1213
+ if self.config.restart_lr_mult > 1.0:
1214
+ epoch_desc += f" ({boost_mult:.2f}x)"
1215
+
1216
+ pbar = tqdm(train_loader, desc=f"{epoch_desc} [Train]")
1217
+
1218
+ for batch_idx, (images, labels) in enumerate(pbar):
1219
+ images, labels = images.to(self.device, non_blocking=True), labels.to(self.device, non_blocking=True)
1220
+
1221
+ # Apply mixing augmentation
1222
+ original_labels = labels
1223
+ mixed_images, mixed_labels_info, mixing_alpha = self.apply_mixing(images, labels)
1224
+ if mixing_alpha is not None:
1225
+ mixing_applied_count += 1
1226
+ images = mixed_images
1227
+
1228
+ total_batches += 1
1229
+
1230
+ # Forward
1231
+ if self.use_amp:
1232
+ with autocast():
1233
+ logits = self.model(images)
1234
+
1235
+ # Compute loss (handle mixed labels)
1236
+ if mixing_alpha is not None:
1237
+ loss = self.compute_mixed_loss(logits, mixed_labels_info)
1238
+ else:
1239
+ loss = self.criterion(logits, labels)
1240
+
1241
+ self.optimizer.zero_grad(set_to_none=True)
1242
+ self.scaler.scale(loss).backward()
1243
+ self.scaler.unscale_(self.optimizer)
1244
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
1245
+ self.scaler.step(self.optimizer)
1246
+ self.scaler.update()
1247
+ else:
1248
+ logits = self.model(images)
1249
+
1250
+ # Compute loss (handle mixed labels)
1251
+ if mixing_alpha is not None:
1252
+ loss = self.compute_mixed_loss(logits, mixed_labels_info)
1253
+ else:
1254
+ loss = self.criterion(logits, labels)
1255
+
1256
+ self.optimizer.zero_grad(set_to_none=True)
1257
+ loss.backward()
1258
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
1259
+ self.optimizer.step()
1260
+
1261
+ # Metrics (use original labels for accuracy)
1262
+ total_loss += loss.item()
1263
+ _, predicted = logits.max(1)
1264
+ correct += predicted.eq(original_labels).sum().item()
1265
+ total += original_labels.size(0)
1266
+
1267
+ # TensorBoard logging
1268
+ if batch_idx % self.config.log_interval == 0:
1269
+ current_lr = self.scheduler.get_last_lr()[0]
1270
+ self.writer.add_scalar('train/loss', loss.item(), self.global_step)
1271
+ self.writer.add_scalar('train/accuracy', 100. * correct / total, self.global_step)
1272
+ self.writer.add_scalar('train/learning_rate', current_lr, self.global_step)
1273
+ if mixing_alpha is not None:
1274
+ self.writer.add_scalar('train/mixing_alpha', mixing_alpha, self.global_step)
1275
+
1276
+ self.global_step += 1
1277
+
1278
+ postfix_dict = {
1279
+ 'loss': f'{loss.item():.4f}',
1280
+ 'acc': f'{100. * correct / total:.2f}%',
1281
+ 'lr': f'{self.scheduler.get_last_lr()[0]:.6f}'
1282
+ }
1283
+ if self.use_mixing:
1284
+ mix_pct = 100.0 * mixing_applied_count / total_batches
1285
+ postfix_dict['mix'] = f'{mix_pct:.0f}%'
1286
+
1287
+ pbar.set_postfix(postfix_dict)
1288
+
1289
+ return total_loss / len(train_loader), 100. * correct / total
1290
+
1291
+ @torch.no_grad()
1292
+ def evaluate(self, val_loader: DataLoader, epoch: int) -> Tuple[float, Dict]:
1293
+ """Evaluate."""
1294
+ self.model.eval()
1295
+ total_loss, correct, total = 0.0, 0, 0
1296
+ consciousness_values = []
1297
+
1298
+ pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{self.config.num_epochs} [Val] ")
1299
+
1300
+ for batch_idx, (images, labels) in enumerate(pbar):
1301
+ images, labels = images.to(self.device, non_blocking=True), labels.to(self.device, non_blocking=True)
1302
+
1303
+ # Forward with fusion info on last batch
1304
+ return_info = (batch_idx == len(val_loader) - 1)
1305
+
1306
+ if self.use_amp:
1307
+ with autocast():
1308
+ if return_info:
1309
+ logits, fusion_infos = self.model(images, return_fusion_info=True)
1310
+ if fusion_infos and fusion_infos[0].get('consciousness') is not None:
1311
+ consciousness_values.append(fusion_infos[0]['consciousness'].mean().item())
1312
+ else:
1313
+ logits = self.model(images)
1314
+ loss = self.criterion(logits, labels)
1315
+ else:
1316
+ if return_info:
1317
+ logits, fusion_infos = self.model(images, return_fusion_info=True)
1318
+ if fusion_infos and fusion_infos[0].get('consciousness') is not None:
1319
+ consciousness_values.append(fusion_infos[0]['consciousness'].mean().item())
1320
+ else:
1321
+ logits = self.model(images)
1322
+ loss = self.criterion(logits, labels)
1323
+
1324
+ total_loss += loss.item()
1325
+ _, predicted = logits.max(1)
1326
+ correct += predicted.eq(labels).sum().item()
1327
+ total += labels.size(0)
1328
+
1329
+ pbar.set_postfix({
1330
+ 'loss': f'{total_loss / (batch_idx + 1):.4f}',
1331
+ 'acc': f'{100. * correct / total:.2f}%'
1332
+ })
1333
+
1334
+ avg_loss = total_loss / len(val_loader)
1335
+ accuracy = 100. * correct / total
1336
+
1337
+ # TensorBoard logging
1338
+ self.writer.add_scalar('val/loss', avg_loss, epoch)
1339
+ self.writer.add_scalar('val/accuracy', accuracy, epoch)
1340
+ if consciousness_values:
1341
+ self.writer.add_scalar('val/consciousness', sum(consciousness_values) / len(consciousness_values), epoch)
1342
+
1343
+ metrics = {
1344
+ 'loss': avg_loss,
1345
+ 'accuracy': accuracy,
1346
+ 'consciousness': sum(consciousness_values) / len(consciousness_values) if consciousness_values else None
1347
+ }
1348
+
1349
+ return accuracy, metrics
1350
+
1351
+ def train(self, train_loader: DataLoader, val_loader: DataLoader):
1352
+ """Full training loop."""
1353
+ print("\n" + "=" * 70)
1354
+ print("Starting training with AdamW + Warm Restarts" + (" + LR Boost πŸš€" if self.config.restart_lr_mult > 1.0 else ""))
1355
+ print(f"Optimizer: {self.config.optimizer_type.upper()}")
1356
+ print(f"Scheduler: {self.config.scheduler_type}")
1357
+ print(f"Restart period: {self.config.restart_period} epochs (T_0)")
1358
+ print(f"Cycle multiplier: {self.config.restart_mult}x (T_mult)")
1359
+ if self.config.restart_lr_mult > 1.0:
1360
+ print(f"LR boost multiplier: {self.config.restart_lr_mult}x πŸš€")
1361
+ print(f"Total restarts: {len(self.restart_epochs)}")
1362
+ print("=" * 70 + "\n")
1363
+
1364
+ for epoch in range(self.config.num_epochs):
1365
+ # Train
1366
+ train_loss, train_acc = self.train_epoch(train_loader, epoch)
1367
+
1368
+ # Evaluate
1369
+ val_acc, val_metrics = self.evaluate(val_loader, epoch)
1370
+
1371
+ # Update scheduler
1372
+ self.scheduler.step()
1373
+
1374
+ # Check if this is a restart epoch or next epoch is a restart
1375
+ is_restart = (epoch in self.restart_epochs)
1376
+ next_is_restart = ((epoch + 1) in self.restart_epochs)
1377
+ next_lr = self.scheduler.get_last_lr()[0]
1378
+
1379
+ # Print summary
1380
+ print(f"\n{'='*70}")
1381
+ print(f"Epoch [{epoch + 1}/{self.config.num_epochs}] Summary:")
1382
+ print(f" Train: Loss={train_loss:.4f}, Acc={train_acc:.2f}%")
1383
+ print(f" Val: Loss={val_metrics['loss']:.4f}, Acc={val_acc:.2f}%")
1384
+ if val_metrics['consciousness']:
1385
+ print(f" Consciousness: {val_metrics['consciousness']:.4f}")
1386
+
1387
+ if next_is_restart:
1388
+ restart_num = self.restart_epochs.index(epoch + 1) + 1
1389
+ boost_mult = self.config.restart_lr_mult ** restart_num if self.config.restart_lr_mult > 1.0 else 1.0
1390
+ print(f" Next LR: {next_lr:.6f}")
1391
+ print(f" ⚠️ RESTART COMING! Next epoch will jump to {next_lr * self.config.restart_lr_mult:.6f}")
1392
+ if self.config.restart_lr_mult > 1.0:
1393
+ print(f" πŸš€ Boosted exploration: {boost_mult:.2f}x baseline!")
1394
+ print(f" (Breaking out of solidified local minima)")
1395
+ elif is_restart:
1396
+ restart_num = self.restart_epochs.index(epoch) + 1
1397
+ boost_mult = self.config.restart_lr_mult ** restart_num if self.config.restart_lr_mult > 1.0 else 1.0
1398
+ print(f" πŸ”„ WARM RESTART #{restart_num}! Current LR: {next_lr:.6f}")
1399
+ if self.config.restart_lr_mult > 1.0:
1400
+ print(f" πŸš€ Exploration boost: {boost_mult:.2f}x baseline")
1401
+ print(f" (Wider curve for aggressive exploration)")
1402
+ else:
1403
+ print(f" Current LR: {next_lr:.6f}")
1404
+
1405
+ # Checkpoint logic
1406
+ is_best = val_acc > self.best_acc
1407
+ should_save_regular = ((epoch + 1) % self.config.save_interval == 0)
1408
+ should_upload_regular = ((epoch + 1) % self.config.checkpoint_upload_interval == 0)
1409
+
1410
+ if is_best:
1411
+ self.best_acc = val_acc
1412
+ print(f" βœ“ New best model! Accuracy: {val_acc:.2f}%")
1413
+ self.save_checkpoint(epoch, val_acc, prefix="best", upload=should_upload_regular)
1414
+
1415
+ if should_save_regular:
1416
+ self.save_checkpoint(epoch, val_acc, prefix=f"epoch_{epoch+1}", upload=should_upload_regular)
1417
+
1418
+ print(f" HF Uploads: {self.upload_count}")
1419
+ print(f"{'='*70}\n")
1420
+
1421
+ # Flush TensorBoard
1422
+ if (epoch + 1) % 10 == 0:
1423
+ self.writer.flush()
1424
+
1425
+ # Training complete
1426
+ training_time = (time.time() - self.start_time) / 3600
1427
+
1428
+ print("\n" + "=" * 70)
1429
+ print("Training Complete!")
1430
+ print(f"Best Validation Accuracy: {self.best_acc:.2f}%")
1431
+ print(f"Training Time: {training_time:.2f} hours")
1432
+ print(f"Total Uploads: {self.upload_count}")
1433
+ print(f"Warm Restarts: {len(self.restart_epochs)}")
1434
+ if self.config.restart_lr_mult > 1.0:
1435
+ print(f"LR Boost: {self.config.restart_lr_mult}x (helped escape local minima! πŸš€)")
1436
+ print("=" * 70)
1437
+
1438
+ # Upload to HuggingFace
1439
+ if self.hf_uploader:
1440
+ print("\n[HF] Uploading final best model...")
1441
+ best_model_path = self.config.checkpoint_dir / "best_model.safetensors"
1442
+ best_state_path = self.config.checkpoint_dir / "best_training_state.pt"
1443
+ best_metadata_path = self.config.checkpoint_dir / "best_metadata.json"
1444
+ config_path = self.config.output_dir / "config.yaml"
1445
+
1446
+ if best_model_path.exists():
1447
+ self.hf_uploader.upload_file(best_model_path, "checkpoints/best_model.safetensors")
1448
+ if best_state_path.exists():
1449
+ self.hf_uploader.upload_file(best_state_path, "checkpoints/best_training_state.pt")
1450
+ if best_metadata_path.exists():
1451
+ self.hf_uploader.upload_file(best_metadata_path, "checkpoints/best_metadata.json")
1452
+ if config_path.exists():
1453
+ self.hf_uploader.upload_file(config_path, "config.yaml")
1454
+
1455
+ print("[HF] Final upload: TensorBoard logs...")
1456
+ self.hf_uploader.upload_folder_contents(self.config.tensorboard_dir, "tensorboard")
1457
+
1458
+ trainer_stats = {
1459
+ 'total_params': sum(p.numel() for p in self.model.parameters()),
1460
+ 'best_acc': self.best_acc,
1461
+ 'training_time': training_time,
1462
+ 'final_epoch': self.config.num_epochs,
1463
+ 'batch_size': self.config.batch_size,
1464
+ 'mixed_precision': self.use_amp
1465
+ }
1466
+ self.hf_uploader.create_model_card(trainer_stats)
1467
+
1468
+ self.writer.close()
1469
+
1470
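# A minimal sketch of the LR curve the messages above describe, assuming each cycle
# is a plain cosine from a boosted peak (base_lr * restart_lr_mult**cycle) down to
# min_lr; the function and argument names are illustrative, not the trainer's API.
import math

def boosted_cosine_restart_lr(epoch, base_lr=1e-4, min_lr=1e-7,
                              t0=40, t_mult=1.5, lr_mult=1.25):
    cycle, start, length = 0, 0, t0
    while epoch >= start + length:            # find the cycle containing this epoch
        start += length
        length *= t_mult
        cycle += 1
    peak = base_lr * (lr_mult ** cycle)        # boosted peak after each restart
    t = (epoch - start) / length               # position within the cycle, in [0, 1)
    return min_lr + 0.5 * (peak - min_lr) * (1 + math.cos(math.pi * t))

for e in (0, 39, 40, 99, 100):                 # restarts land at epochs 40 and 100
    print(e, f"{boosted_cosine_restart_lr(e):.2e}")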
+     def save_checkpoint(self, epoch: int, accuracy: float, prefix: str = "checkpoint", upload: bool = False):
+         """Save checkpoint as safetensors with selective upload."""
+         checkpoint_dir = self.config.checkpoint_dir
+         checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+         # 1. Save model weights as safetensors
+         model_path = checkpoint_dir / f"{prefix}_model.safetensors"
+         save_file(self.model.state_dict(), str(model_path))
+
+         # 2. Save optimizer/scheduler state
+         training_state = {
+             'optimizer_state_dict': self.optimizer.state_dict(),
+             'scheduler_state_dict': self.scheduler.state_dict(),
+         }
+         if self.scaler is not None:
+             training_state['scaler_state_dict'] = self.scaler.state_dict()
+
+         training_state_path = checkpoint_dir / f"{prefix}_training_state.pt"
+         torch.save(training_state, training_state_path)
+
+         # 3. Save metadata
+         metadata = {
+             'epoch': epoch,
+             'accuracy': accuracy,
+             'best_accuracy': self.best_acc,
+             'global_step': self.global_step,
+             'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
+             'optimizer': self.config.optimizer_type,
+             'scheduler': self.config.scheduler_type,
+             'learning_rate': self.scheduler.get_last_lr()[0],
+             'restart_lr_mult': self.config.restart_lr_mult
+         }
+         metadata_path = checkpoint_dir / f"{prefix}_metadata.json"
+         with open(metadata_path, 'w') as f:
+             json.dump(metadata, f, indent=2)
+
+         is_best = (prefix == "best")
+
+         if is_best:
+             print(f" πŸ’Ύ Saved best: {prefix}_model.safetensors")
+         else:
+             print(f" πŸ’Ύ Saved: {prefix}_model.safetensors", end="")
+
+         # Upload to HuggingFace
+         if self.hf_uploader and upload:
+             self.hf_uploader.upload_file(model_path, f"checkpoints/{prefix}_model.safetensors")
+             self.hf_uploader.upload_file(training_state_path, f"checkpoints/{prefix}_training_state.pt")
+             self.hf_uploader.upload_file(metadata_path, f"checkpoints/{prefix}_metadata.json")
+
+             if is_best:
+                 config_path = self.config.output_dir / "config.yaml"
+                 if config_path.exists():
+                     self.hf_uploader.upload_file(config_path, "config.yaml")
+
+             self.upload_count += 1
+
+             if not is_best:
+                 print(" β†’ Uploaded to HF")
+         else:
+             if not is_best:
+                 print(" (local only)")
+
+
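# A minimal sketch of reloading what save_checkpoint() writes, assuming the caller
# has already rebuilt the model/optimizer/scheduler and that checkpoint_dir is a
# pathlib.Path; the helper name is illustrative.
import json
import torch
from safetensors.torch import load_file

def load_checkpoint(model, optimizer, scheduler, checkpoint_dir, prefix="best"):
    # Model weights (safetensors)
    model.load_state_dict(load_file(str(checkpoint_dir / f"{prefix}_model.safetensors")))

    # Optimizer / scheduler (and optional AMP scaler) state
    training_state = torch.load(checkpoint_dir / f"{prefix}_training_state.pt",
                                map_location="cpu")
    optimizer.load_state_dict(training_state['optimizer_state_dict'])
    scheduler.load_state_dict(training_state['scheduler_state_dict'])

    # Metadata (epoch, accuracy, learning rate, ...)
    with open(checkpoint_dir / f"{prefix}_metadata.json") as f:
        return json.load(f)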
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+ # Data Loading (with Cutout)
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+
+ class Cutout:
+     """Cutout data augmentation."""
+     def __init__(self, length: int):
+         self.length = length
+
+     def __call__(self, img):
+         h, w = img.size(1), img.size(2)
+         mask = torch.ones((h, w), dtype=torch.float32)
+         y = torch.randint(h, (1,)).item()
+         x = torch.randint(w, (1,)).item()
+
+         y1 = max(0, y - self.length // 2)
+         y2 = min(h, y + self.length // 2)
+         x1 = max(0, x - self.length // 2)
+         x2 = min(w, x + self.length // 2)
+
+         mask[y1:y2, x1:x2] = 0.
+         mask = mask.expand_as(img)
+         return img * mask
+
+
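# A minimal sketch of Cutout applied to one normalized CIFAR-style tensor; because
# the square is clipped at the image border, at most length x length pixels are erased.
import torch

img = torch.randn(3, 32, 32)                  # (C, H, W), as after ToTensor/Normalize
out = Cutout(length=16)(img)
erased = (out == 0).all(dim=0).sum().item()   # pixels zeroed across every channel
print(erased, "<= 256")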
+ def get_data_loaders(config: CantorTrainingConfig) -> Tuple[DataLoader, DataLoader]:
+     """Create data loaders."""
+
+     # Normalization
+     mean = (0.4914, 0.4822, 0.4465)
+     std = (0.2470, 0.2435, 0.2616)
+
+     # Augmentation
+     if config.use_augmentation:
+         transforms_list = []
+
+         if config.use_autoaugment:
+             policy = transforms.AutoAugmentPolicy.CIFAR10
+             transforms_list.append(transforms.AutoAugment(policy))
+         else:
+             transforms_list.extend([
+                 transforms.RandomCrop(32, padding=4),
+                 transforms.RandomHorizontalFlip(),
+             ])
+
+         transforms_list.append(transforms.ToTensor())
+         transforms_list.append(transforms.Normalize(mean, std))
+
+         if config.use_cutout:
+             transforms_list.append(Cutout(config.cutout_length))
+
+         train_transform = transforms.Compose(transforms_list)
+     else:
+         train_transform = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.Normalize(mean, std)
+         ])
+
+     val_transform = transforms.Compose([
+         transforms.ToTensor(),
+         transforms.Normalize(mean, std)
+     ])
+
+     # Dataset selection
+     if config.dataset == "cifar10":
+         train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
+         val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=val_transform)
+     elif config.dataset == "cifar100":
+         train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=train_transform)
+         val_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=val_transform)
+     else:
+         raise ValueError(f"Unknown dataset: {config.dataset}")
+
+     train_loader = DataLoader(
+         train_dataset,
+         batch_size=config.batch_size,
+         shuffle=True,
+         num_workers=config.num_workers,
+         pin_memory=(config.device == "cuda")
+     )
+
+     val_loader = DataLoader(
+         val_dataset,
+         batch_size=config.batch_size,
+         shuffle=False,
+         num_workers=config.num_workers,
+         pin_memory=(config.device == "cuda")
+     )
+
+     return train_loader, val_loader
+
+
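# A minimal smoke test of get_data_loaders on CPU; it assumes the config fields used
# below are accepted as constructor keywords and will download CIFAR-10 into ./data.
cfg = CantorTrainingConfig(dataset="cifar10", batch_size=8, num_workers=0,
                           device="cpu", use_augmentation=False)
train_loader_demo, val_loader_demo = get_data_loaders(cfg)
images, labels = next(iter(train_loader_demo))
print(images.shape, labels.shape)   # torch.Size([8, 3, 32, 32]) torch.Size([8])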
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+ # Main - AdamW + CosineAnnealingWarmRestarts + LR Boost
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+
+ def main():
+     """Main training function with AdamW + Warm Restarts + LR Boost."""
+
+     # ═══════════════════════════════════════════════════════════════════
+     # Configuration - AdamW with Cosine Annealing Warm Restarts + LR BOOST
+     # ═══════════════════════════════════════════════════════════════════
+
+     config = CantorTrainingConfig(
+         # Dataset
+         dataset="cifar100",
+
+         # Architecture
+         embed_dim=512,
+         num_fusion_blocks=12,
+         num_heads=8,
+         fusion_mode="consciousness",
+         k_simplex=4,
+         use_beatrix=True,
+         fusion_window=32,
+
+         # Optimizer: AdamW
+         optimizer_type="adamw",
+         learning_rate=1e-4,
+         weight_decay=0.005,  # Stronger regularization
+         adamw_betas=(0.9, 0.999),
+
+         # Scheduler: Cosine Annealing with Warm Restarts + LR BOOST
+         scheduler_type="cosine_restarts",
+         restart_period=40,
+         restart_mult=1.5,  # Consistent cycle growth
+         restart_lr_mult=1.25,  # πŸš€ NEW! Boost LR at restarts
+         min_lr=1e-7,
+
+         # Training
+         num_epochs=200,
+         batch_size=256,
+         grad_clip=1.0,
+         label_smoothing=0.15,
+
+         # Augmentation
+         use_augmentation=True,
+         use_autoaugment=True,
+         use_cutout=True,
+         cutout_length=16,
+
+         # Mixing augmentation (AlphaMix)
+         use_mixing=True,  # Enable mixing
+         mixing_type="alphamix",  # "alphamix" or "fractal"
+         mixing_alpha_range=(0.3, 0.7),
+         mixing_spatial_ratio=0.25,
+         mixing_prob=0.5,  # Apply to 50% of batches
+
+         # Regularization
+         dropout=0.1,
+         drop_path_rate=0.15,
+
+         # System
+         device="cuda",
+         use_mixed_precision=False,
+
+         # HuggingFace
+         hf_username="AbstractPhil",
+         upload_to_hf=True,
+         checkpoint_upload_interval=25,
+     )
+
+     print("=" * 70)
+     print(f"Cantor Fusion Classifier - {config.dataset.upper()}")
+     print("Training Strategy: AdamW + Cosine Annealing Warm Restarts")
+     if config.restart_lr_mult > 1.0:
+         print("πŸš€ WITH LR BOOST AT RESTARTS πŸš€")
+     print("=" * 70)
+     print(f"\nConfiguration:")
+     print(f" Dataset: {config.dataset}")
+     print(f" Fusion mode: {config.fusion_mode}")
+     print(f" Optimizer: AdamW")
+     print(f" Scheduler: CosineAnnealingWarmRestarts")
+     print(f" Initial LR: {config.learning_rate}")
+     print(f" Min LR: {config.min_lr}")
+     print(f" Restart period (T_0): {config.restart_period} epochs")
+     print(f" Cycle multiplier (T_mult): {config.restart_mult}x")
+     if config.restart_lr_mult > 1.0:
+         print(f" πŸš€ Restart LR mult: {config.restart_lr_mult}x (BOOST MODE!)")
+     if config.use_mixing:
+         print(f" 🎨 Mixing: {config.mixing_type} (prob={config.mixing_prob})")
+     print(f" Total epochs: {config.num_epochs}")
+
+     # Calculate restart schedule
+     restarts = []
+     current = config.restart_period
+     period = config.restart_period
+     while current < config.num_epochs:
+         restarts.append(current)
+         period *= config.restart_mult
+         current += period
+
+     print(f"\n Restart schedule ({len(restarts)} restarts):")
+     for i, epoch in enumerate(restarts[:5]):
+         boost_mult = config.restart_lr_mult ** (i + 1) if config.restart_lr_mult > 1.0 else 1.0
+         lr = config.learning_rate * boost_mult
+         boost_str = f" ({boost_mult:.2f}x πŸš€)" if config.restart_lr_mult > 1.0 else ""
+         print(f" Restart #{i+1}: Epoch {epoch} β†’ LR: {lr:.2e}{boost_str}")
+     if len(restarts) > 5:
+         print(f" ... and {len(restarts) - 5} more")
+
+     print(f"\n Output: {config.output_dir}")
+     print(f" HuggingFace: {'Enabled' if config.upload_to_hf else 'Disabled'}")
+     if config.upload_to_hf:
+         print(f" Repo: {config.hf_username}/{config.hf_repo_name}")
+         print(f" Run: {config.run_name}")
+
+     if config.restart_lr_mult > 1.0:
+         print("\n" + "=" * 70)
+         print("πŸš€ LR BOOST MODE - Expected Training Behavior:")
+         print("=" * 70)
+         print(f"πŸ“‰ Cycle 1 (epochs 0-{config.restart_period}):")
+         print(f" LR: {config.learning_rate:.2e} β†’ {config.min_lr:.2e} (smooth drop)")
+         print(" Expected: Convergence to local minimum")
+         print("")
+         print(f"πŸ”„ Epoch {config.restart_period}: RESTART WITH BOOST!")
+         boosted_lr = config.learning_rate * config.restart_lr_mult
+         print(f" LR: {config.min_lr:.2e} β†’ {boosted_lr:.2e} ({config.restart_lr_mult}x BOOST!)")
+         print(" Expected: AGGRESSIVE exploration, escape local minimum")
+         print(f" Benefit: Wider curve ({(config.restart_lr_mult - 1) * 100:.0f}% more exploration)")
+         print("")
+         print(f"πŸ“‰ Cycle 2 (epochs {config.restart_period}-{int(config.restart_period * (1 + config.restart_mult))}):")
+         print(f" LR: {boosted_lr:.2e} β†’ {config.min_lr:.2e} (longer cycle)")
+         print(" Expected: Deeper convergence from better starting point")
+         print("")
+         print(f"πŸ”„ Epoch {int(config.restart_period * (1 + config.restart_mult))}: EVEN BIGGER BOOST!")
+         boosted_lr2 = config.learning_rate * (config.restart_lr_mult ** 2)
+         print(f" LR: {config.min_lr:.2e} β†’ {boosted_lr2:.2e} ({config.restart_lr_mult**2:.2f}x!)")
+         print(" Expected: VERY aggressive exploration")
+         print("")
+         print("🎯 Benefits:")
+         print(" - Escape solidified local minima with LR spikes")
+         print(" - Each restart explores WIDER than baseline")
+         print(" - Progressive boost helps late-training plateaus")
+         print(" - Automatic fracturing of failure modes")
+         print("=" * 70)
+
+     # Load data
+     print("\nLoading data...")
+     train_loader, val_loader = get_data_loaders(config)
+     print(f" Train: {len(train_loader.dataset)} samples")
+     print(f" Val: {len(val_loader.dataset)} samples")
+
+     # Train
+     trainer = Trainer(config)
+     trainer.train(train_loader, val_loader)
+
+     print("\n" + "=" * 70)
+     print("🎯 Training complete!")
+     if config.restart_lr_mult > 1.0:
+         print(" Check TensorBoard to see the BOOSTED warm restart cycles!")
+     else:
+         print(" Check TensorBoard to see the warm restart cycles!")
+     print(f" tensorboard --logdir {config.tensorboard_dir}")
+     print("")
+     print(" Look for:")
+     print(" - Smooth LR drops during each cycle")
+     if config.restart_lr_mult > 1.0:
+         print(" - πŸš€ BOOSTED LR jumps at restart epochs")
+         print(" - Wider exploration curves after restarts")
+     else:
+         print(" - Sharp LR jumps at restart epochs")
+     print(" - Accuracy improvements across cycles")
+     print("=" * 70)
+
+
+ if __name__ == "__main__":
+     main()
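# A minimal sketch of reading back the scalars this trainer logs (val/loss,
# val/accuracy, val/consciousness) without the TensorBoard UI; the log-directory
# path below is a placeholder for whatever config.tensorboard_dir resolved to.
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

events = EventAccumulator("path/to/tensorboard")   # placeholder path
events.Reload()
for scalar in events.Scalars("val/accuracy"):
    print(scalar.step, scalar.value)               # epoch index, accuracy in %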