Commit e3013fc (verified) · Parent: 38a31cc · committed by AbstractPhil

Create trainer_v2.py

Files changed (1): trainer_v2.py (+1445, −0)

trainer_v2.py (new file):
"""
Liminal Staircase Training - DANBOORU EDITION (BULLETPROOF + GEOMETRIC + TEXT DROPOUT)
=========================================================================================

Fully hardened trainer with:
- Geometric pentachoron initialization via SimplexFactory
- TEXT MODALITY ROBUSTNESS: dropout, noise, semantic sentinel
- Saves checkpoints BEFORE validation
- Handles all validation crashes gracefully
- Proper scheduler with actual step counts
- Clean model/loss separation
- Keyboard interrupt saves checkpoint before exit
- Fixed shared fusion controller checkpoint handling
- PROPER checkpoint naming (no step in directory name)

Author: AbstractPhil + Claude Sonnet 4.5
Date: 2025-11-17 (Text Robustness Update)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import SiglipModel, SiglipProcessor, CLIPTokenizer
from accelerate import Accelerator
from tqdm.auto import tqdm
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass, asdict
import numpy as np
from safetensors.torch import load_file, save_file
import os
import json
from datetime import datetime
import shutil
import traceback
import signal
import sys

# HuggingFace Hub
from huggingface_hub import HfApi, create_repo, hf_hub_download

# Import from your existing modules
from geovocab2.train.model.core.liminal_staircase_collective_v2 import (
    LiminalStaircase,
    LiminalStaircaseConfig,
    ScaleFusionConfig,
    OrganizedFusionController
)

# NOTE: create_danbooru_dataloaders is called in the trainer below but is not
# defined or imported in this file; it is assumed to come from the companion
# dataset module, e.g. (module path is an assumption):
# from geovocab2.train.data.danbooru import create_danbooru_dataloaders


# ============================================================================
# CONFIGURATION
# ============================================================================

@dataclass
class DanbooruTrainingConfig:
    """Training configuration for Danbooru dataset with organized fusion."""

    # Model identifier (NO STEP COUNT HERE!)
    sub_name: str = "danbooru-v1"

    # Core model architecture
    num_opinion_anchors: int = 225
    pentachoron_dim: int = 512
    scales: List[int] = None
    scale_hidden_dims: Dict[int, int] = None

    # Fusion controller parameters
    alpha_init: float = 0.1
    alpha_learnable: bool = True
    alpha_per_scale: bool = True

    beta_init: float = 0.5
    beta_learnable: bool = True
    beta_per_scale: bool = True

    gamma_learnable: bool = True

    learn_layer_weights: bool = True

    # Encoders
    siglip_model: str = "google/siglip-so400m-patch14-384"
    clip_tokenizer: str = "openai/clip-vit-large-patch14"
    illustrious_clip_path: str = "./models/NAI-11-epsilon_clip_l.safetensors"
    clip_skip: int = 0

    # Layer selection
    siglip_layer_indices: Optional[List[int]] = None
    clip_layer_indices: Optional[List[int]] = None

    # Optimizations
    use_gradient_checkpointing: bool = False
    share_scale_embeddings: bool = True

    # Dataset
    dataset_name: str = "animetimm/danbooru-wdtagger-v4-w640-ws-50k"
    image_size: int = 384
    max_tag_length: int = 77

    # Training
    batch_size: int = 32
    num_epochs: int = 5
    learning_rate: float = 1e-4
    weight_decay: float = 1e-2
    warmup_steps: int = 1000  # note: not consumed by the cosine scheduler in this file
    gradient_clip: float = 1.0
    gradient_accumulation_steps: int = 1

    # Loss weights
    token_loss_weight: float = 1.0
    geometric_weight: float = 0.1
    fusion_strategy: str = "learned_weighted"

    # TEXT MODALITY ROBUSTNESS (NEW!)
    text_dropout_prob: float = 0.3        # 30% vision-only batches
    text_noise_std: float = 0.1           # Gaussian noise std
    text_noise_prob: float = 0.5          # 50% of text batches get noise
    vision_only_text: str = "general: blank_image"  # Semantic sentinel token

    # Progressive curriculum
    text_dropout_schedule: str = "linear"  # constant, linear, cosine
    text_dropout_start: float = 0.1        # Start at 10% dropout
    text_dropout_end: float = 0.5          # End at 50% dropout

    # Checkpointing & Upload
    checkpoint_dir: str = "./checkpoints/liminal_staircase_danbooru"
    save_every: int = 500

    # HuggingFace Upload
    hf_repo_id: Optional[str] = None
    hf_upload_every: int = 5000
    hf_private: bool = False

    # Resume
    resume: bool = False

    # Logging
    log_dir: str = "./logs/liminal_staircase_danbooru"
    log_every: int = 5

    # Device
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    def __post_init__(self):
        if self.scales is None:
            self.scales = [128, 256, 512]

        if self.scale_hidden_dims is None:
            self.scale_hidden_dims = {s: s * 2 for s in self.scales}

        Path(self.checkpoint_dir).mkdir(parents=True, exist_ok=True)
        Path(self.log_dir).mkdir(parents=True, exist_ok=True)

    def to_model_config(self, siglip_hidden_dim: int, siglip_num_layers: int) -> LiminalStaircaseConfig:
        """Convert to LiminalStaircaseConfig with organized fusion."""

        # Create ScaleFusionConfig
        fusion_config = ScaleFusionConfig(
            scales=self.scales,
            scale_hidden_dims=self.scale_hidden_dims,
            alpha_init=self.alpha_init,
            alpha_learnable=self.alpha_learnable,
            alpha_per_scale=self.alpha_per_scale,
            beta_init=self.beta_init,
            beta_learnable=self.beta_learnable,
            beta_per_scale=self.beta_per_scale,
            gamma_learnable=self.gamma_learnable,
            learn_layer_weights=self.learn_layer_weights,
            learn_scale_weights=True,
            track_scale_losses=True
        )

        # Create main model config
        return LiminalStaircaseConfig(
            num_opinion_anchors=self.num_opinion_anchors,
            pentachoron_dim=self.pentachoron_dim,
            siglip_hidden_dim=siglip_hidden_dim,
            siglip_num_layers=siglip_num_layers,
            clip_hidden_dim=768,
            clip_num_layers=12,
            clip_skip=self.clip_skip,
            vocab_size=49408,
            max_seq_len=77,
            siglip_layer_indices=self.siglip_layer_indices,
            clip_layer_indices=self.clip_layer_indices,
            scale_fusion=fusion_config,
            use_gradient_checkpointing=self.use_gradient_checkpointing,
            share_scale_embeddings=self.share_scale_embeddings,
            geometric_init_method="hybrid",
            geometric_init_validate=False,
            geometric_init_seed=42
        )

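# ----------------------------------------------------------------------------
# Illustrative sketch (not called anywhere): how DanbooruTrainingConfig fills
# its defaults and converts to a model config. At runtime the SigLIP
# dimensions are read from the loaded encoder; the 1152/27 values below are
# assumptions matching siglip-so400m-patch14-384, used here only as examples.
def _example_build_model_config() -> LiminalStaircaseConfig:
    cfg = DanbooruTrainingConfig()  # scales default to [128, 256, 512]
    # __post_init__ derives the per-scale hidden dims as {scale: 2 * scale}
    assert cfg.scale_hidden_dims == {128: 256, 256: 512, 512: 1024}
    return cfg.to_model_config(siglip_hidden_dim=1152, siglip_num_layers=27)
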

# ============================================================================
# CHECKPOINT MANAGER
# ============================================================================

class CheckpointManager:
    """Manages checkpoints with proper naming (no step in directory name)."""

    def __init__(
        self,
        local_dir: str,
        hf_repo_id: Optional[str] = None,
        sub_name: str = "default",
        hf_private: bool = False
    ):
        self.local_dir = Path(local_dir)
        self.hf_repo_id = hf_repo_id
        self.sub_name = sub_name
        self.hf_private = hf_private

        # Checkpoint directory structure: checkpoints/{sub_name}/{timestamp}/
        self.sub_checkpoint_dir = self.local_dir / sub_name
        self.sub_checkpoint_dir.mkdir(parents=True, exist_ok=True)

        self.checkpoints_file = self.sub_checkpoint_dir / "checkpoints.json"

        if hf_repo_id:
            self.hf_api = HfApi()
            try:
                create_repo(
                    repo_id=hf_repo_id,
                    private=hf_private,
                    exist_ok=True
                )
                print(f"🤗 HuggingFace repo: {hf_repo_id}")
            except Exception as e:
                print(f"⚠️ Could not create HF repo: {e}")
                self.hf_api = None
        else:
            self.hf_api = None

        self.checkpoint_history = self._load_checkpoint_history()

    def _load_checkpoint_history(self) -> Dict:
        if self.checkpoints_file.exists():
            with open(self.checkpoints_file, 'r') as f:
                return json.load(f)
        return {
            "sub_name": self.sub_name,
            "checkpoints": [],
            "latest": None,
            "best": None
        }

    def _save_checkpoint_history(self):
        with open(self.checkpoints_file, 'w') as f:
            json.dump(self.checkpoint_history, f, indent=2)

    def get_checkpoint_dir(self, step: int, epoch: int) -> Path:
        """Generate checkpoint directory name (timestamp-based, step in metadata)."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        dirname = f"epoch{epoch}_step{step}_{timestamp}"
        return self.sub_checkpoint_dir / dirname

    def _safe_state_dict(self, model: nn.Module) -> Dict[str, torch.Tensor]:
        """Get state dict with shared memory removed and fusion controller deduplicated."""
        state_dict = model.state_dict()

        # Remove fusion controller tracking buffers (shared memory)
        keys_to_remove = [
            k for k in state_dict.keys() if any([
                'fusion_controller.scale_losses' in k,
                'fusion_controller.scale_loss_counts' in k,
                'fusion_controller.scale_beta_losses' in k
            ])
        ]

        for key in keys_to_remove:
            del state_dict[key]

        if keys_to_remove:
            print(f"   ℹ️ Removed {len(keys_to_remove)} shared tracking buffers")

        # DEDUPLICATE fusion controller parameters
        fusion_keys_to_remove = [
            k for k in state_dict.keys() if (
                'siglip_experts.' in k or
                'clip_experts.' in k or
                'fusion.' in k
            ) and '.fusion_controller.' in k
        ]

        for key in fusion_keys_to_remove:
            del state_dict[key]

        if fusion_keys_to_remove:
            print(f"   ℹ️ Removed {len(fusion_keys_to_remove)} duplicate fusion controller references")
            print(f"   ✓ Keeping only main 'fusion_controller.*' parameters")

        return state_dict

    def save_checkpoint(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler,
        epoch: int,
        step: int,
        val_loss: float,
        config: DanbooruTrainingConfig,
        fusion_diagnostics: Dict,
        is_best: bool = False
    ) -> Path:
        """Save checkpoint with proper naming."""
        ckpt_dir = self.get_checkpoint_dir(step, epoch)
        ckpt_dir.mkdir(parents=True, exist_ok=True)

        print(f"\n💾 Saving checkpoint: {self.sub_name}/{ckpt_dir.name}")

        state_dict = self._safe_state_dict(model)
        weights_path = ckpt_dir / "model.safetensors"
        save_file(state_dict, weights_path)
        print(f"   ✓ Model weights: {weights_path.name}")

        training_state = {
            'epoch': epoch,
            'global_step': step,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'val_loss': val_loss,
            'sub_name': self.sub_name
        }
        torch.save(training_state, ckpt_dir / "training_state.pt")
        print(f"   ✓ Training state: training_state.pt")

        config_dict = asdict(config)
        config_dict['timestamp'] = datetime.now().isoformat()
        config_dict['step'] = step
        config_dict['epoch'] = epoch
        config_dict['val_loss'] = val_loss
        config_dict['fusion_diagnostics'] = fusion_diagnostics
        config_dict['is_best'] = is_best

        with open(ckpt_dir / "config.json", 'w') as f:
            json.dump(config_dict, f, indent=2)
        print(f"   ✓ Config: config.json (step={step}, epoch={epoch}, val_loss={val_loss:.4f})")

        checkpoint_info = {
            'timestamp': datetime.now().isoformat(),
            'dirname': ckpt_dir.name,
            'step': step,
            'epoch': epoch,
            'val_loss': val_loss,
            'is_best': is_best,
            'fusion_diagnostics': fusion_diagnostics
        }

        self.checkpoint_history['checkpoints'].append(checkpoint_info)
        self.checkpoint_history['latest'] = checkpoint_info

        if is_best:
            self.checkpoint_history['best'] = checkpoint_info

        self._save_checkpoint_history()
        print(f"   ✓ Updated checkpoint history")

        return ckpt_dir

    def upload_checkpoint(self, ckpt_dir: Path):
        """Upload checkpoint to HuggingFace."""
        if not self.hf_api or not self.hf_repo_id:
            return

        try:
            print(f"\n🤗 Uploading to HuggingFace: {self.hf_repo_id}")
            print(f"   Path: {self.sub_name}/{ckpt_dir.name}")

            self.hf_api.upload_folder(
                repo_id=self.hf_repo_id,
                folder_path=str(ckpt_dir),
                path_in_repo=f"{self.sub_name}/{ckpt_dir.name}",
                commit_message=f"Checkpoint: {self.sub_name}/{ckpt_dir.name}"
            )
            print(f"   ✓ Uploaded checkpoint files")

            self.hf_api.upload_file(
                repo_id=self.hf_repo_id,
                path_or_fileobj=str(self.checkpoints_file),
                path_in_repo=f"{self.sub_name}/checkpoints.json",
                commit_message="Update checkpoint history"
            )
            print(f"   ✓ Updated checkpoints.json")

            print(f"✅ Upload complete: https://huggingface.co/{self.hf_repo_id}")

        except Exception as e:
            print(f"⚠️ Upload failed: {e}")
            traceback.print_exc()

    def find_latest_checkpoint(self) -> Optional[Dict]:
        """Find the latest checkpoint for this sub_name."""
        checkpoints = self.checkpoint_history.get('checkpoints', [])
        if checkpoints:
            return max(checkpoints, key=lambda x: x['step'])
        return None

    def load_checkpoint_for_resume(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler
    ) -> Tuple[int, int, float]:
        """Load checkpoint to resume training."""
        latest = self.find_latest_checkpoint()

        if not latest:
            print(f"ℹ️ No previous checkpoint found for sub_name='{self.sub_name}'")
            return 0, 0, float('inf')

        ckpt_dir = self.sub_checkpoint_dir / latest['dirname']

        if not ckpt_dir.exists():
            if self.hf_api and self.hf_repo_id:
                print(f"📥 Downloading checkpoint from HuggingFace...")
                try:
                    weights_path = hf_hub_download(
                        repo_id=self.hf_repo_id,
                        filename=f"{self.sub_name}/{latest['dirname']}/model.safetensors",
                        local_dir=self.local_dir
                    )

                    state_path = hf_hub_download(
                        repo_id=self.hf_repo_id,
                        filename=f"{self.sub_name}/{latest['dirname']}/training_state.pt",
                        local_dir=self.local_dir
                    )
                    print(f"   ✓ Downloaded checkpoint files")
                except Exception as e:
                    print(f"   ⚠️ Download failed: {e}")
                    return 0, 0, float('inf')
            else:
                print(f"   ⚠️ Checkpoint directory not found: {ckpt_dir}")
                return 0, 0, float('inf')

        print(f"\n🔄 Resuming from checkpoint: {latest['dirname']}")
        print(f"   Step: {latest['step']}, Epoch: {latest['epoch']}, Val Loss: {latest['val_loss']:.4f}")

        weights_path = ckpt_dir / "model.safetensors"
        state_dict = load_file(str(weights_path))

        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

        expected_missing = [
            k for k in missing_keys if (
                'siglip_experts.' in k or
                'clip_experts.' in k or
                'fusion.' in k
            ) and '.fusion_controller.' in k
        ]

        unexpected_missing = [k for k in missing_keys if k not in expected_missing]

        if unexpected_missing:
            print(f"   ⚠️ Unexpected missing keys: {len(unexpected_missing)}")
            for k in unexpected_missing[:5]:
                print(f"      - {k}")

        if unexpected_keys:
            print(f"   ⚠️ Unexpected keys: {len(unexpected_keys)}")

        print(f"   ✓ Loaded model weights ({len(expected_missing)} shared fusion refs skipped)")

        state_path = ckpt_dir / "training_state.pt"
        # map_location keeps the resume path device-agnostic; optimizer tensors
        # are moved back to the parameters' devices by load_state_dict.
        training_state = torch.load(state_path, map_location="cpu")

        optimizer.load_state_dict(training_state['optimizer_state_dict'])
        print(f"   ✓ Loaded optimizer state")

        if scheduler and training_state['scheduler_state_dict']:
            scheduler.load_state_dict(training_state['scheduler_state_dict'])
            print(f"   ✓ Loaded scheduler state")

        return training_state['epoch'], training_state['global_step'], training_state['val_loss']

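# ----------------------------------------------------------------------------
# Illustrative sketch (not called anywhere): CheckpointManager used standalone.
# The paths and names below are placeholders, not part of this training run;
# with hf_repo_id=None the manager stays strictly local and never uploads.
def _example_checkpoint_roundtrip(
    model: nn.Module,
    optimizer: torch.optim.Optimizer
) -> Tuple[int, int, float]:
    manager = CheckpointManager(
        local_dir="./checkpoints/example",
        hf_repo_id=None,
        sub_name="example-run",
    )
    cfg = DanbooruTrainingConfig(sub_name="example-run")
    manager.save_checkpoint(
        model=model, optimizer=optimizer, scheduler=None,
        epoch=0, step=0, val_loss=float("inf"),
        config=cfg, fusion_diagnostics={}, is_best=False,
    )
    # Resume picks the highest-step entry recorded in checkpoints.json.
    return manager.load_checkpoint_for_resume(model, optimizer, None)
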

# ============================================================================
# ILLUSTRIOUS CLIP & SIGLIP
# ============================================================================

class IllustriousCLIPTextEncoder(nn.Module):
    """Loads and wraps Illustrious CLIP text encoder."""

    def __init__(
        self,
        safetensors_path: str,
        tokenizer_name: str = "openai/clip-vit-large-patch14",
        clip_skip: int = 2,
        device: str = "cuda"
    ):
        super().__init__()

        self.clip_skip = clip_skip
        self.device = device

        print(f"\n{'='*80}")
        print("LOADING ILLUSTRIOUS CLIP TEXT ENCODER")
        print(f"{'='*80}")

        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name)
        print(f"✓ Tokenizer: {tokenizer_name}")
        print(f"✓ Vocab size: {self.tokenizer.vocab_size}")

        if not os.path.exists(safetensors_path):
            print(f"\n⚠️ Illustrious CLIP not found: {safetensors_path}")
            print("Falling back to standard CLIP")

            from transformers import CLIPTextModel
            self.model = CLIPTextModel.from_pretrained(tokenizer_name).to(device)
            self.is_illustrious = False
        else:
            print(f"Loading from: {safetensors_path}")

            state_dict = load_file(safetensors_path)
            print(f"✓ Loaded {len(state_dict)} tensors")

            from transformers import CLIPTextModel, CLIPTextConfig
            config = CLIPTextConfig.from_pretrained(tokenizer_name)
            self.model = CLIPTextModel(config).to(device)

            model_state_dict = self.model.state_dict()
            mapped_state = {}

            for key in state_dict.keys():
                if key in model_state_dict:
                    mapped_state[key] = state_dict[key]
                else:
                    new_key = key.replace("text_model.", "")
                    if new_key in model_state_dict:
                        mapped_state[new_key] = state_dict[key]

            print(f"✓ Mapped {len(mapped_state)}/{len(model_state_dict)} parameters")

            missing, unexpected = self.model.load_state_dict(mapped_state, strict=False)
            if missing:
                print(f"⚠️ Missing: {len(missing)} keys")
            if unexpected:
                print(f"⚠️ Unexpected: {len(unexpected)} keys")

            self.is_illustrious = True
            print(f"✅ Illustrious CLIP loaded!")

        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()

        active_layers = 12 - clip_skip
        print(f"✓ Using {active_layers} layers (skip last {clip_skip})")
        print(f"{'='*80}\n")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """Extract features from text encoder layers."""
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                return_dict=True
            )

            # hidden_states has 13 entries (embedding output + 12 layers);
            # skip the embedding output and the last `clip_skip` layers.
            hidden_states = outputs.hidden_states
            num_layers = len(hidden_states) - self.clip_skip - 1

            features = {}
            for i in range(num_layers):
                features[f'clip_layer_{i}'] = hidden_states[i + 1]

            return features


class SigLIPFeatureExtractor(nn.Module):
    """Extracts features from all SigLIP vision layers."""

    def __init__(self, model_name: str, device: str = "cuda"):
        super().__init__()

        print(f"\n{'='*80}")
        print("LOADING SIGLIP VISION ENCODER")
        print(f"{'='*80}")
        print(f"Model: {model_name}")

        self.model = SiglipModel.from_pretrained(model_name).to(device)
        self.processor = SiglipProcessor.from_pretrained(model_name)

        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()

        self.layer_outputs = {}
        self._register_hooks()

        num_layers = len(self.model.vision_model.encoder.layers)
        print(f"✓ {num_layers} vision layers")
        print(f"✓ Frozen encoder")
        print(f"{'='*80}\n")

    def _register_hooks(self):
        """Register forward hooks to capture layer outputs."""
        vision_model = self.model.vision_model

        for i, layer in enumerate(vision_model.encoder.layers):
            def make_hook(layer_idx):
                def hook(module, input, output):
                    # Transformers encoder layers return a tuple; keep only
                    # the hidden states so downstream code receives tensors.
                    hidden = output[0] if isinstance(output, tuple) else output
                    self.layer_outputs[f'siglip_layer_{layer_idx}'] = hidden
                return hook
            layer.register_forward_hook(make_hook(i))

    def forward(self, images: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Extract features from all vision layers using hooks."""
        with torch.no_grad():
            if images.device != next(self.model.parameters()).device:
                images = images.to(next(self.model.parameters()).device)

            self.layer_outputs = {}
            _ = self.model.vision_model(pixel_values=images)

            return dict(self.layer_outputs)

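# ----------------------------------------------------------------------------
# Illustrative sketch (not called anywhere): the forward-hook pattern above,
# reduced to its essentials on a dummy two-layer stack. Shapes are arbitrary
# example values.
def _example_hook_capture() -> Dict[str, torch.Tensor]:
    layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])
    captured: Dict[str, torch.Tensor] = {}

    def make_hook(idx: int):
        # Closure over idx; a bare loop variable would late-bind to its final
        # value, which is why the extractor uses this same factory trick.
        def hook(module, inputs, output):
            captured[f"layer_{idx}"] = output
        return hook

    for i, layer in enumerate(layers):
        layer.register_forward_hook(make_hook(i))

    x = torch.randn(1, 8)
    for layer in layers:  # one forward pass fills `captured` via the hooks
        x = layer(x)
    return captured
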

# ============================================================================
# GEOMETRIC REGULARIZATION
# ============================================================================

class GeometricRegularization(nn.Module):
    """Geometric regularization for pentachoron opinion anchors."""

    def __init__(self):
        super().__init__()

    def cayley_menger_loss(
        self,
        pentachora: torch.Tensor,
        sample_size: int = 50
    ) -> torch.Tensor:
        """Cayley-Menger volume regularization."""
        num_anchors = pentachora.shape[0]

        if num_anchors > sample_size:
            indices = torch.randperm(num_anchors, device=pentachora.device)[:sample_size]
            pentachora = pentachora[indices]

        losses = []
        for i in range(pentachora.shape[0]):
            vertices = pentachora[i]

            diff = vertices.unsqueeze(0) - vertices.unsqueeze(1)
            dist_sq = (diff ** 2).sum(dim=-1)

            M = torch.zeros(6, 6, device=vertices.device, dtype=vertices.dtype)
            M[0, 1:] = 1.0
            M[1:, 0] = 1.0
            M[1:, 1:] = dist_sq

            # Cayley-Menger identity for a 4-simplex:
            # V^2 = -det(M) / (2^4 * (4!)^2) = -det(M) / 9216
            det = torch.linalg.det(M)
            volume_sq = (-det / 9216.0).clamp(min=0.0)
            volume = volume_sq.sqrt()

            # Hinge: only near-degenerate (flat) pentachora are penalized
            volume_loss = F.relu(0.01 - volume)
            losses.append(volume_loss)

        return torch.stack(losses).mean()

    def rose_loss(
        self,
        pentachora: torch.Tensor,
        target_norm: float = 0.29514
    ) -> torch.Tensor:
        """Rose harmonic constraint."""
        vertex_norms = torch.norm(pentachora, dim=-1)
        target = torch.full_like(vertex_norms, target_norm)
        return F.mse_loss(vertex_norms, target)

    def forward(self, pentachora: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute all geometric losses."""
        return {
            'cayley': self.cayley_menger_loss(pentachora),
            'rose': self.rose_loss(pentachora)
        }

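# ----------------------------------------------------------------------------
# Sanity sketch (not called anywhere) for the Cayley-Menger constant above:
# for a 4-simplex, V^2 = -det(M) / (2^4 * (4!)^2) = -det(M) / 9216. A regular
# pentachoron with unit edges gives det(M) = -5 and volume sqrt(5)/96, which
# this check recovers from squared distances alone (no coordinates needed).
def _example_cayley_menger_unit_pentachoron() -> float:
    dist_sq = torch.ones(5, 5) - torch.eye(5)  # all ten edges have length 1
    M = torch.zeros(6, 6)
    M[0, 1:] = 1.0
    M[1:, 0] = 1.0
    M[1:, 1:] = dist_sq
    volume = (-torch.linalg.det(M) / 9216.0).clamp(min=0.0).sqrt()
    return volume.item()  # expected: 5 ** 0.5 / 96 ≈ 0.023292
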

# ============================================================================
# TRAINER WITH TEXT MODALITY ROBUSTNESS
# ============================================================================

class DanbooruLiminalStaircaseTrainer:
    """Trainer with bulletproof checkpointing + text modality robustness."""

    def __init__(self, config: DanbooruTrainingConfig):
        self.config = config
        self._interrupt_received = False
        self._save_on_interrupt = True

        self.accelerator = Accelerator(
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            mixed_precision='fp16' if config.device == 'cuda' else 'no'
        )

        print("\n" + "🎨 " * 40)
        print("LIMINAL STAIRCASE TRAINER - BULLETPROOF + GEOMETRIC + TEXT ROBUSTNESS")
        print("🎨 " * 40 + "\n")

        # Checkpoint manager
        self.checkpoint_manager = CheckpointManager(
            local_dir=config.checkpoint_dir,
            hf_repo_id=config.hf_repo_id,
            sub_name=config.sub_name,
            hf_private=config.hf_private
        )

        # TensorBoard
        if self.accelerator.is_main_process:
            log_dir = Path(config.log_dir) / f"{config.sub_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
            self.writer = SummaryWriter(log_dir=log_dir)
            print(f"📊 TensorBoard logging to: {log_dir}")
        else:
            self.writer = None

        # Feature extractors
        self.siglip_extractor = SigLIPFeatureExtractor(
            config.siglip_model,
            config.device
        )

        self.clip_extractor = IllustriousCLIPTextEncoder(
            config.illustrious_clip_path,
            config.clip_tokenizer,
            config.clip_skip,
            config.device
        )

        # Get dimensions
        siglip_hidden_dim = self.siglip_extractor.model.vision_model.config.hidden_size
        siglip_num_layers = len(self.siglip_extractor.model.vision_model.encoder.layers)

        # Initialize model
        print("\n" + "⚡ " * 40)
        print("INITIALIZING LIMINAL STAIRCASE WITH GEOMETRIC PENTACHORA")
        print("⚡ " * 40)

        model_config = config.to_model_config(siglip_hidden_dim, siglip_num_layers)
        self.model = LiminalStaircase(model_config).to(config.device)

        # Geometric regularization
        self.geometric_reg = GeometricRegularization()

        # Optimizer
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )

        # Create dataloaders
        print("\n" + "🎨 " * 40)
        self.train_loader, self.val_loader, self.tag_vocab = create_danbooru_dataloaders(
            siglip_processor=self.siglip_extractor.processor,
            clip_tokenizer=self.clip_extractor.tokenizer,
            dataset_name=config.dataset_name,
            image_size=config.image_size,
            batch_size=config.batch_size,
            num_workers=4
        )

        # Create scheduler
        steps_per_epoch = len(self.train_loader)
        total_steps = config.num_epochs * steps_per_epoch

        print(f"\n📊 Training schedule:")
        print(f"   Steps per epoch: {steps_per_epoch:,}")
        print(f"   Total training steps: {total_steps:,}")

        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=total_steps
        )

        # PRE-COMPUTE VISION-ONLY SENTINEL (CACHED!)
        print(f"\n🔷 Creating vision-only sentinel token...")
        print(f"   Token: '{config.vision_only_text}'")
        with torch.no_grad():
            sentinel_input = self.clip_extractor.tokenizer(
                config.vision_only_text,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=config.max_tag_length
            ).to(config.device)

            # Extract CLIP features for sentinel
            self.vision_only_clip_features = self.clip_extractor(
                sentinel_input['input_ids'],
                sentinel_input['attention_mask']
            )

            # Freeze these - they're our "no text available" signal
            self.vision_only_clip_features = {
                name: feat.detach().clone()
                for name, feat in self.vision_only_clip_features.items()
            }

        print(f"✓ Vision-only sentinel cached")
        example_shape = list(self.vision_only_clip_features.values())[0].shape
        print(f"   Shape example: {example_shape}")
        print(f"   Text dropout: {config.text_dropout_schedule} schedule")
        print(f"   Start: {config.text_dropout_start:.1%}, End: {config.text_dropout_end:.1%}")

        # Prepare with accelerator
        (
            self.model,
            self.optimizer,
            self.train_loader,
            self.val_loader,
            self.scheduler
        ) = self.accelerator.prepare(
            self.model,
            self.optimizer,
            self.train_loader,
            self.val_loader,
            self.scheduler
        )

        self.global_step = 0
        self.start_epoch = 0
        self.best_val_loss = float('inf')
        self.current_epoch = 0

        # Text modality tracking
        self.text_dropout_stats = {
            'clean': 0,
            'noisy': 0,
            'sentinel': 0
        }

        # Resume if requested
        if config.resume and self.accelerator.is_main_process:
            epoch, step, val_loss = self.checkpoint_manager.load_checkpoint_for_resume(
                self.accelerator.unwrap_model(self.model),
                self.optimizer,
                self.scheduler
            )
            self.start_epoch = epoch
            self.global_step = step
            self.best_val_loss = val_loss

        # Setup interrupt handler
        self._setup_interrupt_handler()

        print("\n" + "✅ " * 40)
        print("TRAINER READY")
        print("✅ " * 40)
        print(f"Sub name: {config.sub_name}")
        print(f"Fusion strategy: {config.fusion_strategy}")
        print(f"Model params: {sum(p.numel() for p in self.model.parameters()):,}")
        print(f"Text robustness: ENABLED")
        print(f"   Sentinel: '{config.vision_only_text}'")
        print(f"   Dropout schedule: {config.text_dropout_schedule}")
        if self.global_step > 0:
            print(f"Resuming from: step {self.global_step}, epoch {self.start_epoch}")
        print(f"⚡ Interrupt handling: Ctrl+C saves checkpoint before exit")
        print("✅ " * 40 + "\n")

    def _setup_interrupt_handler(self):
        """Setup signal handler for graceful interrupt."""
        def signal_handler(sig, frame):
            if self._interrupt_received:
                print("\n⚠️ Second interrupt received, forcing exit...")
                sys.exit(1)

            self._interrupt_received = True
            print("\n" + "⚡ " * 40)
            print("KEYBOARD INTERRUPT DETECTED")
            print("⚡ " * 40)
            print("Saving checkpoint before exit...")

            if self._save_on_interrupt and self.accelerator.is_main_process:
                try:
                    self._emergency_save_checkpoint()
                    print("✅ Emergency checkpoint saved successfully")
                except Exception as e:
                    print(f"⚠️ Emergency save failed: {e}")
                    traceback.print_exc()

            print("\n" + "⚡ " * 40)
            print("Exiting gracefully...")
            print("⚡ " * 40 + "\n")
            sys.exit(0)

        signal.signal(signal.SIGINT, signal_handler)

    def _emergency_save_checkpoint(self):
        """Emergency checkpoint save on interrupt."""
        print(f"\n💾 Emergency save at step {self.global_step}, epoch {self.current_epoch}")

        fusion_diagnostics = self.get_fusion_diagnostics()

        ckpt_dir = self.checkpoint_manager.save_checkpoint(
            model=self.accelerator.unwrap_model(self.model),
            optimizer=self.optimizer,
            scheduler=self.scheduler,
            epoch=self.current_epoch,
            step=self.global_step,
            val_loss=float('inf'),
            config=self.config,
            fusion_diagnostics=fusion_diagnostics,
            is_best=False
        )

        if self.config.hf_repo_id:
            print("Attempting HuggingFace upload...")
            try:
                self.checkpoint_manager.upload_checkpoint(ckpt_dir)
            except Exception as e:
                print(f"⚠️ Upload failed (checkpoint saved locally): {e}")

    def get_text_dropout_prob(self) -> float:
        """Get current text dropout probability with curriculum."""
        if self.config.text_dropout_schedule == "constant":
            return self.config.text_dropout_prob

        # Calculate progress
        steps_per_epoch = len(self.train_loader)
        total_steps = self.config.num_epochs * steps_per_epoch
        progress = self.global_step / max(total_steps, 1)

        if self.config.text_dropout_schedule == "linear":
            dropout = self.config.text_dropout_start + progress * (
                self.config.text_dropout_end - self.config.text_dropout_start
            )
        elif self.config.text_dropout_schedule == "cosine":
            dropout = self.config.text_dropout_start + 0.5 * (
                self.config.text_dropout_end - self.config.text_dropout_start
            ) * (1 - np.cos(np.pi * progress))
        else:
            dropout = self.config.text_dropout_prob

        return dropout

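    # ------------------------------------------------------------------
    # Illustrative sketch (not called anywhere): the curriculum above,
    # detached from trainer state so the schedule shapes are easy to
    # inspect. Step counts here are arbitrary example values.
    @staticmethod
    def _example_dropout_schedule(schedule: str = "linear",
                                  start: float = 0.1,
                                  end: float = 0.5,
                                  total_steps: int = 10) -> List[float]:
        probs = []
        for step in range(total_steps + 1):
            progress = step / total_steps
            if schedule == "cosine":
                # Slow start / slow finish, matching get_text_dropout_prob
                p = start + 0.5 * (end - start) * (1 - np.cos(np.pi * progress))
            else:  # linear
                p = start + progress * (end - start)
            probs.append(round(float(p), 3))
        return probs
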
    def compute_loss(
        self,
        outputs: Dict,
        target_tokens: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Compute ALL losses in trainer."""
        try:
            token_logits = outputs['token_logits']

            B, seq_len, vocab_size = token_logits.shape
            token_logits_flat = token_logits.view(-1, vocab_size)
            target_tokens_flat = target_tokens.view(-1)

            token_loss = F.cross_entropy(
                token_logits_flat,
                target_tokens_flat,
                ignore_index=self.clip_extractor.tokenizer.pad_token_id
            )

            # Geometric regularization
            pentachora = self.accelerator.unwrap_model(self.model).opinion_anchors
            geo_losses = self.geometric_reg(pentachora)

            # Beta losses
            beta_loss = 0.0
            if 'scale_feature_pairs' in outputs and self.model.training:
                beta_losses = []
                for scale, features in outputs['scale_feature_pairs'].items():
                    token_feat = features['token_features']
                    geo_feat = features['geometric_features']
                    beta = features['beta']

                    scale_beta_loss = beta * F.mse_loss(token_feat, geo_feat)
                    beta_losses.append(scale_beta_loss)

                if beta_losses:
                    beta_loss = sum(beta_losses) / len(beta_losses)

            total_loss = (
                self.config.token_loss_weight * token_loss +
                self.config.geometric_weight * (geo_losses['cayley'] + geo_losses['rose'] + beta_loss)
            )

            # Accuracy
            preds = token_logits.argmax(dim=-1)
            mask = target_tokens != self.clip_extractor.tokenizer.pad_token_id
            mask_sum = mask.float().sum()

            if mask_sum > 0:
                acc = ((preds == target_tokens) & mask).float().sum() / mask_sum
            else:
                acc = torch.tensor(0.0, device=token_logits.device)

            metrics = {
                'loss/total': total_loss.item(),
                'loss/token': token_loss.item(),
                'loss/cayley': geo_losses['cayley'].item(),
                'loss/rose': geo_losses['rose'].item(),
                'loss/beta': beta_loss.item() if isinstance(beta_loss, torch.Tensor) else beta_loss,
                'acc/token': acc.item()
            }

            return total_loss, metrics

        except Exception as e:
            print(f"\n⚠️ Error in compute_loss: {e}")
            traceback.print_exc()
            raise

    def get_fusion_diagnostics(self) -> Dict:
        """Get current fusion controller state with error handling."""
        try:
            model = self.accelerator.unwrap_model(self.model)
            return model.fusion_controller.get_diagnostics()
        except Exception as e:
            print(f"⚠️ Error getting fusion diagnostics: {e}")
            return {
                'layer_weights': [],
                'scale_weights': [],
                'alpha_per_scale': [],
                'beta_per_scale': [],
                'scale_statistics': {}
            }

    def train_step(self, batch: Dict) -> Dict[str, float]:
        """Single training step with TEXT MODALITY ROBUSTNESS."""
        try:
            self.model.train()

            # Extract vision features (always present)
            with torch.no_grad():
                siglip_features = self.siglip_extractor(batch['siglip_images'])

            # TEXT MODALITY ROBUSTNESS
            current_dropout = self.get_text_dropout_prob()
            use_text = torch.rand(1).item() > current_dropout
            text_status = "clean"

            if use_text:
                # Extract text features
                with torch.no_grad():
                    clip_features = self.clip_extractor(
                        batch['clip_input_ids'],
                        batch['clip_attention_mask']
                    )

                # Maybe add noise
                if torch.rand(1).item() < self.config.text_noise_prob:
                    for layer_name, features in clip_features.items():
                        noise = torch.randn_like(features) * self.config.text_noise_std
                        clip_features[layer_name] = features + noise
                    text_status = "noisy"
                    self.text_dropout_stats['noisy'] += 1
                else:
                    text_status = "clean"
                    self.text_dropout_stats['clean'] += 1
            else:
                # VISION-ONLY MODE: Use semantic sentinel
                batch_size = batch['siglip_images'].shape[0]
                clip_features = {}

                for layer_name, sentinel_feat in self.vision_only_clip_features.items():
                    # Expand sentinel to batch: [1, seq, dim] -> [batch, seq, dim]
                    clip_features[layer_name] = sentinel_feat.expand(
                        batch_size, -1, -1
                    ).contiguous()

                text_status = "sentinel"
                self.text_dropout_stats['sentinel'] += 1

            # Forward pass
            with self.accelerator.accumulate(self.model):
                outputs = self.model(siglip_features, clip_features)
                loss, metrics = self.compute_loss(outputs, batch['clip_input_ids'])

                # Track text modality usage
                metrics['text_dropout_prob'] = current_dropout
                metrics['text_mode'] = {'clean': 0.0, 'noisy': 0.5, 'sentinel': 1.0}[text_status]

                self.accelerator.backward(loss)

                if self.accelerator.sync_gradients and self.config.gradient_clip > 0:
                    self.accelerator.clip_grad_norm_(
                        self.model.parameters(),
                        self.config.gradient_clip
                    )

                # Under accumulate(), the prepared optimizer internally skips
                # step() on micro-steps where gradients are not yet synced.
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()

            return metrics

        except Exception as e:
            print(f"\n⚠️ Error in train_step at step {self.global_step}: {e}")
            traceback.print_exc()
            return {
                'loss/total': float('nan'),
                'loss/token': float('nan'),
                'loss/cayley': 0.0,
                'loss/rose': 0.0,
                'loss/beta': 0.0,
                'acc/token': 0.0,
                'text_dropout_prob': 0.0,
                'text_mode': 0.0
            }

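    # ------------------------------------------------------------------
    # Illustrative sketch (not called anywhere) of the sentinel broadcast
    # used in train_step and validate: one cached [1, seq, dim] feature map
    # is expanded to the batch as a view, then materialized by .contiguous().
    # The 77/768 shapes are CLIP-L text dimensions used as example values.
    @staticmethod
    def _example_sentinel_expand(batch_size: int = 4) -> torch.Tensor:
        sentinel = torch.randn(1, 77, 768)             # cached once at init
        batched = sentinel.expand(batch_size, -1, -1)  # view, no copy yet
        return batched.contiguous()                    # [batch, seq, dim]
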
    def log_metrics(self, metrics: Dict[str, float], prefix: str = "train"):
        """Log metrics to TensorBoard."""
        if self.writer is None:
            return

        for key, value in metrics.items():
            self.writer.add_scalar(f"{prefix}/{key}", value, self.global_step)

        current_lr = self.optimizer.param_groups[0]['lr']
        self.writer.add_scalar("train/learning_rate", current_lr, self.global_step)

        # Log text modality stats
        if self.global_step % self.config.log_every == 0:
            total = sum(self.text_dropout_stats.values()) or 1
            for mode, count in self.text_dropout_stats.items():
                self.writer.add_scalar(f"text_modality/{mode}_pct", 100 * count / total, self.global_step)

        if self.global_step % (self.config.log_every * 10) == 0:
            fusion_diag = self.get_fusion_diagnostics()

            for i, w in enumerate(fusion_diag.get('layer_weights', [])):
                self.writer.add_scalar(f"fusion/layer_weight_{i}", w, self.global_step)

            for i, w in enumerate(fusion_diag.get('scale_weights', [])):
                self.writer.add_scalar(f"fusion/scale_weight_{i}", w, self.global_step)

            for i, a in enumerate(fusion_diag.get('alpha_per_scale', [])):
                self.writer.add_scalar(f"fusion/alpha_scale_{i}", a, self.global_step)

            for i, b in enumerate(fusion_diag.get('beta_per_scale', [])):
                self.writer.add_scalar(f"fusion/beta_scale_{i}", b, self.global_step)

    @torch.no_grad()
    def validate(self, max_batches: int = 100) -> Dict[str, float]:
        """Validation with both vision-only and vision+text modes."""
        try:
            self.model.eval()

            # Track both modes separately
            stats_with_text = {'loss': 0.0, 'acc': 0.0, 'count': 0}
            stats_vision_only = {'loss': 0.0, 'acc': 0.0, 'count': 0}

            num_batches = 0

            for batch in tqdm(self.val_loader, desc="Validating", leave=False, total=max_batches):
                if num_batches >= max_batches:
                    break

                try:
                    siglip_features = self.siglip_extractor(batch['siglip_images'])
                    batch_size = batch['siglip_images'].shape[0]

                    # TEST 1: Vision + Text (for reference)
                    clip_features_text = self.clip_extractor(
                        batch['clip_input_ids'],
                        batch['clip_attention_mask']
                    )

                    outputs_text = self.model(siglip_features, clip_features_text)
                    loss_text, metrics_text = self.compute_loss(outputs_text, batch['clip_input_ids'])

                    stats_with_text['loss'] += metrics_text['loss/total']
                    stats_with_text['acc'] += metrics_text['acc/token']
                    stats_with_text['count'] += 1

                    # TEST 2: Vision-only (REAL USE CASE!)
                    clip_features_sentinel = {}
                    for layer_name, sentinel_feat in self.vision_only_clip_features.items():
                        clip_features_sentinel[layer_name] = sentinel_feat.expand(
                            batch_size, -1, -1
                        ).contiguous()

                    outputs_vision = self.model(siglip_features, clip_features_sentinel)
                    loss_vision, metrics_vision = self.compute_loss(outputs_vision, batch['clip_input_ids'])

                    stats_vision_only['loss'] += metrics_vision['loss/total']
                    stats_vision_only['acc'] += metrics_vision['acc/token']
                    stats_vision_only['count'] += 1

                    num_batches += 1

                except Exception as e:
                    print(f"\n⚠️ Error in validation batch: {e}")
                    continue

            if stats_with_text['count'] == 0 or stats_vision_only['count'] == 0:
                return {'loss/val': float('inf'), 'acc/val': 0.0}

            return {
                'loss/val_with_text': stats_with_text['loss'] / stats_with_text['count'],
                'acc/val_with_text': stats_with_text['acc'] / stats_with_text['count'],
                'loss/val_vision_only': stats_vision_only['loss'] / stats_vision_only['count'],
                'acc/val_vision_only': stats_vision_only['acc'] / stats_vision_only['count'],
                # Overall metric = vision-only (the real use case)
                'loss/val': stats_vision_only['loss'] / stats_vision_only['count'],
                'acc/val': stats_vision_only['acc'] / stats_vision_only['count'],
            }

        except Exception as e:
            print(f"\n⚠️ Validation completely failed: {e}")
            traceback.print_exc()
            return {'loss/val': float('inf'), 'acc/val': 0.0}

    def save_checkpoint_and_upload(self, epoch: int, val_loss: float = float('inf'), is_best: bool = False):
        """Save checkpoint first, then optionally upload."""
        if not self.accelerator.is_main_process:
            return

        try:
            fusion_diagnostics = self.get_fusion_diagnostics()

            # Add text modality stats to diagnostics
            total = sum(self.text_dropout_stats.values()) or 1
            fusion_diagnostics['text_modality_stats'] = {
                mode: f"{100 * count / total:.1f}%"
                for mode, count in self.text_dropout_stats.items()
            }

            ckpt_dir = self.checkpoint_manager.save_checkpoint(
                model=self.accelerator.unwrap_model(self.model),
                optimizer=self.optimizer,
                scheduler=self.scheduler,
                epoch=epoch,
                step=self.global_step,
                val_loss=val_loss,
                config=self.config,
                fusion_diagnostics=fusion_diagnostics,
                is_best=is_best
            )

            if self.config.hf_repo_id:
                self.checkpoint_manager.upload_checkpoint(ckpt_dir)

        except Exception as e:
            print(f"\n⚠️ Checkpoint save/upload failed: {e}")
            traceback.print_exc()

    # ========================================================================
    # MAIN TRAINING METHOD
    # ========================================================================

    def train(self):
        """Full training loop with bulletproof checkpointing."""
        print("\n" + "🚀 " * 40)
        print("TRAINING START")
        print("🚀 " * 40 + "\n")

        try:
            for epoch in range(self.start_epoch, self.config.num_epochs):
                self.current_epoch = epoch

                if self._interrupt_received:
                    break

                print(f"\n{'🎨'*40}")
                print(f"EPOCH {epoch + 1}/{self.config.num_epochs}")
                print(f"{'🎨'*40}\n")

                pbar = tqdm(
                    self.train_loader,
                    desc=f"Epoch {epoch + 1}",
                    disable=not self.accelerator.is_main_process
                )

                for batch in pbar:
                    if self._interrupt_received:
                        break

                    metrics = self.train_step(batch)
                    self.global_step += 1

                    if self.global_step % self.config.log_every == 0:
                        pbar.set_postfix(metrics)
                        self.log_metrics(metrics, prefix="train")

                    # Save checkpoint (BEFORE validation, so a crash can't lose it)
                    if self.global_step % self.config.save_every == 0:
                        self.save_checkpoint_and_upload(epoch, val_loss=float('inf'), is_best=False)

                        if self.accelerator.is_main_process:
                            print("\n🔍 Running validation...")
                            val_metrics = self.validate(max_batches=50)
                            self.log_metrics(val_metrics, prefix="val")
                            print(f"✓ Val (with text)   - Loss: {val_metrics.get('loss/val_with_text', 0):.4f}, Acc: {val_metrics.get('acc/val_with_text', 0):.4f}")
                            print(f"✓ Val (vision-only) - Loss: {val_metrics.get('loss/val_vision_only', 0):.4f}, Acc: {val_metrics.get('acc/val_vision_only', 0):.4f}")

                    # HuggingFace upload
                    if (self.config.hf_repo_id and
                            self.global_step % self.config.hf_upload_every == 0):
                        self.save_checkpoint_and_upload(epoch, val_loss=float('inf'), is_best=False)

                        if self.accelerator.is_main_process:
                            print("\n🔍 Running validation for upload...")
                            val_metrics = self.validate(max_batches=50)
                            print(f"✓ Val (with text)   - Loss: {val_metrics.get('loss/val_with_text', 0):.4f}, Acc: {val_metrics.get('acc/val_with_text', 0):.4f}")
                            print(f"✓ Val (vision-only) - Loss: {val_metrics.get('loss/val_vision_only', 0):.4f}, Acc: {val_metrics.get('acc/val_vision_only', 0):.4f}")

                if self._interrupt_received:
                    break

                # End of epoch
                if self.accelerator.is_main_process:
                    self.save_checkpoint_and_upload(epoch, val_loss=float('inf'), is_best=False)

                    print("\n🔍 End of epoch validation...")
                    val_metrics = self.validate(max_batches=100)

                    print(f"\n📊 Validation Results:")
                    print(f"   With Text:")
                    print(f"      Loss: {val_metrics.get('loss/val_with_text', 0):.4f}")
                    print(f"      Acc:  {val_metrics.get('acc/val_with_text', 0):.4f}")
                    print(f"   Vision-Only (PRIMARY METRIC):")
                    print(f"      Loss: {val_metrics.get('loss/val_vision_only', 0):.4f}")
                    print(f"      Acc:  {val_metrics.get('acc/val_vision_only', 0):.4f}")

                    self.log_metrics(val_metrics, prefix="val")

                    is_best = val_metrics['loss/val'] < self.best_val_loss
                    if is_best:
                        self.best_val_loss = val_metrics['loss/val']
                        print(f"\n🎉 New best (vision-only): {self.best_val_loss:.4f}")
                        self.save_checkpoint_and_upload(epoch, val_metrics['loss/val'], is_best=True)

                    fusion_diag = self.get_fusion_diagnostics()
                    print(f"\n⚡ Fusion Controller State:")
                    print(f"   Scale weights: {[f'{w:.3f}' for w in fusion_diag.get('scale_weights', [])]}")
                    print(f"   Alpha: {[f'{a:.3f}' for a in fusion_diag.get('alpha_per_scale', [])]}")
                    print(f"   Beta:  {[f'{b:.3f}' for b in fusion_diag.get('beta_per_scale', [])]}")

                    # Print text modality stats
                    total = sum(self.text_dropout_stats.values()) or 1
                    print(f"\n📝 Text Modality Distribution:")
                    for mode, count in self.text_dropout_stats.items():
                        print(f"   {mode}: {100*count/total:.1f}%")

        except KeyboardInterrupt:
            if not self._interrupt_received:
                self._interrupt_received = True
                if self._save_on_interrupt and self.accelerator.is_main_process:
                    self._emergency_save_checkpoint()
            raise

        if not self._interrupt_received:
            print("\n" + "✅ " * 40)
            print("TRAINING COMPLETE")
            print("✅ " * 40)
            print(f"Best val loss (vision-only): {self.best_val_loss:.4f}")

            if self.accelerator.is_main_process:
                print(f"\n📊 TensorBoard logs: {self.config.log_dir}")
                if self.config.hf_repo_id:
                    print(f"🤗 Model on HuggingFace: https://huggingface.co/{self.config.hf_repo_id}")

            print("✅ " * 40 + "\n")

        if self.writer:
            self.writer.close()


# ============================================================================
# MAIN
# ============================================================================

if __name__ == "__main__":
    config = DanbooruTrainingConfig(
        # Run identifier
        sub_name="danbooru-50k-v1-512",

        # Model architecture
        num_opinion_anchors=225,
        pentachoron_dim=256,
        scales=[128, 256, 512, 1024],
        scale_hidden_dims={128: 128, 256: 512, 512: 1024, 1024: 2048},

        # Fusion controller
        alpha_init=0.1,
        alpha_learnable=True,
        beta_init=0.5,
        beta_learnable=True,
        gamma_learnable=True,
        learn_layer_weights=True,

        # Encoders
        clip_skip=0,
        siglip_layer_indices=[3, 6, 9, 12, 21, 23, 24, 25, 26],

        # Optimizations
        use_gradient_checkpointing=False,
        share_scale_embeddings=False,

        # Training
        batch_size=32,
        num_epochs=3,
        learning_rate=1e-4,
        save_every=500,

        # TEXT MODALITY ROBUSTNESS (NEW!)
        text_dropout_prob=0.3,
        text_noise_std=0.1,
        text_noise_prob=0.5,
        vision_only_text="general: blank_image",  # Semantic sentinel
        text_dropout_schedule="linear",           # Curriculum: 10% -> 50%
        text_dropout_start=0.1,
        text_dropout_end=0.5,

        # Resume
        resume=True,

        # HuggingFace
        hf_repo_id="AbstractPhil/liminal-staircase-v2",
        hf_upload_every=1000,
        hf_private=False,
    )

    print("\n" + "🎨 " * 40)
    print("LIMINAL STAIRCASE - BULLETPROOF + GEOMETRIC + TEXT ROBUSTNESS")
    print("🎨 " * 40)
    print(f"\nSub name: {config.sub_name}")
    print(f"Scales: {config.scales}")
    print(f"SigLIP layers: {config.siglip_layer_indices}")
    print(f"CLIP skip: {config.clip_skip}")
    print(f"Geometric init: hybrid pentachora")
    print(f"\n🔷 Text Modality Robustness:")
    print(f"   Sentinel: '{config.vision_only_text}'")
    print(f"   Dropout: {config.text_dropout_schedule} ({config.text_dropout_start:.0%} → {config.text_dropout_end:.0%})")
    print(f"   Noise: {config.text_noise_prob:.0%} of text batches @ std={config.text_noise_std}")
    if config.hf_repo_id:
        print(f"\n🤗 HuggingFace: {config.hf_repo_id}")
    print("\n" + "🎨 " * 40 + "\n")

    trainer = DanbooruLiminalStaircaseTrainer(config)
    trainer.train()
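
# ----------------------------------------------------------------------------
# Usage note: the trainer constructs its own Accelerator, so `python
# trainer_v2.py` runs single-process; `accelerate launch trainer_v2.py` should
# also work for multi-GPU setups, assuming the geovocab2 modules, the dataset
# loader, and the Illustrious CLIP weights exist at the configured paths.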