# --- HuggingFace file-page residue (not Python code; kept for provenance) ---
# AbstractPhil's picture
# updated with new and more reliable trainer to allow continuing and janky prompt recording, will improve over time
# f53ef80 verified
"""
DavidCollective SD1.5 - Complete System with Pattern Supervision
================================================================
Integrates symbolic synthesis + proper pattern-supervised losses.
Key features:
- Symbolic caption synthesis
- All 9 SD1.5 blocks
- Full pattern supervision (1000 classes, not just 100)
- Pattern diversity regularization
- Three accuracy metrics (timestep, pattern, full)
- Minimal disk usage
- TensorBoard logging
Author: AbstractPhil + Claude Sonnet 4.5
Run it in colab after installing the necessary repo.
!pip install git+https://github.com/AbstractEyes/lattice_vocabulary.git
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import time
import json
import numpy as np
from datetime import datetime
# Diffusers
from diffusers import StableDiffusionPipeline
# David imports
from geovocab2.train.model.core.david_diffusion import (
DavidCollective,
DavidCollectiveConfig,
SD15_BLOCKS
)
# Symbolic synthesis
from geovocab2.data.prompt.symbolic_tree import SynthesisSystem
# HuggingFace
try:
from huggingface_hub import HfApi, create_repo, upload_folder
from safetensors.torch import save_file
HF_AVAILABLE = True
except ImportError:
print("โš ๏ธ HuggingFace libraries not available. Install with:")
print(" pip install huggingface_hub safetensors")
HF_AVAILABLE = False
# ============================================================================
# PROMPT LOGGER - Saves ALL prompts to JSONL
# ============================================================================
class PromptLogger:
    """
    Logs ALL prompts with metadata to JSONL.

    One JSON object per line: timestamp, epoch, batch index, global step,
    per-sample index, raw timestep plus its 10-wide bin, and the prompt text.
    The handle is flushed after every batch so a crash loses at most the
    in-flight batch.
    """

    def __init__(self, output_path: str = "./prompts_all_epochs.jsonl"):
        """
        Args:
            output_path: Destination JSONL file. Parent directories are
                created if missing; an existing file is truncated.
        """
        self.output_path = Path(output_path)
        self.output_path.parent.mkdir(parents=True, exist_ok=True)
        # Create/truncate file. UTF-8 explicitly, so non-ASCII prompts do not
        # die on platform-default codecs (e.g. cp1252 on Windows).
        with open(self.output_path, 'w', encoding='utf-8') as f:
            f.write("")
        self.batch_count = 0  # number of log_batch() calls so far
        print(f"โœ“ PromptLogger initialized: {self.output_path}")

    def log_batch(
        self,
        prompts: List[str],
        timesteps: torch.Tensor,
        epoch: int,
        batch_idx: int,
        global_step: int
    ):
        """
        Append one JSONL record per (prompt, timestep) pair.

        Flushes immediately to prevent data loss.

        Args:
            prompts: Batch of prompt strings.
            timesteps: 1-D integer tensor aligned with ``prompts``.
            epoch: Current epoch number.
            batch_idx: Batch index within the epoch.
            global_step: Monotonic step counter across epochs.
        """
        with open(self.output_path, 'a', encoding='utf-8') as f:
            for i, (prompt, t) in enumerate(zip(prompts, timesteps)):
                timestep = int(t.item())
                entry = {
                    'timestamp': datetime.now().isoformat(),
                    'epoch': epoch,
                    'batch': batch_idx,
                    'global_step': global_step,
                    'sample_idx': i,
                    'timestep': timestep,
                    'timestep_bin': timestep // 10,  # 1000 timesteps -> 100 bins
                    'prompt': prompt
                }
                f.write(json.dumps(entry) + '\n')
            f.flush()  # CRITICAL: Force write to disk
        self.batch_count += 1
        if self.batch_count % 100 == 0:
            print(f" ๐Ÿ“ Logged {self.batch_count} batches ({self.batch_count * len(prompts):,} prompts)")

    def get_stats(self) -> dict:
        """Return {'total': line_count, 'size_mb': file_size_in_MB} for the log.

        Streams the file line-by-line instead of readlines() -- the log can
        grow to millions of entries over a long run.
        """
        if not self.output_path.exists():
            return {'total': 0}
        with open(self.output_path, 'r', encoding='utf-8') as f:
            total = sum(1 for _ in f)
        return {
            'total': total,
            'size_mb': self.output_path.stat().st_size / 1024**2
        }
# ============================================================================
# PATTERN-SUPERVISED LOSS
# ============================================================================
class PatternSupervisedLoss(nn.Module):
    """
    Pattern-supervised loss with full 1000-class supervision.

    Supervises all ``num_timestep_bins * num_patterns_per_timestep`` classes
    (default 100 timestep bins x 10 patterns = 1000). Each sample is assigned
    to the most cosine-similar crystal centroid inside its timestep bin, and
    the total loss combines:

      1. feature similarity to the assigned centroid,
      2. rose loss (identical to the similarity term, by design),
      3. cross-entropy over the full class space (hard or soft targets),
      4. a pattern-diversity regularizer (negative entropy) preventing mode
         collapse onto a single pattern per bin.
    """

    def __init__(
        self,
        num_timestep_bins: int = 100,
        num_patterns_per_timestep: int = 10,
        feature_similarity_weight: float = 0.5,
        rose_weight: float = 0.3,
        ce_weight: float = 0.2,
        pattern_diversity_weight: float = 0.05,
        use_soft_assignment: bool = True,
        temperature: float = 0.1
    ):
        """
        Args:
            num_timestep_bins: Discrete timestep bins (raw timesteps 0-999
                are divided by 10 into 100 bins by default).
            num_patterns_per_timestep: Geometric patterns per bin.
            feature_similarity_weight: Weight of the centroid-similarity term.
            rose_weight: Weight of the rose term (same value as the
                similarity term; see forward()).
            ce_weight: Weight of the classification term.
            pattern_diversity_weight: Weight of the anti-collapse term.
            use_soft_assignment: Use temperature-smoothed soft CE targets
                instead of hard argmax targets.
            temperature: Softmax temperature for soft assignment.
        """
        super().__init__()
        self.num_bins = num_timestep_bins
        self.num_patterns = num_patterns_per_timestep
        self.num_classes = num_timestep_bins * num_patterns_per_timestep
        self.feature_sim_weight = feature_similarity_weight
        self.rose_weight = rose_weight
        self.ce_weight = ce_weight
        self.pattern_diversity_weight = pattern_diversity_weight
        self.use_soft_assignment = use_soft_assignment
        self.temperature = temperature

    def _pattern_columns(self, timestep_class: torch.Tensor, device) -> torch.Tensor:
        """[B, num_patterns] column indices of each sample's pattern slots in the full class space."""
        offsets = timestep_class.long().unsqueeze(1) * self.num_patterns  # [B, 1]
        return offsets + torch.arange(self.num_patterns, device=device).unsqueeze(0)

    def assign_patterns(
        self,
        features: torch.Tensor,
        timestep_class: torch.Tensor,
        crystal_centroids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Assign samples to the nearest pattern within their timestep bin.

        Uses COSINE SIMILARITY (not Euclidean distance) to match the
        original trainer.

        Args:
            features: [B, D]
            timestep_class: [B] - timestep bins in [0, num_bins)
            crystal_centroids: [num_bins, num_patterns, D]

        Returns:
            pattern_ids: [B] - pattern indices in [0, num_patterns)
            full_class_ids: [B] - full class ids in [0, num_classes)
        """
        # Centroids for each sample's own timestep bin.
        batch_centroids = crystal_centroids[timestep_class]  # [B, num_patterns, D]
        similarities = F.cosine_similarity(
            features.unsqueeze(1),  # [B, 1, D] broadcasts against patterns
            batch_centroids,
            dim=2
        )  # [B, num_patterns]
        # Highest cosine similarity wins.
        pattern_ids = similarities.argmax(dim=1)
        full_class_ids = timestep_class * self.num_patterns + pattern_ids
        return pattern_ids, full_class_ids

    def compute_soft_assignment(
        self,
        features: torch.Tensor,
        timestep_class: torch.Tensor,
        crystal_centroids: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute soft pattern assignment with temperature smoothing.

        Args:
            features: [B, D]
            timestep_class: [B] - timestep bins
            crystal_centroids: [num_bins, num_patterns, D]

        Returns:
            soft_targets: [B, num_classes] - each row is zero outside the
            sample's own timestep bin and a softmax over its patterns inside.
        """
        B, _ = features.shape
        device = features.device
        batch_centroids = crystal_centroids[timestep_class]  # [B, num_patterns, D]
        similarities = F.cosine_similarity(
            features.unsqueeze(1),
            batch_centroids,
            dim=2
        )  # [B, num_patterns]
        pattern_probs = F.softmax(similarities / self.temperature, dim=1)
        # Vectorized scatter into the full class space (replaces a Python
        # per-sample loop; identical result).
        soft_targets = torch.zeros(B, self.num_classes, device=device)
        soft_targets.scatter_(1, self._pattern_columns(timestep_class, device), pattern_probs)
        return soft_targets

    def compute_pattern_diversity_loss(
        self,
        logits: torch.Tensor,
        timestep_class: torch.Tensor
    ) -> torch.Tensor:
        """
        Encourage diverse pattern usage (prevent mode collapse).

        Returns the negative mean entropy of each sample's within-bin
        pattern distribution, so minimizing it maximizes diversity.
        """
        # Gather each sample's own pattern-slot logits in one shot
        # (replaces a Python per-sample loop; identical result).
        cols = self._pattern_columns(timestep_class, logits.device)
        pattern_probs = F.softmax(logits.gather(1, cols), dim=1)  # [B, num_patterns]
        # Entropy (higher = more diverse); epsilon guards log(0).
        entropy = -(pattern_probs * torch.log(pattern_probs + 1e-8)).sum(dim=1).mean()
        return -entropy

    def forward(
        self,
        student_features: torch.Tensor,
        teacher_features: torch.Tensor,
        student_logits: torch.Tensor,
        crystal_centroids: torch.Tensor,
        timesteps: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Compute the full loss with pattern supervision.

        Args:
            student_features: [B, D] distilled features.
            teacher_features: Unused; retained for interface compatibility
                with the original trainer (supervision targets are the
                crystal centroids, not the teacher).
            student_logits: [B, num_classes] classification logits.
            crystal_centroids: [num_bins, num_patterns, D] anchors.
            timesteps: [B] raw diffusion timesteps (0-999).

        Returns:
            total_loss: Combined weighted loss.
            metrics: Dict of individual loss values and the three accuracies
                (timestep / pattern / full).
        """
        # Timestep classification (0-999 -> 0-99 bins)
        timestep_class = (timesteps // 10).clamp(0, self.num_bins - 1)
        # Pattern assignment (use STUDENT features, not teacher!)
        pattern_ids, full_class_ids = self.assign_patterns(
            student_features,
            timestep_class,
            crystal_centroids
        )
        # Target centroids via advanced indexing (replaces a per-sample
        # stack loop; identical result).
        target_centroids = crystal_centroids[timestep_class, pattern_ids]  # [B, D]
        # 1. Feature similarity loss (student vs target centroids)
        feature_sim_loss = 1.0 - F.cosine_similarity(
            student_features,
            target_centroids,
            dim=-1
        ).mean()
        # 2. Rose loss -- deliberately the SAME value as feature_sim_loss
        # (matches original trainer line 609; not contrastive learning).
        rose_loss = feature_sim_loss
        # 3. Cross-entropy with soft or hard targets.
        if self.use_soft_assignment:
            soft_targets = self.compute_soft_assignment(
                student_features, timestep_class, crystal_centroids
            )
            log_probs = F.log_softmax(student_logits, dim=1)
            ce_loss = -(soft_targets * log_probs).sum(dim=1).mean()
        else:
            ce_loss = F.cross_entropy(student_logits, full_class_ids)
        # 4. Pattern diversity regularizer.
        diversity_loss = self.compute_pattern_diversity_loss(
            student_logits, timestep_class
        )
        total_loss = (
            self.feature_sim_weight * feature_sim_loss +
            self.rose_weight * rose_loss +
            self.ce_weight * ce_loss +
            self.pattern_diversity_weight * diversity_loss
        )
        # Accuracy metrics -- argmax computed once and reused.
        full_pred = student_logits.argmax(dim=-1)
        timestep_pred = full_pred // self.num_patterns
        pattern_pred = full_pred % self.num_patterns
        timestep_acc = (timestep_pred == timestep_class).float().mean()
        pattern_acc = (pattern_pred == pattern_ids).float().mean()
        full_acc = (full_pred == full_class_ids).float().mean()
        metrics = {
            'feature_sim': feature_sim_loss.item(),
            'rose': rose_loss.item(),
            'ce': ce_loss.item(),
            'pattern_diversity': diversity_loss.item(),
            'timestep_acc': timestep_acc.item(),
            'pattern_acc': pattern_acc.item(),
            'full_acc': full_acc.item()
        }
        return total_loss, metrics
# ============================================================================
# CONFIG
# ============================================================================
# Global training configuration: all 9 SD1.5 UNet blocks with 1000-class
# supervision (100 timestep bins x 10 patterns per bin).
FULL_CONFIG = DavidCollectiveConfig(
    # Timestep discretization
    num_timestep_bins=100,
    num_feature_patterns_per_timestep=10,  # CORRECT parameter name
    # Active blocks (all 9)
    active_blocks=['down_0', 'down_1', 'down_2', 'down_3', 'mid', 'up_0', 'up_1', 'up_2', 'up_3'],
    # David architecture
    david_sharing_mode='fully_shared',
    david_fusion_mode='deep_efficiency',
    use_belly=True,
    belly_expand=1.5,
    # Loss weights (cayley term disabled with weight 0.0)
    feature_similarity_weight=0.5,
    rose_weight=0.3,
    cayley_weight=0.0,
    ce_weight=0.2,
    # Geometric constraints
    rose_margin=1.0,
    rose_temperature=0.07,
    cayley_volume_floor=1e-4,
    # Progressive training: blocks activate gradually, each warming up
    # for warmup_epochs_per_block epochs.
    progressive_training=True,
    warmup_epochs_per_block=2,
    # No caching (minimal disk usage)
    cache_dir=None,
    max_cache_size_gb=0.0
)
# ============================================================================
# SD1.5 EXTRACTOR - FIXED (NO POOLING)
# ============================================================================
class StreamingSD15Extractor:
    """
    Extract intermediate features from the SD1.5 UNet via forward hooks.

    CRITICAL: Returns spatial features [B, C, H, W].
    NO pooling here -- David's Companions handle pooling internally.
    """

    def __init__(
        self,
        model_id: str = "runwayml/stable-diffusion-v1-5",
        device: str = "cuda",
        active_blocks: List[str] = None
    ):
        """
        Args:
            model_id: HuggingFace model id for the SD1.5 pipeline (fp16).
            device: Device the whole pipeline is moved to.
            active_blocks: Logical block names to hook; defaults to
                FULL_CONFIG.active_blocks.
        """
        self.device = device
        self.active_blocks = active_blocks or FULL_CONFIG.active_blocks
        print(f"Loading SD1.5 from {model_id}...")
        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            safety_checker=None,
            requires_safety_checker=False
        ).to(device)
        self.unet = self.pipe.unet
        self.vae = self.pipe.vae
        self.text_encoder = self.pipe.text_encoder
        self.tokenizer = self.pipe.tokenizer
        self.scheduler = self.pipe.scheduler
        self.features = {}  # block name -> activation captured on the last forward
        self.hooks = []     # live hook handles; re-registered per batch
        # Logical name -> (UNet attribute, index); 'mid' has no index.
        self.block_mapping = {
            'down_0': ('down_blocks', 0),
            'down_1': ('down_blocks', 1),
            'down_2': ('down_blocks', 2),
            'down_3': ('down_blocks', 3),
            'mid': ('mid_block', None),
            'up_0': ('up_blocks', 0),
            'up_1': ('up_blocks', 1),
            'up_2': ('up_blocks', 2),
            'up_3': ('up_blocks', 3),
        }
        print(f"โœ“ SD1.5 loaded on {device}")

    def _register_hooks(self):
        """Attach forward hooks on each active block's final resnet (or the mid block)."""
        def make_hook(name):
            def hook(module, input, output):
                # CRITICAL: Store WITH spatial dimensions; detach + float so
                # captures hold no graph and are fp32 downstream.
                self.features[name] = output.detach().float()
            return hook
        # Drop any stale hooks before registering fresh ones.
        self._remove_hooks()
        for block_name in self.active_blocks:
            block_type, idx = self.block_mapping[block_name]
            if block_type == 'down_blocks':
                block = self.unet.down_blocks[idx]
                # Hook the LAST resnet so we capture the block's final output.
                if hasattr(block, 'resnets') and len(block.resnets) > 0:
                    hook = block.resnets[-1].register_forward_hook(make_hook(block_name))
                    self.hooks.append(hook)
            elif block_type == 'mid_block':
                hook = self.unet.mid_block.register_forward_hook(make_hook(block_name))
                self.hooks.append(hook)
            elif block_type == 'up_blocks':
                block = self.unet.up_blocks[idx]
                if hasattr(block, 'resnets') and len(block.resnets) > 0:
                    hook = block.resnets[-1].register_forward_hook(make_hook(block_name))
                    self.hooks.append(hook)

    def _remove_hooks(self):
        """Remove all registered forward hooks and clear the handle list."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    @torch.no_grad()
    def extract_batch(
        self,
        prompts: List[str],
        timesteps: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        Run one UNet forward pass and return the hooked features.

        CRITICAL: Returns spatial features [B, C, H, W].
        NO POOLING -- Companions handle it internally.

        Args:
            prompts: Text prompts, tokenized/encoded with the pipeline's CLIP.
            timesteps: [B] per-sample diffusion timesteps.

        Returns:
            Dict mapping active block name -> captured fp32 feature tensor.
        """
        self._register_hooks()
        text_inputs = self.tokenizer(
            prompts,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt"
        ).to(self.device)
        text_embeddings = self.text_encoder(text_inputs.input_ids)[0]
        B = len(prompts)
        # NOTE(review): latents start as pure Gaussian noise (no VAE-encoded
        # image) and are then noised again per-sample by the scheduler --
        # presumably intentional for text/timestep-conditioned feature
        # extraction; confirm against the original trainer.
        latents = torch.randn(B, 4, 64, 64, device=self.device, dtype=torch.float16)
        for i, t in enumerate(timesteps):
            noise = torch.randn_like(latents[i:i+1])
            latents[i:i+1] = self.scheduler.add_noise(
                latents[i:i+1],
                noise,
                t.unsqueeze(0)
            )
        self.features = {}
        _ = self.unet(
            latents,
            timesteps.to(self.device),
            encoder_hidden_states=text_embeddings
        ).sample
        self._remove_hooks()
        # Return features WITH spatial dimensions [B, C, H, W]; shallow copy
        # so the caller's dict survives the next extract_batch().
        return self.features.copy()

    def __del__(self):
        # Guard: __del__ may run after a failed __init__ (or at interpreter
        # shutdown) when self.hooks was never assigned -- avoid AttributeError.
        if getattr(self, 'hooks', None):
            self._remove_hooks()
# ============================================================================
# SYMBOLIC PROMPT DATASET
# ============================================================================
class SymbolicPromptDataset(Dataset):
    """
    Dataset of symbolic synthetic prompts.

    Prompts are PRE-GENERATED at construction time via SynthesisSystem (not
    lazily per item); __getitem__ pairs a stored prompt with a freshly
    sampled diffusion timestep in [0, 1000).
    """

    def __init__(
        self,
        num_samples: int = 10000,
        complexity_distribution: Optional[Dict[int, float]] = None,
        bias_weights_path: Optional[str] = None,
        seed: Optional[int] = None,
        log_synthesis_stats: bool = False
    ):
        """
        Args:
            num_samples: Number of prompts to generate up front.
            complexity_distribution: Mapping complexity level -> probability.
                Values are normalized defensively, so they need not sum to
                exactly 1.0. Defaults to a 1-5 distribution peaked at 3.
            bias_weights_path: Optional JSON file of synthesis bias weights.
            seed: Seed for the dataset's private RandomState (complexity
                sampling and per-item timesteps).
            log_synthesis_stats: Keep per-prompt metadata and print stats.
        """
        self.num_samples = num_samples
        self.log_synthesis_stats = log_synthesis_stats
        if complexity_distribution is None:
            complexity_distribution = {
                1: 0.05, 2: 0.15, 3: 0.40, 4: 0.25, 5: 0.15
            }
        self.complexity_dist = complexity_distribution
        # Initialize synthesis system (no seed parameter)
        self.synth = SynthesisSystem()
        # Apply bias weights if provided (best-effort: only if the synthesis
        # system exposes a bias_weights attribute).
        if bias_weights_path and Path(bias_weights_path).exists():
            with open(bias_weights_path, 'r') as f:
                bias_weights = json.load(f)
            if hasattr(self.synth, 'bias_weights'):
                self.synth.bias_weights = bias_weights
        self.rng = np.random.RandomState(seed)
        self.prompts = []
        self.metadata = []
        # Hoist the sampling tables out of the loop and normalize so a
        # distribution that does not sum exactly to 1.0 no longer raises
        # inside np.random.choice.
        complexities = list(complexity_distribution.keys())
        probs = np.asarray(list(complexity_distribution.values()), dtype=np.float64)
        probs = probs / probs.sum()
        print(f"Generating {num_samples:,} prompts...")
        for i in range(num_samples):
            complexity = self.rng.choice(complexities, p=probs)
            try:
                result = self.synth.synthesize(complexity=complexity)
                # Extract text and path_info from a result dict; fall back to
                # stringifying anything unexpected.
                if isinstance(result, dict):
                    prompt = result.get('text', 'a photo')
                    path_info = result.get('selected_paths', [])
                else:
                    prompt = str(result)
                    path_info = {}
                self.prompts.append(prompt)
                if log_synthesis_stats:
                    self.metadata.append({
                        'complexity': complexity,
                        'path_info': path_info,
                        'sample_id': i
                    })
            except Exception as e:
                # Fallback prompt if generation fails -- keep dataset length
                # exactly num_samples.
                print(f" โš ๏ธ Warning: Failed to generate prompt {i}: {e}")
                self.prompts.append("a photo")
                if log_synthesis_stats:
                    self.metadata.append({
                        'complexity': complexity,
                        'path_info': {},
                        'sample_id': i,
                        'error': str(e)
                    })
            if (i + 1) % 1000 == 0:
                print(f" Generated {i+1:,}/{num_samples:,} prompts...")
        print(f"โœ“ Generated {len(self.prompts):,} prompts")
        if log_synthesis_stats:
            self._log_statistics()

    def _log_statistics(self):
        """Print the realized complexity distribution and a few example prompts."""
        from collections import Counter
        complexity_counts = Counter(m['complexity'] for m in self.metadata)
        print("\nSynthesis Statistics:")
        print(" Complexity distribution:")
        for complexity in sorted(complexity_counts.keys()):
            count = complexity_counts[complexity]
            pct = 100 * count / len(self.metadata)
            print(f" Complexity {complexity}: {count:,} ({pct:.1f}%)")
        print("\n Example prompts:")
        # Sample the quartiles for a quick qualitative look.
        for i in [0, len(self.prompts)//4, len(self.prompts)//2, 3*len(self.prompts)//4]:
            complexity = self.metadata[i]['complexity']
            print(f" [C={complexity}] {self.prompts[i][:80]}...")

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return {'prompt', 'timestep', 'metadata'} for sample ``idx``.

        NOTE(review): the timestep is resampled on every access, so repeated
        reads of the same idx differ, and DataLoader workers each fork a copy
        of self.rng -- presumably acceptable randomness here; confirm.
        """
        prompt = self.prompts[idx]
        timestep = self.rng.randint(0, 1000)
        metadata = self.metadata[idx] if self.log_synthesis_stats else {}
        return {
            'prompt': prompt,
            'timestep': torch.tensor(timestep),
            'metadata': metadata
        }
def collate_symbolic_batch(batch):
    """Collate dataset items into (prompts, stacked timesteps, metadata) for a DataLoader."""
    prompts, timestep_list, metadata = [], [], []
    for item in batch:
        prompts.append(item['prompt'])
        timestep_list.append(item['timestep'])
        metadata.append(item['metadata'])
    return prompts, torch.stack(timestep_list), metadata
# ============================================================================
# HUGGINGFACE UTILITIES
# ============================================================================
def convert_to_safetensors(checkpoint_path: str) -> str:
    """Convert a .pt checkpoint to .safetensors format.

    Accepts a raw state dict, or a wrapper dict keyed by 'model_state_dict'
    or 'state_dict'.

    Args:
        checkpoint_path: Path to the .pt checkpoint on disk.

    Returns:
        Path of the written .safetensors file, or None when the safetensors
        library is unavailable.

    Raises:
        ValueError: If the loaded checkpoint is not a dict.
    """
    if not HF_AVAILABLE:
        print("โš ๏ธ Safetensors not available, skipping conversion")
        return None
    # NOTE: weights_only=False executes arbitrary pickled code -- only load
    # checkpoints you produced yourself.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    # Handle different checkpoint formats
    if isinstance(checkpoint, dict):
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        elif 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            # Assume the dict IS the state_dict
            state_dict = checkpoint
    else:
        raise ValueError(f"Unexpected checkpoint format: {type(checkpoint)}")
    # with_suffix only rewrites the final extension; the previous
    # str.replace('.pt', ...) would also corrupt a '.pt' occurring anywhere
    # else in the path (e.g. 'my.ptmodel.pt').
    output_path = str(Path(checkpoint_path).with_suffix('.safetensors'))
    print(f" Converting to safetensors: {output_path}")
    # Ensure all tensors are contiguous (safetensors requires it)
    state_dict_safe = {
        k: v.contiguous() for k, v in state_dict.items()
    }
    # Save as safetensors
    save_file(state_dict_safe, output_path)
    size_mb = Path(output_path).stat().st_size / 1024**2
    print(f" โœ“ Saved safetensors ({size_mb:.2f} MB)")
    return output_path
def create_readme(
    final_epoch: int,
    history: dict,
    config: DavidCollectiveConfig,
    total_prompts: int
) -> str:
    """Generate comprehensive README for HuggingFace model card.

    Args:
        final_epoch: Last completed epoch, interpolated into the card.
        history: Metric history dict of lists (e.g. 'total_loss',
            'timestep_accuracy'); first/last entries are reported, with 0
            fallbacks for missing/empty series.
        config: Collective config; loss weights, block list and warmup
            epochs are quoted in the card.
        total_prompts: Count of prompts recorded in prompts_all_epochs.jsonl.

    Returns:
        The full markdown README text with all metrics baked in.
    """
    # One big f-string template: every {...} below is evaluated right here,
    # so the returned markdown is a static snapshot of this run. Doubled
    # braces ({{ }}) are literal braces in the JSON/bibtex examples.
    readme = f"""# David Collective - SD1.5 Geometric Distillation (Continued)
## Model Description
**David Collective** is a revolutionary geometric deep learning system that distills Stable Diffusion 1.5's knowledge into an ultra-efficient pentachoron-based architecture. This model was continued from epoch 20 to epoch {final_epoch}, achieving remarkable performance with full pattern supervision.
### Architecture Highlights
- **Geometric Foundation**: Uses 5D pentachora (5-vertex simplices) instead of traditional attention
- **Multi-Scale Learning**: Extracts features from all 9 SD1.5 UNet blocks
- **Crystal Navigation**: 1000-class supervision (100 timesteps ร— 10 geometric patterns)
- **Parameter Efficiency**: Ultra-compact architecture with shared geometric structures
- **Full Supervision**: Every sample supervised by both timestep and geometric pattern
### Training Details
**Continuation Training:**
- Starting epoch: 20
- Final epoch: {final_epoch}
- Total prompts trained: {total_prompts:,}
- **All prompts included**: `prompts_all_epochs.jsonl` contains every prompt with metadata
- Dataset: Symbolic caption synthesis (complexity 1-5)
- Batch size: 32
- Learning rate: 1e-4 with cosine annealing
- Optimizer: AdamW (weight_decay=0.01)
**Final Metrics (Epoch {final_epoch}):**
- Total Loss: {history.get('total_loss', [0])[-1] if history.get('total_loss') else 0:.4f}
- Timestep Accuracy: {history.get('timestep_accuracy', [0])[-1] if history.get('timestep_accuracy') else 0:.2%}
- Pattern Accuracy: {history.get('pattern_accuracy', [0])[-1] if history.get('pattern_accuracy') else 0:.2%}
- Full Accuracy: {history.get('full_accuracy', [0])[-1] if history.get('full_accuracy') else 0:.2%}
- Pattern Diversity: {history.get('pattern_diversity', [0])[-1] if history.get('pattern_diversity') else 0:.3f}
### Active Blocks
David learns from all 9 SD1.5 UNet blocks:
- `down_0`, `down_1`, `down_2`, `down_3`: Coarse semantic features
- `mid`: Bottleneck representations
- `up_0`, `up_1`, `up_2`, `up_3`: Fine reconstruction details
### Loss Components
1. **Feature Similarity** ({config.feature_similarity_weight}): Cosine similarity with teacher
2. **Rose Loss** ({config.rose_weight}): Geometric alignment with crystal centroids
3. **Cross-Entropy** ({config.ce_weight}): 1000-class classification
4. **Pattern Diversity** (0.05): Encourages balanced pattern usage
## Usage
### Loading the Model
```python
import torch
from geovocab2.train.model.core.david_diffusion import DavidCollective, DavidCollectiveConfig
from safetensors.torch import load_file
# Load configuration
config = DavidCollectiveConfig(
    num_timestep_bins=100,
    num_feature_patterns_per_timestep=10,
    active_blocks={config.active_blocks},
    david_sharing_mode='fully_shared',
    david_fusion_mode='deep_efficiency',
    use_belly=True,
    belly_expand=1.5
)
# Create model
model = DavidCollective(config)
# Load weights from safetensors
state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict)
model.eval()
# Inference
with torch.no_grad():
    outputs = model(teacher_features, timesteps)
```
### Training Data
This model includes `prompts_all_epochs.jsonl` - every single prompt used during training with full metadata:
```json
{{"timestamp": "2025-10-27T01:30:00", "epoch": 21, "batch": 0, "global_step": 6250, "sample_idx": 0, "timestep": 453, "timestep_bin": 45, "prompt": "a woman wearing red dress, against mountain landscape"}}
```
**Total prompts:** {total_prompts:,}
You can use this to:
- Analyze training data distribution
- Reproduce training
- Study prompt complexity vs model performance
- Generate similar synthetic datasets
## Technical Details
### Crystal System
- **Architecture**: Pentachoron-based geometric deep learning
- **Centroids**: 100 timestep bins ร— 10 patterns = 1000 anchors
- **Navigation**: Samples assigned to nearest pattern within timestep bin
- **Diversity**: Regularization prevents mode collapse
### Progressive Training
- Started with early blocks (down_0, down_1)
- Progressively activated all 9 blocks
- Each block warmed up for {config.warmup_epochs_per_block} epochs
### Pattern Supervision
Unlike traditional timestep-only supervision, David learns:
1. **When** (timestep bin 0-99)
2. **How** (geometric pattern 0-9 within that bin)
3. **Combined** (full 1000-class space)
This provides 10x finer-grained supervision of the diffusion process.
## Training History
Trained continuously from epoch 20 to epoch {final_epoch}. See metrics:
- Timestep accuracy improved from ~{history.get('timestep_accuracy', [0])[0] if history.get('timestep_accuracy') else 0:.1%} to {history.get('timestep_accuracy', [0])[-1] if history.get('timestep_accuracy') else 0:.2%}
- Pattern accuracy maintained at {history.get('pattern_accuracy', [0])[-1] if history.get('pattern_accuracy') else 0:.2%}
- Loss decreased from {history.get('total_loss', [0])[0] if history.get('total_loss') else 0:.4f} to {history.get('total_loss', [0])[-1] if history.get('total_loss') else 0:.4f}
## Citation
```bibtex
@misc{{david-collective-sd15,
  title={{David Collective: Geometric Deep Learning for Diffusion Distillation}},
  author={{AbstractPhil}},
  year={{2025}},
  publisher={{HuggingFace}},
  howpublished={{\\url{{https://huggingface.co/AbstractPhil/david-collective-sd15-geometric-distillation}}}}
}}
```
## License
MIT License - See repository for details.
## Acknowledgments
Built on the geometric deep learning research by AbstractPhil, using:
- Stable Diffusion 1.5 (teacher model)
- Pentachoron-based geometric algebra
- Crystalline consciousness architectures
- Symbolic caption synthesis
For more information, visit the [geovocab2 repository](https://github.com/AbstractEyes/lattice_vocabulary).
"""
    return readme
def create_model_card(config: DavidCollectiveConfig) -> dict:
    """Create model card metadata for HuggingFace."""
    # Tag list kept separate for readability; values are identical to the
    # published card.
    tags = [
        'geometric-deep-learning',
        'diffusion-distillation',
        'stable-diffusion',
        'pentachoron',
        'crystal-navigation',
        'pattern-supervision',
        'ultra-efficient',
        'sd15-distillation',
    ]
    card = {
        'language': ['en'],
        'license': 'mit',
        'tags': tags,
        'datasets': ['synthetic-captions'],
        'metrics': ['accuracy', 'loss'],
        'library_name': 'pytorch',
        'pipeline_tag': 'image-classification',
    }
    return card
def upload_to_huggingface(
    model_path: str,
    repo_name: str = "AbstractPhil/david-collective-sd15-geometric-distillation",
    final_epoch: int = 50,
    history: dict = None,
    config: DavidCollectiveConfig = None,
    total_prompts: int = 0,
    private: bool = False
):
    """Upload model to HuggingFace Hub with README and model card.

    Pipeline: convert the .pt checkpoint to safetensors, stage all artifacts
    (model, prompt log, README, model card, config) in ./hf_upload_temp,
    then push the folder to ``repo_name``. The staging directory is removed
    only on success, so a failed upload can be fixed and retried manually.

    Args:
        model_path: Path to the .pt checkpoint to convert and upload.
        repo_name: Target HuggingFace repo id (created if missing).
        final_epoch: Last epoch; quoted in the README and commit message.
        history: Metric history forwarded to create_readme.
        config: Collective config; serialized to config.json.
        total_prompts: Prompt count quoted in the README.
        private: Create the repo as private when it does not yet exist.

    Returns:
        The repo URL string on success, otherwise None.
    """
    if not HF_AVAILABLE:
        print("\nโš ๏ธ HuggingFace libraries not available")
        print("Install with: pip install huggingface_hub safetensors")
        return None
    print(f"\n{'='*80}")
    print("UPLOADING TO HUGGINGFACE")
    print(f"{'='*80}\n")
    # Convert to safetensors
    print("[1/5] Converting to safetensors...")
    safetensors_path = convert_to_safetensors(model_path)
    if not safetensors_path:
        print("โŒ Conversion failed, aborting upload")
        return None
    # Create temporary upload directory
    print("\n[2/5] Preparing upload directory...")
    upload_dir = Path("./hf_upload_temp")
    upload_dir.mkdir(exist_ok=True)
    # Copy safetensors
    import shutil
    shutil.copy(safetensors_path, upload_dir / "model.safetensors")
    print(f" โœ“ Copied model.safetensors")
    # Copy prompts file (CRITICAL!) -- ships the full training-prompt log
    # alongside the weights; missing log is non-fatal.
    prompt_file = Path("./prompts_all_epochs.jsonl")
    if prompt_file.exists():
        shutil.copy(prompt_file, upload_dir / "prompts_all_epochs.jsonl")
        print(f" โœ“ Copied prompts_all_epochs.jsonl ({prompt_file.stat().st_size / 1024**2:.2f} MB)")
    else:
        print(f" โš ๏ธ Warning: prompts_all_epochs.jsonl not found, skipping")
    # Generate README
    print("\n[3/5] Generating README...")
    readme_content = create_readme(final_epoch, history, config, total_prompts)
    (upload_dir / "README.md").write_text(readme_content)
    print(f" โœ“ Created README.md")
    # Generate model card
    print("\n[4/5] Creating model card...")
    model_card = create_model_card(config)
    (upload_dir / "model_card.json").write_text(json.dumps(model_card, indent=2))
    print(f" โœ“ Created model_card.json")
    # Save config -- public attributes only (underscore-prefixed keys dropped).
    config_dict = {k: v for k, v in config.__dict__.items() if not k.startswith('_')}
    (upload_dir / "config.json").write_text(json.dumps(config_dict, indent=2))
    print(f" โœ“ Created config.json")
    # Upload
    print(f"\n[5/5] Uploading to {repo_name}...")
    try:
        api = HfApi()
        # Create repo if doesn't exist (exist_ok makes this idempotent; the
        # inner except is belt-and-braces for permission/network errors).
        try:
            create_repo(repo_name, private=private, exist_ok=True)
            print(f" โœ“ Repository ready")
        except Exception as e:
            print(f" โš ๏ธ Repo might already exist: {e}")
        # Upload folder
        api.upload_folder(
            folder_path=str(upload_dir),
            repo_id=repo_name,
            repo_type="model",
            commit_message=f"Upload continuation training (epoch {final_epoch})"
        )
        print(f"\nโœ… UPLOAD COMPLETE!")
        print(f"\n๐Ÿ”— View your model: https://huggingface.co/{repo_name}")
        print(f"\n๐Ÿ“ฆ Uploaded files:")
        print(f" - model.safetensors")
        print(f" - prompts_all_epochs.jsonl ({total_prompts:,} prompts)")
        print(f" - README.md (with metrics)")
        print(f" - config.json")
        print(f" - model_card.json")
        # Cleanup (only on success; see docstring)
        shutil.rmtree(upload_dir)
        print(f"\n๐Ÿงน Cleaned up temporary files")
        return f"https://huggingface.co/{repo_name}"
    except Exception as e:
        print(f"\nโŒ Upload failed: {e}")
        print(f"Files are still in: {upload_dir}")
        print(f"You can upload manually or fix the error and retry")
        return None
# ============================================================================
# CONTINUATION TRAINING
# ============================================================================
def continue_training(
    collective: DavidCollective,
    extractor: StreamingSD15Extractor,
    dataloader: DataLoader,
    start_epoch: int,
    num_epochs: int,
    device: str = "cuda",
    log_dir: str = "./runs/david_continued",
    prompt_log_path: str = "./prompts_all_epochs.jsonl",
    checkpoint_interval: int = 5,
    auto_upload: bool = True,
    hf_repo_name: str = "AbstractPhil/david-collective-sd15-geometric-distillation"
) -> Tuple[DavidCollective, Dict[str, List[float]]]:
    """
    Continue training from checkpoint with full logging and HuggingFace upload.

    Loads model weights from the hard-coded checkpoint file
    ``david_collective_continued_final.pt`` (which is also the path this
    function writes its final checkpoint to), verifies the weights actually
    changed after loading, then runs ``num_epochs`` additional epochs of
    pattern-supervised distillation against SD1.5 teacher features.

    Args:
        collective: DavidCollective model
        extractor: SD1.5 feature extractor
        dataloader: Training data
        start_epoch: Epoch to start from (loaded from checkpoint)
        num_epochs: Additional epochs to train
        device: Device to train on
        log_dir: TensorBoard log directory
        prompt_log_path: Where to save all prompts
        checkpoint_interval: Save checkpoint every N epochs
        auto_upload: Automatically upload to HuggingFace after training
        hf_repo_name: HuggingFace repository name

    Returns:
        Tuple of (trained collective, per-epoch metric history dict).

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
        RuntimeError: if loading the checkpoint left the sampled weight
            tensor unchanged (i.e. the load silently did nothing).
        ValueError: if the checkpoint is not a dict, or teacher features
            are not 4D [B, C, H, W] on the first batch.
    """
    print("\n" + "="*80)
    print("DAVID COLLECTIVE - CONTINUATION TRAINING")
    print("="*80)
    # Load checkpoint
    print(f"\n[1/7] Loading checkpoint...")
    # NOTE(review): this path is hard-coded and is the SAME file the final
    # checkpoint below is saved to — repeated runs resume from the previous
    # run's final state. Confirm this is intentional rather than a parameter
    # that was never threaded through.
    checkpoint_path = Path("david_collective_continued_final.pt")
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    # Capture initial weights for verification: sample the first parameter
    # so we can prove later that load_state_dict actually mutated the model.
    print(f" Capturing initial weights...")
    initial_sample_key = list(collective.state_dict().keys())[0]
    initial_sample_weight = collective.state_dict()[initial_sample_key].clone()
    initial_mean = initial_sample_weight.mean().item()
    # weights_only=False because the checkpoint stores non-tensor metadata
    # (epoch, history, config dict) alongside the state dict.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    # Handle different checkpoint formats
    if isinstance(checkpoint, dict):
        # Check what keys are present
        print(f" Checkpoint keys: {list(checkpoint.keys())}")
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
            actual_epoch = checkpoint.get('epoch', start_epoch)
        elif 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
            actual_epoch = checkpoint.get('epoch', start_epoch)
        else:
            # Assume the dict IS the state_dict
            state_dict = checkpoint
            actual_epoch = start_epoch
        # Load with strict=False to see any issues
        print(f" Loading state dict ({len(state_dict)} parameters)...")
        missing_keys, unexpected_keys = collective.load_state_dict(state_dict, strict=False)
        if missing_keys:
            print(f" โš ๏ธ Missing keys ({len(missing_keys)}): {missing_keys[:3]}...")
        if unexpected_keys:
            print(f" โš ๏ธ Unexpected keys ({len(unexpected_keys)}): {unexpected_keys[:3]}...")
        # Verify weights actually changed: guards against strict=False
        # silently matching nothing and leaving the model untouched.
        final_sample_weight = collective.state_dict()[initial_sample_key]
        final_mean = final_sample_weight.mean().item()
        if torch.equal(initial_sample_weight, final_sample_weight):
            raise RuntimeError(
                f"โŒ CRITICAL: Weights did NOT change after loading!\n"
                f" Sample param: {initial_sample_key}\n"
                f" This means the checkpoint is not being loaded properly."
            )
        print(f" โœ“ Weights verified changed (sample mean: {initial_mean:.6f} -> {final_mean:.6f})")
        print(f" โœ“ Loaded from epoch {actual_epoch}")
    else:
        # Not a dict - shouldn't happen but handle it
        raise ValueError(f"Unexpected checkpoint format: {type(checkpoint)}")
    # Model info
    total_params = sum(p.numel() for p in collective.parameters())
    print(f"\n Model Status:")
    print(f" Parameters: {total_params:,}")
    print(f" Active blocks: {len(collective.config.active_blocks)}")
    print(f" Companions: {list(collective.companions.keys())}")
    # Prompt logger
    print(f"\n[2/7] Initializing prompt logger...")
    prompt_logger = PromptLogger(prompt_log_path)
    print(f" โœ“ Saving to: {prompt_log_path}")
    # Loss function
    print(f"\n[3/7] Setting up loss function...")
    criterion = PatternSupervisedLoss(
        num_timestep_bins=collective.config.num_timestep_bins,
        num_patterns_per_timestep=collective.config.num_feature_patterns_per_timestep,  # CORRECT attribute name
        feature_similarity_weight=0.5,
        rose_weight=0.3,
        ce_weight=0.2,
        pattern_diversity_weight=0.05
    ).to(device)
    print(f" โœ“ Pattern-supervised loss ready")
    # Optimizer
    # NOTE(review): optimizer and scheduler are created fresh here even though
    # the checkpoint contains 'optimizer_state_dict'/'scheduler_state_dict' —
    # momentum/LR state is NOT resumed. Confirm this is deliberate.
    print(f"\n[4/7] Creating optimizer...")
    optimizer = torch.optim.AdamW(
        collective.parameters(),
        lr=1e-4,
        weight_decay=0.01
    )
    # Scheduler: one cosine cycle over ALL remaining steps, stepped per batch.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=num_epochs * len(dataloader),
        eta_min=1e-6
    )
    print(f" โœ“ AdamW + Cosine annealing")
    # TensorBoard
    print(f"\n[5/7] Setting up TensorBoard...")
    writer = SummaryWriter(log_dir)
    print(f" โœ“ Logging to: {log_dir}")
    # Training
    print(f"\n[6/7] Starting training...")
    print(f" Epochs: {start_epoch + 1} โ†’ {start_epoch + num_epochs}")
    print(f" Batches per epoch: {len(dataloader)}")
    print()
    collective.train()
    # Initialize history (match original lines 519-528)
    history = {
        'total_loss': [],
        'feature_sim': [],
        'rose': [],
        'ce': [],
        'pattern_diversity': [],
        'timestep_accuracy': [],
        'pattern_accuracy': [],
        'full_accuracy': []
    }
    # Resume the global step counter so TensorBoard curves continue seamlessly.
    global_step = start_epoch * len(dataloader)
    for epoch in range(start_epoch, start_epoch + num_epochs):
        collective.update_epoch(epoch)  # Match original line 534
        # Running sums for this epoch; divided by num_batches at epoch end.
        epoch_metrics = {
            'total_loss': 0.0,
            'feature_sim': 0.0,
            'rose': 0.0,
            'ce': 0.0,
            'pattern_diversity': 0.0,
            'timestep_accuracy': 0.0,
            'pattern_accuracy': 0.0,
            'full_accuracy': 0.0,
            'num_batches': 0
        }
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{start_epoch+num_epochs}")
        for batch_idx, (prompts, timesteps, metadata) in enumerate(pbar):
            batch_start = time.time()  # NOTE(review): captured but never used
            # Extract features (WITH spatial dimensions!)
            teacher_features = extractor.extract_batch(prompts, timesteps)
            teacher_features = {
                k: v.to(device) for k, v in teacher_features.items()
            }
            timesteps = timesteps.to(device)
            # Verify shapes on first batch of the first (resumed) epoch only.
            if batch_idx == 0 and epoch == start_epoch:
                print(f"\n๐Ÿ” Feature shapes (Epoch {epoch+1}, Batch 1):")
                for k, v in teacher_features.items():
                    print(f" {k}: {v.shape}")
                    if v.dim() != 4:
                        raise ValueError(
                            f"Expected 4D features [B,C,H,W], got {v.dim()}D for {k}!"
                        )
                print()
            # Forward
            outputs = collective(teacher_features, timesteps)
            # Compute loss
            total_loss = torch.tensor(0.0, device=device)
            # Per-block metric lists; averaged per batch for logging.
            block_metrics = {k: [] for k in ['feature_sim', 'rose', 'ce',
                                            'pattern_diversity', 'timestep_acc',
                                            'pattern_acc', 'full_acc']}
            for block_name in collective.companions.keys():
                if block_name not in outputs or block_name not in teacher_features:
                    continue
                # Use EXACT same structure as original trainer (lines 589-592)
                companion = collective.companions[block_name]
                block_output = outputs[block_name]
                # Get features and timestep class FROM DavidCollective output
                student_features = block_output['scale_features'][0]  # First scale
                student_logits = block_output['combined_logits']
                timestep_class = block_output['timestep_class']  # โœ“ FROM OUTPUT, not recomputed!
                # Get crystal centroids (lines 585-587 from original)
                crystal_anchors = companion.crystal_anchors  # [bins, patterns, 5, max_scale]
                scale = companion.david_config.scales[0]  # Use first scale
                # Mean over the 5 anchor vertices -> one centroid per pattern.
                crystal_centroids = crystal_anchors[..., :scale].mean(dim=2)  # [bins, patterns, scale]
                # INLINE LOSS COMPUTATION (like original, NOT using forward())
                # Assign patterns: nearest-centroid class id within the known
                # timestep bin (full id = bin * num_patterns + pattern).
                _, full_class_ids = criterion.assign_patterns(
                    student_features, timestep_class, crystal_centroids
                )
                # Feature similarity loss: pull student features toward the
                # assigned centroid via cosine distance.
                pattern_ids = full_class_ids % criterion.num_patterns
                target_centroids = torch.stack([
                    crystal_centroids[timestep_class[j], pattern_ids[j]]
                    for j in range(len(timestep_class))
                ])
                cos_sim = F.cosine_similarity(student_features, target_centroids, dim=-1)
                feature_sim_loss = (1 - cos_sim).mean()
                # Rose loss (same as feature sim)
                rose_loss = feature_sim_loss
                # Cross-entropy with pattern supervision
                if criterion.use_soft_assignment:
                    soft_targets = criterion.compute_soft_assignment(
                        student_features, timestep_class, crystal_centroids
                    )
                    log_probs = F.log_softmax(student_logits, dim=1)
                    ce_loss = -(soft_targets * log_probs).sum(dim=1).mean()
                else:
                    ce_loss = F.cross_entropy(student_logits, full_class_ids)
                # Pattern diversity
                diversity_loss = criterion.compute_pattern_diversity_loss(
                    student_logits, timestep_class
                )
                # Combined loss for this block (weights come from criterion config)
                block_loss = (
                    criterion.feature_sim_weight * feature_sim_loss +
                    criterion.rose_weight * rose_loss +
                    criterion.ce_weight * ce_loss +
                    criterion.pattern_diversity_weight * diversity_loss
                )
                total_loss = total_loss + block_loss
                # Accuracies (lines 626-642 from original)
                pred_class = student_logits.argmax(dim=1)
                pred_timestep = pred_class // criterion.num_patterns
                pred_pattern = pred_class % criterion.num_patterns
                true_pattern = full_class_ids % criterion.num_patterns
                timestep_acc = (pred_timestep == timestep_class).float().mean()
                # Pattern accuracy is conditional: only scored on samples whose
                # timestep bin was already predicted correctly.
                correct_timestep_mask = (pred_timestep == timestep_class)
                if correct_timestep_mask.sum() > 0:
                    pattern_acc = (
                        pred_pattern[correct_timestep_mask] == true_pattern[correct_timestep_mask]
                    ).float().mean()
                else:
                    pattern_acc = torch.tensor(0.0, device=device)
                full_acc = (pred_class == full_class_ids).float().mean()
                # Collect metrics
                block_metrics['feature_sim'].append(feature_sim_loss.item())
                block_metrics['rose'].append(rose_loss.item())
                block_metrics['ce'].append(ce_loss.item())
                block_metrics['pattern_diversity'].append(diversity_loss.item())
                block_metrics['timestep_acc'].append(timestep_acc.item())
                block_metrics['pattern_acc'].append(pattern_acc.item())
                block_metrics['full_acc'].append(full_acc.item())
            # Average across blocks (from original trainer lines 664-668)
            num_processed_blocks = len([k for k in outputs.keys() if k in collective.companions])
            if num_processed_blocks > 0:
                total_loss = total_loss / num_processed_blocks
            # Backward
            optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(collective.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            # Log prompts (CRITICAL!)
            prompt_logger.log_batch(prompts, timesteps, epoch + 1, batch_idx, global_step)
            # Aggregate metrics (match original lines 677-686)
            epoch_metrics['total_loss'] += total_loss.item()
            epoch_metrics['feature_sim'] += np.mean(block_metrics['feature_sim'])
            epoch_metrics['rose'] += np.mean(block_metrics['rose'])
            epoch_metrics['ce'] += np.mean(block_metrics['ce'])
            epoch_metrics['pattern_diversity'] += np.mean(block_metrics['pattern_diversity'])
            epoch_metrics['timestep_accuracy'] += np.mean(block_metrics['timestep_acc'])
            epoch_metrics['pattern_accuracy'] += np.mean(block_metrics['pattern_acc'])
            epoch_metrics['full_accuracy'] += np.mean(block_metrics['full_acc'])
            epoch_metrics['num_batches'] += 1
            # TensorBoard logging (match original lines 688-696)
            writer.add_scalar('Train/Total_Loss', total_loss.item(), global_step)
            writer.add_scalar('Train/Feature_Similarity', np.mean(block_metrics['feature_sim']), global_step)
            writer.add_scalar('Train/Rose_Loss', np.mean(block_metrics['rose']), global_step)
            writer.add_scalar('Train/CE_Loss', np.mean(block_metrics['ce']), global_step)
            writer.add_scalar('Train/Pattern_Diversity', np.mean(block_metrics['pattern_diversity']), global_step)
            writer.add_scalar('Train/Timestep_Accuracy', np.mean(block_metrics['timestep_acc']), global_step)
            writer.add_scalar('Train/Pattern_Accuracy', np.mean(block_metrics['pattern_acc']), global_step)
            writer.add_scalar('Train/Full_Accuracy', np.mean(block_metrics['full_acc']), global_step)
            # Update progress bar
            pbar.set_postfix({
                'loss': f"{total_loss.item():.4f}",
                't_acc': f"{np.mean(block_metrics['timestep_acc']):.1%}" if block_metrics['timestep_acc'] else "N/A",
                'p_acc': f"{np.mean(block_metrics['pattern_acc']):.1%}" if block_metrics['pattern_acc'] else "N/A",
            })
            global_step += 1
            # Cleanup: drop large tensors before the next extraction batch.
            del teacher_features, outputs, total_loss
            torch.cuda.empty_cache()
        # Epoch summary (match original lines 709-725)
        for key in epoch_metrics:
            if key != 'num_batches':
                avg = epoch_metrics[key] / epoch_metrics['num_batches']
                history[key].append(avg)
                writer.add_scalar(f'Epoch/{key}', avg, epoch)
        print(f"\nEpoch {epoch+1} Summary:")
        print(f" Loss: {history['total_loss'][-1]:.4f}")
        print(f" Timestep Acc: {history['timestep_accuracy'][-1]:.2%}")
        print(f" Pattern Acc: {history['pattern_accuracy'][-1]:.2%}")
        print(f" Full Acc: {history['full_accuracy'][-1]:.2%}")
        print(f" Pattern Diversity: {history['pattern_diversity'][-1]:.3f}")
        # Save checkpoint
        if (epoch + 1) % checkpoint_interval == 0:
            checkpoint_path = f"checkpoint_continued_epoch_{epoch+1:03d}.pt"
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': collective.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'history': history,
                'config': collective.config.__dict__
            }, checkpoint_path)
            print(f" โœ“ Saved: {checkpoint_path}")
            # Also save as safetensors
            if HF_AVAILABLE:
                convert_to_safetensors(checkpoint_path)
    # Final checkpoint — overwrites the file this run loaded from, so the
    # next invocation resumes from here.
    final_path = "david_collective_continued_final.pt"
    torch.save({
        'epoch': start_epoch + num_epochs,
        'model_state_dict': collective.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'history': history,
        'config': collective.config.__dict__
    }, final_path)
    print(f"\nโœ… Final checkpoint: {final_path}")
    # Get prompt stats
    prompt_stats = prompt_logger.get_stats()
    print(f"โœ… Prompts logged: {prompt_stats['total']:,} ({prompt_stats['size_mb']:.2f} MB)")
    writer.close()
    # HuggingFace upload
    if auto_upload:
        print(f"\n[7/7] Uploading to HuggingFace...")
        upload_to_huggingface(
            model_path=final_path,
            repo_name=hf_repo_name,
            final_epoch=start_epoch + num_epochs,
            history=history,
            config=collective.config,
            total_prompts=prompt_stats['total'],
            private=False
        )
    else:
        print(f"\n[7/7] Skipping HuggingFace upload (auto_upload=False)")
    return collective, history
# ============================================================================
# MAIN
# ============================================================================
def main():
    """
    Entry point: load SD1.5, build the symbolic prompt dataset, and resume
    DavidCollective training from the latest checkpoint.

    Returns:
        Tuple of (collective, history, extractor) on success, or
        (None, None, None) when no CUDA device is available.

    Fix: the CPU early-exit previously returned a bare ``None``, which made
    the ``__main__`` guard's three-way unpack raise ``TypeError`` on
    CPU-only machines. It now returns a matching 3-tuple.
    """
    print("\n" + "="*80)
    print("DAVID COLLECTIVE - COMPLETE CONTINUATION SYSTEM")
    print("Checkpoint loading + Prompt logging + HuggingFace upload")
    print("="*80)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"\nDevice: {device}")
    if device == "cpu":
        print("โš ๏ธ WARNING: Requires GPU!")
        # Return a tuple so the caller's unpack does not crash.
        return None, None, None

    # Load SD1.5 teacher (features extracted on the fly, nothing cached to disk)
    print(f"\n[1/4] Loading SD1.5...")
    extractor = StreamingSD15Extractor(
        model_id="runwayml/stable-diffusion-v1-5",
        device=device,
        active_blocks=FULL_CONFIG.active_blocks
    )

    # Create dataset: 100k symbolic prompts, biased toward mid-complexity
    print(f"\n[2/4] Creating symbolic dataset...")
    dataset = SymbolicPromptDataset(
        num_samples=100000,
        complexity_distribution={
            1: 0.05, 2: 0.15, 3: 0.40, 4: 0.25, 5: 0.15
        },
        seed=42,
        log_synthesis_stats=True
    )
    dataloader = DataLoader(
        dataset,
        batch_size=256,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
        collate_fn=collate_symbolic_batch
    )
    print(f" โœ“ Dataset: {len(dataset):,} samples")

    # Initialize collective (weights are replaced from the checkpoint inside
    # continue_training)
    print(f"\n[3/4] Initializing DavidCollective...")
    collective = DavidCollective(FULL_CONFIG).to(device)
    print(f" โœ“ Ready for continuation training")

    # Continue training
    print(f"\n[4/4] Starting continuation training...")
    collective, history = continue_training(
        collective=collective,
        extractor=extractor,
        dataloader=dataloader,
        start_epoch=100,  # Adjust based on your checkpoint
        num_epochs=5,
        device=device,
        log_dir="./runs/david_continued",
        prompt_log_path="./prompts_all_epochs.jsonl",
        checkpoint_interval=1,
        auto_upload=True,  # Set to False to skip HuggingFace upload
        hf_repo_name="AbstractPhil/david-collective-sd15-geometric-distillation"
    )

    def _last(metric: str) -> float:
        """Last recorded value for a history metric, or 0 if absent/empty."""
        values = history.get(metric)
        return values[-1] if values else 0

    print("\n" + "="*80)
    print("TRAINING COMPLETE!")
    print("="*80)
    print(f"\n๐Ÿ“ Files:")
    print(f" Model: david_collective_continued_final.pt")
    print(f" Prompts: ./prompts_all_epochs.jsonl")
    print(f" Logs: ./runs/david_continued")
    print(f"\n๐Ÿ“Š Final Metrics:")
    print(f" Loss: {_last('total_loss'):.4f}")
    print(f" Timestep Acc: {_last('timestep_accuracy'):.2%}")
    print(f" Pattern Acc: {_last('pattern_accuracy'):.2%}")
    print(f" Full Acc: {_last('full_accuracy'):.2%}")
    return collective, history, extractor
# Script entry point: run the full continuation pipeline when executed
# directly (e.g. in a Colab cell); the returned objects stay in scope for
# interactive inspection afterwards.
if __name__ == "__main__":
    collective, history, extractor = main()