"""
DavidCollective SD1.5 - Complete System with Pattern Supervision
================================================================

Integrates symbolic synthesis + proper pattern-supervised losses.

Key features:
- Symbolic caption synthesis
- All 9 SD1.5 blocks
- Full pattern supervision (1000 classes, not just 100)
- Pattern diversity regularization
- Three accuracy metrics (timestep, pattern, full)
- Minimal disk usage
- TensorBoard logging

Author: AbstractPhil + Claude Sonnet 4.5

Run it in Colab after installing the necessary repo:
    !pip install git+https://github.com/AbstractEyes/lattice_vocabulary.git
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import time
import json
import numpy as np
from datetime import datetime

# Diffusers
from diffusers import StableDiffusionPipeline

# David imports
from geovocab2.train.model.core.david_diffusion import (
    DavidCollective,
    DavidCollectiveConfig,
    SD15_BLOCKS
)

# Symbolic synthesis
from geovocab2.data.prompt.symbolic_tree import SynthesisSystem

# HuggingFace
try:
    from huggingface_hub import HfApi, create_repo, upload_folder
    from safetensors.torch import save_file
    HF_AVAILABLE = True
except ImportError:
    print("⚠️ HuggingFace libraries not available. Install with:")
    print(" pip install huggingface_hub safetensors")
    HF_AVAILABLE = False


# ============================================================================
# PROMPT LOGGER - Saves ALL prompts to JSONL
# ============================================================================

class PromptLogger:
    """
    Logs ALL prompts with metadata to JSONL.

    The file is truncated when the logger is created. Each batch is then
    appended and flushed immediately, which limits data loss if training
    crashes.
    """

    def __init__(self, output_path: str = "./prompts_all_epochs.jsonl"):
        self.output_path = Path(output_path)
        self.output_path.parent.mkdir(parents=True, exist_ok=True)

        # Create/truncate file so each run starts from an empty log
        self.output_path.write_text("")

        self.batch_count = 0
        # Running total of written entries. Batches can differ in size
        # (e.g. a short final batch), so batch_count * len(prompts) is not exact.
        self.prompt_count = 0
        print(f"✓ PromptLogger initialized: {self.output_path}")

    def log_batch(
        self,
        prompts: List[str],
        timesteps: torch.Tensor,
        epoch: int,
        batch_idx: int,
        global_step: int
    ):
        """
        Log a batch of prompts with metadata, one JSON object per line.

        Flushes immediately to prevent data loss.

        Args:
            prompts: Prompts of the batch.
            timesteps: Per-sample timesteps aligned with ``prompts``
                (a tensor or any sequence of numbers).
            epoch: Epoch index stored on every entry.
            batch_idx: Batch index within the epoch.
            global_step: Global optimisation step.
        """
        # Convert once. A per-element .item() forces one device sync per
        # sample when the tensor lives on the GPU.
        if torch.is_tensor(timesteps):
            timestep_values = timesteps.tolist()
        else:
            timestep_values = list(timesteps)

        lines = []
        for i, (prompt, t) in enumerate(zip(prompts, timestep_values)):
            t = int(t)
            entry = {
                'timestamp': datetime.now().isoformat(),
                'epoch': epoch,
                'batch': batch_idx,
                'global_step': global_step,
                'sample_idx': i,
                'timestep': t,
                'timestep_bin': t // 10,
                'prompt': prompt
            }
            lines.append(json.dumps(entry) + '\n')

        with open(self.output_path, 'a', encoding='utf-8') as f:
            f.writelines(lines)
            f.flush()  # CRITICAL: Force write to disk

        self.batch_count += 1
        self.prompt_count += len(lines)
        if self.batch_count % 100 == 0:
            print(f" 📝 Logged {self.batch_count} batches ({self.prompt_count:,} prompts)")

    def get_stats(self) -> dict:
        """
        Get statistics about logged prompts.

        Returns:
            ``{'total': 0}`` if the log file is missing. Otherwise returns
            ``{'total': <line count>, 'size_mb': <file size in MiB>}``.
        """
        if not self.output_path.exists():
            return {'total': 0}

        # Stream the file to count lines; the log can grow very large.
        with open(self.output_path, 'r', encoding='utf-8') as f:
            total = sum(1 for _ in f)

        return {
            'total': total,
            'size_mb': self.output_path.stat().st_size / 1024**2
        }


# ============================================================================
# PATTERN-SUPERVISED LOSS
# ============================================================================

class PatternSupervisedLoss(nn.Module):
    """
    Pattern-supervised loss with full 1000-class supervision.

    Supervises all ``num_timestep_bins * num_patterns_per_timestep`` classes
    (100 timesteps × 10 patterns by default). A full class id decomposes as
    ``timestep_bin * num_patterns + pattern``.
    """

    def __init__(
        self,
        num_timestep_bins: int = 100,
        num_patterns_per_timestep: int = 10,
        feature_similarity_weight: float = 0.5,
        rose_weight: float = 0.3,
        ce_weight: float = 0.2,
        pattern_diversity_weight: float = 0.05,
        use_soft_assignment: bool = True,
        temperature: float = 0.1
    ):
        super().__init__()
        self.num_bins = num_timestep_bins
        self.num_patterns = num_patterns_per_timestep
        self.num_classes = num_timestep_bins * num_patterns_per_timestep

        self.feature_sim_weight = feature_similarity_weight
        self.rose_weight = rose_weight
        self.ce_weight = ce_weight
        self.pattern_diversity_weight = pattern_diversity_weight
        self.use_soft_assignment = use_soft_assignment
        self.temperature = temperature

    def _bin_class_indices(
        self,
        timestep_class: torch.Tensor,
        device: torch.device
    ) -> torch.Tensor:
        """
        Return the full-class indices of every pattern in each sample's timestep bin.

        Args:
            timestep_class: [B] timestep bins.
            device: Device for the returned index tensor.

        Returns:
            [B, num_patterns] LongTensor whose row i is
            ``timestep_class[i] * num_patterns + [0, 1, ..., num_patterns - 1]``.
        """
        offsets = torch.arange(self.num_patterns, device=device)
        return timestep_class.to(device).long().unsqueeze(1) * self.num_patterns + offsets

    def assign_patterns(
        self,
        features: torch.Tensor,
        timestep_class: torch.Tensor,
        crystal_centroids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Assign samples to nearest pattern within their timestep bin.

        Uses COSINE SIMILARITY (not Euclidean distance) to match the original trainer.

        Args:
            features: [B, D]
            timestep_class: [B] - timestep bins [0, num_bins)
            crystal_centroids: [num_bins, num_patterns, D]

        Returns:
            pattern_ids: [B] - pattern indices [0, num_patterns)
            full_class_ids: [B] - full class [0, num_classes)
        """
        # Get centroids for each sample's timestep
        batch_centroids = crystal_centroids[timestep_class]  # [B, num_patterns, D]

        # Compute similarities (CRITICAL: Use cosine, not Euclidean!)
        similarities = F.cosine_similarity(
            features.unsqueeze(1),  # [B, 1, D] broadcast against every pattern
            batch_centroids,
            dim=2
        )  # [B, num_patterns]

        # Assign to nearest (highest similarity)
        pattern_ids = similarities.argmax(dim=1)
        full_class_ids = timestep_class * self.num_patterns + pattern_ids

        return pattern_ids, full_class_ids

    def compute_soft_assignment(
        self,
        features: torch.Tensor,
        timestep_class: torch.Tensor,
        crystal_centroids: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute soft pattern assignment with temperature smoothing.

        Matches the original trainer (lines 120-156 there).

        Args:
            features: [B, D]
            timestep_class: [B] - timestep bins
            crystal_centroids: [num_bins, num_patterns, D]

        Returns:
            soft_targets: [B, num_classes]. Each row is a softmax over the
            patterns of the sample's own timestep bin; every other class is 0.
        """
        B = features.shape[0]
        device = features.device

        # Get centroids for each sample's timestep bin
        batch_centroids = crystal_centroids[timestep_class]  # [B, num_patterns, D]

        # Compute cosine similarities
        similarities = F.cosine_similarity(
            features.unsqueeze(1),
            batch_centroids,
            dim=2
        )  # [B, num_patterns]

        # Soft assignment with temperature
        pattern_probs = F.softmax(similarities / self.temperature, dim=1)

        # Scatter each row's in-bin distribution into its slice of the full
        # class space in one op. Out-of-place scatter keeps gradients flowing
        # to pattern_probs, as the original slice assignment did.
        soft_targets = torch.zeros(B, self.num_classes, device=device)
        index = self._bin_class_indices(timestep_class, device)
        return soft_targets.scatter(
            1, index, pattern_probs.to(device=device, dtype=soft_targets.dtype)
        )

    def compute_pattern_diversity_loss(
        self,
        logits: torch.Tensor,
        timestep_class: torch.Tensor
    ) -> torch.Tensor:
        """
        Encourage diverse pattern usage (prevent mode collapse).

        Matches the original trainer (lines 157-182 there).

        Args:
            logits: [B, num_classes] student logits.
            timestep_class: [B] timestep bins.

        Returns:
            Negative mean entropy of the per-sample softmax over the
            patterns of each sample's bin. Minimising it maximises diversity.
        """
        # Gather only the logits of each sample's own timestep bin
        index = self._bin_class_indices(timestep_class, logits.device)
        pattern_probs = F.softmax(logits.gather(1, index), dim=1)  # [B, num_patterns]

        # Entropy (higher = more diverse)
        entropy = -(pattern_probs * torch.log(pattern_probs + 1e-8)).sum(dim=1).mean()

        # Minimize negative entropy (maximize diversity)
        return -entropy

    def forward(
        self,
        student_features: torch.Tensor,
        teacher_features: torch.Tensor,
        student_logits: torch.Tensor,
        crystal_centroids: torch.Tensor,
        timesteps: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Compute full loss with pattern supervision.

        Note: ``teacher_features`` is accepted for interface compatibility
        but is not used; supervision targets come from ``crystal_centroids``.

        Returns:
            total_loss: Combined weighted loss
            metrics: Dict of individual metrics (Python floats)
        """
        # Timestep classification (0-999 -> 0-99 bins)
        timestep_class = (timesteps // 10).clamp(0, self.num_bins - 1)

        # Pattern assignment (uses STUDENT features, not teacher)
        pattern_ids, full_class_ids = self.assign_patterns(
            student_features,  # [B, D]
            timestep_class,
            crystal_centroids
        )

        # Target centroid of each sample's assigned (bin, pattern) pair
        target_centroids = crystal_centroids[timestep_class, pattern_ids]  # [B, D]

        # 1. Feature similarity loss (student vs target centroids, not teacher)
        feature_sim_loss = 1.0 - F.cosine_similarity(
            student_features,
            target_centroids,
            dim=-1
        ).mean()

        # 2. Rose loss: the original trainer (line 609) uses the same value as
        # feature_sim_loss; it is not a contrastive loss.
        rose_loss = feature_sim_loss

        # 3. Cross-entropy with soft assignment (original trainer lines 612-617)
        if self.use_soft_assignment:
            soft_targets = self.compute_soft_assignment(
                student_features,
                timestep_class,
                crystal_centroids
            )
            log_probs = F.log_softmax(student_logits, dim=1)
            ce_loss = -(soft_targets * log_probs).sum(dim=1).mean()
        else:
            ce_loss = F.cross_entropy(student_logits, full_class_ids)

        # 4. Pattern diversity (original trainer lines 622-623)
        diversity_loss = self.compute_pattern_diversity_loss(
            student_logits,
            timestep_class
        )

        # Total loss
        total_loss = (
            self.feature_sim_weight * feature_sim_loss +
            self.rose_weight * rose_loss +
            self.ce_weight * ce_loss +
            self.pattern_diversity_weight * diversity_loss
        )

        # Accuracy metrics: decompose the predicted full class once
        full_pred = student_logits.argmax(dim=-1)
        timestep_pred = full_pred // self.num_patterns
        pattern_pred = full_pred % self.num_patterns

        timestep_acc = (timestep_pred == timestep_class).float().mean()
        pattern_acc = (pattern_pred == pattern_ids).float().mean()
        full_acc = (full_pred == full_class_ids).float().mean()

        metrics = {
            'feature_sim': feature_sim_loss.item(),
            'rose': rose_loss.item(),
            'ce': ce_loss.item(),
            'pattern_diversity': diversity_loss.item(),
            'timestep_acc': timestep_acc.item(),
            'pattern_acc': pattern_acc.item(),
            'full_acc': full_acc.item()
        }

        return total_loss, metrics


# ============================================================================
# CONFIG
# ============================================================================

FULL_CONFIG = DavidCollectiveConfig(
    # Timestep discretization
    num_timestep_bins=100,
    num_feature_patterns_per_timestep=10,  # CORRECT parameter name

    # Active blocks (all 9)
    active_blocks=['down_0', 'down_1', 'down_2', 'down_3', 'mid', 'up_0', 'up_1', 'up_2', 'up_3'],

    # David architecture
    david_sharing_mode='fully_shared',
    david_fusion_mode='deep_efficiency',
    use_belly=True,
    belly_expand=1.5,

    # Loss weights
    feature_similarity_weight=0.5,
    rose_weight=0.3,
    cayley_weight=0.0,
    ce_weight=0.2,

    # Geometric constraints
    rose_margin=1.0,
    rose_temperature=0.07,
    cayley_volume_floor=1e-4,

    # Progressive training
    progressive_training=True,
    warmup_epochs_per_block=2,

    # No caching
    cache_dir=None,
    max_cache_size_gb=0.0
)


# ============================================================================
# SD1.5 EXTRACTOR - FIXED (NO POOLING)
============================================================================ class StreamingSD15Extractor: """ Extract features from SD1.5 UNet. CRITICAL: Returns spatial features [B, C, H, W] David's Companions will handle pooling internally. """ def __init__( self, model_id: str = "runwayml/stable-diffusion-v1-5", device: str = "cuda", active_blocks: List[str] = None ): self.device = device self.active_blocks = active_blocks or FULL_CONFIG.active_blocks print(f"Loading SD1.5 from {model_id}...") self.pipe = StableDiffusionPipeline.from_pretrained( model_id, torch_dtype=torch.float16, safety_checker=None, requires_safety_checker=False ).to(device) self.unet = self.pipe.unet self.vae = self.pipe.vae self.text_encoder = self.pipe.text_encoder self.tokenizer = self.pipe.tokenizer self.scheduler = self.pipe.scheduler self.features = {} self.hooks = [] self.block_mapping = { 'down_0': ('down_blocks', 0), 'down_1': ('down_blocks', 1), 'down_2': ('down_blocks', 2), 'down_3': ('down_blocks', 3), 'mid': ('mid_block', None), 'up_0': ('up_blocks', 0), 'up_1': ('up_blocks', 1), 'up_2': ('up_blocks', 2), 'up_3': ('up_blocks', 3), } print(f"โœ“ SD1.5 loaded on {device}") def _register_hooks(self): def make_hook(name): def hook(module, input, output): # CRITICAL: Store WITH spatial dimensions self.features[name] = output.detach().float() return hook self._remove_hooks() for block_name in self.active_blocks: block_type, idx = self.block_mapping[block_name] if block_type == 'down_blocks': block = self.unet.down_blocks[idx] if hasattr(block, 'resnets') and len(block.resnets) > 0: hook = block.resnets[-1].register_forward_hook(make_hook(block_name)) self.hooks.append(hook) elif block_type == 'mid_block': hook = self.unet.mid_block.register_forward_hook(make_hook(block_name)) self.hooks.append(hook) elif block_type == 'up_blocks': block = self.unet.up_blocks[idx] if hasattr(block, 'resnets') and len(block.resnets) > 0: hook = 
block.resnets[-1].register_forward_hook(make_hook(block_name)) self.hooks.append(hook) def _remove_hooks(self): for hook in self.hooks: hook.remove() self.hooks = [] @torch.no_grad() def extract_batch( self, prompts: List[str], timesteps: torch.Tensor ) -> Dict[str, torch.Tensor]: """ Extract features from SD1.5 UNet. CRITICAL: Returns spatial features [B, C, H, W] NO POOLING - Companions handle it internally """ self._register_hooks() text_inputs = self.tokenizer( prompts, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt" ).to(self.device) text_embeddings = self.text_encoder(text_inputs.input_ids)[0] B = len(prompts) latents = torch.randn(B, 4, 64, 64, device=self.device, dtype=torch.float16) for i, t in enumerate(timesteps): noise = torch.randn_like(latents[i:i+1]) latents[i:i+1] = self.scheduler.add_noise( latents[i:i+1], noise, t.unsqueeze(0) ) self.features = {} _ = self.unet( latents, timesteps.to(self.device), encoder_hidden_states=text_embeddings ).sample self._remove_hooks() # Return features WITH spatial dimensions [B, C, H, W] # NO POOLING HERE - Companions will handle it return self.features.copy() def __del__(self): self._remove_hooks() # ============================================================================ # SYMBOLIC PROMPT DATASET # ============================================================================ class SymbolicPromptDataset(Dataset): """Generate prompts on-the-fly using synthesis system.""" def __init__( self, num_samples: int = 10000, complexity_distribution: Optional[Dict[int, float]] = None, bias_weights_path: Optional[str] = None, seed: Optional[int] = None, log_synthesis_stats: bool = False ): self.num_samples = num_samples self.log_synthesis_stats = log_synthesis_stats if complexity_distribution is None: complexity_distribution = { 1: 0.05, 2: 0.15, 3: 0.40, 4: 0.25, 5: 0.15 } self.complexity_dist = complexity_distribution # Initialize synthesis system (no seed parameter) 
self.synth = SynthesisSystem() # Apply bias weights if provided if bias_weights_path and Path(bias_weights_path).exists(): with open(bias_weights_path, 'r') as f: bias_weights = json.load(f) # Apply bias weights to synthesis system # (assuming it has some method to set them) if hasattr(self.synth, 'bias_weights'): self.synth.bias_weights = bias_weights # Pre-generate prompts self.rng = np.random.RandomState(seed) self.prompts = [] self.metadata = [] print(f"Generating {num_samples:,} prompts...") for i in range(num_samples): # Sample complexity complexities = list(complexity_distribution.keys()) probs = list(complexity_distribution.values()) complexity = self.rng.choice(complexities, p=probs) # Generate prompt try: result = self.synth.synthesize(complexity=complexity) # โœ… Correct method name # Extract text and path_info from result dict if isinstance(result, dict): prompt = result.get('text', 'a photo') path_info = result.get('selected_paths', []) else: # Fallback if unexpected format prompt = str(result) path_info = {} self.prompts.append(prompt) if log_synthesis_stats: self.metadata.append({ 'complexity': complexity, 'path_info': path_info, 'sample_id': i }) except Exception as e: # Fallback prompt if generation fails print(f" โš ๏ธ Warning: Failed to generate prompt {i}: {e}") self.prompts.append("a photo") if log_synthesis_stats: self.metadata.append({ 'complexity': complexity, 'path_info': {}, 'sample_id': i, 'error': str(e) }) if (i + 1) % 1000 == 0: print(f" Generated {i+1:,}/{num_samples:,} prompts...") print(f"โœ“ Generated {len(self.prompts):,} prompts") if log_synthesis_stats: self._log_statistics() def _log_statistics(self): from collections import Counter complexity_counts = Counter(m['complexity'] for m in self.metadata) print("\nSynthesis Statistics:") print(" Complexity distribution:") for complexity in sorted(complexity_counts.keys()): count = complexity_counts[complexity] pct = 100 * count / len(self.metadata) print(f" Complexity {complexity}: 
{count:,} ({pct:.1f}%)") print("\n Example prompts:") for i in [0, len(self.prompts)//4, len(self.prompts)//2, 3*len(self.prompts)//4]: complexity = self.metadata[i]['complexity'] print(f" [C={complexity}] {self.prompts[i][:80]}...") def __len__(self): return self.num_samples def __getitem__(self, idx): prompt = self.prompts[idx] timestep = self.rng.randint(0, 1000) metadata = self.metadata[idx] if self.log_synthesis_stats else {} return { 'prompt': prompt, 'timestep': torch.tensor(timestep), 'metadata': metadata } def collate_symbolic_batch(batch): prompts = [item['prompt'] for item in batch] timesteps = torch.stack([item['timestep'] for item in batch]) metadata = [item['metadata'] for item in batch] return prompts, timesteps, metadata # ============================================================================ # HUGGINGFACE UTILITIES # ============================================================================ def convert_to_safetensors(checkpoint_path: str) -> str: """Convert .pt checkpoint to .safetensors format.""" if not HF_AVAILABLE: print("โš ๏ธ Safetensors not available, skipping conversion") return None checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False) # Handle different checkpoint formats if isinstance(checkpoint, dict): if 'model_state_dict' in checkpoint: state_dict = checkpoint['model_state_dict'] elif 'state_dict' in checkpoint: state_dict = checkpoint['state_dict'] else: # Assume the dict IS the state_dict state_dict = checkpoint else: raise ValueError(f"Unexpected checkpoint format: {type(checkpoint)}") output_path = checkpoint_path.replace('.pt', '.safetensors') print(f" Converting to safetensors: {output_path}") # Ensure all tensors are contiguous state_dict_safe = { k: v.contiguous() for k, v in state_dict.items() } # Save as safetensors save_file(state_dict_safe, output_path) size_mb = Path(output_path).stat().st_size / 1024**2 print(f" โœ“ Saved safetensors ({size_mb:.2f} MB)") return output_path def 
create_readme( final_epoch: int, history: dict, config: DavidCollectiveConfig, total_prompts: int ) -> str: """Generate comprehensive README for HuggingFace model card.""" readme = f"""# David Collective - SD1.5 Geometric Distillation (Continued) ## Model Description **David Collective** is a revolutionary geometric deep learning system that distills Stable Diffusion 1.5's knowledge into an ultra-efficient pentachoron-based architecture. This model was continued from epoch 20 to epoch {final_epoch}, achieving remarkable performance with full pattern supervision. ### Architecture Highlights - **Geometric Foundation**: Uses 5D pentachora (5-vertex simplices) instead of traditional attention - **Multi-Scale Learning**: Extracts features from all 9 SD1.5 UNet blocks - **Crystal Navigation**: 1000-class supervision (100 timesteps ร— 10 geometric patterns) - **Parameter Efficiency**: Ultra-compact architecture with shared geometric structures - **Full Supervision**: Every sample supervised by both timestep and geometric pattern ### Training Details **Continuation Training:** - Starting epoch: 20 - Final epoch: {final_epoch} - Total prompts trained: {total_prompts:,} - **All prompts included**: `prompts_all_epochs.jsonl` contains every prompt with metadata - Dataset: Symbolic caption synthesis (complexity 1-5) - Batch size: 32 - Learning rate: 1e-4 with cosine annealing - Optimizer: AdamW (weight_decay=0.01) **Final Metrics (Epoch {final_epoch}):** - Total Loss: {history.get('total_loss', [0])[-1] if history.get('total_loss') else 0:.4f} - Timestep Accuracy: {history.get('timestep_accuracy', [0])[-1] if history.get('timestep_accuracy') else 0:.2%} - Pattern Accuracy: {history.get('pattern_accuracy', [0])[-1] if history.get('pattern_accuracy') else 0:.2%} - Full Accuracy: {history.get('full_accuracy', [0])[-1] if history.get('full_accuracy') else 0:.2%} - Pattern Diversity: {history.get('pattern_diversity', [0])[-1] if history.get('pattern_diversity') else 0:.3f} ### 
Active Blocks David learns from all 9 SD1.5 UNet blocks: - `down_0`, `down_1`, `down_2`, `down_3`: Coarse semantic features - `mid`: Bottleneck representations - `up_0`, `up_1`, `up_2`, `up_3`: Fine reconstruction details ### Loss Components 1. **Feature Similarity** ({config.feature_similarity_weight}): Cosine similarity with teacher 2. **Rose Loss** ({config.rose_weight}): Geometric alignment with crystal centroids 3. **Cross-Entropy** ({config.ce_weight}): 1000-class classification 4. **Pattern Diversity** (0.05): Encourages balanced pattern usage ## Usage ### Loading the Model ```python import torch from geovocab2.train.model.core.david_diffusion import DavidCollective, DavidCollectiveConfig from safetensors.torch import load_file # Load configuration config = DavidCollectiveConfig( num_timestep_bins=100, num_feature_patterns_per_timestep=10, active_blocks={config.active_blocks}, david_sharing_mode='fully_shared', david_fusion_mode='deep_efficiency', use_belly=True, belly_expand=1.5 ) # Create model model = DavidCollective(config) # Load weights from safetensors state_dict = load_file("model.safetensors") model.load_state_dict(state_dict) model.eval() # Inference with torch.no_grad(): outputs = model(teacher_features, timesteps) ``` ### Training Data This model includes `prompts_all_epochs.jsonl` - every single prompt used during training with full metadata: ```json {{"timestamp": "2025-10-27T01:30:00", "epoch": 21, "batch": 0, "global_step": 6250, "sample_idx": 0, "timestep": 453, "timestep_bin": 45, "prompt": "a woman wearing red dress, against mountain landscape"}} ``` **Total prompts:** {total_prompts:,} You can use this to: - Analyze training data distribution - Reproduce training - Study prompt complexity vs model performance - Generate similar synthetic datasets ## Technical Details ### Crystal System - **Architecture**: Pentachoron-based geometric deep learning - **Centroids**: 100 timestep bins ร— 10 patterns = 1000 anchors - **Navigation**: Samples 
assigned to nearest pattern within timestep bin - **Diversity**: Regularization prevents mode collapse ### Progressive Training - Started with early blocks (down_0, down_1) - Progressively activated all 9 blocks - Each block warmed up for {config.warmup_epochs_per_block} epochs ### Pattern Supervision Unlike traditional timestep-only supervision, David learns: 1. **When** (timestep bin 0-99) 2. **How** (geometric pattern 0-9 within that bin) 3. **Combined** (full 1000-class space) This provides 10x finer-grained supervision of the diffusion process. ## Training History Trained continuously from epoch 20 to epoch {final_epoch}. See metrics: - Timestep accuracy improved from ~{history.get('timestep_accuracy', [0])[0] if history.get('timestep_accuracy') else 0:.1%} to {history.get('timestep_accuracy', [0])[-1] if history.get('timestep_accuracy') else 0:.2%} - Pattern accuracy maintained at {history.get('pattern_accuracy', [0])[-1] if history.get('pattern_accuracy') else 0:.2%} - Loss decreased from {history.get('total_loss', [0])[0] if history.get('total_loss') else 0:.4f} to {history.get('total_loss', [0])[-1] if history.get('total_loss') else 0:.4f} ## Citation ```bibtex @misc{{david-collective-sd15, title={{David Collective: Geometric Deep Learning for Diffusion Distillation}}, author={{AbstractPhil}}, year={{2025}}, publisher={{HuggingFace}}, howpublished={{\\url{{https://huggingface.co/AbstractPhil/david-collective-sd15-geometric-distillation}}}} }} ``` ## License MIT License - See repository for details. ## Acknowledgments Built on the geometric deep learning research by AbstractPhil, using: - Stable Diffusion 1.5 (teacher model) - Pentachoron-based geometric algebra - Crystalline consciousness architectures - Symbolic caption synthesis For more information, visit the [geovocab2 repository](https://github.com/AbstractEyes/lattice_vocabulary). 
""" return readme def create_model_card(config: DavidCollectiveConfig) -> dict: """Create model card metadata for HuggingFace.""" return { 'language': ['en'], 'license': 'mit', 'tags': [ 'geometric-deep-learning', 'diffusion-distillation', 'stable-diffusion', 'pentachoron', 'crystal-navigation', 'pattern-supervision', 'ultra-efficient', 'sd15-distillation' ], 'datasets': ['synthetic-captions'], 'metrics': ['accuracy', 'loss'], 'library_name': 'pytorch', 'pipeline_tag': 'image-classification', } def upload_to_huggingface( model_path: str, repo_name: str = "AbstractPhil/david-collective-sd15-geometric-distillation", final_epoch: int = 50, history: dict = None, config: DavidCollectiveConfig = None, total_prompts: int = 0, private: bool = False ): """Upload model to HuggingFace Hub with README and model card.""" if not HF_AVAILABLE: print("\nโš ๏ธ HuggingFace libraries not available") print("Install with: pip install huggingface_hub safetensors") return None print(f"\n{'='*80}") print("UPLOADING TO HUGGINGFACE") print(f"{'='*80}\n") # Convert to safetensors print("[1/5] Converting to safetensors...") safetensors_path = convert_to_safetensors(model_path) if not safetensors_path: print("โŒ Conversion failed, aborting upload") return None # Create temporary upload directory print("\n[2/5] Preparing upload directory...") upload_dir = Path("./hf_upload_temp") upload_dir.mkdir(exist_ok=True) # Copy safetensors import shutil shutil.copy(safetensors_path, upload_dir / "model.safetensors") print(f" โœ“ Copied model.safetensors") # Copy prompts file (CRITICAL!) 
prompt_file = Path("./prompts_all_epochs.jsonl") if prompt_file.exists(): shutil.copy(prompt_file, upload_dir / "prompts_all_epochs.jsonl") print(f" โœ“ Copied prompts_all_epochs.jsonl ({prompt_file.stat().st_size / 1024**2:.2f} MB)") else: print(f" โš ๏ธ Warning: prompts_all_epochs.jsonl not found, skipping") # Generate README print("\n[3/5] Generating README...") readme_content = create_readme(final_epoch, history, config, total_prompts) (upload_dir / "README.md").write_text(readme_content) print(f" โœ“ Created README.md") # Generate model card print("\n[4/5] Creating model card...") model_card = create_model_card(config) (upload_dir / "model_card.json").write_text(json.dumps(model_card, indent=2)) print(f" โœ“ Created model_card.json") # Save config config_dict = {k: v for k, v in config.__dict__.items() if not k.startswith('_')} (upload_dir / "config.json").write_text(json.dumps(config_dict, indent=2)) print(f" โœ“ Created config.json") # Upload print(f"\n[5/5] Uploading to {repo_name}...") try: api = HfApi() # Create repo if doesn't exist try: create_repo(repo_name, private=private, exist_ok=True) print(f" โœ“ Repository ready") except Exception as e: print(f" โš ๏ธ Repo might already exist: {e}") # Upload folder api.upload_folder( folder_path=str(upload_dir), repo_id=repo_name, repo_type="model", commit_message=f"Upload continuation training (epoch {final_epoch})" ) print(f"\nโœ… UPLOAD COMPLETE!") print(f"\n๐Ÿ”— View your model: https://huggingface.co/{repo_name}") print(f"\n๐Ÿ“ฆ Uploaded files:") print(f" - model.safetensors") print(f" - prompts_all_epochs.jsonl ({total_prompts:,} prompts)") print(f" - README.md (with metrics)") print(f" - config.json") print(f" - model_card.json") # Cleanup shutil.rmtree(upload_dir) print(f"\n๐Ÿงน Cleaned up temporary files") return f"https://huggingface.co/{repo_name}" except Exception as e: print(f"\nโŒ Upload failed: {e}") print(f"Files are still in: {upload_dir}") print(f"You can upload manually or fix the error and 
retry") return None # ============================================================================ # CONTINUATION TRAINING # ============================================================================ def continue_training( collective: DavidCollective, extractor: StreamingSD15Extractor, dataloader: DataLoader, start_epoch: int, num_epochs: int, device: str = "cuda", log_dir: str = "./runs/david_continued", prompt_log_path: str = "./prompts_all_epochs.jsonl", checkpoint_interval: int = 5, auto_upload: bool = True, hf_repo_name: str = "AbstractPhil/david-collective-sd15-geometric-distillation" ): """ Continue training from checkpoint with full logging and HuggingFace upload. Args: collective: DavidCollective model extractor: SD1.5 feature extractor dataloader: Training data start_epoch: Epoch to start from (loaded from checkpoint) num_epochs: Additional epochs to train device: Device to train on log_dir: TensorBoard log directory prompt_log_path: Where to save all prompts checkpoint_interval: Save checkpoint every N epochs auto_upload: Automatically upload to HuggingFace after training hf_repo_name: HuggingFace repository name """ print("\n" + "="*80) print("DAVID COLLECTIVE - CONTINUATION TRAINING") print("="*80) # Load checkpoint print(f"\n[1/7] Loading checkpoint...") checkpoint_path = Path("david_collective_continued_final.pt") if not checkpoint_path.exists(): raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") # Capture initial weights for verification print(f" Capturing initial weights...") initial_sample_key = list(collective.state_dict().keys())[0] initial_sample_weight = collective.state_dict()[initial_sample_key].clone() initial_mean = initial_sample_weight.mean().item() checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) # Handle different checkpoint formats if isinstance(checkpoint, dict): # Check what keys are present print(f" Checkpoint keys: {list(checkpoint.keys())}") if 'model_state_dict' in checkpoint: 
state_dict = checkpoint['model_state_dict'] actual_epoch = checkpoint.get('epoch', start_epoch) elif 'state_dict' in checkpoint: state_dict = checkpoint['state_dict'] actual_epoch = checkpoint.get('epoch', start_epoch) else: # Assume the dict IS the state_dict state_dict = checkpoint actual_epoch = start_epoch # Load with strict=False to see any issues print(f" Loading state dict ({len(state_dict)} parameters)...") missing_keys, unexpected_keys = collective.load_state_dict(state_dict, strict=False) if missing_keys: print(f" โš ๏ธ Missing keys ({len(missing_keys)}): {missing_keys[:3]}...") if unexpected_keys: print(f" โš ๏ธ Unexpected keys ({len(unexpected_keys)}): {unexpected_keys[:3]}...") # Verify weights actually changed final_sample_weight = collective.state_dict()[initial_sample_key] final_mean = final_sample_weight.mean().item() if torch.equal(initial_sample_weight, final_sample_weight): raise RuntimeError( f"โŒ CRITICAL: Weights did NOT change after loading!\n" f" Sample param: {initial_sample_key}\n" f" This means the checkpoint is not being loaded properly." 
) print(f" โœ“ Weights verified changed (sample mean: {initial_mean:.6f} -> {final_mean:.6f})") print(f" โœ“ Loaded from epoch {actual_epoch}") else: # Not a dict - shouldn't happen but handle it raise ValueError(f"Unexpected checkpoint format: {type(checkpoint)}") # Model info total_params = sum(p.numel() for p in collective.parameters()) print(f"\n Model Status:") print(f" Parameters: {total_params:,}") print(f" Active blocks: {len(collective.config.active_blocks)}") print(f" Companions: {list(collective.companions.keys())}") # Prompt logger print(f"\n[2/7] Initializing prompt logger...") prompt_logger = PromptLogger(prompt_log_path) print(f" โœ“ Saving to: {prompt_log_path}") # Loss function print(f"\n[3/7] Setting up loss function...") criterion = PatternSupervisedLoss( num_timestep_bins=collective.config.num_timestep_bins, num_patterns_per_timestep=collective.config.num_feature_patterns_per_timestep, # CORRECT attribute name feature_similarity_weight=0.5, rose_weight=0.3, ce_weight=0.2, pattern_diversity_weight=0.05 ).to(device) print(f" โœ“ Pattern-supervised loss ready") # Optimizer print(f"\n[4/7] Creating optimizer...") optimizer = torch.optim.AdamW( collective.parameters(), lr=1e-4, weight_decay=0.01 ) # Scheduler scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=num_epochs * len(dataloader), eta_min=1e-6 ) print(f" โœ“ AdamW + Cosine annealing") # TensorBoard print(f"\n[5/7] Setting up TensorBoard...") writer = SummaryWriter(log_dir) print(f" โœ“ Logging to: {log_dir}") # Training print(f"\n[6/7] Starting training...") print(f" Epochs: {start_epoch + 1} โ†’ {start_epoch + num_epochs}") print(f" Batches per epoch: {len(dataloader)}") print() collective.train() # Initialize history (match original lines 519-528) history = { 'total_loss': [], 'feature_sim': [], 'rose': [], 'ce': [], 'pattern_diversity': [], 'timestep_accuracy': [], 'pattern_accuracy': [], 'full_accuracy': [] } global_step = start_epoch * len(dataloader) for epoch in 
range(start_epoch, start_epoch + num_epochs): collective.update_epoch(epoch) # Match original line 534 epoch_metrics = { 'total_loss': 0.0, 'feature_sim': 0.0, 'rose': 0.0, 'ce': 0.0, 'pattern_diversity': 0.0, 'timestep_accuracy': 0.0, 'pattern_accuracy': 0.0, 'full_accuracy': 0.0, 'num_batches': 0 } pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{start_epoch+num_epochs}") for batch_idx, (prompts, timesteps, metadata) in enumerate(pbar): batch_start = time.time() # Extract features (WITH spatial dimensions!) teacher_features = extractor.extract_batch(prompts, timesteps) teacher_features = { k: v.to(device) for k, v in teacher_features.items() } timesteps = timesteps.to(device) # Verify shapes on first batch if batch_idx == 0 and epoch == start_epoch: print(f"\n๐Ÿ” Feature shapes (Epoch {epoch+1}, Batch 1):") for k, v in teacher_features.items(): print(f" {k}: {v.shape}") if v.dim() != 4: raise ValueError( f"Expected 4D features [B,C,H,W], got {v.dim()}D for {k}!" ) print() # Forward outputs = collective(teacher_features, timesteps) # Compute loss total_loss = torch.tensor(0.0, device=device) block_metrics = {k: [] for k in ['feature_sim', 'rose', 'ce', 'pattern_diversity', 'timestep_acc', 'pattern_acc', 'full_acc']} for block_name in collective.companions.keys(): if block_name not in outputs or block_name not in teacher_features: continue # Use EXACT same structure as original trainer (lines 589-592) companion = collective.companions[block_name] block_output = outputs[block_name] # Get features and timestep class FROM DavidCollective output student_features = block_output['scale_features'][0] # First scale student_logits = block_output['combined_logits'] timestep_class = block_output['timestep_class'] # โœ“ FROM OUTPUT, not recomputed! 
# Get crystal centroids (lines 585-587 from original) crystal_anchors = companion.crystal_anchors # [bins, patterns, 5, max_scale] scale = companion.david_config.scales[0] # Use first scale crystal_centroids = crystal_anchors[..., :scale].mean(dim=2) # [bins, patterns, scale] # INLINE LOSS COMPUTATION (like original, NOT using forward()) # Assign patterns _, full_class_ids = criterion.assign_patterns( student_features, timestep_class, crystal_centroids ) # Feature similarity loss pattern_ids = full_class_ids % criterion.num_patterns target_centroids = torch.stack([ crystal_centroids[timestep_class[j], pattern_ids[j]] for j in range(len(timestep_class)) ]) cos_sim = F.cosine_similarity(student_features, target_centroids, dim=-1) feature_sim_loss = (1 - cos_sim).mean() # Rose loss (same as feature sim) rose_loss = feature_sim_loss # Cross-entropy with pattern supervision if criterion.use_soft_assignment: soft_targets = criterion.compute_soft_assignment( student_features, timestep_class, crystal_centroids ) log_probs = F.log_softmax(student_logits, dim=1) ce_loss = -(soft_targets * log_probs).sum(dim=1).mean() else: ce_loss = F.cross_entropy(student_logits, full_class_ids) # Pattern diversity diversity_loss = criterion.compute_pattern_diversity_loss( student_logits, timestep_class ) # Combined loss for this block block_loss = ( criterion.feature_sim_weight * feature_sim_loss + criterion.rose_weight * rose_loss + criterion.ce_weight * ce_loss + criterion.pattern_diversity_weight * diversity_loss ) total_loss = total_loss + block_loss # Accuracies (lines 626-642 from original) pred_class = student_logits.argmax(dim=1) pred_timestep = pred_class // criterion.num_patterns pred_pattern = pred_class % criterion.num_patterns true_pattern = full_class_ids % criterion.num_patterns timestep_acc = (pred_timestep == timestep_class).float().mean() correct_timestep_mask = (pred_timestep == timestep_class) if correct_timestep_mask.sum() > 0: pattern_acc = ( 
pred_pattern[correct_timestep_mask] == true_pattern[correct_timestep_mask] ).float().mean() else: pattern_acc = torch.tensor(0.0, device=device) full_acc = (pred_class == full_class_ids).float().mean() # Collect metrics block_metrics['feature_sim'].append(feature_sim_loss.item()) block_metrics['rose'].append(rose_loss.item()) block_metrics['ce'].append(ce_loss.item()) block_metrics['pattern_diversity'].append(diversity_loss.item()) block_metrics['timestep_acc'].append(timestep_acc.item()) block_metrics['pattern_acc'].append(pattern_acc.item()) block_metrics['full_acc'].append(full_acc.item()) # Average across blocks (from original trainer lines 664-668) num_processed_blocks = len([k for k in outputs.keys() if k in collective.companions]) if num_processed_blocks > 0: total_loss = total_loss / num_processed_blocks # Backward optimizer.zero_grad() total_loss.backward() torch.nn.utils.clip_grad_norm_(collective.parameters(), 1.0) optimizer.step() scheduler.step() # Log prompts (CRITICAL!) prompt_logger.log_batch(prompts, timesteps, epoch + 1, batch_idx, global_step) # Aggregate metrics (match original lines 677-686) epoch_metrics['total_loss'] += total_loss.item() epoch_metrics['feature_sim'] += np.mean(block_metrics['feature_sim']) epoch_metrics['rose'] += np.mean(block_metrics['rose']) epoch_metrics['ce'] += np.mean(block_metrics['ce']) epoch_metrics['pattern_diversity'] += np.mean(block_metrics['pattern_diversity']) epoch_metrics['timestep_accuracy'] += np.mean(block_metrics['timestep_acc']) epoch_metrics['pattern_accuracy'] += np.mean(block_metrics['pattern_acc']) epoch_metrics['full_accuracy'] += np.mean(block_metrics['full_acc']) epoch_metrics['num_batches'] += 1 # TensorBoard logging (match original lines 688-696) writer.add_scalar('Train/Total_Loss', total_loss.item(), global_step) writer.add_scalar('Train/Feature_Similarity', np.mean(block_metrics['feature_sim']), global_step) writer.add_scalar('Train/Rose_Loss', np.mean(block_metrics['rose']), global_step) 
writer.add_scalar('Train/CE_Loss', np.mean(block_metrics['ce']), global_step) writer.add_scalar('Train/Pattern_Diversity', np.mean(block_metrics['pattern_diversity']), global_step) writer.add_scalar('Train/Timestep_Accuracy', np.mean(block_metrics['timestep_acc']), global_step) writer.add_scalar('Train/Pattern_Accuracy', np.mean(block_metrics['pattern_acc']), global_step) writer.add_scalar('Train/Full_Accuracy', np.mean(block_metrics['full_acc']), global_step) # Update progress bar pbar.set_postfix({ 'loss': f"{total_loss.item():.4f}", 't_acc': f"{np.mean(block_metrics['timestep_acc']):.1%}" if block_metrics['timestep_acc'] else "N/A", 'p_acc': f"{np.mean(block_metrics['pattern_acc']):.1%}" if block_metrics['pattern_acc'] else "N/A", }) global_step += 1 # Cleanup del teacher_features, outputs, total_loss torch.cuda.empty_cache() # Epoch summary (match original lines 709-725) for key in epoch_metrics: if key != 'num_batches': avg = epoch_metrics[key] / epoch_metrics['num_batches'] history[key].append(avg) writer.add_scalar(f'Epoch/{key}', avg, epoch) print(f"\nEpoch {epoch+1} Summary:") print(f" Loss: {history['total_loss'][-1]:.4f}") print(f" Timestep Acc: {history['timestep_accuracy'][-1]:.2%}") print(f" Pattern Acc: {history['pattern_accuracy'][-1]:.2%}") print(f" Full Acc: {history['full_accuracy'][-1]:.2%}") print(f" Pattern Diversity: {history['pattern_diversity'][-1]:.3f}") # Save checkpoint if (epoch + 1) % checkpoint_interval == 0: checkpoint_path = f"checkpoint_continued_epoch_{epoch+1:03d}.pt" torch.save({ 'epoch': epoch + 1, 'model_state_dict': collective.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), 'history': history, 'config': collective.config.__dict__ }, checkpoint_path) print(f" โœ“ Saved: {checkpoint_path}") # Also save as safetensors if HF_AVAILABLE: convert_to_safetensors(checkpoint_path) # Final checkpoint final_path = "david_collective_continued_final.pt" torch.save({ 'epoch': 
start_epoch + num_epochs, 'model_state_dict': collective.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), 'history': history, 'config': collective.config.__dict__ }, final_path) print(f"\nโœ… Final checkpoint: {final_path}") # Get prompt stats prompt_stats = prompt_logger.get_stats() print(f"โœ… Prompts logged: {prompt_stats['total']:,} ({prompt_stats['size_mb']:.2f} MB)") writer.close() # HuggingFace upload if auto_upload: print(f"\n[7/7] Uploading to HuggingFace...") upload_to_huggingface( model_path=final_path, repo_name=hf_repo_name, final_epoch=start_epoch + num_epochs, history=history, config=collective.config, total_prompts=prompt_stats['total'], private=False ) else: print(f"\n[7/7] Skipping HuggingFace upload (auto_upload=False)") return collective, history # ============================================================================ # MAIN # ============================================================================ def main(): print("\n" + "="*80) print("DAVID COLLECTIVE - COMPLETE CONTINUATION SYSTEM") print("Checkpoint loading + Prompt logging + HuggingFace upload") print("="*80) device = "cuda" if torch.cuda.is_available() else "cpu" print(f"\nDevice: {device}") if device == "cpu": print("โš ๏ธ WARNING: Requires GPU!") return # Load SD1.5 print(f"\n[1/4] Loading SD1.5...") extractor = StreamingSD15Extractor( model_id="runwayml/stable-diffusion-v1-5", device=device, active_blocks=FULL_CONFIG.active_blocks ) # Create dataset print(f"\n[2/4] Creating symbolic dataset...") dataset = SymbolicPromptDataset( num_samples=100000, complexity_distribution={ 1: 0.05, 2: 0.15, 3: 0.40, 4: 0.25, 5: 0.15 }, seed=42, log_synthesis_stats=True ) dataloader = DataLoader( dataset, batch_size=256, shuffle=True, num_workers=0, pin_memory=True, collate_fn=collate_symbolic_batch ) print(f" โœ“ Dataset: {len(dataset):,} samples") # Initialize collective print(f"\n[3/4] Initializing DavidCollective...") collective 
= DavidCollective(FULL_CONFIG).to(device) print(f" โœ“ Ready for continuation training") # Continue training print(f"\n[4/4] Starting continuation training...") collective, history = continue_training( collective=collective, extractor=extractor, dataloader=dataloader, start_epoch=100, # Adjust based on your checkpoint num_epochs=5, device=device, log_dir="./runs/david_continued", prompt_log_path="./prompts_all_epochs.jsonl", checkpoint_interval=1, auto_upload=True, # Set to False to skip HuggingFace upload hf_repo_name="AbstractPhil/david-collective-sd15-geometric-distillation" ) print("\n" + "="*80) print("TRAINING COMPLETE!") print("="*80) print(f"\n๐Ÿ“ Files:") print(f" Model: david_collective_continued_final.pt") print(f" Prompts: ./prompts_all_epochs.jsonl") print(f" Logs: ./runs/david_continued") print(f"\n๐Ÿ“Š Final Metrics:") print(f" Loss: {history.get('total_loss', [0])[-1] if history.get('total_loss') else 0:.4f}") print(f" Timestep Acc: {history.get('timestep_accuracy', [0])[-1] if history.get('timestep_accuracy') else 0:.2%}") print(f" Pattern Acc: {history.get('pattern_accuracy', [0])[-1] if history.get('pattern_accuracy') else 0:.2%}") print(f" Full Acc: {history.get('full_accuracy', [0])[-1] if history.get('full_accuracy') else 0:.2%}") return collective, history, extractor if __name__ == "__main__": collective, history, extractor = main()