AbstractPhil committed on
Commit
3cc46bf
·
verified ·
1 Parent(s): f71172e

Create trainer.py

Files changed (1): trainer.py (+1900 −0)
trainer.py ADDED
@@ -0,0 +1,1900 @@
"""
David Training Pipeline
========================
Author: AbstractPhil
Assistant: Claude Sonnet 4.5
------------------------------------------------------------
Training pipeline for the David multi-scale feature classifier.

Will be placed officially at: geovocab2/train/model/core/david_trainer.py
Or run from: scripts/train_david.py

Runs on Colab without hassle; set your repo and your HF_TOKEN as a userdata secret in Colab.

Features:
- Pure fp32 training (no mixed precision, for geometric stability)
- Mixed precision can be enabled if you want speed
- Adaptive training controller (freeze/unfreeze scales)
- Gradient analysis and scaling
- SafeTensors checkpointing and epoch control support
- Enhanced loss component tracking
- Proper weight organization: weights/model_name/timestamp/
- Accuracy in filenames and comprehensive tracking
- Saves models into a shared index (MODELS_INDEX.json) in the repo
- Parses the repo README if one exists, creates the repo if one doesn't
"""

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from datasets import load_dataset
from huggingface_hub import HfApi, create_repo, upload_folder, upload_file
import numpy as np
import os
import json
import time
import tempfile
from datetime import datetime
from tqdm.auto import tqdm
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from dataclasses import dataclass, field, asdict

# Import David components
from geovocab2.train.config.david_config import (
    DavidArchitectureConfig,
    DavidPresets,
    SharingMode,
    FusionMode
)

from geovocab2.train.model.core.david import (
    David,
    MultiScaleCrystalLoss,
)

# Import SimplexFactory
from geovocab2.shapes.factory import SimplexFactory


# ============================================================================
# TRAINING CONFIGURATION
# ============================================================================

@dataclass
class DavidTrainingConfig:
    """
    Complete training configuration for David.
    Separate from the model architecture config.
    """

    # Metadata
    name: str = "david_training"
    run_id: str = ""  # Auto-generated timestamp

    # Dataset
    dataset_name: str = "AbstractPhil/imagenet-clip-features-orderly"
    model_variant: str = "clip_vit_b16"
    num_classes: int = 1000

    # Model architecture (references to david_config)
    preset: Optional[str] = "balanced"  # Or None to use a custom config
    custom_config_path: Optional[str] = None  # Path to custom david_config.json

    # Architecture overrides (applied to preset or custom config)
    num_classes_override: Optional[int] = None
    use_belly_override: Optional[bool] = None
    belly_expand_override: Optional[float] = None
    progressive_training_override: Optional[bool] = None  # Override progressive training
    scale_warmup_epochs_override: Optional[Dict[int, int]] = None  # Custom warmup schedule

    # Training hyperparameters
    num_epochs: int = 50
    batch_size: int = 512
    learning_rate: float = 5e-3
    weight_decay: float = 1e-5
    warmup_epochs: int = 3

    # Loss weights
    use_rose_loss: bool = True
    rose_initial_weight: float = 0.01
    rose_max_weight: float = 0.1
    rose_weight_schedule: str = "adaptive"
    use_cayley_loss: bool = False
    cayley_weight: float = 0.001
    scale_loss_balance: Optional[Dict[int, float]] = None

    # Optimization
    use_mixed_precision: bool = False  # Keep False for stability
    gradient_clip: float = 5.0
    scheduler_type: str = "cosine_restarts"
    min_lr: float = 1e-6

    # Adaptive training (safer defaults)
    freeze_strategy: str = "never"  # "performance" or "never"
    freeze_threshold: float = 90.0  # Only freeze when a scale hits 90% accuracy
    unfreeze_on_plateau: bool = True
    patience: int = 10

    # Gradient monitoring
    track_gradients: bool = True
    gradient_scale_threshold: float = 1e-5
    gradient_scale_multiplier: float = 10.0

    # Logging
    log_interval: int = 50
    val_interval: int = 1
    save_interval: int = 5
    log_fusion_weights: bool = True
    log_loss_components: bool = True

    # Checkpointing
    save_format: str = "both"  # "pytorch", "safetensors", or "both"

    # HuggingFace Hub (optional)
    hf_repo: Optional[str] = "YourName/Repo"  # e.g. "AbstractPhil/gated-david"
    upload_to_hub: bool = False

    # Local paths
    base_dir: str = "./david_training"

    # Hardware
    num_workers: int = 10
    pin_memory: bool = True
    prefetch_factor: int = 4
    persistent_workers: bool = True

    def __post_init__(self):
        """Generate run_id if not provided."""
        if not self.run_id:
            self.run_id = datetime.now().strftime('%Y%m%d_%H%M%S')

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> 'DavidTrainingConfig':
        """Create from dictionary."""
        return cls(**data)

    def to_json(self, path: str):
        """Save to JSON."""
        data = self.to_dict()
        # Convert any nested dicts with int keys to str keys (JSON requires str keys)
        if data.get('scale_loss_balance'):
            data['scale_loss_balance'] = {
                str(k): v for k, v in data['scale_loss_balance'].items()
            }
        if data.get('scale_warmup_epochs_override'):
            data['scale_warmup_epochs_override'] = {
                str(k): v for k, v in data['scale_warmup_epochs_override'].items()
            }
        with open(path, 'w') as f:
            json.dump(data, f, indent=2)

    @classmethod
    def from_json(cls, path: str) -> 'DavidTrainingConfig':
        """Load from JSON."""
        with open(path, 'r') as f:
            data = json.load(f)
        # Convert str keys back to int for scale_loss_balance
        if 'scale_loss_balance' in data and data['scale_loss_balance']:
            data['scale_loss_balance'] = {
                int(k): v for k, v in data['scale_loss_balance'].items()
            }
        # Convert str keys back to int for scale_warmup_epochs_override
        if 'scale_warmup_epochs_override' in data and data['scale_warmup_epochs_override']:
            data['scale_warmup_epochs_override'] = {
                int(k): v for k, v in data['scale_warmup_epochs_override'].items()
            }
        return cls(**data)

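The `to_json`/`from_json` pair above converts dict keys because JSON object keys are always strings: integer-keyed fields like `scale_loss_balance` must be stringified on the way out and cast back on the way in. A minimal, dependency-free sketch of that round trip (the values here are illustrative, not defaults from the trainer):

```python
import json

# Integer-keyed mapping, shaped like scale_loss_balance (illustrative values).
scale_loss_balance = {256: 1.0, 512: 0.5}

# Serialize: JSON forces string keys, so convert explicitly before dumping.
payload = json.dumps({str(k): v for k, v in scale_loss_balance.items()})

# Deserialize: keys come back as strings and must be cast to int again.
restored = {int(k): v for k, v in json.loads(payload).items()}

assert restored == scale_loss_balance  # round trip preserves int keys
```

Without the explicit cast in `from_json`, a reloaded config would silently hold `{"256": 1.0}` instead of `{256: 1.0}` and scale lookups by int would miss.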
# ============================================================================
# ADAPTIVE TRAINING CONTROLLER
# ============================================================================

class AdaptiveTrainingController:
    """Manages adaptive training strategies for multi-scale model."""

    def __init__(self, model: David, config: DavidTrainingConfig):
        self.model = model
        self.config = config

        scales = model.scales
        self.scale_history = {scale: [] for scale in scales}
        self.best_scale_acc = {scale: 0.0 for scale in scales}
        self.scales_frozen = {scale: False for scale in scales}

        self.overall_history = []
        self.plateau_counter = 0
        self.best_overall = 0.0

    def update_metrics(self, scale_accuracies: Dict[int, float], overall_accuracy: float):
        """Update metrics and best scores."""
        for scale, acc in scale_accuracies.items():
            self.scale_history[scale].append(acc)
            if acc > self.best_scale_acc[scale]:
                self.best_scale_acc[scale] = acc

        self.overall_history.append(overall_accuracy)

        if overall_accuracy > self.best_overall:
            self.best_overall = overall_accuracy
            self.plateau_counter = 0
        else:
            self.plateau_counter += 1

    def should_freeze_scale(self, scale: int, current_acc: float) -> bool:
        """Determine if a scale should be frozen."""
        if self.config.freeze_strategy == "never":
            return False

        if self.scales_frozen[scale]:
            return False

        if self.config.freeze_strategy == "performance":
            return current_acc >= self.config.freeze_threshold

        return False

    def should_unfreeze_scales(self) -> bool:
        """Check if scales should be unfrozen due to plateau."""
        if not self.config.unfreeze_on_plateau:
            return False
        return self.plateau_counter >= 5

    def apply_adaptive_strategies(self, scale_accuracies: Dict[int, float], epoch: int):
        """Apply freeze/unfreeze based on performance."""
        active_scales = self.model.get_active_scales()

        # Don't freeze scales if it would leave no trainable parameters
        for scale, acc in scale_accuracies.items():
            if self.should_freeze_scale(scale, acc):
                # Count how many active scales would remain unfrozen
                active_unfrozen = [s for s in active_scales if not self.scales_frozen.get(s, False)]

                if len(active_unfrozen) <= 1:
                    print(f"[⚠️] Skipping freeze of scale {scale} (would leave no active trainable scales)")
                    continue

                self.model.freeze_scale(scale)
                self.scales_frozen[scale] = True
                print(f"[❄️] Froze scale {scale} (acc={acc:.2f}%)")

        if self.should_unfreeze_scales() and any(self.scales_frozen.values()):
            for scale in self.model.scales:
                if self.scales_frozen[scale]:
                    self.model.unfreeze_scale(scale)
                    self.scales_frozen[scale] = False
            self.plateau_counter = 0
            print("[🔥] Unfroze all scales due to plateau")

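The plateau logic above is simple bookkeeping: the counter resets whenever a new best overall accuracy appears, increments otherwise, and an unfreeze triggers once it reaches 5. A dependency-free sketch of just that counter (the class name and accuracy values are illustrative, not part of the trainer):

```python
class PlateauCounter:
    """Tracks epochs since the last improvement, mirroring the controller's logic."""

    def __init__(self, patience: int = 5):
        self.patience = patience
        self.best = 0.0
        self.count = 0

    def update(self, accuracy: float) -> bool:
        """Record one epoch's accuracy; return True when a plateau is detected."""
        if accuracy > self.best:
            self.best = accuracy
            self.count = 0  # improvement resets the counter
        else:
            self.count += 1
        return self.count >= self.patience


counter = PlateauCounter(patience=5)
history = [10.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0]
flags = [counter.update(acc) for acc in history]
# → [False, False, False, False, False, False, True]
# The improvement at epoch 1 resets the counter; the fifth flat epoch trips the plateau.
```

Note that only a strictly better accuracy resets the counter, so a run that merely matches its best still drifts toward an unfreeze.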
# ============================================================================
# OPTIMIZER & SCHEDULER CREATION
# ============================================================================

def create_optimizer(david: David, config: DavidTrainingConfig) -> torch.optim.Optimizer:
    """Create optimizer with parameter groups."""

    param_groups = []

    # Shared parameters (if they exist)
    if hasattr(david, 'shared_extractor'):
        param_groups.append({
            'params': david.shared_extractor.parameters(),
            'lr': config.learning_rate,
            'name': 'shared'
        })
    elif hasattr(david, 'shared_base'):
        param_groups.append({
            'params': david.shared_base.parameters(),
            'lr': config.learning_rate,
            'name': 'shared'
        })

    # Scale-specific parameters
    for scale in david.scales:
        scale_params = []
        if david.sharing_mode == SharingMode.HIERARCHICAL:
            head = getattr(david, f'head_{scale}', None)
            if head is not None:
                scale_params.extend(head.parameters())
            refine = getattr(david, f'refine_{scale}', None)
            if refine is not None:
                scale_params.extend(refine.parameters())
        else:
            scale_params.extend(david.heads[str(scale)].parameters())

        if scale_params:
            param_groups.append({
                'params': scale_params,
                'lr': config.learning_rate,
                'name': f'scale_{scale}'
            })

    # Fusion parameters (trained at half the base learning rate)
    if hasattr(david, 'fusion'):
        param_groups.append({
            'params': david.fusion.parameters(),
            'lr': config.learning_rate * 0.5,
            'name': 'fusion'
        })
    elif hasattr(david, 'fusion_weights'):
        param_groups.append({
            'params': [david.fusion_weights],
            'lr': config.learning_rate * 0.5,
            'name': 'fusion'
        })

    return torch.optim.AdamW(param_groups, weight_decay=config.weight_decay)


def create_scheduler(optimizer: torch.optim.Optimizer,
                     config: DavidTrainingConfig) -> Optional[torch.optim.lr_scheduler._LRScheduler]:
    """Create learning rate scheduler (returns None for unrecognized scheduler_type)."""

    if config.scheduler_type == "cosine_restarts":
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=10, T_mult=2, eta_min=config.min_lr
        )
    elif config.scheduler_type == "cosine":
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=config.num_epochs, eta_min=config.min_lr
        )
    else:
        return None


# ============================================================================
# GRADIENT ANALYSIS
# ============================================================================

def analyze_gradients(model: David, config: DavidTrainingConfig) -> Dict[str, float]:
    """Analyze gradient magnitudes for debugging."""
    grad_stats = {
        'mean': 0.0,
        'max': 0.0,
        'min': float('inf'),
        'num_zero': 0,
        'num_small': 0,
        'total': 0
    }

    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_norm = param.grad.norm().item()
            grad_stats['mean'] += grad_norm
            grad_stats['max'] = max(grad_stats['max'], grad_norm)
            grad_stats['min'] = min(grad_stats['min'], grad_norm)
            grad_stats['total'] += 1

            if grad_norm < 1e-10:
                grad_stats['num_zero'] += 1
            elif grad_norm < config.gradient_scale_threshold:
                grad_stats['num_small'] += 1

    if grad_stats['total'] > 0:
        grad_stats['mean'] /= grad_stats['total']

    return grad_stats


def scale_small_gradients(model: David, config: DavidTrainingConfig):
    """Scale up very small gradients to prevent vanishing."""
    if not config.track_gradients:
        return

    for param in model.parameters():
        if param.grad is not None:
            grad_norm = param.grad.norm()
            if 0 < grad_norm < config.gradient_scale_threshold:
                param.grad.mul_(config.gradient_scale_multiplier)

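`scale_small_gradients` applies one rule per parameter: a nonzero gradient whose norm falls below `gradient_scale_threshold` is multiplied by `gradient_scale_multiplier`; zero gradients and healthy gradients are left alone. The rule itself can be sketched without torch on plain norm values (the threshold and multiplier mirror the config defaults; the function name is illustrative):

```python
THRESHOLD = 1e-5    # mirrors gradient_scale_threshold
MULTIPLIER = 10.0   # mirrors gradient_scale_multiplier

def rescue_small_norms(norms):
    """Boost nonzero norms below THRESHOLD; leave zeros and healthy norms alone."""
    return [n * MULTIPLIER if 0 < n < THRESHOLD else n for n in norms]

norms = [0.0, 1e-7, 5e-6, 1e-3]
rescued = rescue_small_norms(norms)
# 0.0 stays zero (a truly dead gradient can't be rescued by scaling),
# 1e-7 and 5e-6 are boosted tenfold, and 1e-3 is untouched.
```

The strict `> 0` guard matters: multiplying an exactly zero gradient would still yield zero, so excluding it keeps the `num_zero` statistic in `analyze_gradients` meaningful.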
398
+ # ============================================================================
399
+ # HUGGINGFACE HUB UTILITIES
400
+ # ============================================================================
401
+
402
+ def generate_model_readme(
403
+ config: DavidTrainingConfig,
404
+ david_config: DavidArchitectureConfig,
405
+ best_metrics: Dict,
406
+ run_id: str
407
+ ) -> str:
408
+ """Generate README.md for model card."""
409
+
410
+ readme = f"""---
411
+ language: en
412
+ license: mit
413
+ tags:
414
+ - image-classification
415
+ - imagenet
416
+ - multi-scale
417
+ - feature-geometry
418
+ - david
419
+ datasets:
420
+ - imagenet-1k
421
+ metrics:
422
+ - accuracy
423
+ model-index:
424
+ - name: David-{david_config.sharing_mode}-{david_config.fusion_mode}
425
+ results:
426
+ - task:
427
+ type: image-classification
428
+ dataset:
429
+ name: ImageNet-1K
430
+ type: imagenet-1k
431
+ metrics:
432
+ - type: accuracy
433
+ value: {best_metrics.get('best_val_acc', 0.0):.2f}
434
+ ---
435
+
436
+ # David: Multi-Scale Feature Classifier
437
+
438
+ **David** is a multi-scale deep learning classifier that uses feature geometry (pentachora/4-simplexes)
439
+ as class prototypes with role-weighted similarity computation (Rose Loss).
440
+
441
+ ## Model Details
442
+
443
+ ### Architecture
444
+ - **Preset**: {config.preset}
445
+ - **Sharing Mode**: {david_config.sharing_mode}
446
+ - **Fusion Mode**: {david_config.fusion_mode}
447
+ - **Scales**: {david_config.scales}
448
+ - **Feature Dim**: {david_config.feature_dim}
449
+ - **Parameters**: {best_metrics.get('parameters', 0):,}
450
+
451
+ ### Training Configuration
452
+ - **Dataset**: {config.dataset_name}
453
+ - **Model Variant**: {config.model_variant}
454
+ - **Epochs**: {config.num_epochs}
455
+ - **Batch Size**: {config.batch_size}
456
+ - **Learning Rate**: {config.learning_rate}
457
+ - **Rose Loss Weight**: {config.rose_initial_weight} → {config.rose_max_weight}
458
+ - **Cayley Loss**: {config.use_cayley_loss}
459
+
460
+ ## Performance
461
+
462
+ ### Best Results
463
+ - **Validation Accuracy**: {best_metrics.get('best_val_acc', 0.0):.2f}%
464
+ - **Best Epoch**: {best_metrics.get('best_epoch', 0)}
465
+ - **Final Train Accuracy**: {best_metrics.get('final_train_acc', 0.0):.2f}%
466
+
467
+ ### Per-Scale Performance
468
+ """
469
+
470
+ if 'scale_accuracies' in best_metrics:
471
+ for scale, acc in best_metrics['scale_accuracies'].items():
472
+ readme += f"- **Scale {scale}**: {acc:.2f}%\n"
473
+
474
+ readme += f"""
475
+
476
+ ## Usage
477
+
478
+ ### Quick Model Lookup
479
+
480
+ **Check `MODELS_INDEX.json` in the repo root** - it lists all trained models sorted by accuracy with links to weights and configs.
481
+
482
+ ### Repository Structure
483
+
484
+ ```
485
+ {config.hf_repo if config.hf_repo else 'AbstractPhil/david'}/
486
+ ├── MODELS_INDEX.json # 📊 Master index of all models (sorted by accuracy)
487
+ ├── README.md # This file
488
+ ├── best_model.json # Latest best model info
489
+ ├── weights/
490
+ │ └── {david_config.name}/
491
+ │ └── {run_id}/
492
+ │ ├── MODEL_SUMMARY.txt # 🎯 Human-readable performance summary
493
+ │ ├── training_history.json # 📈 Epoch-by-epoch training curve
494
+ │ ├── best_model_acc{best_metrics.get('best_val_acc', 0.0):.2f}.safetensors # ⭐ Accuracy in filename!
495
+ │ ├── best_model_acc{best_metrics.get('best_val_acc', 0.0):.2f}_metadata.json
496
+ │ ├── final_model.safetensors
497
+ │ ├── checkpoint_epoch_X_accYY.YY.safetensors
498
+ │ ├── david_config.json
499
+ │ └── train_config.json
500
+ └── runs/
501
+ └── {david_config.name}/
502
+ └── {run_id}/
503
+ └── events.out.tfevents.* # TensorBoard logs
504
+ ```
505
+
506
+ ### Loading the Model
507
+
508
+ ```python
509
+ from geovocab2.train.model.core.david import David, DavidArchitectureConfig
510
+ from huggingface_hub import hf_hub_download
511
+
512
+ # Browse available models in MODELS_INDEX.json first!
513
+
514
+ # Specify model variant and run
515
+ model_name = "{david_config.name}"
516
+ run_id = "{run_id}"
517
+ accuracy = "{best_metrics.get('best_val_acc', 0.0):.2f}" # From MODELS_INDEX.json
518
+
519
+ # Download config
520
+ config_path = hf_hub_download(
521
+ repo_id="{config.hf_repo if config.hf_repo else 'AbstractPhil/david'}",
522
+ filename=f"weights/{{model_name}}/{{run_id}}/david_config.json"
523
+ )
524
+ config = DavidArchitectureConfig.from_json(config_path)
525
+
526
+ # Download weights (accuracy in filename!)
527
+ weights_path = hf_hub_download(
528
+ repo_id="{config.hf_repo if config.hf_repo else 'AbstractPhil/david'}",
529
+ filename=f"weights/{{model_name}}/{{run_id}}/best_model_acc{{accuracy}}.safetensors"
530
+ )
531
+
532
+ # Download training history (optional - see full training curve)
533
+ history_path = hf_hub_download(
534
+ repo_id="{config.hf_repo if config.hf_repo else 'AbstractPhil/david'}",
535
+ filename=f"weights/{{model_name}}/{{run_id}}/training_history.json"
536
+ )
537
+
538
+ # Load model
539
+ from safetensors.torch import load_file
540
+ david = David.from_config(config)
541
+ david.load_state_dict(load_file(weights_path))
542
+ david.eval()
543
+ ```
544
+
545
+ ### Inference
546
+
547
+ ```python
548
+ import torch
549
+ import torch.nn.functional as F
550
+
551
+ # Assuming you have CLIP features (512-dim for ViT-B/16)
552
+ features = get_clip_features(image) # [1, 512]
553
+
554
+ # Load anchors
555
+ anchors_dict = torch.load("anchors.pth")
556
+
557
+ # Forward pass
558
+ with torch.no_grad():
559
+ logits, _ = david(features, anchors_dict)
560
+ predictions = logits.argmax(dim=-1)
561
+ ```
562
+
563
+ ## Architecture Overview
564
+
565
+ ### Multi-Scale Processing
566
+ David processes inputs at multiple scales ({', '.join(map(str, david_config.scales))}),
567
+ allowing it to capture both coarse and fine-grained features.
568
+
569
+ ### Feature Geometry
570
+ Each class is represented by a pentachoron (4-simplex) in embedding space with 5 vertices:
571
+ - **Anchor**: Primary class representative
572
+ - **Need**: Complementary direction
573
+ - **Relation**: Contextual alignment
574
+ - **Purpose**: Functional direction
575
+ - **Observer**: Meta-perspective
576
+
577
+ ### Rose Loss
578
+ Similarity computation uses role-weighted cosine similarities:
579
+ ```
580
+ score = w_anchor * sim(z, anchor) + w_need * sim(z, need) + ...
581
+ ```
582
+
583
+ ### Fusion Strategy
584
+ **{david_config.fusion_mode}**: Intelligently combines predictions from multiple scales.
585
+
586
+ ## Training Details
587
+
588
+ ### Loss Components
589
+ - **Cross-Entropy**: Standard classification loss
590
+ - **Rose Loss**: Pentachora role-weighted margin loss (weight: {config.rose_initial_weight}→{config.rose_max_weight})
591
+ - **Cayley Loss**: Geometric regularization ({'enabled' if config.use_cayley_loss else 'disabled'})
592
+
593
+ ### Optimization
594
+ - **Optimizer**: AdamW
595
+ - **Weight Decay**: {config.weight_decay}
596
+ - **Scheduler**: {config.scheduler_type}
597
+ - **Gradient Clip**: {config.gradient_clip}
598
+ - **Mixed Precision**: {config.use_mixed_precision}
599
+
600
+ ## Citation
601
+
602
+ ```bibtex
603
+ @software{{david_classifier_2025,
604
+ title = {{David: Multi-Scale Feature Classifier}},
605
+ author = {{AbstractPhil}},
606
+ year = {{2025}},
607
+ url = {{https://huggingface.co/{config.hf_repo if config.hf_repo else 'AbstractPhil/david'}}},
608
+ note = {{Run ID: {run_id}}}
609
+ }}
610
+ ```
611
+
612
+ ## License
613
+
614
+ MIT License
615
+
616
+ ## Acknowledgments
617
+
618
+ Built with lattice geometry and multi-scale deep learning.
619
+ Special thanks to Claude (Anthropic) for debugging assistance.
620
+
621
+ ---
622
+
623
+ *Generated on {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}*
624
+ """
625
+
626
+ return readme
627
+
628
+
629
+ def save_best_model_json(
630
+ filepath: str,
631
+ metrics: Dict,
632
+ config: DavidTrainingConfig,
633
+ david_config: DavidArchitectureConfig
634
+ ):
635
+ """Save best_model.json with comprehensive metrics."""
636
+
637
+ model_name = f"David-{david_config.sharing_mode}-{david_config.fusion_mode}"
638
+
639
+ best_model_info = {
640
+ "model_name": model_name,
641
+ "run_id": config.run_id,
642
+ "timestamp": datetime.now().isoformat(),
643
+
644
+ # Best metrics
645
+ "best_val_acc": metrics.get('best_val_acc', 0.0),
646
+ "best_epoch": metrics.get('best_epoch', 0),
647
+ "final_train_acc": metrics.get('final_train_acc', 0.0),
648
+ "final_train_loss": metrics.get('final_train_loss', 0.0),
649
+
650
+ # Per-scale performance
651
+ "scale_accuracies": metrics.get('scale_accuracies', {}),
652
+
653
+ # Architecture
654
+ "architecture": {
655
+ "preset": config.preset,
656
+ "sharing_mode": david_config.sharing_mode,
657
+ "fusion_mode": david_config.fusion_mode,
658
+ "scales": david_config.scales,
659
+ "feature_dim": david_config.feature_dim,
660
+ "num_classes": david_config.num_classes,
661
+ "use_belly": david_config.use_belly,
662
+ "belly_expand": david_config.belly_expand,
663
+ },
664
+
665
+ # Training config
666
+ "training": {
667
+ "dataset": config.dataset_name,
668
+ "model_variant": config.model_variant,
669
+ "num_epochs": config.num_epochs,
670
+ "batch_size": config.batch_size,
671
+ "learning_rate": config.learning_rate,
672
+ "rose_weight": f"{config.rose_initial_weight}→{config.rose_max_weight}",
673
+ "cayley_loss": config.use_cayley_loss,
674
+ "optimizer": "AdamW",
675
+ "scheduler": config.scheduler_type,
676
+ },
677
+
678
+ # Files (organized by model/run)
679
+ "files": {
680
+ "weights_safetensors": f"weights/{model_name}/{config.run_id}/best_model_acc{metrics.get('best_val_acc', 0.0):.2f}.safetensors",
681
+ "weights_pytorch": f"weights/{model_name}/{config.run_id}/best_model.pth",
682
+ "config": f"weights/{model_name}/{config.run_id}/david_config.json",
683
+ "training_config": f"weights/{model_name}/{config.run_id}/train_config.json",
684
+ "tensorboard": f"runs/{model_name}/{config.run_id}/"
685
+ }
686
+ }
687
+
688
+ with open(filepath, 'w') as f:
689
+ json.dump(best_model_info, f, indent=2)
690
+
691
+ print(f"[📄] Saved best_model.json: {filepath}")
692
+
693
+
694
+ def create_model_summary(
695
+ weights_dir: str,
696
+ config: DavidTrainingConfig,
697
+ david_config: DavidArchitectureConfig,
698
+ best_metrics: Dict,
699
+ model_name: str
700
+ ):
701
+ """Create prominent model summary with accuracy front and center."""
702
+
703
+ summary_path = os.path.join(weights_dir, 'MODEL_SUMMARY.txt')
704
+
705
+ best_acc = best_metrics.get('best_val_acc', 0.0)
706
+ training_history = best_metrics.get('training_history', {})
707
+
708
+ summary = f"""
709
+ ╔══════════════════════════════════════════════════════════════╗
710
+ ║ DAVID MODEL SUMMARY ║
711
+ ╠══════════════════════════════════════════════════════════════╣
712
+ ║ ║
713
+ ║ 🎯 VALIDATION ACCURACY: {best_acc:.2f}% ║
714
+ ║ ║
715
+ ╚══════════════════════════════════════════��═══════════════════╝
716
+
717
+ MODEL: {model_name}
718
+ RUN ID: {config.run_id}
719
+ BEST EPOCH: {best_metrics.get('best_epoch', 0) + 1}/{config.num_epochs}
720
+
721
+ ═══════════════════════════════════════════════════════════════
722
+
723
+ 📊 PERFORMANCE BREAKDOWN
724
+
725
+ Final Training Accuracy: {best_metrics.get('final_train_acc', 0.0):.2f}%
726
+ Best Validation Accuracy: {best_acc:.2f}%
727
+
728
+ Per-Scale Accuracies:
729
+ """
730
+
731
+ scale_accs = best_metrics.get('scale_accuracies', {})
732
+ for scale in sorted(scale_accs.keys()):
733
+ acc = scale_accs[scale]
734
+ summary += f" • Scale {scale:4d}: {acc:.2f}%\n"
735
+
736
+ summary += f"""
737
+ ═══════════════════════════════════════════════════════════════
738
+
739
+ 🏗️ ARCHITECTURE
740
+
741
+ Preset: {config.preset}
742
+ Sharing Mode: {david_config.sharing_mode}
743
+ Fusion Mode: {david_config.fusion_mode}
744
+ Scales: {len(david_config.scales)} scales - {david_config.scales}
745
+ Feature Dim: {david_config.feature_dim}
746
+ Parameters: {best_metrics.get('parameters', 0):,}
747
+
748
+ ═══════════════════════════════════════════════════════════════
749
+
750
+ 📈 TRAINING CURVE
751
+
752
+ """
753
+
754
+ if training_history and 'val_acc' in training_history:
755
+ summary += "Epoch | Train Acc | Val Acc | Learning Rate\n"
756
+ summary += "------|-----------|----------|--------------\n"
757
+
758
+ for i, epoch in enumerate(training_history.get('epochs', [])):
759
+ train_acc = training_history['train_acc'][i] if i < len(training_history['train_acc']) else 0
760
+ val_acc = training_history['val_acc'][i] if i < len(training_history['val_acc']) else 0
761
+ lr = training_history['lr'][i] if i < len(training_history['lr']) else 0
762
+
763
+ marker = " 👑" if val_acc == best_acc else ""
764
+ summary += f"{epoch:5d} | {train_acc:8.2f}% | {val_acc:7.2f}%{marker} | {lr:.2e}\n"
765
+
766
+ summary += f"""
767
+ ═══════════════════════════════════════════════════════════════
768
+
769
+ 📁 FILES
770
+
771
+ Best Model: best_model_acc{best_acc:.2f}.safetensors
772
+ Config: david_config.json
773
+ Training Cfg: train_config.json
774
+ History: training_history.json
775
+
776
+ ═══════════════════════════════════════════════════════════════
777
+
778
+ Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
779
+ """
780
+
781
+ with open(summary_path, 'w') as f:
782
+ f.write(summary)
783
+
784
+ print(f"[📄] Created MODEL_SUMMARY.txt")
785
+ return summary_path
786
+
787
+
788
+ def update_models_index(
+     config: DavidTrainingConfig,
+     david_config: DavidArchitectureConfig,
+     best_metrics: Dict,
+     model_name: str
+ ):
+     """Update the master models index file tracking all trained models."""
+
+     if not config.upload_to_hub or not config.hf_repo:
+         return
+
+     try:
+         from huggingface_hub import hf_hub_download
+         api = HfApi()
+
+         # Try to download the existing index
+         try:
+             index_path = hf_hub_download(
+                 repo_id=config.hf_repo,
+                 filename="MODELS_INDEX.json",
+                 repo_type="model"
+             )
+             with open(index_path, 'r') as f:
+                 models_index = json.load(f)
+         except Exception:
+             # Create a new index if one doesn't exist yet
+             models_index = {
+                 "repository": config.hf_repo,
+                 "updated": datetime.now().isoformat(),
+                 "models": []
+             }
+
+         # Add the current model entry
+         model_entry = {
+             "model_name": model_name,
+             "run_id": config.run_id,
+             "timestamp": datetime.now().isoformat(),
+             "best_val_acc": best_metrics.get('best_val_acc', 0.0),
+             "best_epoch": best_metrics.get('best_epoch', 0),
+             "num_scales": len(david_config.scales),
+             "scales": david_config.scales,
+             "parameters": best_metrics.get('parameters', 0),
+             "sharing_mode": david_config.sharing_mode,
+             "fusion_mode": david_config.fusion_mode,
+             "preset": config.preset,
+             "weights_path": f"weights/{model_name}/{config.run_id}/best_model_acc{best_metrics.get('best_val_acc', 0.0):.2f}.safetensors",
+             "config_path": f"weights/{model_name}/{config.run_id}/david_config.json",
+             "history_path": f"weights/{model_name}/{config.run_id}/training_history.json"
+         }
+
+         # Remove any old entry for the same run_id (this is an update)
+         models_index["models"] = [m for m in models_index["models"] if m.get("run_id") != config.run_id]
+         models_index["models"].append(model_entry)
+
+         # Sort by accuracy (descending)
+         models_index["models"].sort(key=lambda x: x.get("best_val_acc", 0), reverse=True)
+         models_index["updated"] = datetime.now().isoformat()
+         models_index["total_models"] = len(models_index["models"])
+
+         # Save locally
+         with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json') as f:
+             json.dump(models_index, f, indent=2)
+             temp_path = f.name
+
+         # Upload to the hub root
+         api.upload_file(
+             path_or_fileobj=temp_path,
+             path_in_repo="MODELS_INDEX.json",
+             repo_id=config.hf_repo,
+             commit_message=f"Update models index - {model_name} @ {best_metrics.get('best_val_acc', 0.0):.2f}%"
+         )
+
+         os.unlink(temp_path)
+         print(f"[📊] Updated MODELS_INDEX.json - {len(models_index['models'])} models tracked")
+
+     except Exception as e:
+         print(f"[⚠️] Failed to update models index: {e}")
+
+
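The index update is an upsert: drop any stale entry for the same `run_id`, append the fresh one, and keep the list sorted by best validation accuracy. A torch-free sketch of just that logic, with invented entries:

```python
# Hypothetical distillation of the MODELS_INDEX.json upsert-and-sort step.
def upsert_model_entry(index: dict, entry: dict) -> dict:
    # Remove any old entry with the same run_id, then re-add the new one
    index["models"] = [m for m in index["models"] if m.get("run_id") != entry["run_id"]]
    index["models"].append(entry)
    # Best accuracy first
    index["models"].sort(key=lambda m: m.get("best_val_acc", 0), reverse=True)
    index["total_models"] = len(index["models"])
    return index

index = {"models": [{"run_id": "a", "best_val_acc": 61.0}]}
index = upsert_model_entry(index, {"run_id": "b", "best_val_acc": 70.5})
index = upsert_model_entry(index, {"run_id": "a", "best_val_acc": 64.2})  # replaces old "a"
print([m["run_id"] for m in index["models"]])  # best accuracy first
```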
+ def upload_to_huggingface(
+     local_dir: str,
+     repo_id: str,
+     commit_message: str,
+     path_in_repo: Optional[str] = None,
+     patterns: Optional[List[str]] = None
+ ):
+     """Upload a directory to the HuggingFace Hub."""
+
+     try:
+         api = HfApi()
+
+         # Create the repo if it doesn't exist
+         try:
+             create_repo(repo_id, exist_ok=True, repo_type="model")
+             print(f"[🤗] Repo ready: {repo_id}")
+         except Exception as e:
+             print(f"[⚠️] Repo exists or creation failed: {e}")
+
+         # Upload folder
+         if patterns:
+             # Upload only files matching the given patterns
+             for pattern in patterns:
+                 matching_files = list(Path(local_dir).rglob(pattern))
+                 for file_path in matching_files:
+                     rel_path = file_path.relative_to(local_dir)
+                     if path_in_repo:
+                         repo_path = f"{path_in_repo}/{rel_path}"
+                     else:
+                         repo_path = str(rel_path)
+
+                     api.upload_file(
+                         path_or_fileobj=str(file_path),
+                         path_in_repo=repo_path,
+                         repo_id=repo_id,
+                         commit_message=commit_message
+                     )
+         else:
+             # Upload the entire folder
+             api.upload_folder(
+                 folder_path=local_dir,
+                 repo_id=repo_id,
+                 path_in_repo=path_in_repo,
+                 commit_message=commit_message
+             )
+
+         print(f"[✅] Uploaded to Hub: https://huggingface.co/{repo_id}")
+
+     except Exception as e:
+         print(f"[❌] Hub upload failed: {e}")
+         print("    Continuing training (files saved locally)")
+
+
+ def prepare_hub_upload(
+     weights_dir: str,
+     runs_dir: str,
+     config: DavidTrainingConfig,
+     david_config: DavidArchitectureConfig,
+     best_metrics: Dict,
+     model_name: str
+ ):
+     """Prepare and upload all artifacts to the HuggingFace Hub."""
+
+     if not config.upload_to_hub or not config.hf_repo:
+         return
+
+     print("\n[🤗] Preparing HuggingFace Hub upload...")
+
+     # Create the model summary file
+     summary_path = create_model_summary(weights_dir, config, david_config, best_metrics, model_name)
+
+     # Update the master models index
+     update_models_index(config, david_config, best_metrics, model_name)
+
+     api = HfApi()
+     try:
+         create_repo(config.hf_repo, exist_ok=True, repo_type="model")
+     except Exception:
+         pass  # Repo already exists
+
+     # Create a temporary directory for root files
+     with tempfile.TemporaryDirectory() as temp_dir:
+         # Generate README at the repo root
+         readme_path = os.path.join(temp_dir, "README.md")
+         readme_content = generate_model_readme(config, david_config, best_metrics, config.run_id)
+         with open(readme_path, 'w') as f:
+             f.write(readme_content)
+         print(f"[📝] Generated README.md")
+
+         # Save best_model.json at the repo root
+         best_json_path = os.path.join(temp_dir, "best_model.json")
+         save_best_model_json(best_json_path, best_metrics, config, david_config)
+
+         # Upload root files (README.md, best_model.json)
+         print(f"[📤] Uploading root files...")
+
+         api.upload_file(
+             path_or_fileobj=readme_path,
+             path_in_repo="README.md",
+             repo_id=config.hf_repo,
+             commit_message=f"Update README - Run {config.run_id}"
+         )
+
+         api.upload_file(
+             path_or_fileobj=best_json_path,
+             path_in_repo="best_model.json",
+             repo_id=config.hf_repo,
+             commit_message=f"Update metrics - Run {config.run_id}"
+         )
+
+     # Upload ONLY the essential weight files (not the entire directory!)
+     weights_repo_path = f"weights/{model_name}/{config.run_id}"
+     best_acc = best_metrics.get('best_val_acc', 0.0)
+
+     print(f"[📤] Uploading essential files to {weights_repo_path}...")
+
+     # Specific files to upload (not the entire directory)
+     files_to_upload = [
+         ('MODEL_SUMMARY.txt', 'MODEL_SUMMARY.txt'),
+         ('training_history.json', 'training_history.json'),
+         ('david_config.json', 'david_config.json'),
+         ('train_config.json', 'train_config.json'),
+         (f'best_model_acc{best_acc:.2f}.safetensors', f'best_model_acc{best_acc:.2f}.safetensors'),
+         (f'best_model_acc{best_acc:.2f}_metadata.json', f'best_model_acc{best_acc:.2f}_metadata.json'),
+     ]
+
+     for local_filename, repo_filename in files_to_upload:
+         local_path = os.path.join(weights_dir, local_filename)
+         if os.path.exists(local_path):
+             try:
+                 api.upload_file(
+                     path_or_fileobj=local_path,
+                     path_in_repo=f"{weights_repo_path}/{repo_filename}",
+                     repo_id=config.hf_repo,
+                     commit_message=f"Update {repo_filename} - Run {config.run_id}"
+                 )
+             except Exception as e:
+                 print(f"[⚠️] Failed to upload {repo_filename}: {e}")
+
+     print(f"[✅] Uploaded to Hub: https://huggingface.co/{config.hf_repo}")
+
+     # Upload TensorBoard logs (only if they exist and it's the final upload).
+     # Skipped during training to avoid huge uploads every epoch:
+     # if os.path.exists(runs_dir):
+     #     runs_repo_path = f"runs/{model_name}/{config.run_id}"
+     #     print(f"[📤] Uploading TensorBoard logs to {runs_repo_path}...")
+     #     upload_to_huggingface(
+     #         local_dir=runs_dir,
+     #         repo_id=config.hf_repo,
+     #         commit_message=f"Upload TensorBoard logs - {model_name} - Run {config.run_id}",
+     #         path_in_repo=runs_repo_path
+     #     )
+
+
+ # ============================================================================
+ # CHECKPOINT UTILITIES
+ # ============================================================================
+
+ def save_checkpoint(
+     filepath: str,
+     david: David,
+     optimizer: torch.optim.Optimizer,
+     scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
+     epoch: int,
+     metrics: Dict,
+     train_config: DavidTrainingConfig
+ ):
+     """Save checkpoint in PyTorch and/or SafeTensors format."""
+
+     checkpoint = {
+         'epoch': epoch,
+         'model_state_dict': david.state_dict(),
+         'optimizer_state_dict': optimizer.state_dict(),
+         'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
+         'metrics': metrics,
+         'train_config': train_config.to_dict(),
+     }
+
+     # Add accuracy to the filename if available
+     val_acc = metrics.get('best_val_acc') or metrics.get('val_acc')
+     if val_acc:
+         acc_suffix = f"_acc{val_acc:.2f}"
+         filepath = filepath + acc_suffix
+
+     if train_config.save_format in ['pytorch', 'both']:
+         torch.save(checkpoint, filepath + '.pth')
+         print(f"[💾] Saved PyTorch: {filepath}.pth")
+
+     if train_config.save_format in ['safetensors', 'both']:
+         try:
+             from safetensors.torch import save_file
+
+             # Save model state
+             model_state = {k: v.contiguous() for k, v in david.state_dict().items()}
+             save_file(model_state, filepath + '.safetensors')
+
+             # Save metadata separately (includes the full training history).
+             # Optimizer/scheduler states hold tensors that are not
+             # JSON-serializable, so they are kept out of the JSON metadata.
+             metadata = {k: v for k, v in checkpoint.items()
+                         if k not in ['model_state_dict', 'optimizer_state_dict', 'scheduler_state_dict']}
+             with open(filepath + '_metadata.json', 'w') as f:
+                 json.dump(metadata, f, indent=2, default=str)
+
+             print(f"[💾] Saved SafeTensors: {filepath}.safetensors")
+         except ImportError:
+             print(f"[⚠️] SafeTensors not available, skipping")
+
+
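One detail in the filename logic above is worth calling out: `metrics.get('best_val_acc') or metrics.get('val_acc')` falls through on *falsy* values, so an accuracy of exactly 0.0 skips to the next key. A hypothetical standalone sketch of just the suffix step (the dicts are invented):

```python
# How the checkpoint path gets its "_accXX.XX" suffix, isolated for clarity.
def suffixed_path(base: str, metrics: dict) -> str:
    # NOTE: `or` treats 0.0 as missing, so a genuine 0.0 best_val_acc
    # falls through to val_acc. Harmless here, but easy to trip over.
    val_acc = metrics.get('best_val_acc') or metrics.get('val_acc')
    if val_acc:
        return base + f"_acc{val_acc:.2f}"
    return base

print(suffixed_path("best_model", {"best_val_acc": 73.456}))
print(suffixed_path("best_model", {"best_val_acc": 0.0, "val_acc": 12.5}))
```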
+ def load_checkpoint(
+     checkpoint_path: str,
+     david: David,
+     optimizer: Optional[torch.optim.Optimizer] = None,
+     scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+     device: str = "cuda"
+ ) -> Tuple[int, Dict]:
+     """Load a checkpoint and return (epoch, metrics)."""
+
+     if checkpoint_path.endswith('.safetensors'):
+         # Load SafeTensors format
+         try:
+             from safetensors.torch import load_file
+
+             model_state = load_file(checkpoint_path, device=device)
+             david.load_state_dict(model_state)
+
+             # Load metadata
+             metadata_path = checkpoint_path.replace('.safetensors', '_metadata.json')
+             with open(metadata_path, 'r') as f:
+                 metadata = json.load(f)
+
+             epoch = metadata.get('epoch', 0)
+             metrics = metadata.get('metrics', {})
+
+             if optimizer and 'optimizer_state_dict' in metadata:
+                 optimizer.load_state_dict(metadata['optimizer_state_dict'])
+
+             if scheduler and 'scheduler_state_dict' in metadata and metadata['scheduler_state_dict']:
+                 scheduler.load_state_dict(metadata['scheduler_state_dict'])
+
+             print(f"[✅] Loaded from SafeTensors: {checkpoint_path}")
+             return epoch, metrics
+
+         except ImportError:
+             raise ImportError("safetensors not installed")
+
+     else:
+         # Load PyTorch format
+         checkpoint = torch.load(checkpoint_path, map_location=device)
+
+         david.load_state_dict(checkpoint['model_state_dict'])
+
+         if optimizer and 'optimizer_state_dict' in checkpoint:
+             optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+
+         if scheduler and 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict']:
+             scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+
+         print(f"[✅] Loaded from PyTorch: {checkpoint_path}")
+         return checkpoint['epoch'], checkpoint.get('metrics', {})
+
+
+ # ============================================================================
+ # DATASET
+ # ============================================================================
+
+ class ImageNetHFDataset(Dataset):
+     """PyTorch Dataset wrapper for HuggingFace ImageNet features."""
+
+     def __init__(self, dataset_name: str, model_variant: str, split: str = "train"):
+         # Load only the requested split to avoid downloading all the data
+         print(f"[📥] Loading {split} split for {model_variant}...")
+         self.dataset = load_dataset(
+             dataset_name,
+             name=model_variant,  # Dataset configuration/variant name
+             split=split  # Only load this specific split
+         )
+         self.length = len(self.dataset)
+         print(f"[✅] Loaded {self.length:,} samples from {split} split")
+
+     def __len__(self):
+         return self.length
+
+     def __getitem__(self, idx):
+         item = self.dataset[idx]
+         features = torch.tensor(item['clip_features'], dtype=torch.float32)
+         label = torch.tensor(item['label'], dtype=torch.long)
+         return features, label
+
+
+ def create_dataloaders(config: DavidTrainingConfig):
+     """Create train and validation dataloaders."""
+
+     train_dataset = ImageNetHFDataset(
+         config.dataset_name, config.model_variant, "train"
+     )
+     val_dataset = ImageNetHFDataset(
+         config.dataset_name, config.model_variant, "validation"
+     )
+
+     train_loader = DataLoader(
+         train_dataset,
+         batch_size=config.batch_size,
+         shuffle=True,
+         num_workers=config.num_workers,
+         pin_memory=config.pin_memory,
+         prefetch_factor=config.prefetch_factor,
+         persistent_workers=config.persistent_workers
+     )
+
+     val_loader = DataLoader(
+         val_dataset,
+         batch_size=config.batch_size * 2,  # no gradients in eval, so larger batches fit
+         shuffle=False,
+         num_workers=config.num_workers,
+         pin_memory=config.pin_memory,
+         prefetch_factor=config.prefetch_factor,
+         persistent_workers=config.persistent_workers
+     )
+
+     return train_loader, val_loader
+
+
+ # ============================================================================
+ # CRYSTAL GENERATOR
+ # ============================================================================
+
+ class CrystalGenerator:
+     """Generate crystals for all scales."""
+
+     def __init__(self, num_classes: int, scales: List[int], device: str = "cuda"):
+         self.num_classes = num_classes
+         self.scales = scales
+         self.device = device
+         self.factories = {
+             scale: SimplexFactory(k=4, embed_dim=scale, method="random")
+             for scale in scales
+         }
+
+     def generate(self, seed: int = 42) -> Tuple[Dict[int, torch.Tensor], Dict[int, torch.Tensor]]:
+         """Generate anchors and crystals for all scales."""
+
+         anchors_dict = {}
+         crystals_dict = {}
+
+         for scale in tqdm(self.scales, desc="Generating crystals"):
+             factory = self.factories[scale]
+             batch_crystals = []
+
+             for class_idx in range(self.num_classes):
+                 crystal = factory.build(
+                     backend="torch",
+                     device=self.device,
+                     dtype=torch.float32,
+                     seed=seed + class_idx,
+                     validate=True
+                 )
+                 batch_crystals.append(crystal)
+
+             crystals = torch.stack(batch_crystals)
+             anchors = F.normalize(crystals[:, 0, :], dim=-1)
+
+             # Verify anchor diversity
+             anchor_sims = anchors @ anchors.T
+             off_diag = anchor_sims[~torch.eye(self.num_classes, dtype=bool, device=anchors.device)]
+             max_sim = off_diag.max().item()
+             mean_sim = off_diag.mean().item()
+
+             print(f"  Scale {scale}: max_sim={max_sim:.4f}, mean_sim={mean_sim:.4f}")
+
+             if max_sim > 0.99:
+                 print(f"  ⚠️ WARNING: Anchors too similar at scale {scale}!")
+
+             anchors_dict[scale] = anchors
+             crystals_dict[scale] = crystals
+
+         return anchors_dict, crystals_dict
+
+
+ # ============================================================================
+ # TRAINING LOOP
+ # ============================================================================
+
+ def train_epoch(
+     david: David,
+     train_loader: DataLoader,
+     optimizer: torch.optim.Optimizer,
+     criterion: MultiScaleCrystalLoss,
+     anchors_dict: Dict[int, torch.Tensor],
+     crystals_dict: Dict[int, torch.Tensor],
+     epoch: int,
+     config: DavidTrainingConfig,
+     writer: Optional[SummaryWriter],
+     global_step: int
+ ) -> Tuple[float, float, int, Dict]:
+     """Train for one epoch - pure FP32."""
+
+     david.train()
+     david.update_epoch(epoch)
+
+     total_loss = 0
+     correct = 0
+     total = 0
+     loss_components_sum = {}
+
+     active_scales = david.get_active_scales()
+
+     pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.num_epochs}")
+
+     for batch_idx, (features, labels) in enumerate(pbar):
+         features = features.cuda(non_blocking=True)
+         labels = labels.cuda(non_blocking=True)
+
+         # Zero gradients
+         optimizer.zero_grad()
+
+         # Forward pass - pure FP32, no autocast
+         combined, logits_list, features_list, fusion_weights = david(
+             features, anchors_dict, return_all_scales=True
+         )
+
+         # Compute loss
+         losses = criterion(
+             combined, logits_list, features_list,
+             labels, crystals_dict, epoch
+         )
+
+         # Backward
+         losses['total'].backward()
+
+         # Gradient analysis (global_step already advances once per batch)
+         if config.track_gradients and batch_idx % config.log_interval == 0:
+             grad_stats = analyze_gradients(david, config)
+             if writer:
+                 step = global_step
+                 writer.add_scalar('train/grad_mean', grad_stats['mean'], step)
+                 writer.add_scalar('train/grad_max', grad_stats['max'], step)
+                 writer.add_scalar('train/grad_num_small', grad_stats['num_small'], step)
+
+         # Scale small gradients
+         scale_small_gradients(david, config)
+
+         # Gradient clipping
+         torch.nn.utils.clip_grad_norm_(david.parameters(), config.gradient_clip)
+
+         # Optimizer step
+         optimizer.step()
+
+         # Metrics
+         total_loss += losses['total'].item()
+         _, predicted = torch.max(combined, 1)
+         total += labels.size(0)
+         correct += (predicted == labels).sum().item()
+
+         # Accumulate loss components
+         for key, value in losses.items():
+             if key not in loss_components_sum:
+                 loss_components_sum[key] = 0.0
+             loss_components_sum[key] += value.item()
+
+         # Logging
+         if writer and batch_idx % config.log_interval == 0:
+             step = global_step
+             writer.add_scalar('train/loss_batch', losses['total'].item(), step)
+             writer.add_scalar('train/acc_batch', 100 * correct / total, step)
+
+             if config.log_loss_components:
+                 for key, value in losses.items():
+                     if key != 'total':
+                         writer.add_scalar(f'train/loss_{key}', value.item(), step)
+
+             if config.log_fusion_weights and fusion_weights is not None:
+                 if fusion_weights.dim() == 2:
+                     mean_weights = fusion_weights.mean(dim=0)
+                     for i, w in enumerate(mean_weights):
+                         if i < len(active_scales):
+                             writer.add_scalar(
+                                 f'train/fusion_weight_{active_scales[i]}',
+                                 w.item(), step
+                             )
+
+             writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], step)
+
+         pbar.set_postfix({
+             'loss': f'{total_loss / (batch_idx + 1):.4f}',
+             'acc': f'{100 * correct / total:.2f}%'
+         })
+
+         global_step += 1
+
+     # Average the loss components
+     avg_components = {k: v / len(train_loader) for k, v in loss_components_sum.items()}
+
+     return (
+         total_loss / len(train_loader),
+         100 * correct / total,
+         global_step,
+         avg_components
+     )
+
+
+ @torch.no_grad()
+ def validate(
+     david: David,
+     val_loader: DataLoader,
+     anchors_dict: Dict[int, torch.Tensor],
+     config: DavidTrainingConfig
+ ) -> Tuple[float, Dict[int, float]]:
+     """Validate the model - pure FP32."""
+
+     david.eval()
+
+     correct = 0
+     total = 0
+     active_scales = david.get_active_scales()
+     scale_correct = {scale: 0 for scale in active_scales}
+
+     for features, labels in tqdm(val_loader, desc="Validation", leave=False):
+         features = features.cuda(non_blocking=True)
+         labels = labels.cuda(non_blocking=True)
+
+         # Forward pass - no autocast
+         combined, logits_list, _, _ = david(
+             features, anchors_dict, return_all_scales=True
+         )
+
+         _, predicted = torch.max(combined, 1)
+         total += labels.size(0)
+         correct += (predicted == labels).sum().item()
+
+         for i, scale in enumerate(active_scales):
+             if i < len(logits_list):
+                 _, scale_pred = torch.max(logits_list[i], 1)
+                 scale_correct[scale] += (scale_pred == labels).sum().item()
+
+     accuracy = 100 * correct / total
+     scale_accs = {s: 100 * scale_correct[s] / total for s in scale_correct}
+
+     return accuracy, scale_accs
+
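`validate` keeps one correct-count per scale alongside the fused prediction, then converts counts to percentages at the end. A torch-free sketch of that bookkeeping, with made-up prediction lists standing in for the per-scale argmax outputs:

```python
# Hypothetical distillation of validate()'s per-scale accuracy tallies.
def scale_accuracies(preds_per_scale: dict, labels: list) -> dict:
    total = len(labels)
    return {
        scale: 100 * sum(p == y for p, y in zip(preds, labels)) / total
        for scale, preds in preds_per_scale.items()
    }

labels = [0, 1, 1, 2]
preds = {256: [0, 1, 0, 2], 512: [0, 1, 1, 2]}  # invented scales/predictions
print(scale_accuracies(preds, labels))  # {256: 75.0, 512: 100.0}
```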
+
+ # ============================================================================
+ # MAIN TRAINING FUNCTION
+ # ============================================================================
+
+ def train_david(config: DavidTrainingConfig):
+     """Main training pipeline."""
+
+     # Enable TensorFloat32 for better performance on Ampere+ GPUs
+     torch.set_float32_matmul_precision('high')
+
+     print("="*80)
+     print("🌟 DAVID TRAINING PIPELINE")
+     print("="*80)
+     print(f"Run ID: {config.run_id}")
+     print(f"Preset: {config.preset}")
+     print(f"Batch Size: {config.batch_size}")
+     print(f"Learning Rate: {config.learning_rate}")
+     print(f"Mixed Precision: {config.use_mixed_precision}")
+     print(f"TensorFloat32: Enabled (high precision)")
+     print("="*80)
+
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+     # Load or create the David config FIRST (needed for model_name)
+     if config.custom_config_path:
+         david_config = DavidArchitectureConfig.from_json(config.custom_config_path)
+         print(f"[📁] Loaded custom config: {config.custom_config_path}")
+     elif config.preset:
+         david_config = DavidPresets.get_preset(config.preset)
+         print(f"[⚙️] Using preset: {config.preset}")
+     else:
+         raise ValueError("Must specify either preset or custom_config_path")
+
+     # Create the model name from the architecture
+     model_name = f"David-{david_config.sharing_mode}-{david_config.fusion_mode}"
+     print(f"[🏷️] Model: {model_name}")
+
+     # Set up directories with the hierarchy: weights/model_name/timestamp/
+     weights_dir = os.path.join(config.base_dir, "weights", model_name, config.run_id)
+     runs_dir = os.path.join(config.base_dir, "runs", model_name, config.run_id)
+     os.makedirs(weights_dir, exist_ok=True)
+     os.makedirs(runs_dir, exist_ok=True)
+
+     print(f"[📁] Weights: {weights_dir}")
+     print(f"[📁] Logs: {runs_dir}")
+
+     writer = SummaryWriter(runs_dir)
+
+     # Apply overrides
+     if config.num_classes_override:
+         david_config.num_classes = config.num_classes_override
+     if config.use_belly_override is not None:
+         david_config.use_belly = config.use_belly_override
+     if config.belly_expand_override is not None:
+         david_config.belly_expand = config.belly_expand_override
+     if config.progressive_training_override is not None:
+         david_config.progressive_training = config.progressive_training_override
+         if not david_config.progressive_training:
+             # Disable warmup if progressive training is disabled
+             david_config.scale_warmup_epochs = {s: 0 for s in david_config.scales}
+
+     # Override the scale warmup schedule if provided
+     if config.scale_warmup_epochs_override is not None:
+         david_config.scale_warmup_epochs = config.scale_warmup_epochs_override
+         # Enable progressive training if a custom schedule was provided
+         if not david_config.progressive_training:
+             print(f"[⚙️] Enabling progressive training (custom warmup schedule provided)")
+             david_config.progressive_training = True
+
+     print(f"[⚙️] Progressive training: {david_config.progressive_training}")
+     if david_config.progressive_training:
+         print(f"  Scale warmup schedule: {david_config.scale_warmup_epochs}")
+
+     # Save configs
+     david_config_path = os.path.join(weights_dir, "david_config.json")
+     david_config.to_json(david_config_path)
+     print(f"[💾] Saved David config: {david_config_path}")
+
+     train_config_path = os.path.join(weights_dir, "train_config.json")
+     config.to_json(train_config_path)
+     print(f"[💾] Saved training config: {train_config_path}")
+
+     # Initialize David
+     david = David.from_config(david_config).cuda()
+     print(f"\n{david}\n")
+
+     # Count parameters
+     total_params = sum(p.numel() for p in david.parameters())
+     trainable_params = sum(p.numel() for p in david.parameters() if p.requires_grad)
+     print(f"[📊] Total Parameters: {total_params:,}")
+     print(f"[📊] Trainable Parameters: {trainable_params:,}")
+
+     # Load data
+     train_loader, val_loader = create_dataloaders(config)
+
+     # Generate crystals
+     crystal_gen = CrystalGenerator(
+         david_config.num_classes,
+         david_config.scales,
+         str(device)
+     )
+     anchors_dict, crystals_dict = crystal_gen.generate()
+
+
+     # Set up training
+     criterion = MultiScaleCrystalLoss(
+         scales=david_config.scales,
+         num_classes=david_config.num_classes,
+         use_rose_loss=config.use_rose_loss,
+         use_cayley_loss=config.use_cayley_loss,
+         rose_initial_weight=config.rose_initial_weight,
+         rose_max_weight=config.rose_max_weight,
+         cayley_weight=config.cayley_weight,
+         scale_loss_balance=config.scale_loss_balance
+     ).cuda()
+
+     optimizer = create_optimizer(david, config)
+     scheduler = create_scheduler(optimizer, config)
+
+     controller = AdaptiveTrainingController(david, config)
+
+     # Tracking
+     best_val_acc = 0.0
+     best_epoch = 0
+     best_scale_accs = {}
+     global_step = 0
+     final_train_acc = 0.0
+     final_train_loss = 0.0
+
+     # Training history for epoch-by-epoch tracking
+     training_history = {
+         'epochs': [],
+         'train_loss': [],
+         'train_acc': [],
+         'val_acc': [],
+         'scale_accs': {},
+         'lr': []
+     }
+
+     # DIAGNOSTIC: Test one forward/backward pass before training
+     print("\n[🔍] Running diagnostic forward/backward pass...")
+     # david.compile()
+     david.train()
+
+     # Grab a small batch
+     for features_test, labels_test in train_loader:
+         features_test = features_test.cuda(non_blocking=True)[:8]  # Just 8 samples
+         labels_test = labels_test.cuda(non_blocking=True)[:8]
+
+         # Forward
+         combined_test, logits_test, features_test_out, _ = david(
+             features_test, anchors_dict, return_all_scales=True
+         )
+
+         # Loss
+         losses_test = criterion(
+             combined_test, logits_test, features_test_out,
+             labels_test, crystals_dict, epoch=0
+         )
+
+         print(f"  Initial loss: {losses_test['total'].item():.6f}")
+         print(f"  Loss components:")
+         for key, value in losses_test.items():
+             if key != 'total':
+                 print(f"    {key}: {value.item():.6f}")
+
+         # Backward
+         optimizer.zero_grad()
+         losses_test['total'].backward()
+
+         # Check gradients
+         grad_count = sum(1 for p in david.parameters() if p.grad is not None and p.grad.norm() > 0)
+         total_grad_params = sum(1 for p in david.parameters() if p.requires_grad)
+         print(f"  Parameters with non-zero gradients: {grad_count}/{total_grad_params}")
+
+         if grad_count == 0:
+             print(f"  ❌ ERROR: No gradients! Training will not work.")
+             return None, 0.0
+         elif grad_count < total_grad_params * 0.5:
+             print(f"  ⚠️ WARNING: Less than 50% of parameters have gradients")
+         else:
+             print(f"  ✅ Gradients look good")
+
+         break  # Only test one batch
+
+     print("\n[🚀] Starting training...\n")
+
+ for epoch in range(config.num_epochs):
1593
+ epoch_start = time.time()
1594
+
1595
+ # Train
1596
+ train_loss, train_acc, global_step, loss_components = train_epoch(
1597
+ david, train_loader, optimizer, criterion,
1598
+ anchors_dict, crystals_dict, epoch, config,
1599
+ writer, global_step
1600
+ )
1601
+
1602
+ # Validate
1603
+ val_acc, scale_accs = validate(david, val_loader, anchors_dict, config)
1604
+
1605
+ # Update controller
1606
+ controller.update_metrics(scale_accs, val_acc)
1607
+ controller.apply_adaptive_strategies(scale_accs, epoch)
1608
+
1609
+ # Step scheduler
1610
+ if scheduler:
1611
+ scheduler.step()
1612
+
1613
+ epoch_time = time.time() - epoch_start
1614
+
1615
+ # Print
1616
+ print(f"\n📊 Epoch {epoch+1}/{config.num_epochs} ({epoch_time:.1f}s)")
1617
+ print(f" Train: Loss={train_loss:.4f}, Acc={train_acc:.2f}%")
1618
+ print(f" Val: Acc={val_acc:.2f}% (Best: {best_val_acc:.2f}%)")
1619
+ print(f" Active scales: {david.get_active_scales()}")
1620
+ print(f" LR: {optimizer.param_groups[0]['lr']:.2e}")
1621
+
1622
+ if config.log_loss_components and loss_components:
1623
+ print(f" Loss breakdown:")
1624
+ for key, value in sorted(loss_components.items()):
1625
+ if key != 'total':
1626
+ print(f" {key:20s}: {value:.6f}")
1627
+
1628
+ for scale, acc in scale_accs.items():
1629
+ frozen = "❄️" if controller.scales_frozen.get(scale, False) else "🔥"
1630
+ print(f" {frozen} Scale {scale}: {acc:.2f}%")
1631
+
1632
+ # Update tracking
1633
+ final_train_acc = train_acc
1634
+ final_train_loss = train_loss
1635
+
1636
+ # Record training history
1637
+        training_history['epochs'].append(epoch + 1)
+        training_history['train_loss'].append(train_loss)
+        training_history['train_acc'].append(train_acc)
+        training_history['val_acc'].append(val_acc)
+        training_history['lr'].append(optimizer.param_groups[0]['lr'])
+
+        # Record per-scale accuracies
+        for scale, acc in scale_accs.items():
+            if scale not in training_history['scale_accs']:
+                training_history['scale_accs'][scale] = []
+            training_history['scale_accs'][scale].append(acc)
+
+        # TensorBoard
+        writer.add_scalar('train/loss', train_loss, epoch)
+        writer.add_scalar('train/acc', train_acc, epoch)
+        writer.add_scalar('val/acc', val_acc, epoch)
+
+        for scale, acc in scale_accs.items():
+            writer.add_scalar(f'val/acc_scale_{scale}', acc, epoch)
+
+        # Save best
+        if val_acc > best_val_acc:
+            best_val_acc = val_acc
+            best_epoch = epoch
+            best_scale_accs = scale_accs.copy()
+
+            # Save training history alongside best model
+            history_path = os.path.join(weights_dir, 'training_history.json')
+            with open(history_path, 'w') as f:
+                json.dump(training_history, f, indent=2)
+
+            save_checkpoint(
+                os.path.join(weights_dir, 'best_model'),
+                david, optimizer, scheduler, epoch,
+                {
+                    'best_val_acc': best_val_acc,
+                    'best_epoch': best_epoch,
+                    'scale_accuracies': best_scale_accs,
+                    'training_history': training_history
+                },
+                config
+            )
+
+            # Upload to hub when best model improves
+            if config.upload_to_hub:
+                best_metrics = {
+                    'best_val_acc': best_val_acc,
+                    'best_epoch': best_epoch,
+                    'scale_accuracies': best_scale_accs,
+                    'final_train_acc': train_acc,
+                    'final_train_loss': train_loss,
+                    'training_history': training_history,
+                    'parameters': total_params
+                }
+                prepare_hub_upload(weights_dir, runs_dir, config, david_config, best_metrics, model_name)
+
+        # Periodic save
+        if (epoch + 1) % config.save_interval == 0:
+            save_checkpoint(
+                os.path.join(weights_dir, f'checkpoint_epoch_{epoch+1}'),
+                david, optimizer, scheduler, epoch,
+                {'val_acc': val_acc},
+                config
+            )
+
+    # Final save
+    save_checkpoint(
+        os.path.join(weights_dir, 'final_model'),
+        david, optimizer, scheduler, config.num_epochs - 1,
+        {'final_val_acc': val_acc},
+        config
+    )
+
+    writer.close()
+
+    # Final hub upload with all artifacts
+    if config.upload_to_hub:
+        print("\n[🤗] Performing final HuggingFace Hub upload...")
+        final_metrics = {
+            'best_val_acc': best_val_acc,
+            'best_epoch': best_epoch,
+            'scale_accuracies': best_scale_accs,
+            'final_train_acc': final_train_acc,
+            'final_train_loss': final_train_loss,
+            'training_history': training_history,
+            'parameters': total_params
+        }
+        prepare_hub_upload(weights_dir, runs_dir, config, david_config, final_metrics, model_name)
+
+    print("\n" + "="*80)
+    print("🎉 Training Complete!")
+    print(f"   Best Val Acc: {best_val_acc:.2f}% (Epoch {best_epoch+1})")
+    print(f"   Final Train Acc: {final_train_acc:.2f}%")
+    print(f"   Weights: {weights_dir}")
+    if config.upload_to_hub:
+        print(f"   Hub: https://huggingface.co/{config.hf_repo}")
+    print("="*80)
+
+    return david, best_val_acc
+
+
+# ============================================================================
+# USAGE EXAMPLE
+# ============================================================================
+
+if __name__ == "__main__":
+    configs = []
+    config = DavidTrainingConfig(
+        preset="clip_vit_bigg14",  # Uses progressive training by default
+        model_variant="clip_vit_laion_bigg14",
+
+        num_epochs=10,
+        batch_size=1024,
+        learning_rate=1e-3,
+
+        use_mixed_precision=False,  # Leave off: mixed precision kills accuracy with Rose
+        gradient_clip=10.0,
+
+        use_rose_loss=True,
+        rose_initial_weight=0.1,
+        rose_max_weight=0.5,
+        use_cayley_loss=False,
+        progressive_training_override=False,
+
+        # Adaptive training (disabled by default for stability)
+        freeze_strategy="none",  # Set to "performance" to enable
+        freeze_threshold=90.0,
+
+        save_format="safetensors",
+
+        # HuggingFace Hub upload
+        # DO NOT put your HF token in the config; load it from os.environ or other means.
+        upload_to_hub=False,  # Set to True to upload to your repo
+        hf_repo="YourName/Repo"  # e.g. "AbstractPhil/gated-david"
+    )
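The config comment above says to load the Hub token from `os.environ` rather than hard-coding it. A minimal sketch of that pattern (the helper name `resolve_hf_token` and the `HF_TOKEN` variable name are assumptions for illustration, not part of this trainer):

```python
import os
from typing import Optional


def resolve_hf_token(env_var: str = "HF_TOKEN") -> Optional[str]:
    """Read the HuggingFace Hub token from the environment.

    Returns None when the variable is unset, so callers can skip the
    upload step instead of shipping a hard-coded credential in config.
    """
    return os.environ.get(env_var)
```

A caller would then gate uploads on `resolve_hf_token()` returning a value, keeping the token out of any serialized config or checkpoint metadata.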
+    #configs.append(DavidTrainingConfig(
+    #    preset="high_accuracy",  # Uses progressive training by default
+    #    model_variant="clip_vit_laion_b32",
+    #
+    #    num_epochs=20,
+    #    batch_size=1024,
+    #    learning_rate=1e-3,
+    #
+    #    use_mixed_precision=False,
+    #    gradient_clip=10.0,
+    #
+    #    use_rose_loss=True,
+    #    rose_initial_weight=0.1,
+    #    rose_max_weight=0.5,
+    #    use_cayley_loss=False,
+    #
+    #    # Adaptive training (disabled by default for stability)
+    #    freeze_strategy="never",  # Set to "performance" to enable
+    #    freeze_threshold=90.0,
+    #
+    #    save_format="safetensors",
+    #
+    #    # HuggingFace Hub upload
+    #    upload_to_hub=True,
+    #    hf_repo="AbstractPhil/gated-david",
+    #))
+    #configs.append(DavidTrainingConfig(
+    #    preset="balanced",  # Uses progressive training by default
+    #    model_variant="clip_vit_laion_b32",
+    #
+    #    num_epochs=20,
+    #    batch_size=1024,
+    #    learning_rate=1e-3,
+    #
+    #    # Custom scale warmup schedule (overrides preset)
+    #    #scale_warmup_epochs_override={
+    #    #    384: 0,   # Scale 384 active from epoch 0
+    #    #    #512: 1,  # Scale 512 active from epoch 1
+    #    #    768: 1,   # Scale 768 active from epoch 1
+    #    #    1024: 2,  # Scale 1024 active from epoch 2
+    #    #    1280: 3   # Scale 1280 active from epoch 3
+    #    #},
+    #    #scale_warmup_epochs_override={
+    #    #    256: 0,
+    #    #    512: 1,
+    #    #    768: 2,
+    #    #    1024: 3,
+    #    #    1280: 4,
+    #    #    1536: 5,
+    #    #    1792: 6,
+    #    #    2048: 7,
+    #    #    2304: 8,
+    #    #    2560: 9
+    #    #},
+    #
+    #    use_mixed_precision=False,
+    #    gradient_clip=10.0,
+    #
+    #    use_rose_loss=True,
+    #    rose_initial_weight=0.1,
+    #    rose_max_weight=0.5,
+    #    use_cayley_loss=False,
+    #
+    #    # Adaptive training (disabled by default for stability)
+    #    freeze_strategy="never",  # Set to "performance" to enable
+    #    freeze_threshold=90.0,
+    #
+    #    save_format="safetensors",
+    #
+    #    # HuggingFace Hub upload
+    #    upload_to_hub=True,
+    #    hf_repo="AbstractPhil/gated-david",
+    #))
+    #
+    #configs.append(DavidTrainingConfig(
+    #    preset="clip_vit_l14_ultra_deep",  # Uses progressive training by default
+    #    model_variant="clip_vit_l14",
+    #
+    #    num_epochs=10,
+    #    batch_size=1024,
+    #    learning_rate=1e-3,
+    #
+    #    # Custom scale warmup schedule (overrides preset)
+    #    #scale_warmup_epochs_override={
+    #    #    384: 0,   # Scale 384 active from epoch 0
+    #    #    #512: 1,  # Scale 512 active from epoch 1
+    #    #    768: 1,   # Scale 768 active from epoch 1
+    #    #    1024: 2,  # Scale 1024 active from epoch 2
+    #    #    1280: 3   # Scale 1280 active from epoch 3
+    #    #},
+    #    #scale_warmup_epochs_override={
+    #    #    256: 0,
+    #    #    512: 1,
+    #    #    768: 2,
+    #    #    1024: 3,
+    #    #    1280: 4,
+    #    #    1536: 5,
+    #    #    1792: 6,
+    #    #    2048: 7,
+    #    #    2304: 8,
+    #    #    2560: 9
+    #    #},
+    #
+    #    use_mixed_precision=False,
+    #    gradient_clip=10.0,
+    #
+    #    use_rose_loss=True,
+    #    rose_initial_weight=0.1,
+    #    rose_max_weight=0.5,
+    #    use_cayley_loss=False,
+    #
+    #    # Adaptive training (disabled by default for stability)
+    #    freeze_strategy="never",  # Set to "performance" to enable
+    #    freeze_threshold=90.0,
+    #
+    #    save_format="safetensors",
+    #
+    #    # HuggingFace Hub upload
+    #    upload_to_hub=True,
+    #    hf_repo="AbstractPhil/gated-david",
+    #))
+
+    #for config in configs:
+    #    print("Starting train")
+    #    try:
+    david, best_acc = train_david(config)
+    #    except Exception as e:
+    #        print(f"Error during training: {e}")
+    #print("train complete")