Updated with a new, more reliable trainer that supports resuming training and rough prompt recording; will improve over time.
f53ef80
verified
| """ | |
| DavidCollective SD1.5 - Complete System with Pattern Supervision | |
| ================================================================ | |
| Integrates symbolic synthesis + proper pattern-supervised losses. | |
| Key features: | |
| - Symbolic caption synthesis | |
| - All 9 SD1.5 blocks | |
| - Full pattern supervision (1000 classes, not just 100) | |
| - Pattern diversity regularization | |
| - Three accuracy metrics (timestep, pattern, full) | |
| - Minimal disk usage | |
| - TensorBoard logging | |
| Author: AbstractPhil + Claude Sonnet 4.5 | |
| Run it in colab after installing the necessary repo. | |
| !pip install git+https://github.com/AbstractEyes/lattice_vocabulary.git | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset, DataLoader | |
| from torch.utils.tensorboard import SummaryWriter | |
| from tqdm import tqdm | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple | |
| import time | |
| import json | |
| import numpy as np | |
| from datetime import datetime | |
| # Diffusers | |
| from diffusers import StableDiffusionPipeline | |
| # David imports | |
| from geovocab2.train.model.core.david_diffusion import ( | |
| DavidCollective, | |
| DavidCollectiveConfig, | |
| SD15_BLOCKS | |
| ) | |
| # Symbolic synthesis | |
| from geovocab2.data.prompt.symbolic_tree import SynthesisSystem | |
| # HuggingFace | |
| try: | |
| from huggingface_hub import HfApi, create_repo, upload_folder | |
| from safetensors.torch import save_file | |
| HF_AVAILABLE = True | |
| except ImportError: | |
| print("โ ๏ธ HuggingFace libraries not available. Install with:") | |
| print(" pip install huggingface_hub safetensors") | |
| HF_AVAILABLE = False | |
| # ============================================================================ | |
| # PROMPT LOGGER - Saves ALL prompts to JSONL | |
| # ============================================================================ | |
class PromptLogger:
    """
    Append-only JSONL logger that records every training prompt.

    Each batch is flushed to disk as soon as it is written, so an
    interrupted run loses at most the batch currently in flight.
    """
    def __init__(self, output_path: str = "./prompts_all_epochs.jsonl"):
        self.output_path = Path(output_path)
        self.output_path.parent.mkdir(parents=True, exist_ok=True)
        # Truncate any log left over from a previous run.
        self.output_path.write_text("")
        self.batch_count = 0
        print(f"โ PromptLogger initialized: {self.output_path}")
    def log_batch(
        self,
        prompts: List[str],
        timesteps: torch.Tensor,
        epoch: int,
        batch_idx: int,
        global_step: int
    ):
        """
        Append one batch of prompts (with metadata) to the JSONL log.

        The file handle is flushed before closing so the records reach
        disk even if the process dies right after this call.
        """
        records = []
        for sample_idx, (prompt, step_t) in enumerate(zip(prompts, timesteps)):
            records.append(json.dumps({
                'timestamp': datetime.now().isoformat(),
                'epoch': epoch,
                'batch': batch_idx,
                'global_step': global_step,
                'sample_idx': sample_idx,
                'timestep': int(step_t.item()),
                'timestep_bin': int(step_t.item()) // 10,
                'prompt': prompt
            }) + '\n')
        with open(self.output_path, 'a') as f:
            f.writelines(records)
            f.flush()  # CRITICAL: Force write to disk
        self.batch_count += 1
        if self.batch_count % 100 == 0:
            print(f" ๐ Logged {self.batch_count} batches ({self.batch_count * len(prompts):,} prompts)")
    def get_stats(self) -> dict:
        """Return the number of logged prompts and the log's size in MB."""
        if not self.output_path.exists():
            return {'total': 0}
        with open(self.output_path, 'r') as f:
            total = sum(1 for _ in f)
        return {
            'total': total,
            'size_mb': self.output_path.stat().st_size / 1024**2
        }
| # ============================================================================ | |
| # PATTERN-SUPERVISED LOSS | |
| # ============================================================================ | |
class PatternSupervisedLoss(nn.Module):
    """
    Pattern-supervised loss with full 1000-class supervision.

    Supervises all num_timestep_bins * num_patterns_per_timestep classes
    (100 x 10 = 1000 by default). Loss terms:
      1. Feature similarity: cosine distance to the assigned crystal centroid.
      2. Rose loss: kept as an exact copy of the feature-similarity term,
         matching the original trainer (its line 609).
      3. Cross-entropy over the full class space (soft targets by default).
      4. Pattern-diversity regularizer (negative entropy) that discourages
         collapse onto a single pattern per bin.

    IMPROVED: the per-sample Python loops of the original implementation
    are replaced by batched scatter_/gather ops with identical results.
    """
    def __init__(
        self,
        num_timestep_bins: int = 100,
        num_patterns_per_timestep: int = 10,
        feature_similarity_weight: float = 0.5,
        rose_weight: float = 0.3,
        ce_weight: float = 0.2,
        pattern_diversity_weight: float = 0.05,
        use_soft_assignment: bool = True,
        temperature: float = 0.1
    ):
        super().__init__()
        self.num_bins = num_timestep_bins
        self.num_patterns = num_patterns_per_timestep
        self.num_classes = num_timestep_bins * num_patterns_per_timestep
        self.feature_sim_weight = feature_similarity_weight
        self.rose_weight = rose_weight
        self.ce_weight = ce_weight
        self.pattern_diversity_weight = pattern_diversity_weight
        self.use_soft_assignment = use_soft_assignment
        self.temperature = temperature
    def _block_class_indices(self, timestep_class: torch.Tensor) -> torch.Tensor:
        """Full-class indices [B, num_patterns] owned by each sample's bin."""
        offsets = torch.arange(self.num_patterns, device=timestep_class.device)
        return timestep_class.unsqueeze(1) * self.num_patterns + offsets.unsqueeze(0)
    def assign_patterns(
        self,
        features: torch.Tensor,
        timestep_class: torch.Tensor,
        crystal_centroids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Assign samples to the nearest pattern within their timestep bin.

        Uses COSINE SIMILARITY (not Euclidean distance) to match the
        original trainer.

        Args:
            features: [B, D]
            timestep_class: [B] - timestep bins in [0, num_bins)
            crystal_centroids: [num_bins, num_patterns, D]
        Returns:
            pattern_ids: [B] - pattern indices in [0, num_patterns)
            full_class_ids: [B] - full class ids in [0, num_classes)
        """
        # Centroids for each sample's timestep bin: [B, num_patterns, D]
        batch_centroids = crystal_centroids[timestep_class]
        similarities = F.cosine_similarity(
            features.unsqueeze(1),  # [B, 1, D]
            batch_centroids,
            dim=2
        )  # [B, num_patterns]
        # Nearest = highest cosine similarity
        pattern_ids = similarities.argmax(dim=1)
        full_class_ids = timestep_class * self.num_patterns + pattern_ids
        return pattern_ids, full_class_ids
    def compute_soft_assignment(
        self,
        features: torch.Tensor,
        timestep_class: torch.Tensor,
        crystal_centroids: torch.Tensor
    ) -> torch.Tensor:
        """
        Soft pattern assignment with temperature smoothing.

        IMPROVED: the original O(B) Python loop is replaced by a single
        scatter_ into the [B, num_classes] target matrix (same values).

        Args:
            features: [B, D]
            timestep_class: [B] - timestep bins
            crystal_centroids: [num_bins, num_patterns, D]
        Returns:
            soft_targets: [B, num_classes] soft target distribution,
                nonzero only inside each sample's own bin.
        """
        B = features.shape[0]
        batch_centroids = crystal_centroids[timestep_class]  # [B, num_patterns, D]
        similarities = F.cosine_similarity(
            features.unsqueeze(1),
            batch_centroids,
            dim=2
        )  # [B, num_patterns]
        # Soft assignment with temperature
        pattern_probs = F.softmax(similarities / self.temperature, dim=1)
        soft_targets = torch.zeros(B, self.num_classes, device=features.device)
        # Scatter each sample's probs into its bin's contiguous class slots.
        soft_targets.scatter_(1, self._block_class_indices(timestep_class), pattern_probs)
        return soft_targets
    def compute_pattern_diversity_loss(
        self,
        logits: torch.Tensor,
        timestep_class: torch.Tensor
    ) -> torch.Tensor:
        """
        Encourage diverse pattern usage (prevent mode collapse).

        IMPROVED: gathers each sample's within-bin logits in one batched
        gather instead of a Python loop over the batch.
        Returns negative mean entropy (minimizing it maximizes diversity).
        """
        # [B, num_patterns] logits restricted to each sample's own bin
        block_logits = logits.gather(1, self._block_class_indices(timestep_class))
        pattern_probs = F.softmax(block_logits, dim=1)
        # Entropy (higher = more diverse); epsilon guards log(0)
        entropy = -(pattern_probs * torch.log(pattern_probs + 1e-8)).sum(dim=1).mean()
        return -entropy
    def forward(
        self,
        student_features: torch.Tensor,
        teacher_features: torch.Tensor,
        student_logits: torch.Tensor,
        crystal_centroids: torch.Tensor,
        timesteps: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Compute the full pattern-supervised loss.

        Args:
            student_features: [B, D]
            teacher_features: [B, D] - accepted for interface compatibility;
                not read (supervision targets are the crystal centroids).
            student_logits: [B, num_classes]
            crystal_centroids: [num_bins, num_patterns, D]
            timesteps: [B] raw diffusion timesteps (0-999, binned by // 10)
        Returns:
            total_loss: combined weighted loss
            metrics: dict of individual loss values and accuracies
        """
        # Timestep classification: raw steps 0-999 -> bins 0-99
        timestep_class = (timesteps // 10).clamp(0, self.num_bins - 1)
        # Pattern assignment uses STUDENT features (not teacher)
        pattern_ids, full_class_ids = self.assign_patterns(
            student_features, timestep_class, crystal_centroids
        )
        # Target centroid per sample via advanced indexing: [B, D]
        # (replaces the original per-sample list comprehension)
        target_centroids = crystal_centroids[timestep_class, pattern_ids]
        # 1. Feature similarity loss (student vs assigned centroids)
        feature_sim_loss = 1.0 - F.cosine_similarity(
            student_features, target_centroids, dim=-1
        ).mean()
        # 2. Rose loss: deliberately identical to feature_sim_loss
        #    (matches original trainer line 609 - not contrastive).
        rose_loss = feature_sim_loss
        # 3. Cross-entropy with soft or hard targets
        if self.use_soft_assignment:
            soft_targets = self.compute_soft_assignment(
                student_features, timestep_class, crystal_centroids
            )
            log_probs = F.log_softmax(student_logits, dim=1)
            ce_loss = -(soft_targets * log_probs).sum(dim=1).mean()
        else:
            ce_loss = F.cross_entropy(student_logits, full_class_ids)
        # 4. Pattern diversity regularizer
        diversity_loss = self.compute_pattern_diversity_loss(
            student_logits, timestep_class
        )
        # Total loss
        total_loss = (
            self.feature_sim_weight * feature_sim_loss +
            self.rose_weight * rose_loss +
            self.ce_weight * ce_loss +
            self.pattern_diversity_weight * diversity_loss
        )
        # Accuracy metrics derived from a single argmax over all classes
        full_pred = student_logits.argmax(dim=-1)
        timestep_pred = full_pred // self.num_patterns
        pattern_pred = full_pred % self.num_patterns
        metrics = {
            'feature_sim': feature_sim_loss.item(),
            'rose': rose_loss.item(),
            'ce': ce_loss.item(),
            'pattern_diversity': diversity_loss.item(),
            'timestep_acc': (timestep_pred == timestep_class).float().mean().item(),
            'pattern_acc': (pattern_pred == pattern_ids).float().mean().item(),
            'full_acc': (full_pred == full_class_ids).float().mean().item()
        }
        return total_loss, metrics
| # ============================================================================ | |
| # CONFIG | |
| # ============================================================================ | |
# Full-system configuration: all 9 SD1.5 UNet blocks with 1000-class
# (100 bins x 10 patterns) supervision and no on-disk feature cache.
FULL_CONFIG = DavidCollectiveConfig(
    # Timestep discretization: 1000 diffusion steps -> 100 bins of 10
    num_timestep_bins=100,
    num_feature_patterns_per_timestep=10,  # CORRECT parameter name
    # Active blocks (all 9)
    active_blocks=['down_0', 'down_1', 'down_2', 'down_3', 'mid', 'up_0', 'up_1', 'up_2', 'up_3'],
    # David architecture
    david_sharing_mode='fully_shared',
    david_fusion_mode='deep_efficiency',
    use_belly=True,
    belly_expand=1.5,
    # Loss weights (cayley term disabled; these mirror PatternSupervisedLoss defaults)
    feature_similarity_weight=0.5,
    rose_weight=0.3,
    cayley_weight=0.0,
    ce_weight=0.2,
    # Geometric constraints
    rose_margin=1.0,
    rose_temperature=0.07,
    cayley_volume_floor=1e-4,
    # Progressive training: blocks are activated gradually, 2 warmup epochs each
    progressive_training=True,
    warmup_epochs_per_block=2,
    # No caching (minimal disk usage)
    cache_dir=None,
    max_cache_size_gb=0.0
)
| # ============================================================================ | |
| # SD1.5 EXTRACTOR - FIXED (NO POOLING) | |
| # ============================================================================ | |
class StreamingSD15Extractor:
    """
    Extract intermediate features from the SD1.5 UNet via forward hooks.

    CRITICAL: returns spatial features [B, C, H, W] (presumed shape of the
    hooked resnet outputs -- confirm against diffusers internals).
    David's Companions handle pooling internally; nothing is pooled here.
    """
    def __init__(
        self,
        model_id: str = "runwayml/stable-diffusion-v1-5",
        device: str = "cuda",
        active_blocks: List[str] = None
    ):
        # Fall back to the full 9-block config when no subset is given.
        self.device = device
        self.active_blocks = active_blocks or FULL_CONFIG.active_blocks
        print(f"Loading SD1.5 from {model_id}...")
        # fp16 weights; safety checker disabled (features only, no images shown)
        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            safety_checker=None,
            requires_safety_checker=False
        ).to(device)
        # Keep direct handles to the pipeline components we use.
        self.unet = self.pipe.unet
        self.vae = self.pipe.vae
        self.text_encoder = self.pipe.text_encoder
        self.tokenizer = self.pipe.tokenizer
        self.scheduler = self.pipe.scheduler
        # features: block name -> last hooked output; hooks: active handles
        self.features = {}
        self.hooks = []
        # Maps logical block names to (unet attribute, index) pairs.
        self.block_mapping = {
            'down_0': ('down_blocks', 0),
            'down_1': ('down_blocks', 1),
            'down_2': ('down_blocks', 2),
            'down_3': ('down_blocks', 3),
            'mid': ('mid_block', None),
            'up_0': ('up_blocks', 0),
            'up_1': ('up_blocks', 1),
            'up_2': ('up_blocks', 2),
            'up_3': ('up_blocks', 3),
        }
        print(f"โ SD1.5 loaded on {device}")
    def _register_hooks(self):
        """Attach forward hooks on the active blocks' last resnet (or mid block)."""
        def make_hook(name):
            def hook(module, input, output):
                # CRITICAL: store WITH spatial dimensions; detach + fp32
                # so gradients never flow back into the frozen teacher.
                self.features[name] = output.detach().float()
            return hook
        # Clear any stale hooks before re-registering.
        self._remove_hooks()
        for block_name in self.active_blocks:
            block_type, idx = self.block_mapping[block_name]
            if block_type == 'down_blocks':
                block = self.unet.down_blocks[idx]
                # Hook the last resnet so we capture the block's final features.
                if hasattr(block, 'resnets') and len(block.resnets) > 0:
                    hook = block.resnets[-1].register_forward_hook(make_hook(block_name))
                    self.hooks.append(hook)
            elif block_type == 'mid_block':
                hook = self.unet.mid_block.register_forward_hook(make_hook(block_name))
                self.hooks.append(hook)
            elif block_type == 'up_blocks':
                block = self.unet.up_blocks[idx]
                if hasattr(block, 'resnets') and len(block.resnets) > 0:
                    hook = block.resnets[-1].register_forward_hook(make_hook(block_name))
                    self.hooks.append(hook)
    def _remove_hooks(self):
        """Detach all registered forward hooks."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []
    def extract_batch(
        self,
        prompts: List[str],
        timesteps: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        Run one UNet forward pass and return the hooked block features.

        CRITICAL: returns spatial features [B, C, H, W].
        NO POOLING - Companions handle it internally.
        """
        self._register_hooks()
        # Encode prompts with the CLIP text encoder.
        text_inputs = self.tokenizer(
            prompts,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt"
        ).to(self.device)
        text_embeddings = self.text_encoder(text_inputs.input_ids)[0]
        B = len(prompts)
        # NOTE(review): latents start as pure Gaussian noise and add_noise
        # then mixes in MORE noise per the scheduler schedule -- there is no
        # clean image being noised here. Confirm this matches the intended
        # feature-extraction distribution.
        latents = torch.randn(B, 4, 64, 64, device=self.device, dtype=torch.float16)
        for i, t in enumerate(timesteps):
            noise = torch.randn_like(latents[i:i+1])
            latents[i:i+1] = self.scheduler.add_noise(
                latents[i:i+1],
                noise,
                t.unsqueeze(0)
            )
        # Reset before the forward pass; hooks repopulate this dict.
        self.features = {}
        _ = self.unet(
            latents,
            timesteps.to(self.device),
            encoder_hidden_states=text_embeddings
        ).sample
        self._remove_hooks()
        # Return features WITH spatial dimensions [B, C, H, W]
        # NO POOLING HERE - Companions will handle it
        return self.features.copy()
    def __del__(self):
        # Best effort: free hooks when the extractor is garbage collected.
        self._remove_hooks()
| # ============================================================================ | |
| # SYMBOLIC PROMPT DATASET | |
| # ============================================================================ | |
class SymbolicPromptDataset(Dataset):
    """Generate prompts on-the-fly using the symbolic synthesis system.

    All prompts are pre-generated in __init__ (so __getitem__ is cheap);
    each __getitem__ call draws a fresh random diffusion timestep in
    [0, 1000), so the same prompt pairs with different timesteps across
    epochs.
    """
    def __init__(
        self,
        num_samples: int = 10000,
        complexity_distribution: Optional[Dict[int, float]] = None,
        bias_weights_path: Optional[str] = None,
        seed: Optional[int] = None,
        log_synthesis_stats: bool = False
    ):
        self.num_samples = num_samples
        self.log_synthesis_stats = log_synthesis_stats
        if complexity_distribution is None:
            # Default: weighted toward mid-complexity (3) prompts.
            complexity_distribution = {
                1: 0.05, 2: 0.15, 3: 0.40, 4: 0.25, 5: 0.15
            }
        self.complexity_dist = complexity_distribution
        # Initialize synthesis system (no seed parameter)
        self.synth = SynthesisSystem()
        # Apply bias weights if provided.
        # NOTE(review): assumes SynthesisSystem exposes a plain
        # `bias_weights` attribute -- confirm against the project API.
        if bias_weights_path and Path(bias_weights_path).exists():
            with open(bias_weights_path, 'r') as f:
                bias_weights = json.load(f)
            if hasattr(self.synth, 'bias_weights'):
                self.synth.bias_weights = bias_weights
        # Pre-generate prompts
        self.rng = np.random.RandomState(seed)
        self.prompts = []
        self.metadata = []
        # Hoisted out of the loop: invariant per iteration.
        complexities = list(complexity_distribution.keys())
        probs = list(complexity_distribution.values())
        print(f"Generating {num_samples:,} prompts...")
        for i in range(num_samples):
            # Sample complexity. int() cast: rng.choice returns a numpy
            # scalar, which would otherwise leak into JSON-logged metadata.
            complexity = int(self.rng.choice(complexities, p=probs))
            # Generate prompt
            try:
                result = self.synth.synthesize(complexity=complexity)
                # Extract text and path_info from the result dict
                if isinstance(result, dict):
                    prompt = result.get('text', 'a photo')
                    path_info = result.get('selected_paths', [])
                else:
                    # Fallback if unexpected format
                    prompt = str(result)
                    path_info = []  # FIXED: was {}; keep type consistent with the dict branch
                self.prompts.append(prompt)
                if log_synthesis_stats:
                    self.metadata.append({
                        'complexity': complexity,
                        'path_info': path_info,
                        'sample_id': i
                    })
            except Exception as e:
                # Fallback prompt if generation fails
                print(f" โ ๏ธ Warning: Failed to generate prompt {i}: {e}")
                self.prompts.append("a photo")
                if log_synthesis_stats:
                    self.metadata.append({
                        'complexity': complexity,
                        'path_info': [],  # FIXED: was {}; keep list type
                        'sample_id': i,
                        'error': str(e)
                    })
            if (i + 1) % 1000 == 0:
                print(f" Generated {i+1:,}/{num_samples:,} prompts...")
        print(f"โ Generated {len(self.prompts):,} prompts")
        if log_synthesis_stats:
            self._log_statistics()
    def _log_statistics(self):
        """Print the complexity distribution and a few example prompts."""
        from collections import Counter
        complexity_counts = Counter(m['complexity'] for m in self.metadata)
        print("\nSynthesis Statistics:")
        print(" Complexity distribution:")
        for complexity in sorted(complexity_counts.keys()):
            count = complexity_counts[complexity]
            pct = 100 * count / len(self.metadata)
            print(f" Complexity {complexity}: {count:,} ({pct:.1f}%)")
        print("\n Example prompts:")
        # Sample at the quartile positions of the generated list.
        for i in [0, len(self.prompts)//4, len(self.prompts)//2, 3*len(self.prompts)//4]:
            complexity = self.metadata[i]['complexity']
            print(f" [C={complexity}] {self.prompts[i][:80]}...")
    def __len__(self):
        return self.num_samples
    def __getitem__(self, idx):
        prompt = self.prompts[idx]
        # Fresh uniform timestep each call; randint upper bound is exclusive.
        timestep = self.rng.randint(0, 1000)
        metadata = self.metadata[idx] if self.log_synthesis_stats else {}
        return {
            'prompt': prompt,
            'timestep': torch.tensor(timestep),
            'metadata': metadata
        }
def collate_symbolic_batch(batch):
    """Collate dataset items into (prompts, stacked timesteps, metadata)."""
    prompts, timestep_list, metadata = [], [], []
    for item in batch:
        prompts.append(item['prompt'])
        timestep_list.append(item['timestep'])
        metadata.append(item['metadata'])
    return prompts, torch.stack(timestep_list), metadata
| # ============================================================================ | |
| # HUGGINGFACE UTILITIES | |
| # ============================================================================ | |
def convert_to_safetensors(checkpoint_path: str) -> str:
    """Convert a .pt checkpoint to .safetensors format.

    Accepts raw state dicts as well as trainer checkpoints that wrap the
    weights under 'model_state_dict' or 'state_dict'.

    Args:
        checkpoint_path: path to the .pt checkpoint file.
    Returns:
        Path of the written .safetensors file, or None when the
        huggingface/safetensors libraries are unavailable.
    Raises:
        ValueError: if the loaded checkpoint is not a dict.
    """
    if not HF_AVAILABLE:
        print("โ ๏ธ Safetensors not available, skipping conversion")
        return None
    # weights_only=False: trainer checkpoints may contain non-tensor objects.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    # Handle different checkpoint formats
    if isinstance(checkpoint, dict):
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        elif 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            # Assume the dict IS the state_dict
            state_dict = checkpoint
    else:
        raise ValueError(f"Unexpected checkpoint format: {type(checkpoint)}")
    # FIXED: str.replace('.pt', ...) rewrote '.pt' anywhere in the path
    # (e.g. a './ckpts.pt_dir/' component); swap only the file extension.
    output_path = str(Path(checkpoint_path).with_suffix('.safetensors'))
    print(f" Converting to safetensors: {output_path}")
    # Ensure all tensors are contiguous (safetensors requires it)
    state_dict_safe = {
        k: v.contiguous() for k, v in state_dict.items()
    }
    # Save as safetensors
    save_file(state_dict_safe, output_path)
    size_mb = Path(output_path).stat().st_size / 1024**2
    print(f" โ Saved safetensors ({size_mb:.2f} MB)")
    return output_path
def create_readme(
    final_epoch: int,
    history: dict,
    config: DavidCollectiveConfig,
    total_prompts: int
) -> str:
    """Generate comprehensive README for HuggingFace model card.

    Args:
        final_epoch: last completed epoch, interpolated into the text.
        history: metric-name -> list of per-epoch values; only the first
            and last entries of each list are read.
        config: training config; loss weights and block list are shown.
        total_prompts: number of prompts logged across the run.
    Returns:
        The rendered README as a markdown string.
    """
    # NOTE: one large f-string follows; doubled braces {{ }} render as
    # literal braces in the markdown/bibtex output, and the
    # `history.get(...)[-1] if ... else 0` pattern guards missing metrics.
    readme = f"""# David Collective - SD1.5 Geometric Distillation (Continued)
## Model Description
**David Collective** is a revolutionary geometric deep learning system that distills Stable Diffusion 1.5's knowledge into an ultra-efficient pentachoron-based architecture. This model was continued from epoch 20 to epoch {final_epoch}, achieving remarkable performance with full pattern supervision.
### Architecture Highlights
- **Geometric Foundation**: Uses 5D pentachora (5-vertex simplices) instead of traditional attention
- **Multi-Scale Learning**: Extracts features from all 9 SD1.5 UNet blocks
- **Crystal Navigation**: 1000-class supervision (100 timesteps ร 10 geometric patterns)
- **Parameter Efficiency**: Ultra-compact architecture with shared geometric structures
- **Full Supervision**: Every sample supervised by both timestep and geometric pattern
### Training Details
**Continuation Training:**
- Starting epoch: 20
- Final epoch: {final_epoch}
- Total prompts trained: {total_prompts:,}
- **All prompts included**: `prompts_all_epochs.jsonl` contains every prompt with metadata
- Dataset: Symbolic caption synthesis (complexity 1-5)
- Batch size: 32
- Learning rate: 1e-4 with cosine annealing
- Optimizer: AdamW (weight_decay=0.01)
**Final Metrics (Epoch {final_epoch}):**
- Total Loss: {history.get('total_loss', [0])[-1] if history.get('total_loss') else 0:.4f}
- Timestep Accuracy: {history.get('timestep_accuracy', [0])[-1] if history.get('timestep_accuracy') else 0:.2%}
- Pattern Accuracy: {history.get('pattern_accuracy', [0])[-1] if history.get('pattern_accuracy') else 0:.2%}
- Full Accuracy: {history.get('full_accuracy', [0])[-1] if history.get('full_accuracy') else 0:.2%}
- Pattern Diversity: {history.get('pattern_diversity', [0])[-1] if history.get('pattern_diversity') else 0:.3f}
### Active Blocks
David learns from all 9 SD1.5 UNet blocks:
- `down_0`, `down_1`, `down_2`, `down_3`: Coarse semantic features
- `mid`: Bottleneck representations
- `up_0`, `up_1`, `up_2`, `up_3`: Fine reconstruction details
### Loss Components
1. **Feature Similarity** ({config.feature_similarity_weight}): Cosine similarity with teacher
2. **Rose Loss** ({config.rose_weight}): Geometric alignment with crystal centroids
3. **Cross-Entropy** ({config.ce_weight}): 1000-class classification
4. **Pattern Diversity** (0.05): Encourages balanced pattern usage
## Usage
### Loading the Model
```python
import torch
from geovocab2.train.model.core.david_diffusion import DavidCollective, DavidCollectiveConfig
from safetensors.torch import load_file
# Load configuration
config = DavidCollectiveConfig(
num_timestep_bins=100,
num_feature_patterns_per_timestep=10,
active_blocks={config.active_blocks},
david_sharing_mode='fully_shared',
david_fusion_mode='deep_efficiency',
use_belly=True,
belly_expand=1.5
)
# Create model
model = DavidCollective(config)
# Load weights from safetensors
state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict)
model.eval()
# Inference
with torch.no_grad():
outputs = model(teacher_features, timesteps)
```
### Training Data
This model includes `prompts_all_epochs.jsonl` - every single prompt used during training with full metadata:
```json
{{"timestamp": "2025-10-27T01:30:00", "epoch": 21, "batch": 0, "global_step": 6250, "sample_idx": 0, "timestep": 453, "timestep_bin": 45, "prompt": "a woman wearing red dress, against mountain landscape"}}
```
**Total prompts:** {total_prompts:,}
You can use this to:
- Analyze training data distribution
- Reproduce training
- Study prompt complexity vs model performance
- Generate similar synthetic datasets
## Technical Details
### Crystal System
- **Architecture**: Pentachoron-based geometric deep learning
- **Centroids**: 100 timestep bins ร 10 patterns = 1000 anchors
- **Navigation**: Samples assigned to nearest pattern within timestep bin
- **Diversity**: Regularization prevents mode collapse
### Progressive Training
- Started with early blocks (down_0, down_1)
- Progressively activated all 9 blocks
- Each block warmed up for {config.warmup_epochs_per_block} epochs
### Pattern Supervision
Unlike traditional timestep-only supervision, David learns:
1. **When** (timestep bin 0-99)
2. **How** (geometric pattern 0-9 within that bin)
3. **Combined** (full 1000-class space)
This provides 10x finer-grained supervision of the diffusion process.
## Training History
Trained continuously from epoch 20 to epoch {final_epoch}. See metrics:
- Timestep accuracy improved from ~{history.get('timestep_accuracy', [0])[0] if history.get('timestep_accuracy') else 0:.1%} to {history.get('timestep_accuracy', [0])[-1] if history.get('timestep_accuracy') else 0:.2%}
- Pattern accuracy maintained at {history.get('pattern_accuracy', [0])[-1] if history.get('pattern_accuracy') else 0:.2%}
- Loss decreased from {history.get('total_loss', [0])[0] if history.get('total_loss') else 0:.4f} to {history.get('total_loss', [0])[-1] if history.get('total_loss') else 0:.4f}
## Citation
```bibtex
@misc{{david-collective-sd15,
title={{David Collective: Geometric Deep Learning for Diffusion Distillation}},
author={{AbstractPhil}},
year={{2025}},
publisher={{HuggingFace}},
howpublished={{\\url{{https://huggingface.co/AbstractPhil/david-collective-sd15-geometric-distillation}}}}
}}
```
## License
MIT License - See repository for details.
## Acknowledgments
Built on the geometric deep learning research by AbstractPhil, using:
- Stable Diffusion 1.5 (teacher model)
- Pentachoron-based geometric algebra
- Crystalline consciousness architectures
- Symbolic caption synthesis
For more information, visit the [geovocab2 repository](https://github.com/AbstractEyes/lattice_vocabulary).
"""
    return readme
def create_model_card(config: DavidCollectiveConfig) -> dict:
    """Create model card metadata for HuggingFace.

    The returned metadata is static; ``config`` is accepted for interface
    symmetry with the other card/readme builders but is not read here.
    """
    tags = [
        'geometric-deep-learning',
        'diffusion-distillation',
        'stable-diffusion',
        'pentachoron',
        'crystal-navigation',
        'pattern-supervision',
        'ultra-efficient',
        'sd15-distillation',
    ]
    card = dict(
        language=['en'],
        license='mit',
        tags=tags,
        datasets=['synthetic-captions'],
        metrics=['accuracy', 'loss'],
        library_name='pytorch',
        pipeline_tag='image-classification',
    )
    return card
def upload_to_huggingface(
    model_path: str,
    repo_name: str = "AbstractPhil/david-collective-sd15-geometric-distillation",
    final_epoch: int = 50,
    history: dict = None,
    config: DavidCollectiveConfig = None,
    total_prompts: int = 0,
    private: bool = False
):
    """Upload model to HuggingFace Hub with README and model card.

    Steps: convert the .pt checkpoint to .safetensors, stage all artifacts
    in ./hf_upload_temp, generate README/model card/config JSON, then push
    the staged folder to the Hub.

    Returns:
        The model URL string on success; None when HF libraries are
        missing, conversion fails, or the upload raises. On upload failure
        the staged files are intentionally kept so they can be pushed
        manually.
    """
    if not HF_AVAILABLE:
        print("\nโ ๏ธ HuggingFace libraries not available")
        print("Install with: pip install huggingface_hub safetensors")
        return None
    print(f"\n{'='*80}")
    print("UPLOADING TO HUGGINGFACE")
    print(f"{'='*80}\n")
    # Convert to safetensors
    print("[1/5] Converting to safetensors...")
    safetensors_path = convert_to_safetensors(model_path)
    if not safetensors_path:
        print("โ Conversion failed, aborting upload")
        return None
    # Create temporary upload directory
    print("\n[2/5] Preparing upload directory...")
    upload_dir = Path("./hf_upload_temp")
    upload_dir.mkdir(exist_ok=True)
    # Copy safetensors
    import shutil
    shutil.copy(safetensors_path, upload_dir / "model.safetensors")
    print(f" โ Copied model.safetensors")
    # Copy prompts file (CRITICAL!) - written by PromptLogger during training
    prompt_file = Path("./prompts_all_epochs.jsonl")
    if prompt_file.exists():
        shutil.copy(prompt_file, upload_dir / "prompts_all_epochs.jsonl")
        print(f" โ Copied prompts_all_epochs.jsonl ({prompt_file.stat().st_size / 1024**2:.2f} MB)")
    else:
        print(f" โ ๏ธ Warning: prompts_all_epochs.jsonl not found, skipping")
    # Generate README
    print("\n[3/5] Generating README...")
    readme_content = create_readme(final_epoch, history, config, total_prompts)
    (upload_dir / "README.md").write_text(readme_content)
    print(f" โ Created README.md")
    # Generate model card
    print("\n[4/5] Creating model card...")
    model_card = create_model_card(config)
    (upload_dir / "model_card.json").write_text(json.dumps(model_card, indent=2))
    print(f" โ Created model_card.json")
    # Save config (public attributes only; underscore-prefixed are dropped)
    config_dict = {k: v for k, v in config.__dict__.items() if not k.startswith('_')}
    (upload_dir / "config.json").write_text(json.dumps(config_dict, indent=2))
    print(f" โ Created config.json")
    # Upload
    print(f"\n[5/5] Uploading to {repo_name}...")
    try:
        api = HfApi()
        # Create repo if doesn't exist (exist_ok makes this idempotent)
        try:
            create_repo(repo_name, private=private, exist_ok=True)
            print(f" โ Repository ready")
        except Exception as e:
            print(f" โ ๏ธ Repo might already exist: {e}")
        # Upload folder
        api.upload_folder(
            folder_path=str(upload_dir),
            repo_id=repo_name,
            repo_type="model",
            commit_message=f"Upload continuation training (epoch {final_epoch})"
        )
        print(f"\nโ UPLOAD COMPLETE!")
        print(f"\n๐ View your model: https://huggingface.co/{repo_name}")
        print(f"\n๐ฆ Uploaded files:")
        print(f" - model.safetensors")
        print(f" - prompts_all_epochs.jsonl ({total_prompts:,} prompts)")
        print(f" - README.md (with metrics)")
        print(f" - config.json")
        print(f" - model_card.json")
        # Cleanup only after a successful push
        shutil.rmtree(upload_dir)
        print(f"\n๐งน Cleaned up temporary files")
        return f"https://huggingface.co/{repo_name}"
    except Exception as e:
        # Keep the staged directory so a manual retry is possible.
        print(f"\nโ Upload failed: {e}")
        print(f"Files are still in: {upload_dir}")
        print(f"You can upload manually or fix the error and retry")
        return None
| # ============================================================================ | |
| # CONTINUATION TRAINING | |
| # ============================================================================ | |
def continue_training(
    collective: DavidCollective,
    extractor: StreamingSD15Extractor,
    dataloader: DataLoader,
    start_epoch: int,
    num_epochs: int,
    device: str = "cuda",
    log_dir: str = "./runs/david_continued",
    prompt_log_path: str = "./prompts_all_epochs.jsonl",
    checkpoint_interval: int = 5,
    auto_upload: bool = True,
    hf_repo_name: str = "AbstractPhil/david-collective-sd15-geometric-distillation"
):
    """
    Continue training from checkpoint with full logging and HuggingFace upload.

    Args:
        collective: DavidCollective model
        extractor: SD1.5 feature extractor
        dataloader: Training data
        start_epoch: Epoch to start from (loaded from checkpoint)
        num_epochs: Additional epochs to train
        device: Device to train on
        log_dir: TensorBoard log directory
        prompt_log_path: Where to save all prompts
        checkpoint_interval: Save checkpoint every N epochs
        auto_upload: Automatically upload to HuggingFace after training
        hf_repo_name: HuggingFace repository name

    Returns:
        Tuple ``(collective, history)`` — the trained model and a dict mapping
        metric names to per-epoch averaged values.

    Raises:
        FileNotFoundError: If the resume checkpoint does not exist.
        RuntimeError: If loading the checkpoint left the sampled weight unchanged.
        ValueError: If the checkpoint file is not a dict.
    """
    print("\n" + "="*80)
    print("DAVID COLLECTIVE - CONTINUATION TRAINING")
    print("="*80)
    # Load checkpoint.
    # NOTE(review): the resume path is hard-coded and is the SAME file this
    # function later writes as its final checkpoint — confirm that overwriting
    # the resume source is intended.
    print(f"\n[1/7] Loading checkpoint...")
    checkpoint_path = Path("david_collective_continued_final.pt")
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    # Capture one sample weight so we can verify the load actually mutated the model.
    print(f" Capturing initial weights...")
    initial_sample_key = list(collective.state_dict().keys())[0]
    initial_sample_weight = collective.state_dict()[initial_sample_key].clone()
    initial_mean = initial_sample_weight.mean().item()
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    # Handle different checkpoint formats.
    if isinstance(checkpoint, dict):
        # Check what keys are present
        print(f" Checkpoint keys: {list(checkpoint.keys())}")
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
            actual_epoch = checkpoint.get('epoch', start_epoch)
        elif 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
            actual_epoch = checkpoint.get('epoch', start_epoch)
        else:
            # Assume the dict IS the state_dict
            state_dict = checkpoint
            actual_epoch = start_epoch
        # strict=False so key mismatches are reported instead of raising.
        print(f" Loading state dict ({len(state_dict)} parameters)...")
        missing_keys, unexpected_keys = collective.load_state_dict(state_dict, strict=False)
        if missing_keys:
            print(f" โ ๏ธ Missing keys ({len(missing_keys)}): {missing_keys[:3]}...")
        if unexpected_keys:
            print(f" โ ๏ธ Unexpected keys ({len(unexpected_keys)}): {unexpected_keys[:3]}...")
        # Verify weights actually changed
        final_sample_weight = collective.state_dict()[initial_sample_key]
        final_mean = final_sample_weight.mean().item()
        if torch.equal(initial_sample_weight, final_sample_weight):
            raise RuntimeError(
                f"โ CRITICAL: Weights did NOT change after loading!\n"
                f" Sample param: {initial_sample_key}\n"
                f" This means the checkpoint is not being loaded properly."
            )
        print(f" โ Weights verified changed (sample mean: {initial_mean:.6f} -> {final_mean:.6f})")
        print(f" โ Loaded from epoch {actual_epoch}")
    else:
        # Not a dict - shouldn't happen but handle it
        raise ValueError(f"Unexpected checkpoint format: {type(checkpoint)}")
    # NOTE(review): `actual_epoch` from the checkpoint is only reported; the
    # training loop below trusts the caller-supplied `start_epoch`. Confirm
    # the two agree for your checkpoint, or resume numbering will be off.
    # Model info
    total_params = sum(p.numel() for p in collective.parameters())
    print(f"\n Model Status:")
    print(f" Parameters: {total_params:,}")
    print(f" Active blocks: {len(collective.config.active_blocks)}")
    print(f" Companions: {list(collective.companions.keys())}")
    # Prompt logger
    print(f"\n[2/7] Initializing prompt logger...")
    prompt_logger = PromptLogger(prompt_log_path)
    print(f" โ Saving to: {prompt_log_path}")
    # Loss function
    print(f"\n[3/7] Setting up loss function...")
    criterion = PatternSupervisedLoss(
        num_timestep_bins=collective.config.num_timestep_bins,
        num_patterns_per_timestep=collective.config.num_feature_patterns_per_timestep,  # CORRECT attribute name
        feature_similarity_weight=0.5,
        rose_weight=0.3,
        ce_weight=0.2,
        pattern_diversity_weight=0.05
    ).to(device)
    print(f" โ Pattern-supervised loss ready")
    # Optimizer
    print(f"\n[4/7] Creating optimizer...")
    optimizer = torch.optim.AdamW(
        collective.parameters(),
        lr=1e-4,
        weight_decay=0.01
    )
    # Scheduler: cosine decay over the total number of optimizer steps.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=num_epochs * len(dataloader),
        eta_min=1e-6
    )
    print(f" โ AdamW + Cosine annealing")
    # TensorBoard
    print(f"\n[5/7] Setting up TensorBoard...")
    writer = SummaryWriter(log_dir)
    print(f" โ Logging to: {log_dir}")
    # Training
    print(f"\n[6/7] Starting training...")
    print(f" Epochs: {start_epoch + 1} โ {start_epoch + num_epochs}")
    print(f" Batches per epoch: {len(dataloader)}")
    print()
    collective.train()
    # Initialize history (match original lines 519-528)
    history = {
        'total_loss': [],
        'feature_sim': [],
        'rose': [],
        'ce': [],
        'pattern_diversity': [],
        'timestep_accuracy': [],
        'pattern_accuracy': [],
        'full_accuracy': []
    }
    global_step = start_epoch * len(dataloader)
    for epoch in range(start_epoch, start_epoch + num_epochs):
        collective.update_epoch(epoch)  # Match original line 534
        epoch_metrics = {
            'total_loss': 0.0,
            'feature_sim': 0.0,
            'rose': 0.0,
            'ce': 0.0,
            'pattern_diversity': 0.0,
            'timestep_accuracy': 0.0,
            'pattern_accuracy': 0.0,
            'full_accuracy': 0.0,
            'num_batches': 0
        }
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{start_epoch+num_epochs}")
        for batch_idx, (prompts, timesteps, metadata) in enumerate(pbar):
            # Extract features (WITH spatial dimensions!)
            teacher_features = extractor.extract_batch(prompts, timesteps)
            teacher_features = {
                k: v.to(device) for k, v in teacher_features.items()
            }
            timesteps = timesteps.to(device)
            # Verify shapes on first batch
            if batch_idx == 0 and epoch == start_epoch:
                print(f"\n๐ Feature shapes (Epoch {epoch+1}, Batch 1):")
                for k, v in teacher_features.items():
                    print(f" {k}: {v.shape}")
                    if v.dim() != 4:
                        raise ValueError(
                            f"Expected 4D features [B,C,H,W], got {v.dim()}D for {k}!"
                        )
                print()
            # Forward
            outputs = collective(teacher_features, timesteps)
            # Compute loss
            total_loss = torch.tensor(0.0, device=device)
            # FIX: count the blocks that actually contribute to the loss.
            # The old denominator (len(outputs ∩ companions)) ignored the
            # `teacher_features` membership check below and could under-scale
            # the averaged loss when a block had no teacher features.
            num_processed_blocks = 0
            block_metrics = {k: [] for k in ['feature_sim', 'rose', 'ce',
                                             'pattern_diversity', 'timestep_acc',
                                             'pattern_acc', 'full_acc']}
            for block_name in collective.companions.keys():
                if block_name not in outputs or block_name not in teacher_features:
                    continue
                # Use EXACT same structure as original trainer (lines 589-592)
                companion = collective.companions[block_name]
                block_output = outputs[block_name]
                # Get features and timestep class FROM DavidCollective output
                student_features = block_output['scale_features'][0]  # First scale
                student_logits = block_output['combined_logits']
                timestep_class = block_output['timestep_class']  # โ FROM OUTPUT, not recomputed!
                # Get crystal centroids (lines 585-587 from original)
                crystal_anchors = companion.crystal_anchors  # [bins, patterns, 5, max_scale]
                scale = companion.david_config.scales[0]  # Use first scale
                crystal_centroids = crystal_anchors[..., :scale].mean(dim=2)  # [bins, patterns, scale]
                # INLINE LOSS COMPUTATION (like original, NOT using forward())
                # Assign patterns
                _, full_class_ids = criterion.assign_patterns(
                    student_features, timestep_class, crystal_centroids
                )
                # Feature similarity loss: pull features toward assigned centroid.
                pattern_ids = full_class_ids % criterion.num_patterns
                target_centroids = torch.stack([
                    crystal_centroids[timestep_class[j], pattern_ids[j]]
                    for j in range(len(timestep_class))
                ])
                cos_sim = F.cosine_similarity(student_features, target_centroids, dim=-1)
                feature_sim_loss = (1 - cos_sim).mean()
                # Rose loss (same as feature sim)
                rose_loss = feature_sim_loss
                # Cross-entropy with pattern supervision (soft or hard targets)
                if criterion.use_soft_assignment:
                    soft_targets = criterion.compute_soft_assignment(
                        student_features, timestep_class, crystal_centroids
                    )
                    log_probs = F.log_softmax(student_logits, dim=1)
                    ce_loss = -(soft_targets * log_probs).sum(dim=1).mean()
                else:
                    ce_loss = F.cross_entropy(student_logits, full_class_ids)
                # Pattern diversity regularizer
                diversity_loss = criterion.compute_pattern_diversity_loss(
                    student_logits, timestep_class
                )
                # Combined loss for this block
                block_loss = (
                    criterion.feature_sim_weight * feature_sim_loss +
                    criterion.rose_weight * rose_loss +
                    criterion.ce_weight * ce_loss +
                    criterion.pattern_diversity_weight * diversity_loss
                )
                total_loss = total_loss + block_loss
                num_processed_blocks += 1
                # Accuracies (lines 626-642 from original): decompose the
                # combined class id into (timestep bin, pattern) components.
                pred_class = student_logits.argmax(dim=1)
                pred_timestep = pred_class // criterion.num_patterns
                pred_pattern = pred_class % criterion.num_patterns
                true_pattern = full_class_ids % criterion.num_patterns
                timestep_acc = (pred_timestep == timestep_class).float().mean()
                correct_timestep_mask = (pred_timestep == timestep_class)
                if correct_timestep_mask.sum() > 0:
                    # Pattern accuracy is conditional on getting the timestep right.
                    pattern_acc = (
                        pred_pattern[correct_timestep_mask] == true_pattern[correct_timestep_mask]
                    ).float().mean()
                else:
                    pattern_acc = torch.tensor(0.0, device=device)
                full_acc = (pred_class == full_class_ids).float().mean()
                # Collect metrics
                block_metrics['feature_sim'].append(feature_sim_loss.item())
                block_metrics['rose'].append(rose_loss.item())
                block_metrics['ce'].append(ce_loss.item())
                block_metrics['pattern_diversity'].append(diversity_loss.item())
                block_metrics['timestep_acc'].append(timestep_acc.item())
                block_metrics['pattern_acc'].append(pattern_acc.item())
                block_metrics['full_acc'].append(full_acc.item())
            # Average across blocks that were actually summed (see FIX above).
            if num_processed_blocks > 0:
                total_loss = total_loss / num_processed_blocks
            # Backward
            optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(collective.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            # Log prompts (CRITICAL!)
            prompt_logger.log_batch(prompts, timesteps, epoch + 1, batch_idx, global_step)
            # FIX: compute per-block averages once, guarding empty lists
            # (np.mean([]) returns NaN with a warning). Also avoids computing
            # every mean twice (once for aggregation, once for TensorBoard).
            block_avgs = {
                k: float(np.mean(v)) if v else 0.0
                for k, v in block_metrics.items()
            }
            # Aggregate metrics (match original lines 677-686)
            epoch_metrics['total_loss'] += total_loss.item()
            epoch_metrics['feature_sim'] += block_avgs['feature_sim']
            epoch_metrics['rose'] += block_avgs['rose']
            epoch_metrics['ce'] += block_avgs['ce']
            epoch_metrics['pattern_diversity'] += block_avgs['pattern_diversity']
            epoch_metrics['timestep_accuracy'] += block_avgs['timestep_acc']
            epoch_metrics['pattern_accuracy'] += block_avgs['pattern_acc']
            epoch_metrics['full_accuracy'] += block_avgs['full_acc']
            epoch_metrics['num_batches'] += 1
            # TensorBoard logging (match original lines 688-696)
            writer.add_scalar('Train/Total_Loss', total_loss.item(), global_step)
            writer.add_scalar('Train/Feature_Similarity', block_avgs['feature_sim'], global_step)
            writer.add_scalar('Train/Rose_Loss', block_avgs['rose'], global_step)
            writer.add_scalar('Train/CE_Loss', block_avgs['ce'], global_step)
            writer.add_scalar('Train/Pattern_Diversity', block_avgs['pattern_diversity'], global_step)
            writer.add_scalar('Train/Timestep_Accuracy', block_avgs['timestep_acc'], global_step)
            writer.add_scalar('Train/Pattern_Accuracy', block_avgs['pattern_acc'], global_step)
            writer.add_scalar('Train/Full_Accuracy', block_avgs['full_acc'], global_step)
            # Update progress bar
            pbar.set_postfix({
                'loss': f"{total_loss.item():.4f}",
                't_acc': f"{block_avgs['timestep_acc']:.1%}" if block_metrics['timestep_acc'] else "N/A",
                'p_acc': f"{block_avgs['pattern_acc']:.1%}" if block_metrics['pattern_acc'] else "N/A",
            })
            global_step += 1
            # Cleanup: drop large tensors before the next batch to keep VRAM flat.
            del teacher_features, outputs, total_loss
            torch.cuda.empty_cache()
        # Epoch summary (match original lines 709-725).
        # FIX: guard against an empty dataloader (num_batches == 0).
        batch_count = max(epoch_metrics['num_batches'], 1)
        for key in epoch_metrics:
            if key != 'num_batches':
                avg = epoch_metrics[key] / batch_count
                history[key].append(avg)
                writer.add_scalar(f'Epoch/{key}', avg, epoch)
        print(f"\nEpoch {epoch+1} Summary:")
        print(f" Loss: {history['total_loss'][-1]:.4f}")
        print(f" Timestep Acc: {history['timestep_accuracy'][-1]:.2%}")
        print(f" Pattern Acc: {history['pattern_accuracy'][-1]:.2%}")
        print(f" Full Acc: {history['full_accuracy'][-1]:.2%}")
        print(f" Pattern Diversity: {history['pattern_diversity'][-1]:.3f}")
        # Save checkpoint every `checkpoint_interval` epochs.
        if (epoch + 1) % checkpoint_interval == 0:
            interval_ckpt_path = f"checkpoint_continued_epoch_{epoch+1:03d}.pt"
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': collective.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'history': history,
                'config': collective.config.__dict__
            }, interval_ckpt_path)
            print(f" โ Saved: {interval_ckpt_path}")
            # Also save as safetensors
            if HF_AVAILABLE:
                convert_to_safetensors(interval_ckpt_path)
    # Final checkpoint
    final_path = "david_collective_continued_final.pt"
    torch.save({
        'epoch': start_epoch + num_epochs,
        'model_state_dict': collective.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'history': history,
        'config': collective.config.__dict__
    }, final_path)
    print(f"\nโ Final checkpoint: {final_path}")
    # Get prompt stats
    prompt_stats = prompt_logger.get_stats()
    print(f"โ Prompts logged: {prompt_stats['total']:,} ({prompt_stats['size_mb']:.2f} MB)")
    writer.close()
    # HuggingFace upload
    if auto_upload:
        print(f"\n[7/7] Uploading to HuggingFace...")
        upload_to_huggingface(
            model_path=final_path,
            repo_name=hf_repo_name,
            final_epoch=start_epoch + num_epochs,
            history=history,
            config=collective.config,
            total_prompts=prompt_stats['total'],
            private=False
        )
    else:
        print(f"\n[7/7] Skipping HuggingFace upload (auto_upload=False)")
    return collective, history
| # ============================================================================ | |
| # MAIN | |
| # ============================================================================ | |
def main():
    """
    Entry point: load SD1.5, build the symbolic dataset, and run continuation
    training with prompt logging and (optionally) HuggingFace upload.

    Returns:
        Tuple ``(collective, history, extractor)`` on success, or
        ``(None, None, None)`` when no CUDA device is available.
    """
    print("\n" + "="*80)
    print("DAVID COLLECTIVE - COMPLETE CONTINUATION SYSTEM")
    print("Checkpoint loading + Prompt logging + HuggingFace upload")
    print("="*80)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"\nDevice: {device}")
    if device == "cpu":
        print("โ ๏ธ WARNING: Requires GPU!")
        # FIX: return a 3-tuple so the module-level unpack
        # (`collective, history, extractor = main()`) does not raise
        # TypeError on a CPU-only machine. Previously this returned None.
        return None, None, None
    # Load SD1.5
    print(f"\n[1/4] Loading SD1.5...")
    extractor = StreamingSD15Extractor(
        model_id="runwayml/stable-diffusion-v1-5",
        device=device,
        active_blocks=FULL_CONFIG.active_blocks
    )
    # Create dataset
    print(f"\n[2/4] Creating symbolic dataset...")
    dataset = SymbolicPromptDataset(
        num_samples=100000,
        # Weights over prompt complexity levels 1-5; must sum to 1.0.
        complexity_distribution={
            1: 0.05, 2: 0.15, 3: 0.40, 4: 0.25, 5: 0.15
        },
        seed=42,
        log_synthesis_stats=True
    )
    dataloader = DataLoader(
        dataset,
        batch_size=256,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
        collate_fn=collate_symbolic_batch
    )
    print(f" โ Dataset: {len(dataset):,} samples")
    # Initialize collective
    print(f"\n[3/4] Initializing DavidCollective...")
    collective = DavidCollective(FULL_CONFIG).to(device)
    print(f" โ Ready for continuation training")
    # Continue training
    print(f"\n[4/4] Starting continuation training...")
    collective, history = continue_training(
        collective=collective,
        extractor=extractor,
        dataloader=dataloader,
        start_epoch=100,  # Adjust based on your checkpoint
        num_epochs=5,
        device=device,
        log_dir="./runs/david_continued",
        prompt_log_path="./prompts_all_epochs.jsonl",
        checkpoint_interval=1,
        auto_upload=True,  # Set to False to skip HuggingFace upload
        hf_repo_name="AbstractPhil/david-collective-sd15-geometric-distillation"
    )
    print("\n" + "="*80)
    print("TRAINING COMPLETE!")
    print("="*80)
    print(f"\n๐ Files:")
    print(f" Model: david_collective_continued_final.pt")
    print(f" Prompts: ./prompts_all_epochs.jsonl")
    print(f" Logs: ./runs/david_continued")
    print(f"\n๐ Final Metrics:")
    # .get with truthiness check guards against an empty/missing history series.
    print(f" Loss: {history.get('total_loss', [0])[-1] if history.get('total_loss') else 0:.4f}")
    print(f" Timestep Acc: {history.get('timestep_accuracy', [0])[-1] if history.get('timestep_accuracy') else 0:.2%}")
    print(f" Pattern Acc: {history.get('pattern_accuracy', [0])[-1] if history.get('pattern_accuracy') else 0:.2%}")
    print(f" Full Acc: {history.get('full_accuracy', [0])[-1] if history.get('full_accuracy') else 0:.2%}")
    return collective, history, extractor
if __name__ == "__main__":
    # main() may bail out early when no GPU is present; guard the unpack so a
    # CPU-only run exits cleanly instead of raising TypeError on None.
    _result = main()
    if _result is not None:
        collective, history, extractor = _result