AbstractPhil committed
Commit 0c4228b · verified · 1 Parent(s): 5802634

Create trainer_v4_sgd_warm_restarts.py

Files changed (1):
  1. trainer_v4_sgd_warm_restarts.py +1330 -0
trainer_v4_sgd_warm_restarts.py ADDED
@@ -0,0 +1,1330 @@
+ # trainer_v4_sgd_warm_restarts.py - PRODUCTION WITH ADAMW + WARM RESTARTS
+
+ """
+ Cantor Fusion Classifier with AdamW + Cosine Warm Restarts
+ -----------------------------------------------------------
+ Features:
+ - AdamW optimizer (well suited to ViT-style models)
+ - CosineAnnealingWarmRestarts (automatic drop + restart cycles)
+ - HuggingFace Hub uploads (ONE shared repo, organized by run)
+ - TensorBoard logging (loss, accuracy, fusion metrics, LR tracking)
+ - Easy CIFAR-10/100 switching
+ - Automatic checkpoint management
+ - SafeTensors format (ClamAV safe)
+
+ Author: AbstractPhil
+ License: MIT
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader
+ from torch.utils.tensorboard import SummaryWriter
+ from torchvision import datasets, transforms
+ from torch.cuda.amp import autocast, GradScaler
+ from safetensors.torch import save_file, load_file
+
+ import math
+ import os
+ import json
+ from typing import Optional, Dict, List, Tuple, Union
+ from dataclasses import dataclass, asdict
+ import time
+ from pathlib import Path
+ from tqdm import tqdm
+
+ # HuggingFace
+ from huggingface_hub import HfApi, create_repo, upload_folder, upload_file
+ import yaml
+
+ # Import from your repo
+ from geovocab2.train.model.layers.attention.cantor_multiheaded_fusion import (
+     CantorMultiheadFusion,
+     CantorFusionConfig
+ )
+ from geovocab2.shapes.factory.cantor_route_factory import (
+     CantorRouteFactory,
+     RouteMode,
+     SimplexConfig
+ )
+
+
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+ # Configuration
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+
+ @dataclass
+ class CantorTrainingConfig:
+     """Complete configuration for Cantor fusion training with AdamW + Warm Restarts."""
+
+     # Dataset
+     dataset: str = "cifar10"  # "cifar10" or "cifar100"
+     num_classes: int = 10
+
+     # Architecture
+     image_size: int = 32
+     patch_size: int = 4
+     embed_dim: int = 384
+     num_fusion_blocks: int = 6
+     num_heads: int = 8
+     fusion_window: int = 32
+     fusion_mode: str = "weighted"  # "weighted" or "consciousness"
+     k_simplex: int = 4
+     use_beatrix: bool = False
+     beatrix_tau: float = 0.25
+
+     # Optimization
+     precompute_geometric: bool = True
+     use_torch_compile: bool = True
+     use_mixed_precision: bool = False
+
+     # Regularization
+     dropout: float = 0.1
+     drop_path_rate: float = 0.1
+     label_smoothing: float = 0.1
+
+     # Training - Optimizer (AdamW)
+     optimizer_type: str = "adamw"  # "sgd" or "adamw"
+     batch_size: int = 128
+     num_epochs: int = 300
+     learning_rate: float = 3e-4  # AdamW default
+     weight_decay: float = 0.05
+     grad_clip: float = 1.0
+
+     # SGD-specific (if needed)
+     sgd_momentum: float = 0.9
+     sgd_nesterov: bool = True
+
+     # AdamW-specific
+     adamw_betas: Tuple[float, float] = (0.9, 0.999)
+     adamw_eps: float = 1e-8
+
+     # Learning rate schedule - WARM RESTARTS
+     scheduler_type: str = "cosine_restarts"  # "multistep", "cosine", "cosine_restarts"
+
+     # CosineAnnealingWarmRestarts parameters
+     restart_period: int = 50  # T_0: epochs until first restart
+     restart_mult: int = 2     # T_mult: multiply period after each restart
+     min_lr: float = 1e-7      # eta_min: minimum learning rate
+
+     # MultiStepLR (for SGD fallback)
+     lr_milestones: Optional[List[int]] = None
+     lr_gamma: float = 0.2
+
+     # Cosine annealing (regular, no restarts)
+     warmup_epochs: int = 10
+
+     # Data augmentation
+     use_augmentation: bool = True
+     use_autoaugment: bool = True
+     use_cutout: bool = False
+     cutout_length: int = 16
+
+     # System
+     device: str = "cuda" if torch.cuda.is_available() else "cpu"
+     num_workers: int = 4
+     seed: int = 42
+
+     # Paths
+     weights_dir: str = "weights"
+     model_name: str = "vit-beans-v3"
+     run_name: Optional[str] = None  # Auto-generated if None
+
+     # HuggingFace - ONE SHARED REPO
+     hf_username: str = "AbstractPhil"
+     hf_repo_name: Optional[str] = None
+     upload_to_hf: bool = True
+     hf_token: Optional[str] = None
+
+     # Logging
+     log_interval: int = 50
+     save_interval: int = 10
+     checkpoint_upload_interval: int = 20
+
+     def __post_init__(self):
+         # Auto-set num_classes based on dataset
+         if self.dataset == "cifar10":
+             self.num_classes = 10
+         elif self.dataset == "cifar100":
+             self.num_classes = 100
+         else:
+             raise ValueError(f"Unknown dataset: {self.dataset}")
+
+         # Set default milestones if None (for multistep fallback)
+         if self.lr_milestones is None:
+             if self.num_epochs >= 200:
+                 self.lr_milestones = [60, 120, 160]
+             elif self.num_epochs >= 100:
+                 self.lr_milestones = [30, 60, 80]
+             else:
+                 self.lr_milestones = [
+                     int(self.num_epochs * 0.5),
+                     int(self.num_epochs * 0.75)
+                 ]
+
+         # Auto-generate run name
+         if self.run_name is None:
+             timestamp = time.strftime("%Y%m%d_%H%M%S")
+             opt_name = self.optimizer_type.upper()
+             sched_name = "WarmRestart" if self.scheduler_type == "cosine_restarts" else self.scheduler_type
+             self.run_name = f"{self.dataset}_{self.fusion_mode}_{opt_name}_{sched_name}_{timestamp}"
+
+         # ONE SHARED REPO for all runs
+         if self.hf_repo_name is None:
+             self.hf_repo_name = self.model_name
+
+         # Set HF token from environment if not provided
+         if self.hf_token is None:
+             self.hf_token = os.environ.get("HF_TOKEN")
+
+         # Calculate derived values
+         assert self.image_size % self.patch_size == 0
+         self.num_patches = (self.image_size // self.patch_size) ** 2
+         self.patch_dim = self.patch_size * self.patch_size * 3
+
+         # Create paths
+         self.output_dir = Path(self.weights_dir) / self.model_name / self.run_name
+         self.checkpoint_dir = self.output_dir / "checkpoints"
+         self.tensorboard_dir = self.output_dir / "tensorboard"
+
+         # Create directories
+         self.output_dir.mkdir(parents=True, exist_ok=True)
+         self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
+         self.tensorboard_dir.mkdir(parents=True, exist_ok=True)
+
+     def save(self, path: Union[str, Path]):
+         """Save config to YAML file."""
+         path = Path(path)
+         config_dict = asdict(self)
+         # Never persist secrets: this config.yaml gets uploaded to the Hub
+         config_dict.pop('hf_token', None)
+         # Convert tuples to lists for YAML
+         if 'adamw_betas' in config_dict:
+             config_dict['adamw_betas'] = list(config_dict['adamw_betas'])
+         with open(path, 'w') as f:
+             yaml.dump(config_dict, f, default_flow_style=False)
+
+     @classmethod
+     def load(cls, path: Union[str, Path]):
+         """Load config from YAML file."""
+         path = Path(path)
+         with open(path, 'r') as f:
+             config_dict = yaml.safe_load(f)
+         # Convert lists back to tuples
+         if 'adamw_betas' in config_dict:
+             config_dict['adamw_betas'] = tuple(config_dict['adamw_betas'])
+         return cls(**config_dict)
+
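+ # A minimal sketch (not used by the trainer) of the YAML round-trip defined
+ # above: save() lists the betas tuple for YAML, load() restores it. Note that
+ # constructing a config creates its output directories as a side effect of
+ # __post_init__; the temp path here is illustrative.
+ def _example_config_roundtrip(path: str = "/tmp/cantor_config.yaml") -> "CantorTrainingConfig":
+     cfg = CantorTrainingConfig(dataset="cifar10")
+     cfg.save(path)                              # tuples serialized as lists
+     restored = CantorTrainingConfig.load(path)  # lists converted back to tuples
+     assert restored.adamw_betas == cfg.adamw_betas
+     return restored
+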
+
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+ # Model Components (unchanged from previous version)
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+
+ class PatchEmbedding(nn.Module):
+     """Patch embedding layer."""
+     def __init__(self, config: CantorTrainingConfig):
+         super().__init__()
+         self.config = config
+         self.proj = nn.Conv2d(3, config.embed_dim, kernel_size=config.patch_size, stride=config.patch_size)
+         self.pos_embed = nn.Parameter(torch.randn(1, config.num_patches, config.embed_dim) * 0.02)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.proj(x)
+         x = x.flatten(2).transpose(1, 2)
+         x = x + self.pos_embed
+         return x
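+
+ # Shape walk-through for PatchEmbedding under the default config (illustrative):
+ #   input  (B, 3, 32, 32)
+ #   proj   Conv2d(3, 384, kernel_size=4, stride=4) -> (B, 384, 8, 8)
+ #   tokens flatten(2).transpose(1, 2)              -> (B, 64, 384)
+ # i.e. num_patches = (32 // 4) ** 2 = 64 tokens of width embed_dim = 384,
+ # plus the learned positional embedding broadcast over the batch.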
+
+
+ class DropPath(nn.Module):
+     """Stochastic depth."""
+     def __init__(self, drop_prob: float = 0.0):
+         super().__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x):
+         if self.drop_prob == 0. or not self.training:
+             return x
+         keep_prob = 1 - self.drop_prob
+         shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+         random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+         random_tensor.floor_()
+         return x.div(keep_prob) * random_tensor
+
+
+ class CantorFusionBlock(nn.Module):
+     """Cantor fusion block."""
+     def __init__(self, config: CantorTrainingConfig, drop_path: float = 0.0):
+         super().__init__()
+         self.norm1 = nn.LayerNorm(config.embed_dim)
+
+         fusion_config = CantorFusionConfig(
+             dim=config.embed_dim,
+             num_heads=config.num_heads,
+             fusion_window=config.fusion_window,
+             fusion_mode=config.fusion_mode,
+             k_simplex=config.k_simplex,
+             use_beatrix_routing=config.use_beatrix,
+             use_consciousness_weighting=(config.fusion_mode == "consciousness"),
+             beatrix_tau=config.beatrix_tau,
+             use_gating=True,
+             dropout=config.dropout,
+             residual=False,
+             precompute_staircase=config.precompute_geometric,
+             precompute_routes=config.precompute_geometric,
+             precompute_distances=config.precompute_geometric,
+             use_optimized_gather=True,
+             staircase_cache_sizes=[config.num_patches],
+             use_torch_compile=config.use_torch_compile
+         )
+         self.fusion = CantorMultiheadFusion(fusion_config)
+
+         self.norm2 = nn.LayerNorm(config.embed_dim)
+         mlp_hidden = config.embed_dim * 4
+         self.mlp = nn.Sequential(
+             nn.Linear(config.embed_dim, mlp_hidden),
+             nn.GELU(),
+             nn.Dropout(config.dropout),
+             nn.Linear(mlp_hidden, config.embed_dim),
+             nn.Dropout(config.dropout)
+         )
+         self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()
+
+     def forward(self, x: torch.Tensor, return_fusion_info: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict]]:
+         fusion_result = self.fusion(self.norm1(x))
+         x = x + self.drop_path(fusion_result['output'])
+         x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+         if return_fusion_info:
+             fusion_info = {
+                 'consciousness': fusion_result.get('consciousness'),
+                 'cantor_measure': fusion_result.get('cantor_measure')
+             }
+             return x, fusion_info
+         return x
+
+
+ class CantorClassifier(nn.Module):
+     """Cantor fusion classifier."""
+     def __init__(self, config: CantorTrainingConfig):
+         super().__init__()
+         self.config = config
+
+         self.patch_embed = PatchEmbedding(config)
+
+         dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_fusion_blocks)]
+         self.blocks = nn.ModuleList([
+             CantorFusionBlock(config, drop_path=dpr[i])
+             for i in range(config.num_fusion_blocks)
+         ])
+
+         self.norm = nn.LayerNorm(config.embed_dim)
+         self.head = nn.Linear(config.embed_dim, config.num_classes)
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             nn.init.trunc_normal_(m.weight, std=0.02)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+         elif isinstance(m, nn.Conv2d):
+             nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+
+     def forward(self, x: torch.Tensor, return_fusion_info: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, List[Dict]]]:
+         x = self.patch_embed(x)
+
+         fusion_infos = []
+         for i, block in enumerate(self.blocks):
+             if return_fusion_info and i == len(self.blocks) - 1:
+                 x, fusion_info = block(x, return_fusion_info=True)
+                 fusion_infos.append(fusion_info)
+             else:
+                 x = block(x)
+
+         x = self.norm(x)
+         x = x.mean(dim=1)
+         logits = self.head(x)
+
+         if return_fusion_info:
+             return logits, fusion_infos
+         return logits
+
+
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+ # HuggingFace Integration
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+
+ class HuggingFaceUploader:
+     """Manages HuggingFace Hub uploads to ONE shared repo."""
+
+     def __init__(self, config: CantorTrainingConfig):
+         self.config = config
+         self.api = HfApi(token=config.hf_token) if config.upload_to_hf else None
+         self.repo_id = f"{config.hf_username}/{config.hf_repo_name}"
+         self.run_prefix = f"runs/{config.run_name}"
+
+         if config.upload_to_hf:
+             self._create_repo()
+             self._update_main_readme()
+
+     def _create_repo(self):
+         """Create HuggingFace repo if it doesn't exist."""
+         try:
+             create_repo(
+                 repo_id=self.repo_id,
+                 token=self.config.hf_token,
+                 exist_ok=True,
+                 private=False
+             )
+             print(f"[HF] Repository: https://huggingface.co/{self.repo_id}")
+             print(f"[HF] Run folder: {self.run_prefix}")
+         except Exception as e:
+             print(f"[HF] Warning: Could not create repo: {e}")
+
+     def _update_main_readme(self):
+         """Create or update the main shared README at repo root."""
+         if not self.config.upload_to_hf or self.api is None:
+             return
+
+         main_readme = f"""---
+ tags:
+ - image-classification
+ - cantor-fusion
+ - geometric-deep-learning
+ - safetensors
+ - vision-transformer
+ - warm-restarts
+ library_name: pytorch
+ datasets:
+ - cifar10
+ - cifar100
+ metrics:
+ - accuracy
+ ---
+
+ # {self.config.hf_repo_name}
+
+ **Geometric Deep Learning with Cantor Multihead Fusion + AdamW Warm Restarts**
+
+ This repository contains multiple training runs using the Cantor fusion architecture with pentachoron structures, geometric routing, and **CosineAnnealingWarmRestarts** for automatic exploration cycles.
+
+ ## Training Strategy: AdamW + Warm Restarts
+
+ This model uses **AdamW with Cosine Annealing Warm Restarts** (SGDR):
+ - **Drop phase**: LR decays from {self.config.learning_rate} → {self.config.min_lr} over {self.config.restart_period} epochs
+ - **Restart phase**: LR jumps back to {self.config.learning_rate} to explore new regions
+ - **Cycle multiplier**: Each cycle is {self.config.restart_mult}x longer than the previous
+ - **Benefits**: Automatic exploration + exploitation, finds better minima, robust training
+
+ ### Restart Schedule
+ ```
+ Epochs 0-{self.config.restart_period}: LR: {self.config.learning_rate} → {self.config.min_lr} (first cycle)
+ Epoch {self.config.restart_period}: LR: RESTART to {self.config.learning_rate} 🔄
+ Epochs {self.config.restart_period}-{self.config.restart_period * (1 + self.config.restart_mult)}: LR: {self.config.learning_rate} → {self.config.min_lr} (longer cycle)
+ ...
+ ```
+
+ ## Current Run
+
+ **Latest**: `{self.config.run_name}`
+ - **Dataset**: {self.config.dataset.upper()}
+ - **Fusion Mode**: {self.config.fusion_mode}
+ - **Optimizer**: AdamW (adaptive moments)
+ - **Scheduler**: CosineAnnealingWarmRestarts
+ - **Architecture**: {self.config.num_fusion_blocks} blocks, {self.config.num_heads} heads
+ - **Simplex**: {self.config.k_simplex}-simplex ({self.config.k_simplex + 1} vertices)
+
+ ## Architecture
+
+ The Cantor Fusion architecture uses:
+ - **Geometric Routing**: Pentachoron (4-simplex, 5 vertices) structures for token routing
+ - **Cantor Multihead Fusion**: Multiple fusion heads with geometric attention
+ - **Beatrix Consciousness Routing**: Optional consciousness-aware token fusion
+ - **SafeTensors Format**: All model weights use SafeTensors (not pickle)
+
+ ## Usage
+ ```python
+ from huggingface_hub import hf_hub_download
+ from safetensors.torch import load_file
+
+ model_path = hf_hub_download(
+     repo_id="{self.repo_id}",
+     filename="runs/YOUR_RUN_NAME/checkpoints/best_model.safetensors"
+ )
+
+ state_dict = load_file(model_path)
+ model.load_state_dict(state_dict)
+ ```
+
+ ## Citation
+ ```bibtex
+ @misc{{{self.config.hf_repo_name.replace('-', '_')},
+   author = {{AbstractPhil}},
+   title = {{{self.config.hf_repo_name}: Geometric Deep Learning with Warm Restarts}},
+   year = {{2025}},
+   publisher = {{HuggingFace}},
+   url = {{https://huggingface.co/{self.repo_id}}}
+ }}
+ ```
+
+ ---
+
+ **Repository maintained by**: [@{self.config.hf_username}](https://huggingface.co/{self.config.hf_username})
+
+ **Latest update**: {time.strftime("%Y-%m-%d %H:%M:%S")}
+ """
+
+         main_readme_path = Path(self.config.weights_dir) / self.config.model_name / "MAIN_README.md"
+         main_readme_path.parent.mkdir(parents=True, exist_ok=True)
+         with open(main_readme_path, 'w') as f:
+             f.write(main_readme)
+
+         try:
+             upload_file(
+                 path_or_fileobj=str(main_readme_path),
+                 path_in_repo="README.md",
+                 repo_id=self.repo_id,
+                 token=self.config.hf_token
+             )
+             print(f"[HF] Updated main README")
+         except Exception as e:
+             print(f"[HF] Main README upload failed: {e}")
+
+     def upload_file(self, file_path: Path, repo_path: str):
+         """Upload a single file to HuggingFace."""
+         if not self.config.upload_to_hf or self.api is None:
+             return
+
+         # Resolve full_path before the upload so the except clause can report it
+         if not repo_path.startswith(self.run_prefix) and not repo_path.startswith("runs/"):
+             full_path = f"{self.run_prefix}/{repo_path}"
+         else:
+             full_path = repo_path
+
+         try:
+             upload_file(
+                 path_or_fileobj=str(file_path),
+                 path_in_repo=full_path,
+                 repo_id=self.repo_id,
+                 token=self.config.hf_token
+             )
+             print(f"[HF] ✓ Uploaded: {full_path}")
+         except Exception as e:
+             print(f"[HF] ✗ Upload failed ({full_path}): {e}")
+
+     def upload_folder_contents(self, folder_path: Path, repo_folder: str):
+         """Upload an entire folder to HuggingFace."""
+         if not self.config.upload_to_hf or self.api is None:
+             return
+
+         try:
+             full_path = f"{self.run_prefix}/{repo_folder}"
+             upload_folder(
+                 folder_path=str(folder_path),
+                 repo_id=self.repo_id,
+                 path_in_repo=full_path,
+                 token=self.config.hf_token,
+                 ignore_patterns=["*.pyc", "__pycache__"]
+             )
+             print(f"[HF] Uploaded folder: {full_path}")
+         except Exception as e:
+             print(f"[HF] Folder upload failed: {e}")
+
+     def create_model_card(self, trainer_stats: Dict):
+         """Create and upload a run-specific model card."""
+         if not self.config.upload_to_hf:
+             return
+
+         run_card = f"""# Run: {self.config.run_name}
+
+ ## Configuration
+ - **Dataset**: {self.config.dataset.upper()}
+ - **Fusion Mode**: {self.config.fusion_mode}
+ - **Parameters**: {trainer_stats['total_params']:,}
+ - **Simplex**: {self.config.k_simplex}-simplex ({self.config.k_simplex + 1} vertices)
+
+ ## Performance
+ - **Best Validation Accuracy**: {trainer_stats['best_acc']:.2f}%
+ - **Training Time**: {trainer_stats['training_time']:.1f} hours
+ - **Final Epoch**: {trainer_stats['final_epoch']}
+
+ ## Training Setup: AdamW + Warm Restarts
+ - **Optimizer**: AdamW (lr={self.config.learning_rate}, wd={self.config.weight_decay})
+ - **Scheduler**: CosineAnnealingWarmRestarts
+ - **Restart Period (T_0)**: {self.config.restart_period} epochs
+ - **Cycle Multiplier (T_mult)**: {self.config.restart_mult}x
+ - **Min LR**: {self.config.min_lr}
+ - **Batch Size**: {self.config.batch_size}
+ - **Mixed Precision**: {trainer_stats.get('mixed_precision', False)}
+
+ ### Learning Rate Schedule
+ ```
+ Cycle 1: Epochs 0-{self.config.restart_period}
+   LR: {self.config.learning_rate} → {self.config.min_lr} (drop)
+   Expected: Convergence to local minimum
+
+ Epoch {self.config.restart_period}: RESTART 🔄
+   LR: {self.config.min_lr} → {self.config.learning_rate} (jump!)
+   Expected: Escape local minimum, explore new regions
+
+ Cycle 2: Epochs {self.config.restart_period}-{self.config.restart_period * (1 + self.config.restart_mult)}
+   LR: {self.config.learning_rate} → {self.config.min_lr} (longer cycle)
+   Expected: Deeper convergence
+
+ ... and so on
+ ```
+
+ ## Files
+ - `{self.run_prefix}/checkpoints/best_model.safetensors` - Model weights
+ - `{self.run_prefix}/checkpoints/best_training_state.pt` - Optimizer state
+ - `{self.run_prefix}/config.yaml` - Full configuration
+ - `{self.run_prefix}/tensorboard/` - TensorBoard logs (LR tracking!)
+
+ ## Usage
+ ```python
+ from safetensors.torch import load_file
+ from huggingface_hub import hf_hub_download
+
+ model_path = hf_hub_download(
+     repo_id="{self.repo_id}",
+     filename="{self.run_prefix}/checkpoints/best_model.safetensors"
+ )
+
+ state_dict = load_file(model_path)
+ model.load_state_dict(state_dict)
+ ```
+
+ ## Training Notes
+
+ **Warm Restarts Benefits:**
+ - 🔄 **Exploration**: Periodic LR jumps escape local minima
+ - 📉 **Exploitation**: Long drop phases converge deeply
+ - 🎯 **Robustness**: Multiple restarts find better solutions
+ - 📊 **Monitoring**: Watch TensorBoard for restart effects!
+
+ **Expected Behavior:**
+ - Accuracy improves during each drop phase
+ - Brief accuracy dips after restarts (exploration)
+ - Overall upward trend across cycles
+ - Best models are often found late in long cycles
+
+ ---
+
+ Built with geometric consciousness-aware routing using the Devil's Staircase (Beatrix) and pentachoron parameterization.
+
+ **Training completed**: {time.strftime("%Y-%m-%d %H:%M:%S")}
+
+ [← Back to main repository](https://huggingface.co/{self.repo_id})
+ """
+
+         readme_path = self.config.output_dir / "RUN_README.md"
+         with open(readme_path, 'w') as f:
+             f.write(run_card)
+
+         try:
+             upload_file(
+                 path_or_fileobj=str(readme_path),
+                 path_in_repo=f"{self.run_prefix}/README.md",
+                 repo_id=self.repo_id,
+                 token=self.config.hf_token
+             )
+             print(f"[HF] Uploaded run README")
+         except Exception as e:
+             print(f"[HF] Run README upload failed: {e}")
+
+
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+ # Trainer with AdamW + CosineAnnealingWarmRestarts
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+
+ class Trainer:
+     """Training manager with AdamW + Warm Restarts."""
+
+     def __init__(self, config: CantorTrainingConfig):
+         self.config = config
+         self.device = torch.device(config.device)
+
+         # Set seed
+         torch.manual_seed(config.seed)
+         if torch.cuda.is_available():
+             torch.cuda.manual_seed(config.seed)
+
+         # Model
+         print("\n" + "=" * 70)
+         print(f"Initializing Cantor Classifier - {config.dataset.upper()}")
+         print("=" * 70)
+
+         init_start = time.time()
+         self.model = CantorClassifier(config).to(self.device)
+         init_time = time.time() - init_start
+
+         print(f"\n[Model] Initialization time: {init_time:.2f}s")
+         self.print_model_info()
+
+         # Track restart epochs for logging
+         self.restart_epochs = self._calculate_restart_epochs()
+
+         # Optimizer
+         self.optimizer = self.create_optimizer()
+
+         # Scheduler
+         self.scheduler = self.create_scheduler()
+
+         # Loss
+         self.criterion = nn.CrossEntropyLoss(label_smoothing=config.label_smoothing)
+
+         # Mixed precision
+         self.use_amp = config.use_mixed_precision and config.device == "cuda"
+         self.scaler = GradScaler() if self.use_amp else None
+
+         if self.use_amp:
+             print(f"[Training] Mixed precision enabled")
+
+         # TensorBoard
+         self.writer = SummaryWriter(log_dir=str(config.tensorboard_dir))
+         print(f"[TensorBoard] Logging to: {config.tensorboard_dir}")
+         print(f"[Checkpoints] Format: SafeTensors (ClamAV safe)")
+
+         # HuggingFace
+         self.hf_uploader = HuggingFaceUploader(config) if config.upload_to_hf else None
+
+         # Save config
+         config.save(config.output_dir / "config.yaml")
+
+         # Metrics
+         self.best_acc = 0.0
+         self.global_step = 0
+         self.start_time = time.time()
+         self.upload_count = 0
+
+     def _calculate_restart_epochs(self) -> List[int]:
+         """Calculate when restarts will occur."""
+         if self.config.scheduler_type != "cosine_restarts":
+             return []
+
+         restarts = []
+         current = self.config.restart_period
+         period = self.config.restart_period
+
+         while current < self.config.num_epochs:
+             restarts.append(current)
+             period *= self.config.restart_mult
+             current += period
+
+         return restarts
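+
+     # Worked example: with the config defaults T_0 = 50 and T_mult = 2, the loop
+     # above yields restarts at epochs 50, 50 + 100 = 150, 150 + 200 = 350, ...,
+     # so a 300-epoch run restarts at 50 and 150 only. The main() config below
+     # uses T_0 = 20 instead, giving restarts at epochs 20, 60, and 140.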
+
+     def create_optimizer(self):
+         """Create optimizer based on config."""
+         if self.config.optimizer_type == "sgd":
+             print(f"\n[Optimizer] SGD")
+             print(f"  LR: {self.config.learning_rate}")
+             print(f"  Momentum: {self.config.sgd_momentum}")
+             print(f"  Nesterov: {self.config.sgd_nesterov}")
+             print(f"  Weight decay: {self.config.weight_decay}")
+
+             return torch.optim.SGD(
+                 self.model.parameters(),
+                 lr=self.config.learning_rate,
+                 momentum=self.config.sgd_momentum,
+                 weight_decay=self.config.weight_decay,
+                 nesterov=self.config.sgd_nesterov
+             )
+
+         elif self.config.optimizer_type == "adamw":
+             print(f"\n[Optimizer] AdamW")
+             print(f"  LR: {self.config.learning_rate}")
+             print(f"  Betas: {self.config.adamw_betas}")
+             print(f"  Weight decay: {self.config.weight_decay}")
+
+             return torch.optim.AdamW(
+                 self.model.parameters(),
+                 lr=self.config.learning_rate,
+                 betas=self.config.adamw_betas,
+                 eps=self.config.adamw_eps,
+                 weight_decay=self.config.weight_decay
+             )
+
+         else:
+             raise ValueError(f"Unknown optimizer: {self.config.optimizer_type}")
+
+     def create_scheduler(self):
+         """Create LR scheduler based on config."""
+         if self.config.scheduler_type == "cosine_restarts":
+             print(f"\n[Scheduler] CosineAnnealingWarmRestarts")
+             print(f"  T_0 (restart period): {self.config.restart_period} epochs")
+             print(f"  T_mult (cycle multiplier): {self.config.restart_mult}x")
+             print(f"  Min LR: {self.config.min_lr}")
+             print(f"\n  Restart schedule:")
+             for i, epoch in enumerate(self.restart_epochs[:5]):  # Show first 5
+                 print(f"    Restart #{i+1}: Epoch {epoch}")
+             if len(self.restart_epochs) > 5:
+                 print(f"    ... and {len(self.restart_epochs) - 5} more")
+
+             return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+                 self.optimizer,
+                 T_0=self.config.restart_period,
+                 T_mult=self.config.restart_mult,
+                 eta_min=self.config.min_lr
+             )
+
+         elif self.config.scheduler_type == "multistep":
+             print(f"\n[Scheduler] MultiStepLR")
+             print(f"  Milestones: {self.config.lr_milestones}")
+             print(f"  Gamma: {self.config.lr_gamma}")
+
+             return torch.optim.lr_scheduler.MultiStepLR(
+                 self.optimizer,
+                 milestones=self.config.lr_milestones,
+                 gamma=self.config.lr_gamma
+             )
+
+         elif self.config.scheduler_type == "cosine":
+             print(f"\n[Scheduler] Cosine annealing with warmup")
+             print(f"  Warmup epochs: {self.config.warmup_epochs}")
+             print(f"  Min LR: {self.config.min_lr}")
+
+             def lr_lambda(epoch):
+                 if epoch < self.config.warmup_epochs:
+                     return (epoch + 1) / self.config.warmup_epochs
+                 progress = (epoch - self.config.warmup_epochs) / (self.config.num_epochs - self.config.warmup_epochs)
+                 return 0.5 * (1 + math.cos(math.pi * progress))
+
+             return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)
+
+         else:
+             raise ValueError(f"Unknown scheduler: {self.config.scheduler_type}")
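+
+     # A minimal sketch (never invoked by the trainer) that traces the per-epoch
+     # LR CosineAnnealingWarmRestarts would produce, handy for sanity-checking a
+     # restart schedule before committing to a long run. Parameter values and the
+     # helper name are illustrative, not part of the training pipeline.
+     @staticmethod
+     def _sketch_restart_lr_curve(base_lr: float = 3e-4, t0: int = 20, t_mult: int = 2,
+                                  eta_min: float = 1e-7, num_epochs: int = 300) -> List[float]:
+         dummy_opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=base_lr)
+         sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+             dummy_opt, T_0=t0, T_mult=t_mult, eta_min=eta_min)
+         lrs = []
+         for _ in range(num_epochs):
+             lrs.append(sched.get_last_lr()[0])  # LR in effect for this epoch
+             sched.step()                        # advance one epoch, as train() does
+         return lrs  # expect jumps back to base_lr at epochs 20, 60, 140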
+
+     def print_model_info(self):
+         """Print model info."""
+         total_params = sum(p.numel() for p in self.model.parameters())
+         print(f"\nParameters: {total_params:,}")
+         print(f"Dataset: {self.config.dataset.upper()}")
+         print(f"Classes: {self.config.num_classes}")
+         print(f"Fusion mode: {self.config.fusion_mode}")
+         print(f"Optimizer: {self.config.optimizer_type.upper()}")
+         print(f"Scheduler: {self.config.scheduler_type}")
+         print(f"Output: {self.config.output_dir}")
+
+     def train_epoch(self, train_loader: DataLoader, epoch: int) -> Tuple[float, float]:
+         """Train one epoch."""
+         self.model.train()
+         total_loss, correct, total = 0.0, 0, 0
+
+         # Check if this is a restart epoch
+         is_restart = (epoch in self.restart_epochs)
+         epoch_desc = f"Epoch {epoch+1}/{self.config.num_epochs}"
+         if is_restart:
+             epoch_desc += " 🔄 RESTART"
+
+         pbar = tqdm(train_loader, desc=f"{epoch_desc} [Train]")
+
+         for batch_idx, (images, labels) in enumerate(pbar):
+             images, labels = images.to(self.device, non_blocking=True), labels.to(self.device, non_blocking=True)
+
+             # Forward
+             if self.use_amp:
+                 with autocast():
+                     logits = self.model(images)
+                     loss = self.criterion(logits, labels)
+                 self.optimizer.zero_grad(set_to_none=True)
+                 self.scaler.scale(loss).backward()
+                 self.scaler.unscale_(self.optimizer)
+                 torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
+                 self.scaler.step(self.optimizer)
+                 self.scaler.update()
+             else:
+                 logits = self.model(images)
+                 loss = self.criterion(logits, labels)
+                 self.optimizer.zero_grad(set_to_none=True)
+                 loss.backward()
+                 torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
+                 self.optimizer.step()
+
+             # Metrics
+             total_loss += loss.item()
+             _, predicted = logits.max(1)
+             correct += predicted.eq(labels).sum().item()
+             total += labels.size(0)
+
+             # TensorBoard logging
+             if batch_idx % self.config.log_interval == 0:
+                 current_lr = self.scheduler.get_last_lr()[0]
+                 self.writer.add_scalar('train/loss', loss.item(), self.global_step)
+                 self.writer.add_scalar('train/accuracy', 100. * correct / total, self.global_step)
+                 self.writer.add_scalar('train/learning_rate', current_lr, self.global_step)
+
+             self.global_step += 1
+
+             pbar.set_postfix({
+                 'loss': f'{loss.item():.4f}',
+                 'acc': f'{100. * correct / total:.2f}%',
+                 'lr': f'{self.scheduler.get_last_lr()[0]:.6f}'
+             })
+
+         return total_loss / len(train_loader), 100. * correct / total
+
+     @torch.no_grad()
+     def evaluate(self, val_loader: DataLoader, epoch: int) -> Tuple[float, Dict]:
+         """Evaluate."""
+         self.model.eval()
+         total_loss, correct, total = 0.0, 0, 0
+         consciousness_values = []
+
+         pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{self.config.num_epochs} [Val]  ")
+
+         for batch_idx, (images, labels) in enumerate(pbar):
+             images, labels = images.to(self.device, non_blocking=True), labels.to(self.device, non_blocking=True)
+
+             # Forward with fusion info on the last batch
+             return_info = (batch_idx == len(val_loader) - 1)
+
+             if self.use_amp:
+                 with autocast():
+                     if return_info:
+                         logits, fusion_infos = self.model(images, return_fusion_info=True)
+                         if fusion_infos and fusion_infos[0].get('consciousness') is not None:
+                             consciousness_values.append(fusion_infos[0]['consciousness'].mean().item())
+                     else:
+                         logits = self.model(images)
+                     loss = self.criterion(logits, labels)
+             else:
+                 if return_info:
+                     logits, fusion_infos = self.model(images, return_fusion_info=True)
+                     if fusion_infos and fusion_infos[0].get('consciousness') is not None:
+                         consciousness_values.append(fusion_infos[0]['consciousness'].mean().item())
+                 else:
+                     logits = self.model(images)
+                 loss = self.criterion(logits, labels)
+
+             total_loss += loss.item()
+             _, predicted = logits.max(1)
+             correct += predicted.eq(labels).sum().item()
+             total += labels.size(0)
+
+             pbar.set_postfix({
+                 'loss': f'{total_loss / (batch_idx + 1):.4f}',
+                 'acc': f'{100. * correct / total:.2f}%'
+             })
+
+         avg_loss = total_loss / len(val_loader)
+         accuracy = 100. * correct / total
+
+         # TensorBoard logging
+         self.writer.add_scalar('val/loss', avg_loss, epoch)
+         self.writer.add_scalar('val/accuracy', accuracy, epoch)
+         if consciousness_values:
+             self.writer.add_scalar('val/consciousness', sum(consciousness_values) / len(consciousness_values), epoch)
+
+         metrics = {
+             'loss': avg_loss,
+             'accuracy': accuracy,
+             'consciousness': sum(consciousness_values) / len(consciousness_values) if consciousness_values else None
+         }
+
+         return accuracy, metrics
+
+     def train(self, train_loader: DataLoader, val_loader: DataLoader):
+         """Full training loop."""
+         print("\n" + "=" * 70)
+         print("Starting training with AdamW + Warm Restarts")
+         print(f"Optimizer: {self.config.optimizer_type.upper()}")
+         print(f"Scheduler: {self.config.scheduler_type}")
+         print(f"Restart period: {self.config.restart_period} epochs (T_0)")
+         print(f"Cycle multiplier: {self.config.restart_mult}x (T_mult)")
+         print(f"Total restarts: {len(self.restart_epochs)}")
+         print("=" * 70 + "\n")
+
+         for epoch in range(self.config.num_epochs):
+             # Train
+             train_loss, train_acc = self.train_epoch(train_loader, epoch)
+
+             # Evaluate
+             val_acc, val_metrics = self.evaluate(val_loader, epoch)
+
+             # Update scheduler
+             self.scheduler.step()
+
+             # Check if this is a restart epoch
+             is_restart = (epoch in self.restart_epochs)
+             next_lr = self.scheduler.get_last_lr()[0]
+
+             # Print summary
+             print(f"\n{'='*70}")
+             print(f"Epoch [{epoch + 1}/{self.config.num_epochs}] Summary:")
+             print(f"  Train: Loss={train_loss:.4f}, Acc={train_acc:.2f}%")
+             print(f"  Val:   Loss={val_metrics['loss']:.4f}, Acc={val_acc:.2f}%")
+             if val_metrics['consciousness'] is not None:
+                 print(f"  Consciousness: {val_metrics['consciousness']:.4f}")
+
+             if is_restart:
+                 print(f"  🔄 WARM RESTART! Next LR: {next_lr:.6f}")
+                 print(f"     (Escaping local minimum, exploring new regions)")
+             else:
+                 print(f"  Current LR: {next_lr:.6f}")
+
+             # Checkpoint logic
+             is_best = val_acc > self.best_acc
+             should_save_regular = ((epoch + 1) % self.config.save_interval == 0)
+             should_upload_regular = ((epoch + 1) % self.config.checkpoint_upload_interval == 0)
+
+             if is_best:
+                 self.best_acc = val_acc
+                 print(f"  ✓ New best model! Accuracy: {val_acc:.2f}%")
+                 self.save_checkpoint(epoch, val_acc, prefix="best", upload=should_upload_regular)
+
+             if should_save_regular:
+                 self.save_checkpoint(epoch, val_acc, prefix=f"epoch_{epoch+1}", upload=should_upload_regular)
+
+             print(f"  HF Uploads: {self.upload_count}")
+             print(f"{'='*70}\n")
+
+             # Flush TensorBoard
+             if (epoch + 1) % 10 == 0:
+                 self.writer.flush()
+
+         # Training complete
+         training_time = (time.time() - self.start_time) / 3600
+
+         print("\n" + "=" * 70)
+         print("Training Complete!")
+         print(f"Best Validation Accuracy: {self.best_acc:.2f}%")
+         print(f"Training Time: {training_time:.2f} hours")
+         print(f"Total Uploads: {self.upload_count}")
+         print(f"Warm Restarts: {len(self.restart_epochs)}")
+         print("=" * 70)
+
+         # Upload to HuggingFace
+         if self.hf_uploader:
+             print("\n[HF] Uploading final best model...")
+             best_model_path = self.config.checkpoint_dir / "best_model.safetensors"
+             best_state_path = self.config.checkpoint_dir / "best_training_state.pt"
+             best_metadata_path = self.config.checkpoint_dir / "best_metadata.json"
+             config_path = self.config.output_dir / "config.yaml"
+
+             if best_model_path.exists():
+                 self.hf_uploader.upload_file(best_model_path, "checkpoints/best_model.safetensors")
+             if best_state_path.exists():
+                 self.hf_uploader.upload_file(best_state_path, "checkpoints/best_training_state.pt")
+             if best_metadata_path.exists():
+                 self.hf_uploader.upload_file(best_metadata_path, "checkpoints/best_metadata.json")
+             if config_path.exists():
+                 self.hf_uploader.upload_file(config_path, "config.yaml")
+
+             print("[HF] Final upload: TensorBoard logs...")
+             self.hf_uploader.upload_folder_contents(self.config.tensorboard_dir, "tensorboard")
+
+             trainer_stats = {
+                 'total_params': sum(p.numel() for p in self.model.parameters()),
+                 'best_acc': self.best_acc,
+                 'training_time': training_time,
+                 'final_epoch': self.config.num_epochs,
+                 'batch_size': self.config.batch_size,
+                 'mixed_precision': self.use_amp
+             }
+             self.hf_uploader.create_model_card(trainer_stats)
+
+         self.writer.close()
+
+     def save_checkpoint(self, epoch: int, accuracy: float, prefix: str = "checkpoint", upload: bool = False):
+         """Save checkpoint as safetensors with selective upload."""
+         checkpoint_dir = self.config.checkpoint_dir
+         checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+         # 1. Save model weights as safetensors
+         model_path = checkpoint_dir / f"{prefix}_model.safetensors"
+         save_file(self.model.state_dict(), str(model_path))
+
+         # 2. Save optimizer/scheduler state
+         training_state = {
+             'optimizer_state_dict': self.optimizer.state_dict(),
+             'scheduler_state_dict': self.scheduler.state_dict(),
+         }
+         if self.scaler is not None:
+             training_state['scaler_state_dict'] = self.scaler.state_dict()
+
+         training_state_path = checkpoint_dir / f"{prefix}_training_state.pt"
+         torch.save(training_state, training_state_path)
+
+         # 3. Save metadata
+         metadata = {
+             'epoch': epoch,
+             'accuracy': accuracy,
+             'best_accuracy': self.best_acc,
+             'global_step': self.global_step,
+             'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
+             'optimizer': self.config.optimizer_type,
+             'scheduler': self.config.scheduler_type,
+             'learning_rate': self.scheduler.get_last_lr()[0]
+         }
+         metadata_path = checkpoint_dir / f"{prefix}_metadata.json"
+         with open(metadata_path, 'w') as f:
+             json.dump(metadata, f, indent=2)
+
+         is_best = (prefix == "best")
+
+         if is_best:
+             print(f"  💾 Saved best: {prefix}_model.safetensors")
+         else:
+             print(f"  💾 Saved: {prefix}_model.safetensors", end="")
+
+         # Upload to HuggingFace
+         if self.hf_uploader and upload:
+             self.hf_uploader.upload_file(model_path, f"checkpoints/{prefix}_model.safetensors")
+             self.hf_uploader.upload_file(training_state_path, f"checkpoints/{prefix}_training_state.pt")
+             self.hf_uploader.upload_file(metadata_path, f"checkpoints/{prefix}_metadata.json")
+
+             if is_best:
+                 config_path = self.config.output_dir / "config.yaml"
+                 if config_path.exists():
+                     self.hf_uploader.upload_file(config_path, "config.yaml")
+
+             self.upload_count += 1
+
+             if not is_best:
+                 print(" → Uploaded to HF")
+         else:
+             if not is_best:
+                 print(" (local only)")
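+
+     # A hedged sketch of the inverse operation: resuming from the files that
+     # save_checkpoint writes. Not called anywhere in this script; the method
+     # name is an assumption, and paths/keys mirror those used above.
+     def load_checkpoint(self, prefix: str = "best") -> Dict:
+         """Restore model weights plus optimizer/scheduler (and scaler) state."""
+         checkpoint_dir = self.config.checkpoint_dir
+         state_dict = load_file(str(checkpoint_dir / f"{prefix}_model.safetensors"))
+         self.model.load_state_dict(state_dict)
+
+         training_state = torch.load(checkpoint_dir / f"{prefix}_training_state.pt",
+                                     map_location=self.device)
+         self.optimizer.load_state_dict(training_state['optimizer_state_dict'])
+         self.scheduler.load_state_dict(training_state['scheduler_state_dict'])
+         if self.scaler is not None and 'scaler_state_dict' in training_state:
+             self.scaler.load_state_dict(training_state['scaler_state_dict'])
+
+         with open(checkpoint_dir / f"{prefix}_metadata.json") as f:
+             metadata = json.load(f)
+         self.best_acc = metadata.get('best_accuracy', self.best_acc)
+         return metadata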
+
+
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+ # Data Loading (with Cutout)
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+
+ class Cutout:
+     """Cutout data augmentation."""
+     def __init__(self, length: int):
+         self.length = length
+
+     def __call__(self, img):
+         h, w = img.size(1), img.size(2)
+         mask = torch.ones((h, w), dtype=torch.float32)
+         y = torch.randint(h, (1,)).item()
+         x = torch.randint(w, (1,)).item()
+
+         y1 = max(0, y - self.length // 2)
+         y2 = min(h, y + self.length // 2)
+         x1 = max(0, x - self.length // 2)
+         x2 = min(w, x + self.length // 2)
+
+         mask[y1:y2, x1:x2] = 0.
+         mask = mask.expand_as(img)
+         return img * mask
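+
+     # Usage sketch: Cutout operates on tensor images (C, H, W), so it must come
+     # after ToTensor() in a Compose, as get_data_loaders below arranges when
+     # use_cutout is enabled (lengths here are illustrative):
+     #   transforms.Compose([transforms.ToTensor(),
+     #                       transforms.Normalize(mean, std),
+     #                       Cutout(length=16)])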
+
+
+ def get_data_loaders(config: CantorTrainingConfig) -> Tuple[DataLoader, DataLoader]:
+     """Create data loaders."""
+
+     # Normalization (CIFAR channel statistics)
+     mean = (0.4914, 0.4822, 0.4465)
+     std = (0.2470, 0.2435, 0.2616)
+
+     # Augmentation
+     if config.use_augmentation:
+         transforms_list = []
+
+         if config.use_autoaugment:
+             policy = transforms.AutoAugmentPolicy.CIFAR10
+             transforms_list.append(transforms.AutoAugment(policy))
+         else:
+             transforms_list.extend([
+                 transforms.RandomCrop(32, padding=4),
+                 transforms.RandomHorizontalFlip(),
+             ])
+
+         transforms_list.append(transforms.ToTensor())
+         transforms_list.append(transforms.Normalize(mean, std))
+
+         if config.use_cutout:
+             transforms_list.append(Cutout(config.cutout_length))
+
+         train_transform = transforms.Compose(transforms_list)
+     else:
+         train_transform = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.Normalize(mean, std)
+         ])
+
+     val_transform = transforms.Compose([
+         transforms.ToTensor(),
+         transforms.Normalize(mean, std)
+     ])
+
+     # Dataset selection
+     if config.dataset == "cifar10":
+         train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
+         val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=val_transform)
+     elif config.dataset == "cifar100":
+         train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=train_transform)
+         val_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=val_transform)
+     else:
+         raise ValueError(f"Unknown dataset: {config.dataset}")
+
+     train_loader = DataLoader(
+         train_dataset,
+         batch_size=config.batch_size,
+         shuffle=True,
+         num_workers=config.num_workers,
+         pin_memory=(config.device == "cuda")
+     )
+
+     val_loader = DataLoader(
+         val_dataset,
+         batch_size=config.batch_size,
+         shuffle=False,
+         num_workers=config.num_workers,
+         pin_memory=(config.device == "cuda")
+     )
+
+     return train_loader, val_loader
+
+
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+ # Main - AdamW + CosineAnnealingWarmRestarts
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+
+ def main():
+     """Main training function with AdamW + Warm Restarts."""
+
+     # ═══════════════════════════════════════════════════════════════════
+     # Configuration - AdamW with Cosine Annealing Warm Restarts (SGDR)
+     # ═══════════════════════════════════════════════════════════════════
+
+     config = CantorTrainingConfig(
+         # Dataset
+         dataset="cifar100",  # "cifar10" or "cifar100"
+
+         # Architecture
+         embed_dim=512,
+         num_fusion_blocks=6,
+         num_heads=8,
+         fusion_mode="weighted",  # "weighted" or "consciousness"
+         k_simplex=4,
+         use_beatrix=True,
+
+         # Optimizer: AdamW (well suited to ViTs)
+         optimizer_type="adamw",
+         learning_rate=3e-4,  # standard ViT learning rate
+         weight_decay=0.05,   # standard ViT weight decay
+         adamw_betas=(0.9, 0.999),
+
+         # Scheduler: Cosine Annealing with Warm Restarts (SGDR)
+         scheduler_type="cosine_restarts",  # THE KEY!
+         restart_period=20,  # T_0: first restart at epoch 20
+         restart_mult=2,     # T_mult: each cycle 2x longer
+         min_lr=1e-7,        # minimum LR at end of each cycle
+
+         # Training
+         num_epochs=300,  # longer training with multiple cycles
+         batch_size=128,
+         grad_clip=1.0,
+         label_smoothing=0.1,
+
+         # Augmentation
+         use_augmentation=True,
+         use_autoaugment=True,
+         use_cutout=False,
+         cutout_length=16,
+
+         # Regularization
+         dropout=0.1,
+         drop_path_rate=0.1,
+
+         # System
+         device="cuda",
+         use_mixed_precision=False,
+
+         # HuggingFace
+         hf_username="AbstractPhil",
+         upload_to_hf=True,
+         checkpoint_upload_interval=25,  # upload every 25 epochs
+     )
+
+     print("=" * 70)
+     print(f"Cantor Fusion Classifier - {config.dataset.upper()}")
+     print("Training Strategy: AdamW + Cosine Annealing Warm Restarts (SGDR)")
+     print("=" * 70)
+     print("\nConfiguration:")
+     print(f"  Dataset: {config.dataset}")
+     print(f"  Fusion mode: {config.fusion_mode}")
+     print(f"  Optimizer: AdamW")
+     print(f"  Scheduler: CosineAnnealingWarmRestarts")
+     print(f"  Initial LR: {config.learning_rate}")
+     print(f"  Min LR: {config.min_lr}")
+     print(f"  Restart period (T_0): {config.restart_period} epochs")
+     print(f"  Cycle multiplier (T_mult): {config.restart_mult}x")
+     print(f"  Total epochs: {config.num_epochs}")
+
+     # Calculate restart schedule (mirrors Trainer._calculate_restart_epochs)
+     restarts = []
+     current = config.restart_period
+     period = config.restart_period
+     while current < config.num_epochs:
+         restarts.append(current)
+         period *= config.restart_mult
+         current += period
+
+     print(f"\n  Restart schedule ({len(restarts)} restarts):")
+     for i, epoch in enumerate(restarts[:5]):
+         print(f"    Restart #{i+1}: Epoch {epoch}")
+     if len(restarts) > 5:
+         print(f"    ... and {len(restarts) - 5} more")
+
+     print(f"\n  Output: {config.output_dir}")
+     print(f"  HuggingFace: {'Enabled' if config.upload_to_hf else 'Disabled'}")
+     if config.upload_to_hf:
+         print(f"    Repo: {config.hf_username}/{config.hf_repo_name}")
+         print(f"    Run: {config.run_name}")
+
+     # With T_0=20 and T_mult=2, restarts land at epochs 20, 60, and 140
+     print("\n" + "=" * 70)
+     print("Expected Training Behavior:")
+     print("=" * 70)
+     print("📉 Cycle 1 (epochs 0-20):")
+     print("   LR: 3e-4 → 1e-7 (smooth drop)")
+     print("   Expected: Convergence to local minimum")
+     print("")
+     print("🔄 Epoch 20: RESTART!")
+     print("   LR: 1e-7 → 3e-4 (jump back up)")
+     print("   Expected: Escape local minimum, explore new regions")
+     print("")
+     print("📉 Cycle 2 (epochs 20-60):")
+     print("   LR: 3e-4 → 1e-7 (longer cycle, deeper convergence)")
+     print("")
+     print("🔄 Epochs 60 and 140: further RESTARTs")
+     print("   LR: 1e-7 → 3e-4")
+     print("")
+     print("📉 Cycle 3 (epochs 60-140), then Cycle 4 (epochs 140-300):")
+     print("   LR: 3e-4 → 1e-7 (longest cycles, final polish)")
+     print("")
+     print("🎯 Benefits:")
+     print("   - Multiple exploration cycles find better minima")
+     print("   - Automatic drop + restart (no manual tuning)")
+     print("   - Robust to different datasets/architectures")
+     print("=" * 70)
+
+     # Load data
+     print("\nLoading data...")
+     train_loader, val_loader = get_data_loaders(config)
+     print(f"  Train: {len(train_loader.dataset)} samples")
+     print(f"  Val:   {len(val_loader.dataset)} samples")
+
+     # Train
+     trainer = Trainer(config)
+     trainer.train(train_loader, val_loader)
+
+     print("\n" + "=" * 70)
+     print("🎯 Training complete!")
+     print("   Check TensorBoard to see the warm restart cycles!")
+     print(f"   tensorboard --logdir {config.tensorboard_dir}")
+     print("")
+     print("   Look for:")
+     print("   - Smooth LR drops during each cycle")
+     print("   - Sharp LR jumps at restart epochs")
+     print("   - Accuracy improvements across cycles")
+     print("=" * 70)
+
+
+ if __name__ == "__main__":
+     main()