|
|
""" |
|
|
Liminal Staircase Training - DANBOORU EDITION (BULLETPROOF + GEOMETRIC + TEXT DROPOUT) |
|
|
========================================================================================= |
|
|
|
|
|
Fully hardened trainer with: |
|
|
- Geometric pentachoron initialization via SimplexFactory |
|
|
- TEXT MODALITY ROBUSTNESS: dropout, noise, semantic sentinel |
|
|
- Saves checkpoints BEFORE validation |
|
|
- Handles all validation crashes gracefully |
|
|
- Proper scheduler with actual step counts |
|
|
- Clean model/loss separation |
|
|
- Keyboard interrupt saves checkpoint before exit |
|
|
- Fixed shared fusion controller checkpoint handling |
|
|
- PROPER checkpoint naming (no step in directory name) |
|
|
|
|
|
Author: AbstractPhil + Claude Sonnet 4.5 |
|
|
Date: 2025-11-17 (Text Robustness Update) |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import DataLoader |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
from transformers import SiglipModel, SiglipProcessor, CLIPTokenizer |
|
|
from accelerate import Accelerator |
|
|
from tqdm.auto import tqdm |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Tuple, Optional |
|
|
from dataclasses import dataclass, asdict |
|
|
import numpy as np |
|
|
from safetensors.torch import load_file, save_file |
|
|
import os |
|
|
import json |
|
|
from datetime import datetime |
|
|
import shutil |
|
|
import traceback |
|
|
import signal |
|
|
import sys |
|
|
|
|
|
|
|
|
from huggingface_hub import HfApi, create_repo, hf_hub_download |
|
|
|
|
|
|
|
|
from geovocab2.train.model.core.liminal_staircase_collective_v2 import ( |
|
|
LiminalStaircase, |
|
|
LiminalStaircaseConfig, |
|
|
ScaleFusionConfig, |
|
|
OrganizedFusionController |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class DanbooruTrainingConfig:
    """Training configuration for Danbooru dataset with organized fusion.

    ``scales`` defaults to ``[128, 256, 512]`` and ``scale_hidden_dims`` to
    ``{scale: 2 * scale}`` when left as ``None``; both are filled in by
    ``__post_init__``. Constructing an instance also creates ``checkpoint_dir``
    and ``log_dir`` on disk.
    """

    # --- Run identity --------------------------------------------------------
    sub_name: str = "danbooru-v1"

    # --- Model geometry ------------------------------------------------------
    num_opinion_anchors: int = 225
    pentachoron_dim: int = 512
    scales: Optional[List[int]] = None
    scale_hidden_dims: Optional[Dict[int, int]] = None

    # --- Fusion: alpha -------------------------------------------------------
    alpha_init: float = 0.1
    alpha_learnable: bool = True
    alpha_per_scale: bool = True

    # --- Fusion: beta --------------------------------------------------------
    beta_init: float = 0.5
    beta_learnable: bool = True
    beta_per_scale: bool = True

    # --- Fusion: gamma / layer weights ---------------------------------------
    gamma_learnable: bool = True
    learn_layer_weights: bool = True

    # --- Frozen encoders -----------------------------------------------------
    siglip_model: str = "google/siglip-so400m-patch14-384"
    clip_tokenizer: str = "openai/clip-vit-large-patch14"
    illustrious_clip_path: str = "./models/NAI-11-epsilon_clip_l.safetensors"
    clip_skip: int = 0

    # --- Layer selection (passed straight through to LiminalStaircaseConfig) -
    siglip_layer_indices: Optional[List[int]] = None
    clip_layer_indices: Optional[List[int]] = None

    # --- Memory / parameter sharing ------------------------------------------
    use_gradient_checkpointing: bool = False
    share_scale_embeddings: bool = True

    # --- Data ----------------------------------------------------------------
    dataset_name: str = "animetimm/danbooru-wdtagger-v4-w640-ws-50k"
    image_size: int = 384
    max_tag_length: int = 77

    # --- Optimisation --------------------------------------------------------
    batch_size: int = 32
    num_epochs: int = 5
    learning_rate: float = 1e-4
    weight_decay: float = 1e-2
    warmup_steps: int = 1000
    gradient_clip: float = 1.0
    gradient_accumulation_steps: int = 1

    # --- Loss weights --------------------------------------------------------
    token_loss_weight: float = 1.0
    geometric_weight: float = 0.1
    fusion_strategy: str = "learned_weighted"

    # --- Text-modality robustness --------------------------------------------
    text_dropout_prob: float = 0.3
    text_noise_std: float = 0.1
    text_noise_prob: float = 0.5
    vision_only_text: str = "general: blank_image"

    # --- Text-dropout curriculum ---------------------------------------------
    text_dropout_schedule: str = "linear"
    text_dropout_start: float = 0.1
    text_dropout_end: float = 0.5

    # --- Checkpointing -------------------------------------------------------
    checkpoint_dir: str = "./checkpoints/liminal_staircase_danbooru"
    save_every: int = 500

    # --- HuggingFace Hub -----------------------------------------------------
    hf_repo_id: Optional[str] = None
    hf_upload_every: int = 5000
    hf_private: bool = False

    # --- Resume --------------------------------------------------------------
    resume: bool = False

    # --- Logging -------------------------------------------------------------
    log_dir: str = "./logs/liminal_staircase_danbooru"
    log_every: int = 5

    # --- Hardware ------------------------------------------------------------
    # Evaluated once, when the class body is executed.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    def __post_init__(self):
        """Fill ``None`` defaults and make sure output directories exist."""
        if self.scales is None:
            self.scales = [128, 256, 512]

        if self.scale_hidden_dims is None:
            self.scale_hidden_dims = {s: s * 2 for s in self.scales}

        Path(self.checkpoint_dir).mkdir(parents=True, exist_ok=True)
        Path(self.log_dir).mkdir(parents=True, exist_ok=True)

    def to_model_config(self, siglip_hidden_dim: int, siglip_num_layers: int) -> "LiminalStaircaseConfig":
        """Convert to LiminalStaircaseConfig with organized fusion.

        Args:
            siglip_hidden_dim: Hidden size of the SigLIP vision tower.
            siglip_num_layers: Number of SigLIP vision encoder layers.

        Returns:
            A ``LiminalStaircaseConfig`` whose ``scale_fusion`` is built from the
            alpha/beta/gamma/layer-weight fields of this config.
        """
        fusion_config = ScaleFusionConfig(
            scales=self.scales,
            scale_hidden_dims=self.scale_hidden_dims,
            alpha_init=self.alpha_init,
            alpha_learnable=self.alpha_learnable,
            alpha_per_scale=self.alpha_per_scale,
            beta_init=self.beta_init,
            beta_learnable=self.beta_learnable,
            beta_per_scale=self.beta_per_scale,
            gamma_learnable=self.gamma_learnable,
            learn_layer_weights=self.learn_layer_weights,
            learn_scale_weights=True,
            track_scale_losses=True
        )

        # The CLIP dimensions below are hard-coded for the CLIP-L text encoder
        # (768 hidden, 12 layers, 49408-token vocab, 77 positions).
        return LiminalStaircaseConfig(
            num_opinion_anchors=self.num_opinion_anchors,
            pentachoron_dim=self.pentachoron_dim,
            siglip_hidden_dim=siglip_hidden_dim,
            siglip_num_layers=siglip_num_layers,
            clip_hidden_dim=768,
            clip_num_layers=12,
            clip_skip=self.clip_skip,
            vocab_size=49408,
            max_seq_len=77,
            siglip_layer_indices=self.siglip_layer_indices,
            clip_layer_indices=self.clip_layer_indices,
            scale_fusion=fusion_config,
            use_gradient_checkpointing=self.use_gradient_checkpointing,
            share_scale_embeddings=self.share_scale_embeddings,
            geometric_init_method="hybrid",
            geometric_init_validate=True,
            geometric_init_seed=42
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CheckpointManager:
    """Manage checkpoints for one training run, locally and optionally on the HF Hub.

    Every run writes into its own ``<local_dir>/<sub_name>-<YYYYmmdd_HHMMSS>``
    directory. Checkpoints inside it are named ``step{N}`` and indexed by a
    ``checkpoints.json`` history file that also records ``latest`` and ``best``.
    """

    # Suffix format appended to ``sub_name`` to name each run directory.
    RUN_TIMESTAMP_FORMAT = "%Y%m%d_%H%M%S"
    # Per-run index of saved checkpoints.
    HISTORY_FILENAME = "checkpoints.json"
    # Fusion-controller loss-tracking buffers; never written to the weights file.
    TRACKING_BUFFER_MARKERS = (
        'fusion_controller.scale_losses',
        'fusion_controller.scale_loss_counts',
        'fusion_controller.scale_beta_losses',
    )

    def __init__(
        self,
        local_dir: str,
        hf_repo_id: Optional[str] = None,
        sub_name: str = "default",
        hf_private: bool = False
    ):
        """Create this run's directory and (optionally) the HF repo.

        Args:
            local_dir: Root directory holding one sub-directory per run.
            hf_repo_id: HuggingFace repo to mirror checkpoints to, or ``None``.
            sub_name: Base run name; a timestamp is appended per run.
            hf_private: Create the HF repo as private.
        """
        self.local_dir = Path(local_dir)
        self.hf_repo_id = hf_repo_id
        self.base_sub_name = sub_name

        # A fresh timestamp per run keeps runs from overwriting each other.
        run_timestamp = datetime.now().strftime(self.RUN_TIMESTAMP_FORMAT)
        self.sub_name = f"{sub_name}-{run_timestamp}"
        self.hf_private = hf_private

        self.sub_checkpoint_dir = self.local_dir / self.sub_name
        self.sub_checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.checkpoints_file = self.sub_checkpoint_dir / self.HISTORY_FILENAME

        if hf_repo_id:
            self.hf_api = HfApi()
            try:
                create_repo(
                    repo_id=hf_repo_id,
                    private=hf_private,
                    exist_ok=True
                )
                print(f"🤗 HuggingFace repo: {hf_repo_id}")
            except Exception as e:
                # Best effort: training continues with local checkpoints only.
                print(f"⚠️ Could not create HF repo: {e}")
                self.hf_api = None
        else:
            self.hf_api = None

        self.checkpoint_history = self._load_checkpoint_history()

    # ------------------------------------------------------------------ #
    # History bookkeeping
    # ------------------------------------------------------------------ #

    def _empty_history(self) -> Dict:
        """Return a fresh history record for this run."""
        return {
            "sub_name": self.sub_name,
            "base_name": self.base_sub_name,
            "checkpoints": [],
            "latest": None,
            "best": None
        }

    def _load_checkpoint_history(self) -> Dict:
        """Load this run's ``checkpoints.json``; start fresh if absent or unreadable."""
        if self.checkpoints_file.exists():
            try:
                with open(self.checkpoints_file, 'r') as f:
                    return json.load(f)
            except (OSError, json.JSONDecodeError) as e:
                print(f"⚠️ Could not read {self.checkpoints_file}: {e}; starting a fresh history")
        return self._empty_history()

    def _save_checkpoint_history(self):
        """Write ``checkpoints.json`` atomically (temp file + rename).

        An interrupt mid-write therefore never leaves a truncated history behind.
        """
        tmp_path = self.checkpoints_file.with_name(self.checkpoints_file.name + ".tmp")
        with open(tmp_path, 'w') as f:
            json.dump(self.checkpoint_history, f, indent=2, default=self._json_default)
        tmp_path.replace(self.checkpoints_file)

    @staticmethod
    def _json_default(obj):
        """``json.dump`` fallback: convert tensors / numpy values via ``tolist()``.

        Raises:
            TypeError: For any other non-serializable object (json's usual behavior).
        """
        tolist = getattr(obj, "tolist", None)
        if callable(tolist):
            return tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def get_checkpoint_dir(self, step: int, epoch: int) -> Path:
        """Generate checkpoint directory name: just step{N} (``epoch`` is unused)."""
        return self.sub_checkpoint_dir / f"step{step}"

    # ------------------------------------------------------------------ #
    # State-dict sanitising
    # ------------------------------------------------------------------ #

    @staticmethod
    def _is_shared_fusion_ref(key: str) -> bool:
        """True for expert/fusion references to the shared fusion controller.

        These duplicate the top-level ``fusion_controller.*`` parameters, so they
        are dropped on save and expected to be missing on load.
        """
        return (
            'siglip_experts.' in key or
            'clip_experts.' in key or
            'fusion.' in key
        ) and '.fusion_controller.' in key

    def _safe_state_dict(self, model: nn.Module) -> Dict[str, torch.Tensor]:
        """Get state dict with tracking buffers removed and fusion controller deduplicated."""
        state_dict = model.state_dict()

        tracking_keys = [
            k for k in state_dict
            if any(marker in k for marker in self.TRACKING_BUFFER_MARKERS)
        ]
        for key in tracking_keys:
            del state_dict[key]
        if tracking_keys:
            print(f" ℹ️ Removed {len(tracking_keys)} shared tracking buffers")

        # safetensors refuses shared storage, so keep only the top-level controller.
        fusion_keys = [k for k in state_dict if self._is_shared_fusion_ref(k)]
        for key in fusion_keys:
            del state_dict[key]
        if fusion_keys:
            print(f" ℹ️ Removed {len(fusion_keys)} duplicate fusion controller references")
            print(f" ✓ Keeping only main 'fusion_controller.*' parameters")

        return state_dict

    # ------------------------------------------------------------------ #
    # Saving / uploading
    # ------------------------------------------------------------------ #

    def save_checkpoint(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler,
        epoch: int,
        step: int,
        val_loss: float,
        config: "DanbooruTrainingConfig",
        fusion_diagnostics: Dict,
        is_best: bool = False
    ) -> Path:
        """Save weights, optimizer/scheduler state and config to ``step{N}/``.

        Writes ``model.safetensors``, ``training_state.pt`` and ``config.json``,
        then appends the checkpoint to the run history (updating ``latest`` and,
        if ``is_best``, ``best``).

        Returns:
            The checkpoint directory.
        """
        ckpt_dir = self.get_checkpoint_dir(step, epoch)
        ckpt_dir.mkdir(parents=True, exist_ok=True)

        print(f"\n💾 Saving checkpoint: {self.sub_name}/{ckpt_dir.name}")
        print(f" Step: {step}, Epoch: {epoch}")

        state_dict = self._safe_state_dict(model)
        save_file(state_dict, ckpt_dir / "model.safetensors")
        print(f" ✓ Model weights: model.safetensors")

        training_state = {
            'epoch': epoch,
            'global_step': step,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'val_loss': val_loss,
            'sub_name': self.sub_name,
            'base_name': self.base_sub_name
        }
        torch.save(training_state, ckpt_dir / "training_state.pt")
        print(f" ✓ Training state: training_state.pt")

        config_dict = asdict(config)
        config_dict['timestamp'] = datetime.now().isoformat()
        config_dict['step'] = step
        config_dict['epoch'] = epoch
        config_dict['val_loss'] = val_loss
        config_dict['fusion_diagnostics'] = fusion_diagnostics
        config_dict['is_best'] = is_best

        # Diagnostics may carry tensors; _json_default turns them into lists.
        with open(ckpt_dir / "config.json", 'w') as f:
            json.dump(config_dict, f, indent=2, default=self._json_default)
        print(f" ✓ Config: config.json (step={step}, epoch={epoch}, val_loss={val_loss:.4f})")

        checkpoint_info = {
            'timestamp': datetime.now().isoformat(),
            'dirname': ckpt_dir.name,
            'step': step,
            'epoch': epoch,
            'val_loss': val_loss,
            'is_best': is_best,
            'fusion_diagnostics': fusion_diagnostics
        }

        self.checkpoint_history['checkpoints'].append(checkpoint_info)
        self.checkpoint_history['latest'] = checkpoint_info
        if is_best:
            self.checkpoint_history['best'] = checkpoint_info

        self._save_checkpoint_history()
        print(f" ✓ Updated checkpoint history")

        return ckpt_dir

    def upload_checkpoint(self, ckpt_dir: Path):
        """Upload a checkpoint folder plus the run history to HuggingFace.

        No-op without an HF repo; failures are printed, never raised.
        """
        if not self.hf_api or not self.hf_repo_id:
            return

        try:
            print(f"\n📤 Uploading to HuggingFace: {self.hf_repo_id}")
            print(f" Path: {self.sub_name}/{ckpt_dir.name}")

            self.hf_api.upload_folder(
                repo_id=self.hf_repo_id,
                folder_path=str(ckpt_dir),
                path_in_repo=f"{self.sub_name}/{ckpt_dir.name}",
                commit_message=f"Checkpoint: {self.sub_name}/{ckpt_dir.name}"
            )
            print(f" ✓ Uploaded checkpoint files")

            self.hf_api.upload_file(
                repo_id=self.hf_repo_id,
                path_or_fileobj=str(self.checkpoints_file),
                path_in_repo=f"{self.sub_name}/{self.HISTORY_FILENAME}",
                commit_message="Update checkpoint history"
            )
            print(f" ✓ Updated checkpoints.json")

            print(f"✅ Upload complete: https://huggingface.co/{self.hf_repo_id}")

        except Exception as e:
            print(f"⚠️ Upload failed: {e}")
            traceback.print_exc()

    # ------------------------------------------------------------------ #
    # Resuming
    # ------------------------------------------------------------------ #

    def find_latest_checkpoint(self) -> Optional[Dict]:
        """Return this run's history entry with the highest step, or ``None``."""
        checkpoints = self.checkpoint_history.get('checkpoints', [])
        if checkpoints:
            return max(checkpoints, key=lambda x: x['step'])
        return None

    def _find_resume_source(self) -> Optional[Tuple[Path, str, Dict]]:
        """Locate the newest checkpoint to resume from.

        The current run is checked first. Because every run gets a new
        timestamped directory, its history is normally empty, so earlier runs
        named ``<base_sub_name>-<timestamp>`` are then searched, newest first.

        Returns:
            ``(run_dir, run_sub_name, history_entry)`` or ``None``.
        """
        latest = self.find_latest_checkpoint()
        if latest:
            return self.sub_checkpoint_dir, self.sub_name, latest

        prefix = f"{self.base_sub_name}-"
        previous_runs = []
        for run_dir in self.local_dir.iterdir():
            if (not run_dir.is_dir()
                    or run_dir == self.sub_checkpoint_dir
                    or not run_dir.name.startswith(prefix)):
                continue
            stamp = run_dir.name[len(prefix):]
            try:
                datetime.strptime(stamp, self.RUN_TIMESTAMP_FORMAT)
            except ValueError:
                continue  # a different base name that merely shares this prefix
            previous_runs.append((stamp, run_dir))

        # Timestamp strings sort chronologically; newest run first.
        for _, run_dir in sorted(previous_runs, reverse=True):
            history_file = run_dir / self.HISTORY_FILENAME
            if not history_file.exists():
                continue
            try:
                with open(history_file, 'r') as f:
                    history = json.load(f)
            except (OSError, json.JSONDecodeError) as e:
                print(f"⚠️ Skipping unreadable history {history_file}: {e}")
                continue
            checkpoints = history.get('checkpoints') or []
            if checkpoints:
                return run_dir, run_dir.name, max(checkpoints, key=lambda c: c['step'])

        return None

    def load_checkpoint_for_resume(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler
    ) -> Tuple[int, int, float]:
        """Load the newest available checkpoint to resume training.

        Missing checkpoint files are fetched from the HF Hub when configured.
        An incompatible scheduler state is reported and skipped, leaving the
        scheduler at its freshly-constructed state.

        Returns:
            ``(epoch, global_step, val_loss)`` stored in the checkpoint, or
            ``(0, 0, inf)`` when nothing could be loaded.
        """
        source = self._find_resume_source()
        if source is None:
            print(f"ℹ️ No previous checkpoint found for training runs of '{self.base_sub_name}'")
            return 0, 0, float('inf')

        run_dir, run_sub_name, latest = source
        ckpt_dir = run_dir / latest['dirname']

        if not ckpt_dir.exists():
            if self.hf_api and self.hf_repo_id:
                print(f"📥 Downloading checkpoint from HuggingFace...")
                try:
                    # Downloads land in <local_dir>/<run_sub_name>/<dirname>/, i.e. ckpt_dir.
                    for filename in ("model.safetensors", "training_state.pt"):
                        hf_hub_download(
                            repo_id=self.hf_repo_id,
                            filename=f"{run_sub_name}/{latest['dirname']}/{filename}",
                            local_dir=self.local_dir
                        )
                    print(f" ✓ Downloaded checkpoint files")
                except Exception as e:
                    print(f" ⚠️ Download failed: {e}")
                    return 0, 0, float('inf')
            else:
                print(f" ⚠️ Checkpoint directory not found: {ckpt_dir}")
                return 0, 0, float('inf')

        print(f"\n🔄 Resuming from checkpoint: {run_sub_name}/{latest['dirname']}")
        print(f" Step: {latest['step']}, Epoch: {latest['epoch']}, Val Loss: {latest['val_loss']:.4f}")

        state_dict = load_file(str(ckpt_dir / "model.safetensors"))
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

        # Keys deliberately dropped by _safe_state_dict are expected to be missing.
        def _dropped_on_save(key: str) -> bool:
            return self._is_shared_fusion_ref(key) or any(
                marker in key for marker in self.TRACKING_BUFFER_MARKERS
            )

        expected_missing = [k for k in missing_keys if _dropped_on_save(k)]
        unexpected_missing = [k for k in missing_keys if not _dropped_on_save(k)]

        if unexpected_missing:
            print(f" ⚠️ Unexpected missing keys: {len(unexpected_missing)}")
            for k in unexpected_missing[:5]:
                print(f" - {k}")

        if unexpected_keys:
            print(f" ⚠️ Unexpected keys: {len(unexpected_keys)}")

        print(f" ✓ Loaded model weights ({len(expected_missing)} shared fusion refs skipped)")

        # The optimizer moves its state onto each parameter's device when loading.
        training_state = torch.load(ckpt_dir / "training_state.pt", map_location="cpu")

        optimizer.load_state_dict(training_state['optimizer_state_dict'])
        print(f" ✓ Loaded optimizer state")

        scheduler_state = training_state.get('scheduler_state_dict')
        if scheduler and scheduler_state:
            try:
                scheduler.load_state_dict(scheduler_state)
                print(f" ✓ Loaded scheduler state")
            except (KeyError, ValueError, TypeError) as e:
                print(f" ⚠️ Scheduler state incompatible ({e}); continuing with a fresh schedule")

        return training_state['epoch'], training_state['global_step'], training_state['val_loss']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IllustriousCLIPTextEncoder(nn.Module):
    """Frozen CLIP text encoder, optionally loaded from an Illustrious safetensors file.

    If ``safetensors_path`` does not exist, the pretrained text model named by
    ``tokenizer_name`` is used instead. All parameters are frozen and the model
    is kept in eval mode; ``forward`` returns the per-layer hidden states.
    """

    def __init__(
        self,
        safetensors_path: str,
        tokenizer_name: str = "openai/clip-vit-large-patch14",
        clip_skip: int = 2,
        device: str = "cuda"
    ):
        """
        Args:
            safetensors_path: Path to the Illustrious CLIP-L weights.
            tokenizer_name: HF id used for the tokenizer, the text config and
                the fallback weights.
            clip_skip: Number of final encoder layers to leave out of ``forward``.
            device: Device the text model is moved to.
        """
        super().__init__()

        self.clip_skip = clip_skip
        self.device = device

        print(f"\n{'='*80}")
        print("LOADING ILLUSTRIOUS CLIP TEXT ENCODER")
        print(f"{'='*80}")

        from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name)
        print(f"✓ Tokenizer: {tokenizer_name}")
        print(f"✓ Vocab size: {self.tokenizer.vocab_size}")

        if not os.path.exists(safetensors_path):
            print(f"\n⚠️ Illustrious CLIP not found: {safetensors_path}")
            print("Falling back to standard CLIP")

            self.model = CLIPTextModel.from_pretrained(tokenizer_name).to(device)
            self.is_illustrious = False
        else:
            print(f"Loading from: {safetensors_path}")

            state_dict = load_file(safetensors_path)
            print(f"✓ Loaded {len(state_dict)} tensors")

            config = CLIPTextConfig.from_pretrained(tokenizer_name)
            self.model = CLIPTextModel(config).to(device)

            model_state_dict = self.model.state_dict()
            mapped_state = self._map_state_dict(state_dict, model_state_dict.keys())
            print(f"✓ Mapped {len(mapped_state)}/{len(model_state_dict)} parameters")
            if not mapped_state:
                print("⚠️ No tensors matched the CLIP text model; weights stay randomly initialised!")

            missing, unexpected = self.model.load_state_dict(mapped_state, strict=False)
            if missing:
                print(f"⚠️ Missing: {len(missing)} keys")
            if unexpected:
                print(f"⚠️ Unexpected: {len(unexpected)} keys")

            self.is_illustrious = True
            print(f"✅ Illustrious CLIP loaded!")

        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()

        active_layers = self.model.config.num_hidden_layers - clip_skip
        print(f"✓ Using {active_layers} layers (skip last {clip_skip})")
        print(f"{'='*80}\n")

    @staticmethod
    def _map_state_dict(
        source: Dict[str, torch.Tensor],
        target_keys
    ) -> Dict[str, torch.Tensor]:
        """Rename checkpoint keys onto the ``CLIPTextModel`` key space.

        A key is kept when it matches a target key directly, when the part from
        ``"text_model."`` onwards matches (e.g. ``transformer.text_model.*``), or
        when prefixing it with ``"text_model."`` matches. Unmatched keys are dropped.
        """
        targets = set(target_keys)
        mapped = {}
        for key, tensor in source.items():
            if key in targets:
                candidate = key
            elif "text_model." in key:
                candidate = key[key.index("text_model."):]
            else:
                candidate = "text_model." + key
            if candidate in targets:
                mapped[candidate] = tensor
        return mapped

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """Return hidden states of the first ``len(hidden_states) - 1 - clip_skip`` layers.

        Keys are ``clip_layer_{i}`` mapped to ``hidden_states[i + 1]``; index 0 is
        skipped (in HF CLIP it holds the embedding output).
        """
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                return_dict=True
            )

        hidden_states = outputs.hidden_states
        num_layers = len(hidden_states) - self.clip_skip - 1
        return {f'clip_layer_{i}': hidden_states[i + 1] for i in range(num_layers)}
|
|
|
|
|
|
|
|
class SigLIPFeatureExtractor(nn.Module):
    """Frozen SigLIP vision tower that exposes every encoder layer's output.

    Forward hooks on ``model.vision_model.encoder.layers`` capture each layer's
    hidden states; ``forward`` returns them keyed ``siglip_layer_{i}``.
    """

    def __init__(self, model_name: str, device: str = "cuda"):
        """Load ``SiglipModel`` + processor, freeze the model and register hooks."""
        super().__init__()

        print(f"\n{'='*80}")
        print("LOADING SIGLIP VISION ENCODER")
        print(f"{'='*80}")
        print(f"Model: {model_name}")

        self.model = SiglipModel.from_pretrained(model_name).to(device)
        self.processor = SiglipProcessor.from_pretrained(model_name)

        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()

        self.layer_outputs = {}
        self._register_hooks()

        num_layers = len(self.model.vision_model.encoder.layers)
        print(f"✓ {num_layers} vision layers")
        print(f"✓ Frozen encoder")
        print(f"{'='*80}\n")

    def _register_hooks(self):
        """Attach one forward hook per vision encoder layer, keeping the handles."""
        self._hook_handles = [
            layer.register_forward_hook(self._make_hook(i))
            for i, layer in enumerate(self.model.vision_model.encoder.layers)
        ]

    def _make_hook(self, layer_idx: int):
        """Build a hook that stores layer ``layer_idx``'s hidden states."""
        def hook(module, inputs, output):
            # Encoder layers may return a tensor or a tuple whose first element
            # is the hidden states (depends on the transformers version).
            hidden = output[0] if isinstance(output, (tuple, list)) else output
            self.layer_outputs[f'siglip_layer_{layer_idx}'] = hidden
        return hook

    def forward(self, images: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run the vision tower and return ``{siglip_layer_{i}: hidden_states}``."""
        with torch.no_grad():
            model_device = next(self.model.parameters()).device
            if images.device != model_device:
                images = images.to(model_device)

            self.layer_outputs = {}
            self.model.vision_model(pixel_values=images)

        return dict(self.layer_outputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GeometricRegularization(nn.Module):
    """Geometric regularization for pentachoron opinion anchors.

    Anchors are indexed as ``pentachora[anchor, vertex, :]`` (5 vertices for a
    pentachoron). Provides a minimum-volume penalty (Cayley-Menger) and a
    vertex-norm penalty ("rose").
    """

    # Floor for the squared volume before the sqrt. d sqrt(x)/dx is infinite at
    # 0, which previously turned degenerate simplices into NaN gradients.
    _VOLUME_SQ_EPS = 1e-12

    def __init__(self):
        super().__init__()

    def cayley_menger_loss(
        self,
        pentachora: torch.Tensor,
        sample_size: int = 50,
        min_volume: float = 0.01
    ) -> torch.Tensor:
        """Cayley-Menger volume regularization.

        Penalises ``relu(min_volume - volume)`` averaged over at most
        ``sample_size`` randomly chosen anchors. The simplex dimension is
        ``num_vertices - 1`` (4 for a pentachoron).

        Returns:
            Scalar tensor; an empty input yields a zero still attached to the graph.
        """
        import math

        num_anchors = pentachora.shape[0]
        if num_anchors == 0:
            return pentachora.sum() * 0.0

        if num_anchors > sample_size:
            indices = torch.randperm(num_anchors, device=pentachora.device)[:sample_size]
            pentachora = pentachora[indices]

        num_vertices = pentachora.shape[1]
        n = num_vertices - 1
        # V^2 = (-1)^(n+1) / (2^n (n!)^2) * det(CM); equals -det/9216 for n = 4.
        coeff = (-1) ** (n + 1) / (2 ** n * math.factorial(n) ** 2)

        diff = pentachora.unsqueeze(1) - pentachora.unsqueeze(2)
        dist_sq = (diff ** 2).sum(dim=-1)

        # Bordered Cayley-Menger matrices for the whole batch at once.
        cm = pentachora.new_zeros(pentachora.shape[0], num_vertices + 1, num_vertices + 1)
        cm[:, 0, 1:] = 1.0
        cm[:, 1:, 0] = 1.0
        cm[:, 1:, 1:] = dist_sq

        volume_sq = (coeff * torch.linalg.det(cm)).clamp(min=self._VOLUME_SQ_EPS)
        volume = volume_sq.sqrt()

        return F.relu(min_volume - volume).mean()

    def rose_loss(
        self,
        pentachora: torch.Tensor,
        target_norm: float = 0.29514
    ) -> torch.Tensor:
        """Rose harmonic constraint: MSE between every vertex norm and ``target_norm``."""
        vertex_norms = torch.norm(pentachora, dim=-1)
        target = torch.full_like(vertex_norms, target_norm)
        return F.mse_loss(vertex_norms, target)

    def forward(self, pentachora: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute all geometric losses: ``{'cayley': ..., 'rose': ...}``."""
        return {
            'cayley': self.cayley_menger_loss(pentachora),
            'rose': self.rose_loss(pentachora)
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DanbooruLiminalStaircaseTrainer: |
|
|
"""Trainer with bulletproof checkpointing + text modality robustness.""" |
|
|
|
|
|
def __init__(self, config: DanbooruTrainingConfig):
    """Assemble everything needed for training.

    Order: Accelerator -> CheckpointManager -> TensorBoard writer (main
    process only) -> frozen SigLIP / CLIP extractors -> LiminalStaircase,
    geometric regularizer, AdamW -> dataloaders -> LR schedule -> cached
    vision-only sentinel CLIP features -> ``accelerator.prepare`` -> optional
    resume -> SIGINT handler.

    Args:
        config: Full training configuration.
    """
    self.config = config
    self._interrupt_received = False
    self._save_on_interrupt = True

    self.accelerator = Accelerator(
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        mixed_precision='fp16' if config.device == 'cuda' else 'no'
    )

    print("\n" + "🎨 " * 40)
    print("LIMINAL STAIRCASE TRAINER - BULLETPROOF + GEOMETRIC + TEXT ROBUSTNESS")
    print("🎨 " * 40 + "\n")

    # Checkpoints: timestamped run directory, optionally mirrored to the HF Hub.
    self.checkpoint_manager = CheckpointManager(
        local_dir=config.checkpoint_dir,
        hf_repo_id=config.hf_repo_id,
        sub_name=config.sub_name,
        hf_private=config.hf_private
    )

    # TensorBoard: only the main process writes logs.
    if self.accelerator.is_main_process:
        log_dir = Path(config.log_dir) / f"{config.sub_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        self.writer = SummaryWriter(log_dir=log_dir)
        print(f"📊 TensorBoard logging to: {log_dir}")
    else:
        self.writer = None

    # Frozen feature extractors.
    self.siglip_extractor = SigLIPFeatureExtractor(
        config.siglip_model,
        config.device
    )
    self.clip_extractor = IllustriousCLIPTextEncoder(
        config.illustrious_clip_path,
        config.clip_tokenizer,
        config.clip_skip,
        config.device
    )

    siglip_hidden_dim = self.siglip_extractor.model.vision_model.config.hidden_size
    siglip_num_layers = len(self.siglip_extractor.model.vision_model.encoder.layers)

    print("\n" + "⚡ " * 40)
    print("INITIALIZING LIMINAL STAIRCASE WITH GEOMETRIC PENTACHORA")
    print("⚡ " * 40)

    model_config = config.to_model_config(siglip_hidden_dim, siglip_num_layers)
    self.model = LiminalStaircase(model_config).to(config.device)

    self.geometric_reg = GeometricRegularization()

    self.optimizer = torch.optim.AdamW(
        self.model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )

    print("\n" + "🎨 " * 40)
    self.train_loader, self.val_loader, self.tag_vocab = create_danbooru_dataloaders(
        siglip_processor=self.siglip_extractor.processor,
        clip_tokenizer=self.clip_extractor.tokenizer,
        dataset_name=config.dataset_name,
        image_size=config.image_size,
        batch_size=config.batch_size,
        num_workers=4
    )

    # Step counts use the loader length measured before accelerator.prepare().
    steps_per_epoch = len(self.train_loader)
    total_steps = config.num_epochs * steps_per_epoch

    print(f"\n📊 Training schedule:")
    print(f" Steps per epoch: {steps_per_epoch:,}")
    print(f" Total training steps: {total_steps:,}")

    self.scheduler = self._build_scheduler(total_steps)

    # Vision-only sentinel: CLIP features substituted when text is dropped.
    print(f"\n🔷 Creating vision-only sentinel token...")
    print(f" Token: '{config.vision_only_text}'")
    with torch.no_grad():
        sentinel_input = self.clip_extractor.tokenizer(
            config.vision_only_text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=config.max_tag_length
        ).to(config.device)

        sentinel_features = self.clip_extractor(
            sentinel_input['input_ids'],
            sentinel_input['attention_mask']
        )

        # Detached copies: the same tensors are reused for every text-dropped batch.
        self.vision_only_clip_features = {
            name: feat.detach().clone()
            for name, feat in sentinel_features.items()
        }

    print(f"✓ Vision-only sentinel cached")
    example_shape = next(iter(self.vision_only_clip_features.values())).shape
    print(f" Shape example: {example_shape}")
    print(f" Text dropout: {config.text_dropout_schedule} schedule")
    print(f" Start: {config.text_dropout_start:.1%}, End: {config.text_dropout_end:.1%}")

    (
        self.model,
        self.optimizer,
        self.train_loader,
        self.val_loader,
        self.scheduler
    ) = self.accelerator.prepare(
        self.model,
        self.optimizer,
        self.train_loader,
        self.val_loader,
        self.scheduler
    )

    self.global_step = 0
    self.start_epoch = 0
    self.best_val_loss = float('inf')
    self.current_epoch = 0

    # How often each text-conditioning mode was used.
    self.text_dropout_stats = {
        'clean': 0,
        'noisy': 0,
        'sentinel': 0
    }

    # Resume on every process so all ranks start from identical weights and
    # optimizer state (previously only the main process loaded them).
    if config.resume:
        epoch, step, val_loss = self.checkpoint_manager.load_checkpoint_for_resume(
            self.accelerator.unwrap_model(self.model),
            self.optimizer,
            self.scheduler
        )
        self.start_epoch = epoch
        self.global_step = step
        self.best_val_loss = val_loss

    self._setup_interrupt_handler()

    print("\n" + "✅ " * 40)
    print("TRAINER READY")
    print("✅ " * 40)
    print(f"Sub name: {config.sub_name}")
    print(f"Fusion strategy: {config.fusion_strategy}")
    print(f"Model params: {sum(p.numel() for p in self.model.parameters()):,}")
    print(f"Text robustness: ENABLED")
    print(f" Sentinel: '{config.vision_only_text}'")
    print(f" Dropout schedule: {config.text_dropout_schedule}")
    if self.global_step > 0:
        print(f"Resuming from: step {self.global_step}, epoch {self.start_epoch}")
    print(f"⚡ Interrupt handling: Ctrl+C saves checkpoint before exit")
    print("✅ " * 40 + "\n")

def _build_scheduler(self, total_steps: int):
    """Linear warmup followed by cosine decay to zero, counted in optimizer updates.

    ``config.warmup_steps`` was previously ignored and the cosine horizon did
    not account for gradient accumulation. The schedule now spans
    ``ceil(total_steps / gradient_accumulation_steps)`` updates.
    NOTE(review): under multi-process accelerate the prepared scheduler may
    advance once per process per update — confirm warmup units if scaling out.
    """
    import math

    cfg = self.config
    accum = max(int(cfg.gradient_accumulation_steps), 1)
    num_updates = max(math.ceil(total_steps / accum), 1)
    warmup = min(max(int(cfg.warmup_steps), 0), max(num_updates - 1, 0))

    def lr_lambda(update: int) -> float:
        if update < warmup:
            return (update + 1) / (warmup + 1)
        progress = (update - warmup) / max(num_updates - warmup, 1)
        return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))

    print(f" Optimizer updates: {num_updates:,} (grad accumulation x{accum})")
    print(f" LR schedule: {warmup:,} warmup updates, then cosine decay")

    return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)
|
|
|
|
|
def _setup_interrupt_handler(self):
    """Install a SIGINT handler that saves an emergency checkpoint, then exits.

    First Ctrl+C: on the main process (and if ``_save_on_interrupt``) call
    ``_emergency_save_checkpoint``, then ``sys.exit(0)``. A second Ctrl+C
    received while the first is still being handled exits with status 1.
    """
    def signal_handler(sig, frame):
        if self._interrupt_received:
            print("\n⚠️ Second interrupt received, forcing exit...")
            sys.exit(1)

        self._interrupt_received = True
        print("\n" + "⚡ " * 40)
        print("KEYBOARD INTERRUPT DETECTED")
        print("⚡ " * 40)

        if self._save_on_interrupt and self.accelerator.is_main_process:
            print("Saving checkpoint before exit...")
            try:
                self._emergency_save_checkpoint()
                print("✅ Emergency checkpoint saved successfully")
            except Exception as e:
                print(f"⚠️ Emergency save failed: {e}")
                traceback.print_exc()

        print("\n" + "⚡ " * 40)
        print("Exiting gracefully...")
        print("⚡ " * 40 + "\n")
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)
|
|
|
|
|
def _emergency_save_checkpoint(self):
    """Save a checkpoint immediately (used by the SIGINT handler) and try to upload it.

    The checkpoint is recorded with ``val_loss=inf`` and ``is_best=False``
    because no validation ran for it. Upload errors are printed, not raised.

    Returns:
        Whatever ``checkpoint_manager.save_checkpoint`` returned (the checkpoint
        directory).
    """
    print(f"\n💾 Emergency save at step {self.global_step}, epoch {self.current_epoch}")

    fusion_diagnostics = self.get_fusion_diagnostics()

    ckpt_dir = self.checkpoint_manager.save_checkpoint(
        model=self.accelerator.unwrap_model(self.model),
        optimizer=self.optimizer,
        scheduler=self.scheduler,
        epoch=self.current_epoch,
        step=self.global_step,
        val_loss=float('inf'),
        config=self.config,
        fusion_diagnostics=fusion_diagnostics,
        is_best=False
    )

    if self.config.hf_repo_id:
        print("Attempting HuggingFace upload...")
        try:
            self.checkpoint_manager.upload_checkpoint(ckpt_dir)
        except Exception as e:
            print(f"⚠️ Upload failed (checkpoint saved locally): {e}")

    return ckpt_dir
|
|
|
|
|
def get_text_dropout_prob(self) -> float:
    """Return the probability of replacing a batch's text features with the sentinel.

    ``"constant"`` returns ``config.text_dropout_prob``. ``"linear"`` and
    ``"cosine"`` move from ``text_dropout_start`` to ``text_dropout_end`` over
    ``num_epochs * len(train_loader)`` steps. Progress is clamped to [0, 1], so
    the value holds at ``text_dropout_end`` once the schedule is exhausted
    instead of extrapolating past it. Unknown schedule names fall back to
    ``text_dropout_prob``.
    """
    cfg = self.config
    if cfg.text_dropout_schedule == "constant":
        return cfg.text_dropout_prob

    total_steps = cfg.num_epochs * len(self.train_loader)
    progress = min(max(self.global_step / max(total_steps, 1), 0.0), 1.0)
    span = cfg.text_dropout_end - cfg.text_dropout_start

    if cfg.text_dropout_schedule == "linear":
        return float(cfg.text_dropout_start + progress * span)
    if cfg.text_dropout_schedule == "cosine":
        return float(cfg.text_dropout_start + 0.5 * span * (1 - np.cos(np.pi * progress)))
    return cfg.text_dropout_prob
|
|
|
|
|
def compute_loss(
    self,
    outputs: Dict,
    target_tokens: torch.Tensor
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """Compute the full training objective and scalar metrics.

    total = token_loss_weight * CE(token_logits, target_tokens)
            + geometric_weight * (cayley + rose + beta)

    Padding tokens (the tokenizer's ``pad_token_id``, or -100 if it has none)
    are ignored. A batch with no non-pad targets contributes a zero token loss
    (kept on the graph) instead of NaN. The beta term is only added in training
    mode when ``outputs`` contains ``'scale_feature_pairs'``.

    Returns:
        ``(total_loss, metrics)``; metrics are Python floats keyed
        ``loss/total``, ``loss/token``, ``loss/cayley``, ``loss/rose``,
        ``loss/beta`` and ``acc/token``.

    Raises:
        Re-raises any exception after printing its traceback.
    """
    try:
        token_logits = outputs['token_logits']
        vocab_size = token_logits.shape[-1]

        pad_token_id = self.clip_extractor.tokenizer.pad_token_id
        ignore_index = -100 if pad_token_id is None else pad_token_id

        mask = target_tokens != ignore_index
        mask_sum = mask.float().sum()
        has_targets = bool(mask_sum > 0)

        if has_targets:
            # reshape (not view): logits may be non-contiguous.
            token_loss = F.cross_entropy(
                token_logits.reshape(-1, vocab_size),
                target_tokens.reshape(-1),
                ignore_index=ignore_index
            )
        else:
            # Mean over zero targets would be NaN; keep a zero on the graph.
            token_loss = (token_logits * 0.0).sum()

        pentachora = self.accelerator.unwrap_model(self.model).opinion_anchors
        geo_losses = self.geometric_reg(pentachora)

        # Beta alignment: per-scale MSE between token and geometric features.
        beta_loss = 0.0
        if 'scale_feature_pairs' in outputs and self.model.training:
            beta_losses = [
                features['beta'] * F.mse_loss(
                    features['token_features'],
                    features['geometric_features']
                )
                for features in outputs['scale_feature_pairs'].values()
            ]
            if beta_losses:
                beta_loss = sum(beta_losses) / len(beta_losses)

        total_loss = (
            self.config.token_loss_weight * token_loss +
            self.config.geometric_weight * (geo_losses['cayley'] + geo_losses['rose'] + beta_loss)
        )

        # Token accuracy over non-pad positions only.
        preds = token_logits.argmax(dim=-1)
        if has_targets:
            acc = ((preds == target_tokens) & mask).float().sum() / mask_sum
        else:
            acc = torch.zeros((), device=token_logits.device)

        metrics = {
            'loss/total': total_loss.item(),
            'loss/token': token_loss.item(),
            'loss/cayley': geo_losses['cayley'].item(),
            'loss/rose': geo_losses['rose'].item(),
            'loss/beta': beta_loss.item() if isinstance(beta_loss, torch.Tensor) else beta_loss,
            'acc/token': acc.item()
        }

        return total_loss, metrics

    except Exception as e:
        print(f"\n⚠️ Error in compute_loss: {e}")
        traceback.print_exc()
        raise
|
|
|
|
|
def get_fusion_diagnostics(self) -> Dict:
    """Return the fusion controller's diagnostics, or an empty placeholder.

    Delegates to ``unwrap_model(self.model).fusion_controller.get_diagnostics()``.
    Any exception (e.g. the model has no ``fusion_controller``) is printed and a
    fresh placeholder dict is returned. The placeholder holds empty values for
    the keys this trainer reads when logging and printing diagnostics, so
    callers can use ``.get(...)`` and iterate without special-casing failure.
    """
    try:
        unwrapped = self.accelerator.unwrap_model(self.model)
        return unwrapped.fusion_controller.get_diagnostics()
    except Exception as e:
        print(f"⚠️ Error getting fusion diagnostics: {e}")
        return {
            'layer_weights': [],
            'scale_weights': [],
            'alpha_per_scale': [],
            'beta_per_scale': [],
            'scale_statistics': {},
        }
|
|
|
|
|
def train_step(self, batch: Dict) -> Dict[str, float]:
    """Run one optimization step with text-modality robustness.

    The text conditioning for this step is chosen stochastically:

    * if ``torch.rand(1) > self.get_text_dropout_prob()`` the real CLIP features
      are used; with probability ``config.text_noise_prob`` every layer is then
      perturbed by Gaussian noise of std ``config.text_noise_std`` ("noisy"),
      otherwise they are used as-is ("clean");
    * otherwise the cached ``self.vision_only_clip_features`` sentinel features
      are expanded to the batch size ("sentinel").

    The chosen mode is counted in ``self.text_dropout_stats``.

    Args:
        batch: Mapping with 'siglip_images', 'clip_input_ids' and
            'clip_attention_mask' entries.

    Returns:
        The metrics from ``compute_loss`` plus ``text_dropout_prob`` and
        ``text_mode`` (0.0 clean, 0.5 noisy, 1.0 sentinel). If anything raises,
        the error is printed, any gradients already written by a partial step
        are cleared, and a placeholder dict with NaN losses is returned so the
        training loop can continue.
    """
    try:
        self.model.train()

        # Feature extractors are run without autograd.
        with torch.no_grad():
            siglip_features = self.siglip_extractor(batch['siglip_images'])

        current_dropout = self.get_text_dropout_prob()
        use_text = torch.rand(1).item() > current_dropout

        if use_text:
            with torch.no_grad():
                clip_features = self.clip_extractor(
                    batch['clip_input_ids'],
                    batch['clip_attention_mask'],
                )

            if torch.rand(1).item() < self.config.text_noise_prob:
                for layer_name, features in clip_features.items():
                    noise = torch.randn_like(features) * self.config.text_noise_std
                    clip_features[layer_name] = features + noise
                text_status = "noisy"
            else:
                text_status = "clean"
        else:
            # Vision-only step: broadcast the sentinel features over the batch.
            batch_size = batch['siglip_images'].shape[0]
            clip_features = {
                layer_name: sentinel_feat.expand(batch_size, -1, -1).contiguous()
                for layer_name, sentinel_feat in self.vision_only_clip_features.items()
            }
            text_status = "sentinel"

        self.text_dropout_stats[text_status] += 1

        with self.accelerator.accumulate(self.model):
            outputs = self.model(siglip_features, clip_features)
            loss, metrics = self.compute_loss(outputs, batch['clip_input_ids'])

            metrics['text_dropout_prob'] = current_dropout
            metrics['text_mode'] = {'clean': 0.0, 'noisy': 0.5, 'sentinel': 1.0}[text_status]

            self.accelerator.backward(loss)

            if self.accelerator.sync_gradients and self.config.gradient_clip > 0:
                self.accelerator.clip_grad_norm_(
                    self.model.parameters(),
                    self.config.gradient_clip,
                )

            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()

        return metrics

    except Exception as e:
        print(f"\n⚠️ Error in train_step at step {self.global_step}: {e}")
        traceback.print_exc()
        # A failure after backward() (e.g. in clipping or optimizer.step) leaves
        # gradients in .grad; without clearing them they would be silently folded
        # into the next successful step. The model is zeroed directly because a
        # prepared (accelerate) optimizer may skip zero_grad() outside sync steps.
        # Trade-off: this also drops gradients accumulated earlier in the window.
        try:
            self.model.zero_grad(set_to_none=True)
        except Exception as clear_err:
            print(f"⚠️ Could not clear gradients after failed step: {clear_err}")
        return {
            'loss/total': float('nan'),
            'loss/token': float('nan'),
            'loss/cayley': 0.0,
            'loss/rose': 0.0,
            'loss/beta': 0.0,
            'acc/token': 0.0,
            'text_dropout_prob': 0.0,
            'text_mode': 0.0,
        }
|
|
|
|
|
def log_metrics(self, metrics: Dict[str, float], prefix: str = "train"):
    """Write ``metrics`` to TensorBoard at ``self.global_step``.

    Tag naming:

    * ``prefix == "val"`` and key ``"<category>/<name>"`` with category ``loss``
      or ``acc``: ``"val/<category>"`` when ``name == "val"`` (the primary
      ``loss/val`` / ``acc/val`` metrics), otherwise ``"val/<category>_<name>"``.
      Previously the category was stripped, so ``loss/val`` and ``acc/val``
      were both written to the same tag ``val/val``.
    * every other key: ``"<prefix>/<key>"``.

    For ``prefix == "train"`` it also logs the current learning rate, the
    text-modality distribution when ``global_step`` is a multiple of
    ``config.log_every``, and fusion-controller diagnostics when it is a
    multiple of ``10 * config.log_every``. Does nothing if ``self.writer`` is None.
    """
    writer = self.writer
    if writer is None:
        return

    step = self.global_step

    for key, value in metrics.items():
        if prefix == "val" and key.startswith(('loss/', 'acc/')):
            category, _, name = key.partition('/')
            tag = f"val/{category}" if name == "val" else f"val/{category}_{name}"
        else:
            tag = f"{prefix}/{key}"
        writer.add_scalar(tag, value, step)

    if prefix == "train":
        writer.add_scalar("train/learning_rate", self.optimizer.param_groups[0]['lr'], step)

        if step % self.config.log_every == 0:
            # `or 1` avoids division by zero before any step has been counted.
            total = sum(self.text_dropout_stats.values()) or 1
            for mode, count in self.text_dropout_stats.items():
                writer.add_scalar(f"text_modality/{mode}_pct", 100 * count / total, step)

        if step % (self.config.log_every * 10) == 0:
            fusion_diag = self.get_fusion_diagnostics()
            # (diagnostics key, TensorBoard tag stem)
            for diag_key, tag_stem in (
                ('layer_weights', 'layer_weight'),
                ('scale_weights', 'scale_weight'),
                ('alpha_per_scale', 'alpha_scale'),
                ('beta_per_scale', 'beta_scale'),
            ):
                for i, value in enumerate(fusion_diag.get(diag_key, [])):
                    writer.add_scalar(f"fusion/{tag_stem}_{i}", value, step)

    writer.flush()
|
|
|
|
|
@torch.no_grad()
def validate(self, max_batches: int = 100) -> Dict[str, float]:
    """Evaluate on up to ``max_batches`` validation batches in two text modes.

    Each batch is run through the model twice:

    * with text: CLIP features from ``batch['clip_input_ids']`` /
      ``batch['clip_attention_mask']``;
    * vision-only: ``self.vision_only_clip_features`` sentinel features
      expanded to the batch size.

    A batch contributes to the averages only if BOTH passes succeed, so the two
    sets of statistics are computed over the same batches. Failed batches still
    count toward ``max_batches``, so a persistent error cannot make validation
    sweep the whole loader.

    Returns:
        ``val_with_text_{loss,acc}``, ``val_vision_only_{loss,acc}`` and the
        primary metrics ``loss/val`` / ``acc/val`` (vision-only values). Losses
        are ``inf`` and accuracies ``0.0`` when no batch succeeded or when
        validation failed outright.
    """
    from itertools import islice

    failure_metrics = {
        'val_with_text_loss': float('inf'),
        'val_with_text_acc': 0.0,
        'val_vision_only_loss': float('inf'),
        'val_vision_only_acc': 0.0,
        'loss/val': float('inf'),
        'acc/val': 0.0,
    }

    try:
        self.model.eval()

        text_loss_sum = text_acc_sum = 0.0
        vision_loss_sum = vision_acc_sum = 0.0
        num_ok = 0

        limit = max(max_batches, 0)
        try:
            progress_total = min(limit, len(self.val_loader))
        except TypeError:
            # Loaders over iterable datasets may not define __len__.
            progress_total = limit

        # islice stops after `limit` batches without loading an extra one.
        for batch in tqdm(islice(self.val_loader, limit), desc="Validating",
                          leave=False, total=progress_total):
            try:
                siglip_features = self.siglip_extractor(batch['siglip_images'])
                batch_size = batch['siglip_images'].shape[0]

                # Pass 1: real text conditioning.
                clip_features_text = self.clip_extractor(
                    batch['clip_input_ids'],
                    batch['clip_attention_mask'],
                )
                outputs_text = self.model(siglip_features, clip_features_text)
                _, metrics_text = self.compute_loss(outputs_text, batch['clip_input_ids'])

                # Pass 2: vision-only via the sentinel features.
                clip_features_sentinel = {
                    layer_name: sentinel_feat.expand(batch_size, -1, -1).contiguous()
                    for layer_name, sentinel_feat in self.vision_only_clip_features.items()
                }
                outputs_vision = self.model(siglip_features, clip_features_sentinel)
                _, metrics_vision = self.compute_loss(outputs_vision, batch['clip_input_ids'])
            except Exception as e:
                print(f"\n⚠️ Error in validation batch: {e}")
                continue

            text_loss_sum += metrics_text['loss/total']
            text_acc_sum += metrics_text['acc/token']
            vision_loss_sum += metrics_vision['loss/total']
            vision_acc_sum += metrics_vision['acc/token']
            num_ok += 1

        if num_ok == 0:
            return failure_metrics

        vision_loss = vision_loss_sum / num_ok
        vision_acc = vision_acc_sum / num_ok
        return {
            'val_with_text_loss': text_loss_sum / num_ok,
            'val_with_text_acc': text_acc_sum / num_ok,
            'val_vision_only_loss': vision_loss,
            'val_vision_only_acc': vision_acc,
            # Vision-only is the primary metric used for best-model selection.
            'loss/val': vision_loss,
            'acc/val': vision_acc,
        }

    except Exception as e:
        print(f"\n⚠️ Validation completely failed: {e}")
        traceback.print_exc()
        return failure_metrics
|
|
|
|
|
def save_checkpoint_and_upload(self, epoch: int, val_loss: float = float('inf'), is_best: bool = False):
    """Save a checkpoint locally, then upload it when ``config.hf_repo_id`` is set.

    Only the main process does anything. The fusion diagnostics stored with the
    checkpoint are augmented with ``text_modality_stats``: the share of
    clean/noisy/sentinel steps formatted as percentage strings. Errors are
    printed, never raised, so training continues. An upload failure is reported
    separately so it does not hide a successful local save.

    Args:
        epoch: Epoch index forwarded to the checkpoint manager.
        val_loss: Validation loss to record (``inf`` when not validated).
        is_best: Forwarded to ``checkpoint_manager.save_checkpoint``.
    """
    if not self.accelerator.is_main_process:
        return

    try:
        fusion_diagnostics = self.get_fusion_diagnostics()

        # `or 1` avoids division by zero before any step has been counted.
        total = sum(self.text_dropout_stats.values()) or 1
        fusion_diagnostics['text_modality_stats'] = {
            mode: f"{100 * count / total:.1f}%"
            for mode, count in self.text_dropout_stats.items()
        }

        ckpt_dir = self.checkpoint_manager.save_checkpoint(
            model=self.accelerator.unwrap_model(self.model),
            optimizer=self.optimizer,
            scheduler=self.scheduler,
            epoch=epoch,
            step=self.global_step,
            val_loss=val_loss,
            config=self.config,
            fusion_diagnostics=fusion_diagnostics,
            is_best=is_best,
        )
    except Exception as e:
        print(f"\n⚠️ Checkpoint save/upload failed: {e}")
        traceback.print_exc()
        return

    try:
        if self.config.hf_repo_id:
            self.checkpoint_manager.upload_checkpoint(ckpt_dir)
    except Exception as e:
        print(f"\n⚠️ Checkpoint upload failed (local checkpoint kept at {ckpt_dir}): {e}")
        traceback.print_exc()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train(self):
    """Run the full training loop with periodic checkpointing and validation.

    Per step: ``train_step``, then metrics logging every ``config.log_every``
    steps. A checkpoint is saved and a 50-batch validation is run (on the main
    process) when the step is a multiple of ``config.save_every``, or of
    ``config.hf_upload_every`` when ``config.hf_repo_id`` is set.
    ``save_checkpoint_and_upload`` already uploads whenever ``hf_repo_id`` is
    set, so a step that hits both schedules saves, uploads and validates once
    instead of twice.

    Per epoch, on the main process: save, run a 100-batch validation, log it,
    track the best vision-only loss (saving a best checkpoint on improvement),
    and print fusion and text-modality summaries.

    ``self._interrupt_received`` stops the loop between steps. A
    KeyboardInterrupt triggers ``_emergency_save_checkpoint`` (main process,
    when ``_save_on_interrupt``) and is re-raised. The TensorBoard writer is
    closed on every exit path.
    """
    print("\n" + "🚀 " * 40)
    print("TRAINING START")
    print("🚀 " * 40 + "\n")

    try:
        for epoch in range(self.start_epoch, self.config.num_epochs):
            self.current_epoch = epoch

            if self._interrupt_received:
                break

            print(f"\n{'🎨' * 40}")
            print(f"EPOCH {epoch + 1}/{self.config.num_epochs}")
            print(f"{'🎨' * 40}\n")

            pbar = tqdm(
                self.train_loader,
                desc=f"Epoch {epoch + 1}",
                disable=not self.accelerator.is_main_process,
            )

            for batch in pbar:
                if self._interrupt_received:
                    break

                metrics = self.train_step(batch)
                self.global_step += 1

                if self.global_step % self.config.log_every == 0:
                    pbar.set_postfix(metrics)
                    self.log_metrics(metrics, prefix="train")

                periodic_save = self.global_step % self.config.save_every == 0
                periodic_upload = (
                    bool(self.config.hf_repo_id)
                    and self.global_step % self.config.hf_upload_every == 0
                )
                if periodic_save or periodic_upload:
                    # Save BEFORE validating so a validation crash cannot lose progress.
                    self.save_checkpoint_and_upload(epoch, val_loss=float('inf'), is_best=False)

                    if self.accelerator.is_main_process:
                        print("\n🔍 Running validation...")
                        val_metrics = self.validate(max_batches=50)
                        self.log_metrics(val_metrics, prefix="val")
                        print(f"✓ Val (with text) - Loss: {val_metrics['val_with_text_loss']:.4f}, Acc: {val_metrics['val_with_text_acc']:.4f}")
                        print(f"✓ Val (vision-only) - Loss: {val_metrics['val_vision_only_loss']:.4f}, Acc: {val_metrics['val_vision_only_acc']:.4f}")

            if self._interrupt_received:
                break

            if self.accelerator.is_main_process:
                # End-of-epoch checkpoint, again saved before validation.
                self.save_checkpoint_and_upload(epoch, val_loss=float('inf'), is_best=False)

                print("\n📊 End of epoch validation...")
                val_metrics = self.validate(max_batches=100)

                print("\n📊 Validation Results:")
                print(" With Text:")
                print(f" Loss: {val_metrics['val_with_text_loss']:.4f}")
                print(f" Acc: {val_metrics['val_with_text_acc']:.4f}")
                print(" Vision-Only (PRIMARY METRIC):")
                print(f" Loss: {val_metrics['val_vision_only_loss']:.4f}")
                print(f" Acc: {val_metrics['val_vision_only_acc']:.4f}")

                self.log_metrics(val_metrics, prefix="val")

                # Best-model selection uses the vision-only loss ('loss/val').
                is_best = val_metrics['loss/val'] < self.best_val_loss
                if is_best:
                    self.best_val_loss = val_metrics['loss/val']
                    print(f"\n🏆 New best (vision-only): {self.best_val_loss:.4f}")
                    self.save_checkpoint_and_upload(epoch, val_metrics['loss/val'], is_best=True)

                fusion_diag = self.get_fusion_diagnostics()
                print("\n⚡ Fusion Controller State:")
                print(f" Scale weights: {[f'{w:.3f}' for w in fusion_diag.get('scale_weights', [])]}")
                print(f" Alpha: {[f'{a:.3f}' for a in fusion_diag.get('alpha_per_scale', [])]}")
                print(f" Beta: {[f'{b:.3f}' for b in fusion_diag.get('beta_per_scale', [])]}")

                # `or 1` avoids division by zero before any step has been counted.
                total = sum(self.text_dropout_stats.values()) or 1
                print("\n📝 Text Modality Distribution:")
                for mode, count in self.text_dropout_stats.items():
                    print(f" {mode}: {100 * count / total:.1f}%")

    except KeyboardInterrupt:
        if not self._interrupt_received:
            self._interrupt_received = True
            if self._save_on_interrupt and self.accelerator.is_main_process:
                self._emergency_save_checkpoint()
        raise

    finally:
        # Close the writer on every exit path (normal, flag-interrupt, Ctrl+C).
        if self.writer:
            self.writer.close()

    if not self._interrupt_received:
        print("\n" + "✅" * 40)
        print("TRAINING COMPLETE")
        print("✅" * 40)
        print(f"Best val loss (vision-only): {self.best_val_loss:.4f}")

        if self.accelerator.is_main_process:
            print(f"\n📊 TensorBoard logs: {self.config.log_dir}")
            if self.config.hf_repo_id:
                print(f"🤗 Model on HuggingFace: https://huggingface.co/{self.config.hf_repo_id}")

        print("✅" * 40 + "\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    config = DanbooruTrainingConfig(
        # Dataset subset
        sub_name="danbooru-50k-v1-512-2",

        # Geometry / multi-scale architecture
        num_opinion_anchors=225,
        pentachoron_dim=512,
        scales=[128, 256, 512, 1024],
        scale_hidden_dims={128: 256, 256: 512, 512: 1024, 1024: 2048},

        # Fusion controller parameters
        alpha_init=0.125,
        alpha_learnable=True,
        beta_init=0.5,
        beta_learnable=True,
        gamma_learnable=True,
        learn_layer_weights=True,

        # Encoder layer selection
        clip_skip=1,
        siglip_layer_indices=[1, 2, 3, 4, 5, 6, 9, 12, 18, 21, 23, 24, 25, 26],

        # Memory / sharing options
        use_gradient_checkpointing=False,
        share_scale_embeddings=False,

        # Optimization and checkpoint cadence
        batch_size=24,
        num_epochs=20,
        learning_rate=1e-4,
        save_every=500,

        # Text modality robustness
        text_dropout_prob=0.3,
        text_noise_std=0.1,
        text_noise_prob=0.5,
        vision_only_text="general: blank_image",
        text_dropout_schedule="linear",
        text_dropout_start=0.1,
        text_dropout_end=0.5,

        # Resume
        resume=False,

        # HuggingFace Hub
        hf_repo_id="AbstractPhil/liminal-staircase-v2",
        hf_upload_every=1000,
        hf_private=False,
    )

    print("\n" + "🎨 " * 40)
    print("LIMINAL STAIRCASE - BULLETPROOF + GEOMETRIC + TEXT ROBUSTNESS")
    print("🎨 " * 40)
    print(f"\nSub name: {config.sub_name}")
    print(f"Scales: {config.scales}")
    print(f"SigLIP layers: {config.siglip_layer_indices}")
    print(f"CLIP skip: {config.clip_skip}")
    print("Geometric init: hybrid pentachora")
    print("\n🔷 Text Modality Robustness:")
    print(f" Sentinel: '{config.vision_only_text}'")
    print(f" Dropout: {config.text_dropout_schedule} ({config.text_dropout_start:.0%} → {config.text_dropout_end:.0%})")
    print(f" Noise: {config.text_noise_prob:.0%} of text batches @ std={config.text_noise_std}")
    if config.hf_repo_id:
        print(f"\n🤗 HuggingFace: {config.hf_repo_id}")
    print("\n" + "🎨 " * 40 + "\n")

    trainer = DanbooruLiminalStaircaseTrainer(config)
    trainer.train()