AbstractPhil commited on
Commit
356a611
Β·
verified Β·
1 Parent(s): f3a9ceb

Create trainer.py

Browse files
Files changed (1) hide show
  1. trainer.py +1118 -0
trainer.py ADDED
@@ -0,0 +1,1118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # train_cantor_fusion_hf.py - PRODUCTION WITH HUGGINGFACE + TENSORBOARD + SAFETENSORS
2
+
3
+ """
4
+ Cantor Fusion Classifier with HuggingFace Integration
5
+ ------------------------------------------------------
6
+
7
+ # Install
8
+ try:
9
+ !pip uninstall -qy geometricvocab
10
+ except:
11
+ pass
12
+
13
+ !pip install -q git+https://github.com/AbstractEyes/lattice_vocabulary.git
14
+
15
+ #
16
+
17
+ Features:
18
+ - HuggingFace Hub uploads (ONE shared repo, organized by run)
19
+ - TensorBoard logging (loss, accuracy, fusion metrics)
20
+ - Easy CIFAR-10/100 switching
21
+ - Automatic checkpoint management
22
+ - SafeTensors format (ClamAV safe)
23
+ - Smart upload intervals
24
+
25
+ Author: AbstractPhil
26
+ License: MIT
27
+ """
28
+
29
+ import torch
30
+ import torch.nn as nn
31
+ import torch.nn.functional as F
32
+ from torch.utils.data import DataLoader
33
+ from torch.utils.tensorboard import SummaryWriter
34
+ from torchvision import datasets, transforms
35
+ from torch.cuda.amp import autocast, GradScaler
36
+ from safetensors.torch import save_file, load_file
37
+
38
+ import math
39
+ import os
40
+ import json
41
+ from typing import Optional, Dict, List, Tuple, Union
42
+ from dataclasses import dataclass, asdict
43
+ import time
44
+ from pathlib import Path
45
+ from tqdm import tqdm
46
+
47
+ # HuggingFace
48
+ from huggingface_hub import HfApi, create_repo, upload_folder, upload_file
49
+ import yaml
50
+
51
+ # Import from your repo
52
+ from geovocab2.train.model.layers.attention.cantor_multiheaded_fusion import (
53
+ CantorMultiheadFusion,
54
+ CantorFusionConfig
55
+ )
56
+ from geovocab2.shapes.factory.cantor_route_factory import (
57
+ CantorRouteFactory,
58
+ RouteMode,
59
+ SimplexConfig
60
+ )
61
+
62
+
63
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
64
+ # Configuration
65
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
66
+
67
@dataclass
class CantorTrainingConfig:
    """Complete configuration for Cantor fusion training.

    Derived attributes (``num_patches``, ``patch_dim``, output/checkpoint/
    tensorboard paths) are computed in ``__post_init__``, which also creates
    the output directory tree on disk as a side effect.

    Raises:
        ValueError: if ``dataset`` is not "cifar10"/"cifar100", or if
            ``image_size`` is not divisible by ``patch_size``.
    """

    # Dataset
    dataset: str = "cifar10"  # "cifar10" or "cifar100"
    num_classes: int = 10  # auto-set from `dataset` in __post_init__

    # Architecture
    image_size: int = 32
    patch_size: int = 4
    embed_dim: int = 384
    num_fusion_blocks: int = 6
    num_heads: int = 8
    fusion_window: int = 32
    fusion_mode: str = "weighted"  # "weighted" or "consciousness"
    k_simplex: int = 4
    use_beatrix: bool = False
    beatrix_tau: float = 0.25

    # Optimization
    precompute_geometric: bool = True
    use_torch_compile: bool = True
    use_mixed_precision: bool = False

    # Regularization
    dropout: float = 0.1
    drop_path_rate: float = 0.15

    # Training
    batch_size: int = 128
    num_epochs: int = 100
    learning_rate: float = 3e-4
    weight_decay: float = 0.05
    warmup_epochs: int = 5
    grad_clip: float = 1.0

    # Data augmentation
    use_augmentation: bool = True
    use_autoaugment: bool = True

    # System
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    num_workers: int = 4
    seed: int = 42

    # Paths
    weights_dir: str = "weights"
    model_name: str = "vit-beans-v3"
    run_name: Optional[str] = None  # Auto-generated if None

    # HuggingFace - ONE SHARED REPO
    hf_username: str = "AbstractPhil"
    hf_repo_name: Optional[str] = None  # Auto-generated if None (shared repo)
    upload_to_hf: bool = True
    hf_token: Optional[str] = None  # Set via environment or pass directly

    # Logging
    log_interval: int = 50  # Log every N batches
    save_interval: int = 10  # Save checkpoint every N epochs
    checkpoint_upload_interval: int = 10  # Upload checkpoint every N epochs

    def __post_init__(self):
        # Auto-set num_classes based on dataset
        if self.dataset == "cifar10":
            self.num_classes = 10
        elif self.dataset == "cifar100":
            self.num_classes = 100
        else:
            raise ValueError(f"Unknown dataset: {self.dataset}")

        # Auto-generate run name
        if self.run_name is None:
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            self.run_name = f"{self.dataset}_{self.fusion_mode}_{timestamp}"

        # ONE SHARED REPO for all runs
        if self.hf_repo_name is None:
            self.hf_repo_name = self.model_name

        # Set HF token from environment if not provided
        if self.hf_token is None:
            self.hf_token = os.environ.get("HF_TOKEN")

        # Calculate derived values.
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would silently allow invalid patch geometry.
        if self.image_size % self.patch_size != 0:
            raise ValueError(
                f"image_size ({self.image_size}) must be divisible by "
                f"patch_size ({self.patch_size})"
            )
        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.patch_dim = self.patch_size * self.patch_size * 3

        # Create paths
        self.output_dir = Path(self.weights_dir) / self.model_name / self.run_name
        self.checkpoint_dir = self.output_dir / "checkpoints"
        self.tensorboard_dir = self.output_dir / "tensorboard"

        # Create directories (side effect: touches the filesystem)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.tensorboard_dir.mkdir(parents=True, exist_ok=True)

    def save(self, path: Union[str, Path]):
        """Save config to YAML file (only declared dataclass fields are dumped)."""
        path = Path(path)
        with open(path, 'w') as f:
            yaml.dump(asdict(self), f, default_flow_style=False)

    @classmethod
    def load(cls, path: Union[str, Path]):
        """Load config from YAML file; derived values are recomputed in __post_init__."""
        path = Path(path)
        with open(path, 'r') as f:
            config_dict = yaml.safe_load(f)
        return cls(**config_dict)
179
+
180
+
181
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
182
+ # Model Components
183
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
184
+
185
class PatchEmbedding(nn.Module):
    """Splits an image into non-overlapping patches and embeds each one.

    A strided Conv2d (kernel == stride == patch_size) implements the
    patchify-plus-linear-projection in one op; a learned positional
    embedding is then added. Output shape: (B, num_patches, embed_dim).
    """

    def __init__(self, config: CantorTrainingConfig):
        super().__init__()
        self.config = config
        self.proj = nn.Conv2d(
            3,
            config.embed_dim,
            kernel_size=config.patch_size,
            stride=config.patch_size,
        )
        # Small-std init keeps positional signal from dominating early training.
        self.pos_embed = nn.Parameter(
            torch.randn(1, config.num_patches, config.embed_dim) * 0.02
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, 3, H, W) -> (B, embed_dim, H/p, W/p)
        tokens = self.proj(x)
        # Flatten spatial dims, move channels last: (B, N, embed_dim)
        tokens = tokens.flatten(2).transpose(1, 2)
        return tokens + self.pos_embed
198
+
199
+
200
class DropPath(nn.Module):
    """Stochastic depth: randomly zeroes whole residual paths per sample.

    In eval mode, or with drop_prob == 0, this is the identity. In training
    mode each sample's path survives with probability 1 - drop_prob, and
    survivors are rescaled by 1/keep_prob so the expectation is unchanged.
    """

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Fast path: identity when not training or dropping is disabled.
        if not self.training or self.drop_prob == 0.:
            return x
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast across all trailing dims.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = keep_prob + torch.rand(mask_shape, dtype=x.dtype, device=x.device)
        mask.floor_()  # binarize: 1 with probability keep_prob, else 0
        return x.div(keep_prob) * mask
214
+
215
+
216
class CantorFusionBlock(nn.Module):
    """Pre-norm transformer-style block: Cantor multihead fusion + MLP.

    Both sub-layers use residual connections with optional stochastic depth.
    """

    def __init__(self, config: CantorTrainingConfig, drop_path: float = 0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(config.embed_dim)

        fusion_config = CantorFusionConfig(
            dim=config.embed_dim,
            num_heads=config.num_heads,
            fusion_window=config.fusion_window,
            fusion_mode=config.fusion_mode,
            k_simplex=config.k_simplex,
            use_beatrix_routing=config.use_beatrix,
            use_consciousness_weighting=(config.fusion_mode == "consciousness"),
            beatrix_tau=config.beatrix_tau,
            use_gating=True,
            dropout=config.dropout,
            # The residual is applied by this block, not inside the fusion layer.
            residual=False,
            precompute_staircase=config.precompute_geometric,
            precompute_routes=config.precompute_geometric,
            precompute_distances=config.precompute_geometric,
            use_optimized_gather=True,
            staircase_cache_sizes=[config.num_patches],
            use_torch_compile=config.use_torch_compile,
        )
        self.fusion = CantorMultiheadFusion(fusion_config)

        self.norm2 = nn.LayerNorm(config.embed_dim)
        hidden_dim = config.embed_dim * 4
        self.mlp = nn.Sequential(
            nn.Linear(config.embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(hidden_dim, config.embed_dim),
            nn.Dropout(config.dropout),
        )
        self.drop_path = nn.Identity() if drop_path <= 0 else DropPath(drop_path)

    def forward(self, x: torch.Tensor, return_fusion_info: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict]]:
        """Apply fusion + MLP; optionally also return fusion diagnostics."""
        fusion_result = self.fusion(self.norm1(x))
        x = x + self.drop_path(fusion_result['output'])
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        if not return_fusion_info:
            return x
        # Diagnostics come straight from the fusion layer's result dict;
        # keys may be absent, in which case the values are None.
        info = {
            'consciousness': fusion_result.get('consciousness'),
            'cantor_measure': fusion_result.get('cantor_measure'),
        }
        return x, info
266
+
267
+
268
class CantorClassifier(nn.Module):
    """Vision classifier built from stacked Cantor fusion blocks.

    Pipeline: patch embedding -> N fusion blocks (with linearly increasing
    stochastic-depth rates) -> LayerNorm -> mean pooling -> linear head.
    """

    def __init__(self, config: CantorTrainingConfig):
        super().__init__()
        self.config = config

        self.patch_embed = PatchEmbedding(config)

        # Ramp drop-path from 0 up to drop_path_rate across block depth.
        rates = torch.linspace(0, config.drop_path_rate, config.num_fusion_blocks).tolist()
        self.blocks = nn.ModuleList(
            CantorFusionBlock(config, drop_path=rate) for rate in rates
        )

        self.norm = nn.LayerNorm(config.embed_dim)
        self.head = nn.Linear(config.embed_dim, config.num_classes)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """ViT-style initialization, applied recursively via nn.Module.apply."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x: torch.Tensor, return_fusion_info: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, List[Dict]]]:
        """Classify a batch of images; optionally collect final-block fusion info."""
        x = self.patch_embed(x)

        fusion_infos = []
        last_idx = len(self.blocks) - 1
        for i, block in enumerate(self.blocks):
            if return_fusion_info and i == last_idx:
                # Only the final block's diagnostics are requested/collected.
                x, info = block(x, return_fusion_info=True)
                fusion_infos.append(info)
            else:
                x = block(x)

        x = self.norm(x)
        pooled = x.mean(dim=1)  # mean-pool over the patch/token dimension
        logits = self.head(pooled)

        return (logits, fusion_infos) if return_fusion_info else logits
316
+
317
+
318
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
319
+ # HuggingFace Integration - ONE SHARED REPO
320
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
321
+
322
class HuggingFaceUploader:
    """Manages HuggingFace Hub uploads to ONE shared repo.

    All runs share a single repository (``hf_username/hf_repo_name``); each
    run's artifacts live under ``runs/<run_name>/``. Every upload method is
    best-effort: failures are printed, never raised, so training continues.
    """

    def __init__(self, config: CantorTrainingConfig):
        self.config = config
        # api stays None when uploads are disabled; methods guard on this.
        self.api = HfApi(token=config.hf_token) if config.upload_to_hf else None
        self.repo_id = f"{config.hf_username}/{config.hf_repo_name}"
        # Organize by run inside the shared repo
        self.run_prefix = f"runs/{config.run_name}"

        if config.upload_to_hf:
            self._create_repo()
            self._update_main_readme()  # NEW: Update main README

    def _create_repo(self):
        """Create HuggingFace repo if it doesn't exist (idempotent via exist_ok)."""
        try:
            create_repo(
                repo_id=self.repo_id,
                token=self.config.hf_token,
                exist_ok=True,
                private=False
            )
            print(f"[HF] Repository: https://huggingface.co/{self.repo_id}")
            print(f"[HF] Run folder: {self.run_prefix}")
        except Exception as e:
            # Best-effort: a failed repo creation must not abort training.
            print(f"[HF] Warning: Could not create repo: {e}")

    def _update_main_readme(self):
        """Create or update the main shared README at repo root.

        Writes the rendered template to disk locally, then uploads it as the
        repo-root README.md (overwriting whatever previous run wrote there).
        """
        if not self.config.upload_to_hf or self.api is None:
            return

        main_readme = f"""---
tags:
- image-classification
- cantor-fusion
- geometric-deep-learning
- safetensors
- vision-transformer
library_name: pytorch
datasets:
- cifar10
- cifar100
metrics:
- accuracy
---

# {self.config.hf_repo_name}

**Geometric Deep Learning with Cantor Multihead Fusion**

This repository contains multiple training runs using Cantor fusion architecture with pentachoron structures and geometric routing. All models use SafeTensors format for security.

## Repository Structure
```
{self.config.hf_repo_name}/
β”œβ”€β”€ runs/
β”‚ β”œβ”€β”€ cifar10_weighted_TIMESTAMP/
β”‚ β”‚ β”œβ”€β”€ checkpoints/
β”‚ β”‚ β”‚ β”œβ”€β”€ best_model.safetensors
β”‚ β”‚ β”‚ β”œβ”€β”€ best_training_state.pt
β”‚ β”‚ β”‚ └── best_metadata.json
β”‚ β”‚ β”œβ”€β”€ tensorboard/
β”‚ β”‚ β”œβ”€β”€ config.yaml
β”‚ β”‚ └── README.md
β”‚ β”œβ”€β”€ cifar100_consciousness_TIMESTAMP/
β”‚ β”‚ └── ...
β”‚ └── ...
└── README.md (this file)
```

## Current Run

**Latest**: `{self.config.run_name}`
- **Dataset**: {self.config.dataset.upper()}
- **Fusion Mode**: {self.config.fusion_mode}
- **Architecture**: {self.config.num_fusion_blocks} blocks, {self.config.num_heads} heads
- **Simplex**: {self.config.k_simplex}-simplex ({self.config.k_simplex + 1} vertices)

## Architecture

The Cantor Fusion architecture uses:
- **Geometric Routing**: Pentachoron (5-simplex) structures for token routing
- **Cantor Multihead Fusion**: Multiple fusion heads with geometric attention
- **Beatrix Consciousness Routing**: Optional consciousness-aware token fusion using the Devil's Staircase
- **SafeTensors Format**: All model weights use SafeTensors (not pickle) for security

## Usage

### Download a Model
```python
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import torch

# Download model weights
model_path = hf_hub_download(
    repo_id="{self.repo_id}",
    filename="runs/YOUR_RUN_NAME/checkpoints/best_model.safetensors"
)

# Load weights (SafeTensors - no pickle!)
state_dict = load_file(model_path)
model.load_state_dict(state_dict)
```

### Browse Runs

Each run directory contains:
- `checkpoints/` - Model weights (safetensors), training state, metadata
- `tensorboard/` - TensorBoard logs for visualization
- `config.yaml` - Complete training configuration
- `README.md` - Run-specific details and results

## Model Variants

- **Weighted Fusion**: Standard geometric fusion with learned weights
- **Consciousness Fusion**: Uses Beatrix routing with consciousness emergence

## Citation
```bibtex
@misc{{{self.config.hf_repo_name.replace('-', '_')},
author = {{AbstractPhil}},
title = {{{self.config.hf_repo_name}: Geometric Deep Learning with Cantor Fusion}},
year = {{2025}},
publisher = {{HuggingFace}},
url = {{https://huggingface.co/{self.repo_id}}}
}}
```

## Training Details

All models trained with:
- Optimizer: AdamW
- Mixed Precision: Available on A100
- Augmentation: AutoAugment (CIFAR10 policy)
- Format: SafeTensors (ClamAV safe)

Built with geometric consciousness-aware routing using the Devil's Staircase (Beatrix) and pentachoron parameterization.

---

**Repository maintained by**: [@{self.config.hf_username}](https://huggingface.co/{self.config.hf_username})

**Latest update**: {time.strftime("%Y-%m-%d %H:%M:%S")}
"""

        # Save main README locally
        main_readme_path = Path(self.config.weights_dir) / self.config.model_name / "MAIN_README.md"
        main_readme_path.parent.mkdir(parents=True, exist_ok=True)
        with open(main_readme_path, 'w') as f:
            f.write(main_readme)

        try:
            # Upload to repo root (not inside runs/)
            upload_file(
                path_or_fileobj=str(main_readme_path),
                path_in_repo="README.md",  # Root level!
                repo_id=self.repo_id,
                token=self.config.hf_token
            )
            print(f"[HF] Updated main README")
        except Exception as e:
            print(f"[HF] Main README upload failed: {e}")

    def upload_checkpoint(self, checkpoint_path: Path, is_best: bool = False):
        """Upload checkpoint to HuggingFace.

        Best checkpoints are renamed to a fixed path so the repo always has a
        single "best" file per run.
        """
        if not self.config.upload_to_hf or self.api is None:
            return

        try:
            # Upload to run-specific folder
            path_in_repo = f"{self.run_prefix}/checkpoints/{checkpoint_path.name}"
            if is_best:
                # NOTE(review): fixed name uses a .pt extension while saved
                # models use .safetensors elsewhere — confirm this is intended.
                path_in_repo = f"{self.run_prefix}/checkpoints/best_model.pt"

            upload_file(
                path_or_fileobj=str(checkpoint_path),
                path_in_repo=path_in_repo,
                repo_id=self.repo_id,
                token=self.config.hf_token
            )
            print(f"[HF] Uploaded: {path_in_repo}")
        except Exception as e:
            print(f"[HF] Upload failed: {e}")

    def upload_file(self, file_path: Path, repo_path: str):
        """Upload single file to HuggingFace.

        NOTE: this method shadows the module-level ``upload_file`` imported
        from huggingface_hub; the bare call below resolves to that global
        function, not to this method.
        """
        if not self.config.upload_to_hf or self.api is None:
            return

        try:
            # Prepend run prefix if not already there
            if not repo_path.startswith(self.run_prefix) and not repo_path.startswith("runs/"):
                full_path = f"{self.run_prefix}/{repo_path}"
            else:
                full_path = repo_path

            upload_file(
                path_or_fileobj=str(file_path),
                path_in_repo=full_path,
                repo_id=self.repo_id,
                token=self.config.hf_token
            )
            print(f"[HF] βœ“ Uploaded: {full_path}")
        except Exception as e:
            print(f"[HF] βœ— Upload failed ({full_path}): {e}")

    def upload_folder_contents(self, folder_path: Path, repo_folder: str):
        """Upload entire folder to HuggingFace (under this run's prefix)."""
        if not self.config.upload_to_hf or self.api is None:
            return

        try:
            # Upload to run-specific folder
            full_path = f"{self.run_prefix}/{repo_folder}"
            upload_folder(
                folder_path=str(folder_path),
                repo_id=self.repo_id,
                path_in_repo=full_path,
                token=self.config.hf_token,
                ignore_patterns=["*.pyc", "__pycache__"]
            )
            print(f"[HF] Uploaded folder: {full_path}")
        except Exception as e:
            print(f"[HF] Folder upload failed: {e}")

    def create_model_card(self, trainer_stats: Dict):
        """Create and upload run-specific model card.

        ``trainer_stats`` is expected to contain at least 'total_params',
        'best_acc', 'training_time', and 'final_epoch' (used unguarded below);
        'batch_size' and 'mixed_precision' are optional.
        """
        if not self.config.upload_to_hf:
            return

        run_card = f"""# Run: {self.config.run_name}

## Configuration
- **Dataset**: {self.config.dataset.upper()}
- **Fusion Mode**: {self.config.fusion_mode}
- **Parameters**: {trainer_stats['total_params']:,}
- **Simplex**: {self.config.k_simplex}-simplex ({self.config.k_simplex + 1} vertices)

## Performance
- **Best Validation Accuracy**: {trainer_stats['best_acc']:.2f}%
- **Training Time**: {trainer_stats['training_time']:.1f} hours
- **Batch Size**: {trainer_stats.get('batch_size', 'N/A')}
- **Mixed Precision**: {trainer_stats.get('mixed_precision', False)}
- **Final Epoch**: {trainer_stats['final_epoch']}

## Files
- `{self.run_prefix}/checkpoints/best_model.safetensors` - Model weights (SafeTensors)
- `{self.run_prefix}/checkpoints/best_training_state.pt` - Optimizer/scheduler state
- `{self.run_prefix}/checkpoints/best_metadata.json` - Training metadata
- `{self.run_prefix}/config.yaml` - Full configuration
- `{self.run_prefix}/tensorboard/` - TensorBoard logs

## Usage
```python
from safetensors.torch import load_file
import torch

# Download from HuggingFace Hub
from huggingface_hub import hf_hub_download

model_path = hf_hub_download(
    repo_id="{self.repo_id}",
    filename="{self.run_prefix}/checkpoints/best_model.safetensors"
)

# Load model weights (SafeTensors - no pickle!)
state_dict = load_file(model_path)
model.load_state_dict(state_dict)
```

## Training Configuration
```yaml
embed_dim: {self.config.embed_dim}
num_fusion_blocks: {self.config.num_fusion_blocks}
num_heads: {self.config.num_heads}
fusion_mode: {self.config.fusion_mode}
k_simplex: {self.config.k_simplex}
learning_rate: {self.config.learning_rate}
batch_size: {self.config.batch_size}
epochs: {self.config.num_epochs}
weight_decay: {self.config.weight_decay}
```

## Details

Built with geometric consciousness-aware routing using the Devil's Staircase (Beatrix) and pentachoron parameterization.

**Training completed**: {time.strftime("%Y-%m-%d %H:%M:%S")}

**Safe Format**: All model weights use SafeTensors (not pickle) for maximum security.

---

[← Back to main repository](https://huggingface.co/{self.repo_id})
"""

        # Save run-specific README
        readme_path = self.config.output_dir / "RUN_README.md"
        with open(readme_path, 'w') as f:
            f.write(run_card)

        try:
            upload_file(
                path_or_fileobj=str(readme_path),
                path_in_repo=f"{self.run_prefix}/README.md",
                repo_id=self.repo_id,
                token=self.config.hf_token
            )
            print(f"[HF] Uploaded run README")
        except Exception as e:
            print(f"[HF] Run README upload failed: {e}")
636
+
637
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
638
+ # Trainer with TensorBoard + HuggingFace + SafeTensors
639
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
640
+
641
+ class Trainer:
642
+ """Training manager with TensorBoard, HuggingFace, and SafeTensors."""
643
+
644
    def __init__(self, config: CantorTrainingConfig):
        """Build the model, optimizer, scheduler, AMP scaler, TensorBoard
        writer, and (optionally) the HuggingFace uploader for one run.

        Side effects: seeds RNGs, prints a startup banner, writes
        config.yaml into the run's output directory.
        """
        self.config = config
        self.device = torch.device(config.device)

        # Set seed
        torch.manual_seed(config.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(config.seed)

        # Model
        print("\n" + "=" * 70)
        print(f"Initializing Cantor Classifier - {config.dataset.upper()}")
        print("=" * 70)

        init_start = time.time()
        self.model = CantorClassifier(config).to(self.device)
        init_time = time.time() - init_start

        print(f"\n[Model] Initialization time: {init_time:.2f}s")
        self.print_model_info()

        # Optimizer & Scheduler
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )
        self.scheduler = self.create_scheduler()
        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

        # Mixed precision: only enabled when requested AND running on CUDA
        self.use_amp = config.use_mixed_precision and config.device == "cuda"
        self.scaler = GradScaler() if self.use_amp else None

        if self.use_amp:
            print(f"[Training] Mixed precision enabled")

        # TensorBoard
        self.writer = SummaryWriter(log_dir=str(config.tensorboard_dir))
        print(f"[TensorBoard] Logging to: {config.tensorboard_dir}")
        print(f"[Checkpoints] Format: SafeTensors (ClamAV safe)")

        # HuggingFace (None when uploads disabled)
        self.hf_uploader = HuggingFaceUploader(config) if config.upload_to_hf else None

        # Save config
        config.save(config.output_dir / "config.yaml")

        # Metrics
        self.best_acc = 0.0           # best validation accuracy seen so far
        self.global_step = 0          # batch counter across epochs (TensorBoard x-axis)
        self.start_time = time.time()
        self.upload_count = 0         # number of HF uploads performed
697
+
698
+ def print_model_info(self):
699
+ """Print model info."""
700
+ total_params = sum(p.numel() for p in self.model.parameters())
701
+ print(f"\nParameters: {total_params:,}")
702
+ print(f"Dataset: {self.config.dataset.upper()}")
703
+ print(f"Classes: {self.config.num_classes}")
704
+ print(f"Fusion mode: {self.config.fusion_mode}")
705
+ print(f"Output: {self.config.output_dir}")
706
+
707
+ def create_scheduler(self):
708
+ """Create scheduler with warmup."""
709
+ def lr_lambda(epoch):
710
+ if epoch < self.config.warmup_epochs:
711
+ return (epoch + 1) / self.config.warmup_epochs
712
+ progress = (epoch - self.config.warmup_epochs) / (self.config.num_epochs - self.config.warmup_epochs)
713
+ return 0.5 * (1 + math.cos(math.pi * progress))
714
+ return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)
715
+
716
    def train_epoch(self, train_loader: DataLoader, epoch: int) -> Tuple[float, float]:
        """Train one epoch.

        Returns:
            (mean batch loss over the epoch, training accuracy in percent).
        """
        self.model.train()
        total_loss, correct, total = 0.0, 0, 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{self.config.num_epochs} [Train]")

        for batch_idx, (images, labels) in enumerate(pbar):
            images, labels = images.to(self.device, non_blocking=True), labels.to(self.device, non_blocking=True)

            # Forward
            if self.use_amp:
                # AMP path: scale loss for backward, then unscale BEFORE
                # clipping so grad_clip applies to true gradient magnitudes.
                with autocast():
                    logits = self.model(images)
                    loss = self.criterion(logits, labels)
                self.optimizer.zero_grad(set_to_none=True)
                self.scaler.scale(loss).backward()
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                logits = self.model(images)
                loss = self.criterion(logits, labels)
                self.optimizer.zero_grad(set_to_none=True)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
                self.optimizer.step()

            # Metrics (running totals over the epoch)
            total_loss += loss.item()
            _, predicted = logits.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

            # TensorBoard logging (every log_interval batches)
            if batch_idx % self.config.log_interval == 0:
                self.writer.add_scalar('train/loss', loss.item(), self.global_step)
                self.writer.add_scalar('train/accuracy', 100. * correct / total, self.global_step)
                self.writer.add_scalar('train/learning_rate', self.scheduler.get_last_lr()[0], self.global_step)

            self.global_step += 1

            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100. * correct / total:.2f}%',
                'lr': f'{self.scheduler.get_last_lr()[0]:.6f}'
            })

        return total_loss / len(train_loader), 100. * correct / total
766
+
767
    @torch.no_grad()
    def evaluate(self, val_loader: DataLoader, epoch: int) -> Tuple[float, Dict]:
        """Evaluate one validation pass.

        On the final batch only, also requests fusion diagnostics from the
        model and records mean 'consciousness' when the model produces it.
        Logs loss/accuracy (and consciousness, if any) to TensorBoard.

        Returns:
            (accuracy in percent, metrics dict with keys 'loss', 'accuracy',
            'consciousness' — the latter None when no values were collected).
        """
        self.model.eval()
        total_loss, correct, total = 0.0, 0, 0
        consciousness_values = []

        pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{self.config.num_epochs} [Val] ")

        for batch_idx, (images, labels) in enumerate(pbar):
            images, labels = images.to(self.device, non_blocking=True), labels.to(self.device, non_blocking=True)

            # Forward with fusion info on last batch
            return_info = (batch_idx == len(val_loader) - 1)

            if self.use_amp:
                with autocast():
                    if return_info:
                        logits, fusion_infos = self.model(images, return_fusion_info=True)
                        if fusion_infos and fusion_infos[0].get('consciousness') is not None:
                            consciousness_values.append(fusion_infos[0]['consciousness'].mean().item())
                    else:
                        logits = self.model(images)
                    loss = self.criterion(logits, labels)
            else:
                if return_info:
                    logits, fusion_infos = self.model(images, return_fusion_info=True)
                    if fusion_infos and fusion_infos[0].get('consciousness') is not None:
                        consciousness_values.append(fusion_infos[0]['consciousness'].mean().item())
                else:
                    logits = self.model(images)
                loss = self.criterion(logits, labels)

            total_loss += loss.item()
            _, predicted = logits.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

            pbar.set_postfix({
                'loss': f'{total_loss / (batch_idx + 1):.4f}',
                'acc': f'{100. * correct / total:.2f}%'
            })

        avg_loss = total_loss / len(val_loader)
        accuracy = 100. * correct / total

        # TensorBoard logging (per-epoch, unlike the per-step training logs)
        self.writer.add_scalar('val/loss', avg_loss, epoch)
        self.writer.add_scalar('val/accuracy', accuracy, epoch)
        if consciousness_values:
            self.writer.add_scalar('val/consciousness', sum(consciousness_values) / len(consciousness_values), epoch)

        metrics = {
            'loss': avg_loss,
            'accuracy': accuracy,
            'consciousness': sum(consciousness_values) / len(consciousness_values) if consciousness_values else None
        }

        return accuracy, metrics
826
+
827
+ def train(self, train_loader: DataLoader, val_loader: DataLoader):
828
+ """Full training loop."""
829
+ print("\n" + "=" * 70)
830
+ print("Starting training...")
831
+ print(f"Format: SafeTensors (model) + PT (training state)")
832
+ print(f"Upload: Best + every {self.config.checkpoint_upload_interval} epochs")
833
+ print("=" * 70 + "\n")
834
+
835
+ for epoch in range(self.config.num_epochs):
836
+ # Train
837
+ train_loss, train_acc = self.train_epoch(train_loader, epoch)
838
+
839
+ # Evaluate
840
+ val_acc, val_metrics = self.evaluate(val_loader, epoch)
841
+
842
+ # Update scheduler
843
+ self.scheduler.step()
844
+
845
+ # Print summary
846
+ print(f"\n{'='*70}")
847
+ print(f"Epoch [{epoch + 1}/{self.config.num_epochs}] Summary:")
848
+ print(f" Train: Loss={train_loss:.4f}, Acc={train_acc:.2f}%")
849
+ print(f" Val: Loss={val_metrics['loss']:.4f}, Acc={val_acc:.2f}%")
850
+ if val_metrics['consciousness']:
851
+ print(f" Consciousness: {val_metrics['consciousness']:.4f}")
852
+
853
+ # Checkpoint logic
854
+ is_best = val_acc > self.best_acc
855
+ should_save_regular = ((epoch + 1) % self.config.save_interval == 0)
856
+ should_upload_regular = ((epoch + 1) % self.config.checkpoint_upload_interval == 0)
857
+
858
+ if is_best:
859
+ self.best_acc = val_acc
860
+ print(f" βœ“ New best model! Accuracy: {val_acc:.2f}%")
861
+ # Save best locally, upload only on interval
862
+ self.save_checkpoint(epoch, val_acc, prefix="best", upload=should_upload_regular)
863
+
864
+ if should_save_regular:
865
+ self.save_checkpoint(epoch, val_acc, prefix=f"epoch_{epoch+1}", upload=should_upload_regular)
866
+
867
+ print(f" HF Uploads: {self.upload_count}")
868
+ print(f"{'='*70}\n")
869
+
870
+ # Flush TensorBoard
871
+ if (epoch + 1) % 10 == 0:
872
+ self.writer.flush()
873
+
874
+ # Training complete
875
+ training_time = (time.time() - self.start_time) / 3600
876
+
877
+ print("\n" + "=" * 70)
878
+ print("Training Complete!")
879
+ print(f"Best Validation Accuracy: {self.best_acc:.2f}%")
880
+ print(f"Training Time: {training_time:.2f} hours")
881
+ print(f"Total Uploads: {self.upload_count}")
882
+ print("=" * 70)
883
+
884
+ # Upload to HuggingFace
885
+ if self.hf_uploader:
886
+ # Always upload final best model
887
+ print("\n[HF] Uploading final best model...")
888
+ best_model_path = self.config.checkpoint_dir / "best_model.safetensors"
889
+ best_state_path = self.config.checkpoint_dir / "best_training_state.pt"
890
+ best_metadata_path = self.config.checkpoint_dir / "best_metadata.json"
891
+ config_path = self.config.output_dir / "config.yaml"
892
+
893
+ if best_model_path.exists():
894
+ self.hf_uploader.upload_file(best_model_path, "checkpoints/best_model.safetensors")
895
+ if best_state_path.exists():
896
+ self.hf_uploader.upload_file(best_state_path, "checkpoints/best_training_state.pt")
897
+ if best_metadata_path.exists():
898
+ self.hf_uploader.upload_file(best_metadata_path, "checkpoints/best_metadata.json")
899
+ if config_path.exists():
900
+ self.hf_uploader.upload_file(config_path, "config.yaml")
901
+
902
+ print("[HF] Final upload: TensorBoard logs...")
903
+ self.hf_uploader.upload_folder_contents(self.config.tensorboard_dir, "tensorboard")
904
+
905
+ trainer_stats = {
906
+ 'total_params': sum(p.numel() for p in self.model.parameters()),
907
+ 'best_acc': self.best_acc,
908
+ 'training_time': training_time,
909
+ 'final_epoch': self.config.num_epochs,
910
+ 'batch_size': self.config.batch_size,
911
+ 'mixed_precision': self.use_amp
912
+ }
913
+ self.hf_uploader.create_model_card(trainer_stats)
914
+
915
+ self.writer.close()
916
+
917
+ def save_checkpoint(self, epoch: int, accuracy: float, prefix: str = "checkpoint", upload: bool = False):
918
+ """Save checkpoint as safetensors with selective upload."""
919
+ checkpoint_dir = self.config.checkpoint_dir
920
+ checkpoint_dir.mkdir(parents=True, exist_ok=True)
921
+
922
+ # 1. Save model weights as safetensors (SAFE!)
923
+ model_path = checkpoint_dir / f"{prefix}_model.safetensors"
924
+ save_file(self.model.state_dict(), str(model_path))
925
+
926
+ # 2. Save optimizer/scheduler state separately (small .pt files)
927
+ training_state = {
928
+ 'optimizer_state_dict': self.optimizer.state_dict(),
929
+ 'scheduler_state_dict': self.scheduler.state_dict(),
930
+ }
931
+ if self.scaler is not None:
932
+ training_state['scaler_state_dict'] = self.scaler.state_dict()
933
+
934
+ training_state_path = checkpoint_dir / f"{prefix}_training_state.pt"
935
+ torch.save(training_state, training_state_path)
936
+
937
+ # 3. Save metadata as JSON
938
+ metadata = {
939
+ 'epoch': epoch,
940
+ 'accuracy': accuracy,
941
+ 'best_accuracy': self.best_acc,
942
+ 'global_step': self.global_step,
943
+ 'timestamp': time.strftime("%Y-%m-%d %H:%M:%S")
944
+ }
945
+ metadata_path = checkpoint_dir / f"{prefix}_metadata.json"
946
+ with open(metadata_path, 'w') as f:
947
+ json.dump(metadata, f, indent=2)
948
+
949
+ is_best = (prefix == "best")
950
+
951
+ if is_best:
952
+ print(f" πŸ’Ύ Saved best: {prefix}_model.safetensors")
953
+ else:
954
+ print(f" πŸ’Ύ Saved: {prefix}_model.safetensors", end="")
955
+
956
+ # Upload to HuggingFace
957
+ if self.hf_uploader and upload:
958
+ # Upload model weights (safetensors)
959
+ self.hf_uploader.upload_file(
960
+ model_path,
961
+ f"checkpoints/{prefix}_model.safetensors"
962
+ )
963
+
964
+ # Upload training state (.pt - small file)
965
+ self.hf_uploader.upload_file(
966
+ training_state_path,
967
+ f"checkpoints/{prefix}_training_state.pt"
968
+ )
969
+
970
+ # Upload metadata (json)
971
+ self.hf_uploader.upload_file(
972
+ metadata_path,
973
+ f"checkpoints/{prefix}_metadata.json"
974
+ )
975
+
976
+ # Upload config (only for best)
977
+ if is_best:
978
+ config_path = self.config.output_dir / "config.yaml"
979
+ if config_path.exists():
980
+ self.hf_uploader.upload_file(config_path, "config.yaml")
981
+
982
+ self.upload_count += 1
983
+
984
+ if not is_best:
985
+ print(" β†’ Uploaded to HF")
986
+ else:
987
+ if not is_best:
988
+ print(" (local only)")
989
+
990
+
991
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
992
+ # Data Loading
993
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
994
+
995
def get_data_loaders(config: CantorTrainingConfig) -> Tuple[DataLoader, DataLoader]:
    """Build the train/validation DataLoader pair for the configured dataset."""

    # CIFAR channel statistics (shared by both supported datasets here).
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)

    normalize_steps = [transforms.ToTensor(), transforms.Normalize(mean, std)]

    # Choose the augmentation prefix for the training pipeline.
    if not config.use_augmentation:
        augment_steps = []
    elif config.use_autoaugment:
        augment_steps = [transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10)]
    else:
        augment_steps = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ]

    train_transform = transforms.Compose(augment_steps + normalize_steps)
    val_transform = transforms.Compose(normalize_steps)

    # Dataset selection via dispatch table.
    dataset_classes = {
        "cifar10": datasets.CIFAR10,
        "cifar100": datasets.CIFAR100,
    }
    if config.dataset not in dataset_classes:
        raise ValueError(f"Unknown dataset: {config.dataset}")
    dataset_cls = dataset_classes[config.dataset]

    train_dataset = dataset_cls(root='./data', train=True, download=True, transform=train_transform)
    val_dataset = dataset_cls(root='./data', train=False, download=True, transform=val_transform)

    # Shared loader settings; pinned host memory only helps on CUDA.
    loader_kwargs = dict(
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        pin_memory=(config.device == "cuda"),
    )

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader
1057
+
1058
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
1059
+ # Main
1060
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
1061
+
1062
def main():
    """Configure and launch a Cantor fusion classifier training run."""

    config = CantorTrainingConfig(
        # Dataset: "cifar10" or "cifar100"
        dataset="cifar100",

        # Architecture
        embed_dim=512,
        num_fusion_blocks=6,
        num_heads=8,
        fusion_mode="consciousness",  # "weighted" or "consciousness"
        k_simplex=4,
        use_beatrix=False,

        # Training
        batch_size=128,
        num_epochs=100,
        learning_rate=3e-4,

        # Augmentation
        use_augmentation=True,
        use_autoaugment=True,

        # System
        device="cuda",

        # HuggingFace - ONE SHARED REPO
        hf_username="AbstractPhil",
        upload_to_hf=True,
    )

    banner = "=" * 70
    print(banner)
    print(f"Cantor Fusion Classifier - {config.dataset.upper()}")
    print(banner)
    print("\nConfiguration:")
    print(f" Dataset: {config.dataset}")
    print(f" Fusion mode: {config.fusion_mode}")
    print(f" Output: {config.output_dir}")
    print(f" HuggingFace: {'Enabled' if config.upload_to_hf else 'Disabled'}")
    if config.upload_to_hf:
        print(f" Repo: {config.hf_username}/{config.hf_repo_name}")
        print(f" Run: {config.run_name}")

    # Build loaders before constructing the trainer so dataset download
    # problems surface early.
    print("\nLoading data...")
    train_loader, val_loader = get_data_loaders(config)
    print(f" Train: {len(train_loader.dataset)} samples")
    print(f" Val: {len(val_loader.dataset)} samples")

    # Train
    trainer = Trainer(config)
    trainer.train(train_loader, val_loader)


if __name__ == "__main__":
    main()