AbstractPhil committed on
Commit 18e6c6b · verified · 1 Parent(s): 613faa7

Create trainer_v1.py

Files changed (1): trainer_v1.py (+1594, -0)

trainer_v1.py ADDED
@@ -0,0 +1,1594 @@
"""
Train DavidBeans: The Dynamic Duo
==================================

    ┌─────────────────┐
    │      BEANS      │  "I see the patches..."
    │  (ViT Backbone) │
    │  🫘 → 🫘 → 🫘  │  Cantor-routed sparse attention
    └────────┬────────┘
             │
             ▼
    ┌─────────────────┐
    │      DAVID      │  "I know the crystals..."
    │  (Classifier)   │
    │  💎 → 💎 → 💎  │  Multi-scale projection
    └────────┬────────┘
             │
             ▼
       [Prediction]

Cross-contrast learning aligns patch features with crystal anchors.
Unified Cayley-Menger loss maintains geometric structure throughout.

Features:
- HuggingFace Hub integration for model upload
- Automatic model card generation
- Checkpoint management

Author: AbstractPhil
Date: November 28, 2025
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
from tqdm.auto import tqdm
import time
import math
from pathlib import Path
from typing import Dict, Optional, Tuple, List
from dataclasses import dataclass, field
import json
import os
from datetime import datetime

# Import the model
from geofractal.model.david_beans.model import DavidBeans, DavidBeansConfig

# HuggingFace Hub integration
try:
    from huggingface_hub import HfApi, create_repo, upload_folder
    HF_HUB_AVAILABLE = True
except ImportError:
    HF_HUB_AVAILABLE = False
    print(" [!] huggingface_hub not installed. Run: pip install huggingface_hub")

# Safetensors support
try:
    from safetensors.torch import save_file as save_safetensors
    SAFETENSORS_AVAILABLE = True
except ImportError:
    SAFETENSORS_AVAILABLE = False

# TensorBoard support
try:
    from torch.utils.tensorboard import SummaryWriter
    TENSORBOARD_AVAILABLE = True
except ImportError:
    TENSORBOARD_AVAILABLE = False
    print(" [!] tensorboard not installed. Run: pip install tensorboard")

# ============================================================================
# TRAINING CONFIGURATION
# ============================================================================

@dataclass
class TrainingConfig:
    """Training hyperparameters."""

    # Run identification
    run_name: str = "default"          # Descriptive name for this run
    run_number: Optional[int] = None   # Auto-incremented if None

    # Data
    dataset: str = "cifar10"
    image_size: int = 32
    batch_size: int = 128
    num_workers: int = 4

    # Training schedule
    epochs: int = 100
    warmup_epochs: int = 5

    # Optimizer
    learning_rate: float = 1e-3
    weight_decay: float = 0.05
    betas: Tuple[float, float] = (0.9, 0.999)

    # Learning rate schedule
    scheduler: str = "cosine"
    min_lr: float = 1e-6

    # Loss weights
    ce_weight: float = 1.0
    cayley_weight: float = 0.01
    contrast_weight: float = 0.5
    scale_ce_weight: float = 0.1

    # Regularization
    gradient_clip: float = 1.0
    label_smoothing: float = 0.1

    # Augmentation
    use_augmentation: bool = True
    mixup_alpha: float = 0.2
    cutmix_alpha: float = 1.0

    # Checkpointing
    save_interval: int = 10
    output_dir: str = "./checkpoints"
    resume_from: Optional[str] = None  # Path to checkpoint or "latest"

    # TensorBoard
    use_tensorboard: bool = True
    log_interval: int = 50  # Log every N batches

    # HuggingFace Hub
    push_to_hub: bool = False
    hub_repo_id: Optional[str] = None
    hub_private: bool = False
    hub_append_run: bool = True  # Append run info to repo_id (e.g., repo-run001-baseline)

    # Device
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    def to_dict(self) -> Dict:
        return {k: v for k, v in self.__dict__.items()}

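The `warmup_epochs`, `scheduler="cosine"`, and `min_lr` fields above imply a linear-warmup-then-cosine-decay curve. A minimal sketch of that per-epoch schedule, for illustration only (the trainer itself builds a `CosineAnnealingLR`/`OneCycleLR`; `lr_at_epoch` is a hypothetical helper, not part of the file):

```python
import math

def lr_at_epoch(epoch: int, epochs: int = 100, warmup_epochs: int = 5,
                base_lr: float = 1e-3, min_lr: float = 1e-6) -> float:
    """Linear warmup for warmup_epochs, then cosine decay down to min_lr."""
    if epoch < warmup_epochs:
        # Ramp linearly from base_lr/warmup_epochs up to base_lr
        return base_lr * (epoch + 1) / warmup_epochs
    # Cosine decay over the remaining epochs
    progress = (epoch - warmup_epochs) / max(1, epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))
```

With the defaults above, the rate ramps to 1e-3 over the first five epochs and decays toward 1e-6 by epoch 100.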
# ============================================================================
# HUGGINGFACE HUB INTEGRATION
# ============================================================================

def generate_model_card(
    model_config: DavidBeansConfig,
    train_config: TrainingConfig,
    best_acc: float,
    training_history: Optional[Dict] = None
) -> str:
    """Generate a model card for HuggingFace Hub."""

    scales_str = ", ".join([str(s) for s in model_config.scales])

    dataset_info = {
        "cifar10": ("CIFAR-10", 10, "Image classification on 32x32 images"),
        "cifar100": ("CIFAR-100", 100, "Fine-grained image classification on 32x32 images"),
    }.get(train_config.dataset, (train_config.dataset, model_config.num_classes, ""))

    card_content = f"""---
library_name: pytorch
license: apache-2.0
tags:
- vision
- image-classification
- geometric-deep-learning
- vit
- cantor-routing
- pentachoron
- multi-scale
datasets:
- {train_config.dataset}
metrics:
- accuracy
model-index:
- name: DavidBeans
  results:
  - task:
      type: image-classification
      name: Image Classification
    dataset:
      name: {dataset_info[0]}
      type: {train_config.dataset}
    metrics:
    - type: accuracy
      value: {best_acc:.2f}
      name: Top-1 Accuracy
---

# 🫘💎 DavidBeans: Unified Vision-to-Crystal Architecture

DavidBeans combines **ViT-Beans** (Cantor-routed sparse attention) with **David** (multi-scale crystal classification) into a unified geometric deep learning architecture.

## Model Description

This model implements several novel techniques:

- **Hybrid Cantor Routing**: Combines fractal Cantor set distances with positional proximity for sparse attention patterns
- **Pentachoron Experts**: 5-vertex simplex structure with Cayley-Menger geometric regularization
- **Multi-Scale Crystal Projection**: Projects features to multiple representation scales with learned fusion
- **Cross-Contrastive Learning**: Aligns patch-level features with crystal anchors

## Architecture

```
Image [B, 3, {model_config.image_size}, {model_config.image_size}]
        │
        ▼
┌─────────────────────────────────────────────┐
│              BEANS BACKBONE                 │
│  ├─ Patch Embed → [{model_config.num_patches} patches, {model_config.dim}d]
│  ├─ Hybrid Cantor Router (α={model_config.cantor_weight})
│  ├─ {model_config.num_layers} × Attention Blocks ({model_config.num_heads} heads)
│  └─ {model_config.num_layers} × Pentachoron Expert Layers
└─────────────────────────────────────────────┘
        │
        ▼
┌─────────────────────────────────────────────┐
│               DAVID HEAD                    │
│  ├─ Multi-scale projection: [{scales_str}]
│  ├─ Per-scale Crystal Heads
│  └─ Geometric Fusion (learned weights)
└─────────────────────────────────────────────┘
        │
        ▼
[{model_config.num_classes} classes]
```

## Training Details

| Parameter | Value |
|-----------|-------|
| Dataset | {dataset_info[0]} |
| Classes | {model_config.num_classes} |
| Image Size | {model_config.image_size}×{model_config.image_size} |
| Patch Size | {model_config.patch_size}×{model_config.patch_size} |
| Embedding Dim | {model_config.dim} |
| Layers | {model_config.num_layers} |
| Attention Heads | {model_config.num_heads} |
| Experts | {model_config.num_experts} (pentachoron) |
| Sparse Neighbors | k={model_config.k_neighbors} |
| Scales | [{scales_str}] |
| Epochs | {train_config.epochs} |
| Batch Size | {train_config.batch_size} |
| Learning Rate | {train_config.learning_rate} |
| Weight Decay | {train_config.weight_decay} |
| Mixup α | {train_config.mixup_alpha} |
| CutMix α | {train_config.cutmix_alpha} |
| Label Smoothing | {train_config.label_smoothing} |

## Results

| Metric | Value |
|--------|-------|
| **Top-1 Accuracy** | **{best_acc:.2f}%** |

## TensorBoard Logs

Training logs are included in the `tensorboard/` directory. To view:

```bash
tensorboard --logdir tensorboard/
```

## Usage

```python
import torch
from safetensors.torch import load_file
from david_beans import DavidBeans, DavidBeansConfig

# Load config
config = DavidBeansConfig(
    image_size={model_config.image_size},
    patch_size={model_config.patch_size},
    dim={model_config.dim},
    num_layers={model_config.num_layers},
    num_heads={model_config.num_heads},
    num_experts={model_config.num_experts},
    k_neighbors={model_config.k_neighbors},
    cantor_weight={model_config.cantor_weight},
    scales={model_config.scales},
    num_classes={model_config.num_classes}
)

# Create model and load weights
model = DavidBeans(config)
state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict)

# Inference
model.eval()
with torch.no_grad():
    output = model(images)
    predictions = output['logits'].argmax(dim=-1)
```

## Citation

```bibtex
@misc{{davidbeans2025,
  author = {{AbstractPhil}},
  title = {{DavidBeans: Unified Vision-to-Crystal Architecture}},
  year = {{2025}},
  publisher = {{HuggingFace}},
  url = {{https://huggingface.co/{train_config.hub_repo_id or 'AbstractPhil/david-beans'}}}
}}
```

## License

Apache 2.0
"""

    return card_content

def save_for_hub(
    model: DavidBeans,
    model_config: DavidBeansConfig,
    train_config: TrainingConfig,
    best_acc: float,
    output_dir: Path,
    training_history: Optional[Dict] = None
) -> Path:
    """Save model in HuggingFace Hub format."""

    hub_dir = output_dir / "hub"
    hub_dir.mkdir(parents=True, exist_ok=True)

    # 1. Save model weights - clone to avoid shared memory issues
    state_dict = {k: v.clone() for k, v in model.state_dict().items()}

    if SAFETENSORS_AVAILABLE:
        try:
            save_safetensors(state_dict, hub_dir / "model.safetensors")
            print("  ✓ Saved model.safetensors")
        except Exception as e:
            print(f"  [!] Safetensors failed ({e}), using pytorch format only")

    # Also save PyTorch format for compatibility
    torch.save(state_dict, hub_dir / "pytorch_model.bin")
    print("  ✓ Saved pytorch_model.bin")

    # 2. Save config
    config_dict = {
        "architecture": "DavidBeans",
        "model_type": "david_beans",
        **model_config.__dict__
    }
    with open(hub_dir / "config.json", "w") as f:
        json.dump(config_dict, f, indent=2, default=str)
    print("  ✓ Saved config.json")

    # 3. Save training config
    with open(hub_dir / "training_config.json", "w") as f:
        json.dump(train_config.to_dict(), f, indent=2, default=str)

    # 4. Generate and save model card
    model_card = generate_model_card(model_config, train_config, best_acc, training_history)
    with open(hub_dir / "README.md", "w") as f:
        f.write(model_card)
    print("  ✓ Generated README.md (model card)")

    # 5. Save training history if available
    if training_history:
        with open(hub_dir / "training_history.json", "w") as f:
            json.dump(training_history, f, indent=2)

    # 6. Copy TensorBoard logs if they exist
    tb_dir = output_dir / "tensorboard"
    if tb_dir.exists():
        import shutil
        hub_tb_dir = hub_dir / "tensorboard"
        if hub_tb_dir.exists():
            shutil.rmtree(hub_tb_dir)
        shutil.copytree(tb_dir, hub_tb_dir)
        print("  ✓ Copied TensorBoard logs")

    return hub_dir


def push_to_hub(
    hub_dir: Path,
    repo_id: str,
    private: bool = False,
    commit_message: Optional[str] = None
) -> str:
    """Push model to HuggingFace Hub."""

    if not HF_HUB_AVAILABLE:
        raise RuntimeError("huggingface_hub not installed. Run: pip install huggingface_hub")

    api = HfApi()

    # Create repo if it doesn't exist
    try:
        create_repo(repo_id, private=private, exist_ok=True)
        print(f"  ✓ Repository ready: {repo_id}")
    except Exception as e:
        print(f"  [!] Repo creation note: {e}")

    # Upload
    if commit_message is None:
        commit_message = f"Upload DavidBeans model - {datetime.now().strftime('%Y-%m-%d %H:%M')}"

    url = upload_folder(
        folder_path=str(hub_dir),
        repo_id=repo_id,
        commit_message=commit_message
    )

    print(f"  ✓ Uploaded to: https://huggingface.co/{repo_id}")

    return url

# ============================================================================
# DATA LOADING
# ============================================================================

def get_dataloaders(config: TrainingConfig) -> Tuple[DataLoader, DataLoader, int]:
    """Get train and test dataloaders."""

    try:
        import torchvision
        import torchvision.transforms as T

        if config.dataset == "cifar10":
            if config.use_augmentation:
                train_transform = T.Compose([
                    T.RandomCrop(32, padding=4),
                    T.RandomHorizontalFlip(),
                    T.AutoAugment(T.AutoAugmentPolicy.CIFAR10),
                    T.ToTensor(),
                    T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
                ])
            else:
                train_transform = T.Compose([
                    T.ToTensor(),
                    T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
                ])

            test_transform = T.Compose([
                T.ToTensor(),
                T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
            ])

            train_dataset = torchvision.datasets.CIFAR10(
                root='./data', train=True, download=True, transform=train_transform
            )
            test_dataset = torchvision.datasets.CIFAR10(
                root='./data', train=False, download=True, transform=test_transform
            )
            num_classes = 10

        elif config.dataset == "cifar100":
            if config.use_augmentation:
                train_transform = T.Compose([
                    T.RandomCrop(32, padding=4),
                    T.RandomHorizontalFlip(),
                    # torchvision ships no CIFAR-100 AutoAugment policy, so the CIFAR-10 policy is reused
                    T.AutoAugment(T.AutoAugmentPolicy.CIFAR10),
                    T.ToTensor(),
                    T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
                ])
            else:
                train_transform = T.Compose([
                    T.ToTensor(),
                    T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
                ])

            test_transform = T.Compose([
                T.ToTensor(),
                T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
            ])

            train_dataset = torchvision.datasets.CIFAR100(
                root='./data', train=True, download=True, transform=train_transform
            )
            test_dataset = torchvision.datasets.CIFAR100(
                root='./data', train=False, download=True, transform=test_transform
            )
            num_classes = 100
        else:
            raise ValueError(f"Unknown dataset: {config.dataset}")

        train_loader = DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=config.num_workers,
            pin_memory=True,
            persistent_workers=config.num_workers > 0,
            drop_last=True
        )
        test_loader = DataLoader(
            test_dataset,
            batch_size=config.batch_size,
            shuffle=False,
            num_workers=config.num_workers,
            pin_memory=True,
            persistent_workers=config.num_workers > 0
        )

        return train_loader, test_loader, num_classes

    except ImportError:
        print(" [!] torchvision not available, using synthetic data")
        return get_synthetic_dataloaders(config)


def get_synthetic_dataloaders(config: TrainingConfig) -> Tuple[DataLoader, DataLoader, int]:
    """Fallback synthetic data for testing."""

    class SyntheticDataset(torch.utils.data.Dataset):
        def __init__(self, size: int, image_size: int, num_classes: int):
            self.size = size
            self.image_size = image_size
            self.num_classes = num_classes

        def __len__(self):
            return self.size

        def __getitem__(self, idx):
            x = torch.randn(3, self.image_size, self.image_size)
            y = idx % self.num_classes
            return x, y

    num_classes = 10
    train_dataset = SyntheticDataset(5000, config.image_size, num_classes)
    test_dataset = SyntheticDataset(1000, config.image_size, num_classes)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

    return train_loader, test_loader, num_classes

# ============================================================================
# MIXUP / CUTMIX AUGMENTATION
# ============================================================================

def mixup_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 0.2):
    """Mixup augmentation."""
    if alpha > 0:
        lam = torch.distributions.Beta(alpha, alpha).sample().item()
    else:
        lam = 1.0

    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]

    return mixed_x, y_a, y_b, lam


def cutmix_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
    """CutMix augmentation."""
    if alpha > 0:
        lam = torch.distributions.Beta(alpha, alpha).sample().item()
    else:
        lam = 1.0

    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=x.device)

    _, _, H, W = x.shape

    cut_ratio = math.sqrt(1 - lam)
    cut_h = int(H * cut_ratio)
    cut_w = int(W * cut_ratio)

    cx = torch.randint(0, H, (1,)).item()
    cy = torch.randint(0, W, (1,)).item()

    x1 = max(0, cx - cut_h // 2)
    x2 = min(H, cx + cut_h // 2)
    y1 = max(0, cy - cut_w // 2)
    y2 = min(W, cy + cut_w // 2)

    mixed_x = x.clone()
    mixed_x[:, :, x1:x2, y1:y2] = x[index, :, x1:x2, y1:y2]

    # Recompute lam from the actual (clipped) box area
    lam = 1 - ((x2 - x1) * (y2 - y1)) / (H * W)

    y_a, y_b = y, y[index]

    return mixed_x, y_a, y_b, lam

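Because `cutmix_data` clips the cut box at the image border, the sampled λ no longer matches the pasted area, which is why it recomputes λ from the box that was actually pasted. A tensor-free sketch of just that box-sampling-and-correction step (`cutmix_box` is a hypothetical helper mirroring the logic above, not part of the file):

```python
import math
import random

def cutmix_box(H: int, W: int, lam: float, rng: random.Random):
    """Sample a cut box covering ~(1 - lam) of the image, clipped to bounds.

    Returns the clipped box and the lam value corrected to the true pasted area.
    """
    cut_ratio = math.sqrt(1 - lam)
    cut_h, cut_w = int(H * cut_ratio), int(W * cut_ratio)
    # Random box center
    cx, cy = rng.randrange(H), rng.randrange(W)
    # Clip the box to the image
    x1, x2 = max(0, cx - cut_h // 2), min(H, cx + cut_h // 2)
    y1, y2 = max(0, cy - cut_w // 2), min(W, cy + cut_w // 2)
    # Corrected mixing coefficient: fraction of pixels NOT replaced
    lam_corrected = 1 - ((x2 - x1) * (y2 - y1)) / (H * W)
    return (x1, x2, y1, y2), lam_corrected
```

The corrected λ is then what weights the two cross-entropy terms in the mixed loss.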
# ============================================================================
# METRICS TRACKER
# ============================================================================

class MetricsTracker:
    """Track training metrics with EMA smoothing."""

    def __init__(self, ema_decay: float = 0.9):
        self.ema_decay = ema_decay
        self.metrics = {}
        self.ema_metrics = {}
        self.history = {}

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()

            if k not in self.metrics:
                self.metrics[k] = []
                self.ema_metrics[k] = v
                self.history[k] = []

            self.metrics[k].append(v)
            self.ema_metrics[k] = self.ema_decay * self.ema_metrics[k] + (1 - self.ema_decay) * v

    def get_ema(self, key: str) -> float:
        return self.ema_metrics.get(key, 0.0)

    def get_epoch_mean(self, key: str) -> float:
        values = self.metrics.get(key, [])
        return sum(values) / len(values) if values else 0.0

    def end_epoch(self):
        for k, v in self.metrics.items():
            if v:
                self.history[k].append(sum(v) / len(v))
        self.metrics = {k: [] for k in self.metrics}

    def get_history(self) -> Dict:
        return self.history

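`MetricsTracker` seeds each metric's EMA with its first observed value rather than zero, which avoids the warmup bias of a zero-initialized average. A standalone sketch of that update rule (`ema_series` is illustrative, not part of the file):

```python
def ema_series(values, decay: float = 0.9):
    """EMA as MetricsTracker computes it: seeded with the first value,
    then ema = decay * ema + (1 - decay) * v for each new observation."""
    out = []
    ema = None
    for v in values:
        ema = v if ema is None else decay * ema + (1 - decay) * v
        out.append(ema)
    return out
```

With decay 0.9, a jump from 1.0 to 2.0 moves the smoothed value only to 1.1, which is what keeps the progress-bar numbers stable between batches.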
# ============================================================================
# CHECKPOINT UTILITIES
# ============================================================================

def find_latest_checkpoint(output_dir: Path) -> Optional[Path]:
    """Find the most recent checkpoint in output directory."""
    checkpoints = list(output_dir.glob("checkpoint_epoch_*.pt"))

    if not checkpoints:
        # Try best_model.pt as fallback
        best_model = output_dir / "best_model.pt"
        if best_model.exists():
            return best_model
        return None

    # Sort by epoch number
    def get_epoch(p):
        try:
            return int(p.stem.split("_")[-1])
        except ValueError:
            return 0

    checkpoints.sort(key=get_epoch, reverse=True)
    return checkpoints[0]


def get_next_run_number(base_dir: Path) -> int:
    """Get the next run number by scanning existing run directories."""
    if not base_dir.exists():
        return 1

    max_num = 0
    for d in base_dir.iterdir():
        if d.is_dir() and d.name.startswith("run_"):
            try:
                # Extract number from "run_XXX_name_timestamp"
                num = int(d.name.split("_")[1])
                max_num = max(max_num, num)
            except (IndexError, ValueError):
                continue

    return max_num + 1


def generate_run_dir_name(run_number: int, run_name: str) -> str:
    """Generate a run directory name with number, name, and timestamp."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Sanitize run_name: lowercase, replace spaces with underscores, remove special chars
    safe_name = "".join(c if c.isalnum() or c == "_" else "_" for c in run_name.lower())
    safe_name = "_".join(filter(None, safe_name.split("_")))  # Collapse consecutive underscores
    return f"run_{run_number:03d}_{safe_name}_{timestamp}"

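The sanitization in `generate_run_dir_name` maps any non-alphanumeric character to `_` and then collapses runs of underscores, so arbitrary run names always produce a filesystem-safe directory segment. A minimal sketch of just that step (`safe_run_name` is a hypothetical extraction of the two lines above):

```python
def safe_run_name(run_name: str) -> str:
    """Sanitize a run name the way generate_run_dir_name does:
    lowercase, non-alphanumerics to underscores, runs of underscores collapsed."""
    s = "".join(c if c.isalnum() or c == "_" else "_" for c in run_name.lower())
    # split("_") yields empty strings for consecutive underscores; filter drops them
    return "_".join(filter(None, s.split("_")))
```

So a name like "My Run #1!" becomes "my_run_1" before the run number and timestamp are attached around it.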
def find_latest_run_dir(base_dir: Path) -> Optional[Path]:
    """Find the most recent run directory."""
    if not base_dir.exists():
        return None

    run_dirs = [d for d in base_dir.iterdir() if d.is_dir() and d.name.startswith("run_")]

    if not run_dirs:
        return None

    # Sort by modification time (most recent first)
    run_dirs.sort(key=lambda d: d.stat().st_mtime, reverse=True)
    return run_dirs[0]


def find_checkpoint_in_runs(base_dir: Path, resume_from: str) -> Optional[Path]:
    """
    Find a checkpoint to resume from.

    Args:
        base_dir: Base checkpoint directory (e.g., ./checkpoints/cifar100)
        resume_from: Either "latest", a run directory name, or a full path

    Returns:
        Path to checkpoint file, or None
    """
    if resume_from == "latest":
        # Find most recent run directory
        run_dir = find_latest_run_dir(base_dir)
        if run_dir:
            return find_latest_checkpoint(run_dir)
        # Fallback: check base_dir itself (for old-style checkpoints)
        return find_latest_checkpoint(base_dir)

    # Check if it's a full path
    full_path = Path(resume_from)
    if full_path.exists():
        if full_path.is_file():
            return full_path
        elif full_path.is_dir():
            return find_latest_checkpoint(full_path)

    # Check if it's a run directory name within base_dir
    run_path = base_dir / resume_from
    if run_path.exists():
        return find_latest_checkpoint(run_path)

    return None


def load_checkpoint(
    checkpoint_path: Path,
    model: DavidBeans,
    optimizer: Optional[torch.optim.Optimizer] = None,
    device: str = "cuda"
) -> Tuple[int, float]:
    """
    Load checkpoint and return (start_epoch, best_acc).

    Returns:
        start_epoch: Epoch to resume from (checkpoint_epoch + 1)
        best_acc: Best accuracy so far
    """
    print(f"\n📂 Loading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)

    model.load_state_dict(checkpoint['model_state_dict'])
    print("  ✓ Loaded model weights")

    if optimizer is not None and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print("  ✓ Loaded optimizer state")

    epoch = checkpoint.get('epoch', 0)
    best_acc = checkpoint.get('best_acc', 0.0)

    print(f"  ✓ Loaded checkpoint from epoch {epoch + 1}, best_acc={best_acc:.2f}%")
    print(f"  ✓ Will resume training from epoch {epoch + 2}")

    return epoch + 1, best_acc


def get_config_from_checkpoint(checkpoint_path: Path) -> Tuple[DavidBeansConfig, dict]:
    """
    Extract model and training configs from a checkpoint.

    Returns:
        (model_config, train_config_dict)
    """
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    model_config_dict = checkpoint.get('model_config', {})
    train_config_dict = checkpoint.get('train_config', {})

    # Handle tuple conversion for betas
    if 'betas' in train_config_dict and isinstance(train_config_dict['betas'], list):
        train_config_dict['betas'] = tuple(train_config_dict['betas'])

    model_config = DavidBeansConfig(**model_config_dict)

    return model_config, train_config_dict

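The betas conversion in `get_config_from_checkpoint` exists because JSON (and some checkpoint serializers) round-trip tuples as lists, while the `TrainingConfig.betas` field and AdamW both expect a tuple. A minimal sketch of that round-trip fix in isolation (`restore_train_config` is illustrative, not part of the file):

```python
import json

def restore_train_config(d: dict) -> dict:
    """Convert a JSON-deserialized betas list back to the tuple the config expects."""
    if isinstance(d.get("betas"), list):
        d["betas"] = tuple(d["betas"])
    return d
```

Without this, comparing a restored config against `TrainingConfig` defaults, or re-passing it into the optimizer, would silently mix `[0.9, 0.999]` and `(0.9, 0.999)`.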
794
+
795
+ # ============================================================================
796
+ # TRAINING LOOP
797
+ # ============================================================================
798
+
799
+ def train_epoch(
800
+ model: DavidBeans,
801
+ train_loader: DataLoader,
802
+ optimizer: torch.optim.Optimizer,
803
+ scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
804
+ config: TrainingConfig,
805
+ epoch: int,
806
+ tracker: MetricsTracker,
807
+ writer: Optional['SummaryWriter'] = None
808
+ ) -> Dict[str, float]:
809
+ """Train for one epoch."""
810
+
811
+ model.train()
812
+ device = config.device
813
+
814
+ total_loss = 0.0
815
+ total_correct = 0
816
+ total_samples = 0
817
+ global_step = epoch * len(train_loader)
818
+
819
+ pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}", leave=True)
820
+
821
+ for batch_idx, (images, targets) in enumerate(pbar):
822
+ images = images.to(device, non_blocking=True)
823
+ targets = targets.to(device, non_blocking=True)
824
+
825
+ # Apply mixup/cutmix
826
+ use_mixup = config.use_augmentation and config.mixup_alpha > 0
827
+ use_cutmix = config.use_augmentation and config.cutmix_alpha > 0
828
+
829
+ mixed = False
830
+ if use_mixup or use_cutmix:
831
+ r = torch.rand(1).item()
832
+ if r < 0.5:
833
+ pass
834
+ elif r < 0.75 and use_mixup:
835
+ images, targets_a, targets_b, lam = mixup_data(images, targets, config.mixup_alpha)
836
+ mixed = True
837
+ elif use_cutmix:
838
+ images, targets_a, targets_b, lam = cutmix_data(images, targets, config.cutmix_alpha)
839
+ mixed = True
        # Forward pass
        result = model(images, targets=targets, return_loss=True)
        losses = result['losses']

        if mixed:
            logits = result['logits']
            ce_loss = lam * F.cross_entropy(logits, targets_a, label_smoothing=config.label_smoothing) + \
                      (1 - lam) * F.cross_entropy(logits, targets_b, label_smoothing=config.label_smoothing)
            losses['ce'] = ce_loss

        # Compute total loss
        loss = (
            config.ce_weight * losses['ce'] +
            config.cayley_weight * losses.get('geometric', torch.tensor(0.0, device=device)) +
            config.contrast_weight * losses.get('contrast', torch.tensor(0.0, device=device))
        )

        for scale in model.config.scales:
            scale_ce = losses.get(f'ce_{scale}', 0.0)
            if isinstance(scale_ce, torch.Tensor):
                loss = loss + config.scale_ce_weight * scale_ce

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        if config.gradient_clip > 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.gradient_clip)
        else:
            grad_norm = 0.0

        optimizer.step()

        if scheduler is not None and config.scheduler == "onecycle":
            scheduler.step()

        # Compute accuracy
        with torch.no_grad():
            logits = result['logits']
            preds = logits.argmax(dim=-1)

            if mixed:
                correct = (lam * (preds == targets_a).float() +
                           (1 - lam) * (preds == targets_b).float()).sum()
            else:
                correct = (preds == targets).sum()
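        # The mixed-accuracy expression above splits credit between the two
        # blended targets in proportion to lam. The same bookkeeping on plain
        # Python lists (soft_correct is an illustrative helper, not used by
        # the trainer):

```python
def soft_correct(preds, targets_a, targets_b, lam):
    # Count exact hits against each of the two blended target sets,
    # then weight them by lam and (1 - lam) respectively.
    hits_a = sum(p == t for p, t in zip(preds, targets_a))
    hits_b = sum(p == t for p, t in zip(preds, targets_b))
    return lam * hits_a + (1 - lam) * hits_b
```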
        total_correct += correct.item()
        total_samples += targets.size(0)
        total_loss += loss.item()

        # Track metrics
        def to_float(v):
            return v.item() if isinstance(v, torch.Tensor) else float(v)

        geo_loss = to_float(losses.get('geometric', 0.0))
        contrast_loss = to_float(losses.get('contrast', 0.0))
        expert_vol = to_float(losses.get('expert_volume', 0.0))
        expert_collapse = to_float(losses.get('expert_collapse', 0.0))
        expert_edge = to_float(losses.get('expert_edge_dev', 0.0))
        current_lr = optimizer.param_groups[0]['lr']

        tracker.update(
            loss=loss.item(),
            ce=losses['ce'].item(),
            geo=geo_loss,
            contrast=contrast_loss,
            expert_vol=expert_vol,
            expert_collapse=expert_collapse,
            expert_edge=expert_edge,
            lr=current_lr
        )
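        # MetricsTracker is defined earlier in this file; the progress bar
        # below reads exponentially smoothed values from it via get_ema. A
        # minimal sketch of that smoothing (EmaTracker and its decay value
        # are illustrative, not the real class):

```python
class EmaTracker:
    def __init__(self, decay=0.98):
        self.decay = decay
        self.ema = {}

    def update(self, **metrics):
        # First observation seeds the EMA; later ones blend in with (1 - decay).
        for key, value in metrics.items():
            prev = self.ema.get(key)
            self.ema[key] = value if prev is None else (
                self.decay * prev + (1 - self.decay) * value)

    def get_ema(self, key):
        return self.ema.get(key, 0.0)
```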
        # TensorBoard logging (every log_interval batches)
        if writer is not None and (batch_idx + 1) % config.log_interval == 0:
            step = global_step + batch_idx

            # Loss components
            writer.add_scalar('train/loss_total', loss.item(), step)
            writer.add_scalar('train/loss_ce', losses['ce'].item(), step)
            writer.add_scalar('train/loss_geometric', geo_loss, step)
            writer.add_scalar('train/loss_contrast', contrast_loss, step)

            # Geometric metrics
            writer.add_scalar('train/expert_volume', expert_vol, step)
            writer.add_scalar('train/expert_collapse', expert_collapse, step)
            writer.add_scalar('train/expert_edge_dev', expert_edge, step)

            # Training dynamics
            writer.add_scalar('train/learning_rate', current_lr, step)
            writer.add_scalar('train/grad_norm', to_float(grad_norm), step)
            writer.add_scalar('train/batch_acc', 100.0 * correct.item() / targets.size(0), step)

        pbar.set_postfix({
            'loss': f"{tracker.get_ema('loss'):.3f}",
            'acc': f"{100.0 * total_correct / total_samples:.1f}%",
            'geo': f"{tracker.get_ema('geo'):.4f}",
            'vol': f"{tracker.get_ema('expert_vol'):.4f}"
        })

    if scheduler is not None and config.scheduler == "cosine":
        scheduler.step()

    return {
        'loss': total_loss / len(train_loader),
        'acc': 100.0 * total_correct / total_samples
    }


@torch.no_grad()
def evaluate(
    model: DavidBeans,
    test_loader: DataLoader,
    config: TrainingConfig
) -> Dict[str, float]:
    """Evaluate on the test set."""
    model.eval()
    device = config.device

    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    scale_correct = {s: 0 for s in model.config.scales}

    for images, targets in test_loader:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        result = model(images, targets=targets, return_loss=True)

        logits = result['logits']
        losses = result['losses']

        loss = losses['total']
        preds = logits.argmax(dim=-1)

        total_loss += loss.item() * targets.size(0)
        total_correct += (preds == targets).sum().item()
        total_samples += targets.size(0)

        for i, scale in enumerate(model.config.scales):
            scale_logits = result['scale_logits'][i]
            scale_preds = scale_logits.argmax(dim=-1)
            scale_correct[scale] += (scale_preds == targets).sum().item()

    metrics = {
        'loss': total_loss / total_samples,
        'acc': 100.0 * total_correct / total_samples
    }

    for scale, correct in scale_correct.items():
        metrics[f'acc_{scale}'] = 100.0 * correct / total_samples

    return metrics


# ============================================================================
# MAIN TRAINING FUNCTION
# ============================================================================

def train_david_beans(
    model_config: Optional[DavidBeansConfig] = None,
    train_config: Optional[TrainingConfig] = None
):
    """Main training function."""
    print("=" * 70)
    print(" DAVID-BEANS TRAINING: The Dynamic Duo")
    print("=" * 70)
    print()
    print(" 🫘 BEANS (ViT) + πŸ’Ž DAVID (Crystal)")
    print(" Sparse Attention Multi-Scale Projection")
    print()
    print("=" * 70)

    if train_config is None:
        train_config = TrainingConfig()

    base_output_dir = Path(train_config.output_dir)
    base_output_dir.mkdir(parents=True, exist_ok=True)

    # Check for resume FIRST - load config from checkpoint if resuming
    checkpoint_path = None
    run_dir = None  # Set either from resume or from a new run below

    if train_config.resume_from:
        # Find the checkpoint using the run-directory structure
        checkpoint_path = find_checkpoint_in_runs(base_output_dir, train_config.resume_from)

        if checkpoint_path and checkpoint_path.exists():
            print(f"\nπŸ“‚ Found checkpoint: {checkpoint_path}")
            # The run directory is the parent of the checkpoint
            run_dir = checkpoint_path.parent
            print(f" βœ“ Resuming in run directory: {run_dir.name}")

            # Load config from checkpoint to ensure the architecture matches
            loaded_model_config, loaded_train_config_dict = get_config_from_checkpoint(checkpoint_path)

            if model_config is None:
                model_config = loaded_model_config
                print(" βœ“ Using model config from checkpoint")
            else:
                # Warn if the configs differ
                if model_config.dim != loaded_model_config.dim or model_config.scales != loaded_model_config.scales:
                    print(" ⚠ WARNING: Provided config differs from checkpoint!")
                    print(f" Checkpoint: dim={loaded_model_config.dim}, scales={loaded_model_config.scales}")
                    print(f" Provided: dim={model_config.dim}, scales={model_config.scales}")
                    print(" βœ“ Using checkpoint config to ensure compatibility")
                    model_config = loaded_model_config
        else:
            print(f" [!] Checkpoint not found: {train_config.resume_from}")
            checkpoint_path = None

    # If not resuming (or resume failed), create a new run directory
    if run_dir is None:
        # Get the run number
        if train_config.run_number is None:
            run_number = get_next_run_number(base_output_dir)
        else:
            run_number = train_config.run_number

        # Generate the run directory name
        run_dir_name = generate_run_dir_name(run_number, train_config.run_name)
        run_dir = base_output_dir / run_dir_name
        run_dir.mkdir(parents=True, exist_ok=True)

        print(f"\nπŸ“ New run: {run_dir_name}")
        print(f" Run #{run_number}: {train_config.run_name}")
    else:
        # Extract the run number from the existing directory name for the hub repo
        try:
            run_number = int(run_dir.name.split("_")[1])
        except (IndexError, ValueError):
            run_number = 1

    # Update output_dir to point to the run directory
    output_dir = run_dir

    # Generate the effective hub repo ID with run info
    effective_hub_repo_id = train_config.hub_repo_id
    if train_config.hub_repo_id and train_config.hub_append_run:
        # Extract the run name from the directory (run_XXX_name_timestamp -> name)
        parts = run_dir.name.split("_")
        if len(parts) >= 3:
            run_name_part = parts[2]
        else:
            run_name_part = train_config.run_name
        effective_hub_repo_id = f"{train_config.hub_repo_id}-run{run_number:03d}-{run_name_part}"
        print(f" Hub repo: {effective_hub_repo_id}")
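    # The naming and parsing above assume run directories of the form
    # run_XXX_name_timestamp. A sketch of both directions, consistent with
    # the split("_") parsing in this function (generate_run_name and
    # effective_repo_id are hypothetical helper names, not the real
    # generate_run_dir_name defined elsewhere in this file):

```python
from datetime import datetime

def generate_run_name(run_number, run_name):
    # run_003_myrun_20240101-120000 style name; run_name must not contain '_'
    # for the parsing below to recover it.
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    return f"run_{run_number:03d}_{run_name}_{stamp}"

def effective_repo_id(base_repo, run_dir_name, fallback_name):
    # Mirror of the parsing above: parts[1] is the run number, parts[2] the name.
    parts = run_dir_name.split("_")
    run_number = int(parts[1])
    name = parts[2] if len(parts) >= 3 else fallback_name
    return f"{base_repo}-run{run_number:03d}-{name}"
```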

    if model_config is None:
        model_config = DavidBeansConfig(
            image_size=train_config.image_size,
            patch_size=4,
            dim=256,
            num_layers=6,
            num_heads=8,
            num_experts=5,
            k_neighbors=16,
            cantor_weight=0.3,
            scales=[64, 128, 256],
            num_classes=10,
            contrast_weight=train_config.contrast_weight,
            cayley_weight=train_config.cayley_weight,
            dropout=0.1
        )

    device = train_config.device
    print(f"\nDevice: {device}")

    # Data
    print("\nLoading data...")
    train_loader, test_loader, num_classes = get_dataloaders(train_config)
    print(f" Dataset: {train_config.dataset}")
    print(f" Train: {len(train_loader.dataset)}, Test: {len(test_loader.dataset)}")
    print(f" Classes: {num_classes}")

    model_config.num_classes = num_classes

    # Model
    print("\nBuilding model...")
    model = DavidBeans(model_config)
    model = model.to(device)

    print(f"\n{model}")

    num_params = sum(p.numel() for p in model.parameters())
    num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\nParameters: {num_params:,} ({num_trainable:,} trainable)")

    # Optimizer: biases, norms, and embeddings are excluded from weight decay
    print("\nSetting up optimizer...")

    decay_params = []
    no_decay_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if 'bias' in name or 'norm' in name or 'embedding' in name:
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    optimizer = AdamW([
        {'params': decay_params, 'weight_decay': train_config.weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0}
    ], lr=train_config.learning_rate, betas=train_config.betas)
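    # The parameter grouping above is purely name-based. The same rule on
    # bare (name, param) pairs, so it can be checked without a model
    # (split_decay_groups is an illustrative helper, not used by the trainer):

```python
def split_decay_groups(named_params):
    # Biases, norms, and embeddings go to the no-decay group; everything else decays.
    decay, no_decay = [], []
    for name, _param in named_params:
        if 'bias' in name or 'norm' in name or 'embedding' in name:
            no_decay.append(name)
        else:
            decay.append(name)
    return decay, no_decay
```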

    if train_config.scheduler == "cosine":
        scheduler = CosineAnnealingLR(
            optimizer,
            T_max=train_config.epochs - train_config.warmup_epochs,
            eta_min=train_config.min_lr
        )
    elif train_config.scheduler == "onecycle":
        scheduler = OneCycleLR(
            optimizer,
            max_lr=train_config.learning_rate,
            epochs=train_config.epochs,
            steps_per_epoch=len(train_loader),
            pct_start=train_config.warmup_epochs / train_config.epochs
        )
    else:
        scheduler = None

    print(f" Optimizer: AdamW (lr={train_config.learning_rate}, wd={train_config.weight_decay})")
    print(f" Scheduler: {train_config.scheduler}")
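    # With the "cosine" setting, the training loop below applies a manual
    # linear warmup for the first warmup_epochs epochs and then lets
    # CosineAnnealingLR decay the rate to min_lr. A closed-form sketch of
    # the resulting schedule (lr_at_epoch is illustrative, not used here):

```python
import math

def lr_at_epoch(epoch, base_lr, min_lr, warmup_epochs, total_epochs):
    # Linear warmup: lr ramps from base_lr/warmup_epochs up to base_lr.
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    # Cosine annealing from base_lr down to min_lr over the remaining epochs.
    t = epoch - warmup_epochs
    span = total_epochs - warmup_epochs
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t / span))
```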

    tracker = MetricsTracker()
    best_acc = 0.0
    start_epoch = 0

    print(f"\nOutput directory: {output_dir}")

    # Load weights from the checkpoint if we found one earlier
    if checkpoint_path and checkpoint_path.exists():
        start_epoch, best_acc = load_checkpoint(
            checkpoint_path, model, optimizer, device
        )

        # Fast-forward the scheduler to the correct position
        if scheduler is not None and train_config.scheduler == "cosine":
            for _ in range(start_epoch):
                scheduler.step()

    # TensorBoard setup
    writer = None
    if train_config.use_tensorboard and TENSORBOARD_AVAILABLE:
        tb_dir = output_dir / "tensorboard"
        tb_dir.mkdir(parents=True, exist_ok=True)
        writer = SummaryWriter(log_dir=str(tb_dir))
        print(f" TensorBoard: {tb_dir}")

        # Log the configs as text
        config_text = json.dumps(model_config.__dict__, indent=2, default=str)
        writer.add_text("config/model", f"```json\n{config_text}\n```", 0)

        train_text = json.dumps(train_config.to_dict(), indent=2, default=str)
        writer.add_text("config/training", f"```json\n{train_text}\n```", 0)
    elif train_config.use_tensorboard:
        print(" [!] TensorBoard requested but not available")

    with open(output_dir / "model_config.json", "w") as f:
        json.dump(model_config.__dict__, f, indent=2, default=str)
    with open(output_dir / "train_config.json", "w") as f:
        json.dump(train_config.to_dict(), f, indent=2, default=str)

    # Training loop
    print("\n" + "=" * 70)
    print(" TRAINING")
    print("=" * 70)

    if start_epoch > 0:
        print(f" Resuming from epoch {start_epoch + 1}/{train_config.epochs}")
1221
+ for epoch in range(start_epoch, train_config.epochs):
1222
+ epoch_start = time.time()
1223
+
1224
+ if epoch < train_config.warmup_epochs and train_config.scheduler == "cosine":
1225
+ warmup_lr = train_config.learning_rate * (epoch + 1) / train_config.warmup_epochs
1226
+ for param_group in optimizer.param_groups:
1227
+ param_group['lr'] = warmup_lr
1228
+
1229
+ train_metrics = train_epoch(
1230
+ model, train_loader, optimizer, scheduler,
1231
+ train_config, epoch, tracker, writer
1232
+ )
1233
+
1234
+ test_metrics = evaluate(model, test_loader, train_config)
1235
+
1236
+ epoch_time = time.time() - epoch_start
1237
+
1238
+ # TensorBoard epoch logging
1239
+ if writer is not None:
1240
+ # Epoch-level metrics
1241
+ writer.add_scalar('epoch/train_loss', train_metrics['loss'], epoch)
1242
+ writer.add_scalar('epoch/train_acc', train_metrics['acc'], epoch)
1243
+ writer.add_scalar('epoch/test_loss', test_metrics['loss'], epoch)
1244
+ writer.add_scalar('epoch/test_acc', test_metrics['acc'], epoch)
1245
+ writer.add_scalar('epoch/learning_rate', optimizer.param_groups[0]['lr'], epoch)
1246
+ writer.add_scalar('epoch/time_seconds', epoch_time, epoch)
1247
+
1248
+ # Per-scale accuracies
1249
+ for scale in model.config.scales:
1250
+ writer.add_scalar(f'scales/acc_{scale}', test_metrics[f'acc_{scale}'], epoch)
1251
+
1252
+ # Generalization gap
1253
+ writer.add_scalar('epoch/generalization_gap', test_metrics['acc'] - train_metrics['acc'], epoch)
1254
+
1255
+ # Flush periodically
1256
+ if (epoch + 1) % 5 == 0:
1257
+ writer.flush()
1258
+
1259
+ scale_accs = " | ".join([f"{s}:{test_metrics[f'acc_{s}']:.1f}%" for s in model.config.scales])
1260
+ star = "β˜…" if test_metrics['acc'] > best_acc else ""
1261
+
1262
+ print(f" β†’ Train: {train_metrics['acc']:.1f}% | Test: {test_metrics['acc']:.1f}% | "
1263
+ f"Scales: [{scale_accs}] | {epoch_time:.0f}s {star}")
1264
+
1265
+ if test_metrics['acc'] > best_acc:
1266
+ best_acc = test_metrics['acc']
1267
+
1268
+ torch.save({
1269
+ 'epoch': epoch,
1270
+ 'model_state_dict': model.state_dict(),
1271
+ 'optimizer_state_dict': optimizer.state_dict(),
1272
+ 'best_acc': best_acc,
1273
+ 'model_config': model_config.__dict__,
1274
+ 'train_config': train_config.to_dict()
1275
+ }, output_dir / "best_model.pt")
1276
+
1277
+ if (epoch + 1) % train_config.save_interval == 0:
1278
+ torch.save({
1279
+ 'epoch': epoch,
1280
+ 'model_state_dict': model.state_dict(),
1281
+ 'optimizer_state_dict': optimizer.state_dict(),
1282
+ 'best_acc': best_acc
1283
+ }, output_dir / f"checkpoint_epoch_{epoch + 1}.pt")
1284
+
1285
+ # Periodic HuggingFace Hub upload
1286
+ if train_config.push_to_hub and HF_HUB_AVAILABLE and effective_hub_repo_id:
1287
+ try:
1288
+ # Save current best for upload
1289
+ checkpoint = torch.load(output_dir / "best_model.pt", map_location='cpu')
1290
+ model_cpu = DavidBeans(model_config)
1291
+ model_cpu.load_state_dict(checkpoint['model_state_dict'])
1292
+
1293
+ hub_dir = save_for_hub(
1294
+ model=model_cpu,
1295
+ model_config=model_config,
1296
+ train_config=train_config,
1297
+ best_acc=best_acc,
1298
+ output_dir=output_dir,
1299
+ training_history=tracker.get_history()
1300
+ )
1301
+
1302
+ push_to_hub(
1303
+ hub_dir=hub_dir,
1304
+ repo_id=effective_hub_repo_id,
1305
+ private=train_config.hub_private,
1306
+ commit_message=f"Checkpoint epoch {epoch + 1} - {best_acc:.2f}% acc"
1307
+ )
1308
+ print(f" πŸ“€ Uploaded to {effective_hub_repo_id}")
1309
+ except Exception as e:
1310
+ print(f" [!] Hub upload failed: {e}")
1311
+
1312
+ tracker.end_epoch()
1313
+

    # Final summary
    print("\n" + "=" * 70)
    print(" TRAINING COMPLETE")
    print("=" * 70)
    print(f"\n Best Test Accuracy: {best_acc:.2f}%")
    print(f" Model saved to: {output_dir / 'best_model.pt'}")

    # Save training history
    history = tracker.get_history()
    with open(output_dir / "training_history.json", "w") as f:
        json.dump(history, f, indent=2)

    # Final TensorBoard logging
    if writer is not None:
        # Log the best accuracy as an hparam metric
        hparams = {
            'dim': model_config.dim,
            'num_layers': model_config.num_layers,
            'num_heads': model_config.num_heads,
            'num_experts': model_config.num_experts,
            'k_neighbors': model_config.k_neighbors,
            'cantor_weight': model_config.cantor_weight,
            'learning_rate': train_config.learning_rate,
            'weight_decay': train_config.weight_decay,
            'batch_size': train_config.batch_size,
            'mixup_alpha': train_config.mixup_alpha,
            'cutmix_alpha': train_config.cutmix_alpha,
        }
        writer.add_hparams(hparams, {'hparam/best_acc': best_acc})
        writer.add_scalar('final/best_acc', best_acc, 0)
        writer.close()
        print(f" TensorBoard logs: {output_dir / 'tensorboard'}")

    # HuggingFace Hub upload
    if train_config.push_to_hub:
        print("\n" + "=" * 70)
        print(" UPLOADING TO HUGGINGFACE HUB")
        print("=" * 70)

        if not HF_HUB_AVAILABLE:
            print(" [!] huggingface_hub not installed. Skipping upload.")
        elif not effective_hub_repo_id:
            print(" [!] hub_repo_id not set. Skipping upload.")
        else:
            checkpoint = torch.load(output_dir / "best_model.pt", map_location='cpu')
            model.load_state_dict(checkpoint['model_state_dict'])

            print("\n Preparing model for upload...")
            hub_dir = save_for_hub(
                model=model,
                model_config=model_config,
                train_config=train_config,
                best_acc=best_acc,
                output_dir=output_dir,
                training_history=history
            )

            print(f"\n Uploading to {effective_hub_repo_id}...")
            push_to_hub(
                hub_dir=hub_dir,
                repo_id=effective_hub_repo_id,
                private=train_config.hub_private
            )

            print(f"\n πŸŽ‰ Model uploaded to: https://huggingface.co/{effective_hub_repo_id}")

    return model, best_acc


# ============================================================================
# PRESETS
# ============================================================================

def train_cifar10_small(run_name: str = "cifar10_small"):
    """Small model for CIFAR-10."""
    model_config = DavidBeansConfig(
        image_size=32, patch_size=4, dim=256, num_layers=4,
        num_heads=4, num_experts=5, k_neighbors=16,
        cantor_weight=0.3, scales=[64, 128, 256, 512],
        num_classes=10, dropout=0.1
    )

    train_config = TrainingConfig(
        run_name=run_name,
        dataset="cifar10", epochs=50, batch_size=128,
        learning_rate=1e-3, weight_decay=0.05, warmup_epochs=5,
        cayley_weight=0.01, contrast_weight=0.3,
        output_dir="./checkpoints/cifar10"
    )

    return train_david_beans(model_config, train_config)


def train_cifar100(
    run_name: str = "cifar100_base",
    push_to_hub: bool = False,
    hub_repo_id: Optional[str] = None,
    resume: bool = False
):
    """Model for CIFAR-100 with optional HF Hub upload and resume."""
    model_config = DavidBeansConfig(
        image_size=32, patch_size=4, dim=512, num_layers=8,
        num_heads=8, num_experts=5, k_neighbors=32,
        cantor_weight=0.3, scales=[256, 512, 768],
        num_classes=100, dropout=0.15
    )

    train_config = TrainingConfig(
        run_name=run_name,
        dataset="cifar100", epochs=200, batch_size=128,
        learning_rate=5e-4, weight_decay=0.1, warmup_epochs=20,
        cayley_weight=0.01, contrast_weight=0.5,
        label_smoothing=0.1, mixup_alpha=0.3, cutmix_alpha=1.0,
        output_dir="./checkpoints/cifar100",
        resume_from="latest" if resume else None,
        push_to_hub=push_to_hub, hub_repo_id=hub_repo_id, hub_private=False
    )

    return train_david_beans(model_config, train_config)


def resume_training(
    checkpoint_dir: str = "./checkpoints/cifar100",
    push_to_hub: bool = False,
    hub_repo_id: Optional[str] = None
):
    """
    Resume training from the latest checkpoint in a directory.

    Usage:
        resume_training("./checkpoints/cifar100", push_to_hub=True, hub_repo_id="user/repo")
    """
    output_dir = Path(checkpoint_dir)

    # Load configs from the checkpoint directory
    model_config_path = output_dir / "model_config.json"
    train_config_path = output_dir / "train_config.json"

    if not model_config_path.exists():
        raise FileNotFoundError(f"No model_config.json in {output_dir}")
    if not train_config_path.exists():
        raise FileNotFoundError(f"No train_config.json in {output_dir}")

    with open(model_config_path) as f:
        model_config_dict = json.load(f)

    with open(train_config_path) as f:
        train_config_dict = json.load(f)

    # JSON stores tuples as lists, so restore betas to a tuple
    if 'betas' in train_config_dict and isinstance(train_config_dict['betas'], list):
        train_config_dict['betas'] = tuple(train_config_dict['betas'])

    model_config = DavidBeansConfig(**model_config_dict)
    train_config = TrainingConfig(**train_config_dict)

    # Override with resume settings
    train_config.resume_from = "latest"
    train_config.push_to_hub = push_to_hub
    if hub_repo_id:
        train_config.hub_repo_id = hub_repo_id

    return train_david_beans(model_config, train_config)
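# The betas fix above matters because json.load returns lists where the
# dataclass expects a tuple. The round-trip in isolation (load_train_kwargs
# is an illustrative helper, not part of the trainer):

```python
import json

def load_train_kwargs(text):
    # Parse a serialized training config and restore tuple-typed fields.
    cfg = json.loads(text)
    if isinstance(cfg.get('betas'), list):
        cfg['betas'] = tuple(cfg['betas'])
    return cfg
```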


# ============================================================================
# STANDALONE UPLOAD FUNCTION
# ============================================================================

def upload_checkpoint(
    checkpoint_path: str,
    repo_id: str,
    best_acc: Optional[float] = None,
    private: bool = False
):
    """
    Upload an existing checkpoint to HuggingFace Hub.

    Usage:
        upload_checkpoint(
            checkpoint_path="./checkpoints/cifar100/best_model.pt",
            repo_id="AbstractPhil/david-beans-cifar100",
            best_acc=70.0  # Optional; read from the checkpoint if omitted
        )
    """
    if not HF_HUB_AVAILABLE:
        raise RuntimeError("huggingface_hub not installed. Run: pip install huggingface_hub")

    print(f"\nπŸ“¦ Loading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Reconstruct configs
    model_config_dict = checkpoint.get('model_config', {})
    train_config_dict = checkpoint.get('train_config', {})

    model_config = DavidBeansConfig(**model_config_dict)
    train_config = TrainingConfig(**train_config_dict)
    train_config.hub_repo_id = repo_id

    # Build the model and load weights
    model = DavidBeans(model_config)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Explicit None check so a passed-in best_acc of 0.0 is not discarded
    actual_best_acc = best_acc if best_acc is not None else checkpoint.get('best_acc', 0.0)

    # Prepare and upload
    output_dir = Path(checkpoint_path).parent

    print("\nπŸ“ Preparing files for upload...")
    hub_dir = save_for_hub(
        model=model,
        model_config=model_config,
        train_config=train_config,
        best_acc=actual_best_acc,
        output_dir=output_dir
    )

    print(f"\nπŸš€ Uploading to {repo_id}...")
    push_to_hub(hub_dir, repo_id, private=private)

    print(f"\nπŸŽ‰ Done! https://huggingface.co/{repo_id}")
1533
+
1534
+
1535
+ # ============================================================================
1536
+ # MAIN
1537
+ # ============================================================================
1538
+
1539
+ if __name__ == "__main__":
1540
+ # =====================================================
1541
+ # CONFIGURATION
1542
+ # =====================================================
1543
+
1544
+ PRESET = "cifar100" # "test", "small", "cifar100", "resume"
1545
+ RESUME = False # Set True to resume from latest checkpoint
1546
+ RUN_NAME = "5expert_3scale" # Descriptive name for this run
1547
+
1548
+ # HuggingFace Hub settings
1549
+ PUSH_TO_HUB = False
1550
+ HUB_REPO_ID = "AbstractPhil/geovit-david-beans"
1551
+
1552
+ # =====================================================
1553
+ # RUN
1554
+ # =====================================================
1555
+
1556
+ if PRESET == "test":
1557
+ print("πŸ§ͺ Quick test...")
1558
+ model_config = DavidBeansConfig(
1559
+ image_size=32, patch_size=4, dim=128, num_layers=2,
1560
+ num_heads=4, num_experts=5, k_neighbors=8,
1561
+ scales=[32, 64, 128], num_classes=10
1562
+ )
1563
+ train_config = TrainingConfig(
1564
+ run_name="test",
1565
+ epochs=2, batch_size=32,
1566
+ use_augmentation=False, mixup_alpha=0.0, cutmix_alpha=0.0
1567
+ )
1568
+ model, acc = train_david_beans(model_config, train_config)
1569
+
1570
+ elif PRESET == "small":
1571
+ print("πŸ«˜πŸ’Ž Training DavidBeans - Small (CIFAR-10)...")
1572
+ model, acc = train_cifar10_small()
1573
+
1574
+ elif PRESET == "cifar100":
1575
+ print("πŸ«˜πŸ’Ž Training DavidBeans - CIFAR-100...")
1576
+ model, acc = train_cifar100(
1577
+ run_name=RUN_NAME,
1578
+ push_to_hub=PUSH_TO_HUB,
1579
+ hub_repo_id=HUB_REPO_ID,
1580
+ resume=RESUME
1581
+ )
1582
+
1583
+ elif PRESET == "resume":
1584
+ print("πŸ”„ Resuming training from latest checkpoint...")
1585
+ model, acc = resume_training(
1586
+ checkpoint_dir="./checkpoints/cifar100",
1587
+ push_to_hub=PUSH_TO_HUB,
1588
+ hub_repo_id=HUB_REPO_ID
1589
+ )
1590
+
1591
+ else:
1592
+ raise ValueError(f"Unknown preset: {PRESET}")
1593
+
1594
+ print(f"\nπŸŽ‰ Done! Best accuracy: {acc:.2f}%")