"""
GeoDavidCollective Trainer
==============================================
Complete training system for ProjectiveHead-enhanced GeoDavidCollective:
- Proven data pipeline (StreamingSD15Extractor, SymbolicPromptDataset)
- Enhanced GeoDavidCollective with ProjectiveHead architecture
- Comprehensive logging and checkpointing
- HuggingFace Hub integration (currently not wired in; only safetensors export is implemented)
Author: AbstractPhil
License: MIT
"""
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from pathlib import Path
from typing import Dict, List, Optional
import json
import numpy as np
from datetime import datetime
# Diffusers
from diffusers import StableDiffusionPipeline
# ENHANCED: Import GeoDavidCollective Enhanced
from geovocab2.train.model.core.geo_david_collective import GeoDavidCollective
# Symbolic synthesis
from geovocab2.data.prompt.symbolic_tree import SynthesisSystem
# HuggingFace
try:
from huggingface_hub import HfApi, create_repo, upload_folder
from safetensors.torch import save_file
HF_AVAILABLE = True
except ImportError:
HF_AVAILABLE = False
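# A minimal sketch of the Hub upload step noted as missing in the module docstring.
# Illustrative only and never called here; the repo id is a hypothetical placeholder,
# and create_repo / upload_folder are standard huggingface_hub helpers.
def _example_push_checkpoints_to_hub(
    repo_id: str = "your-username/geo-david-collective",
    folder_path: str = "./checkpoints"
):
    if not HF_AVAILABLE:
        raise RuntimeError("huggingface_hub / safetensors are not installed")
    create_repo(repo_id, exist_ok=True)
    upload_folder(repo_id=repo_id, folder_path=folder_path, commit_message="Upload checkpoints")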
# ============================================================================
# PROMPT LOGGER
# ============================================================================
class PromptLogger:
"""Logs all prompts with metadata to JSONL, flushed per batch."""
def __init__(self, output_path: str = "./prompts_all_epochs.jsonl"):
self.output_path = Path(output_path)
self.output_path.parent.mkdir(parents=True, exist_ok=True)
# Create/truncate file
with open(self.output_path, 'w') as f:
f.write("")
self.batch_count = 0
print(f"βœ“ PromptLogger initialized: {self.output_path}")
def log_batch(
self,
prompts: List[str],
timesteps: torch.Tensor,
epoch: int,
batch_idx: int,
global_step: int
):
"""Log batch of prompts with immediate flush."""
with open(self.output_path, 'a') as f:
for i, (prompt, t) in enumerate(zip(prompts, timesteps)):
entry = {
'timestamp': datetime.now().isoformat(),
'epoch': epoch,
'batch': batch_idx,
'global_step': global_step,
'sample_idx': i,
'timestep': int(t.item()),
'timestep_bin': int(t.item()) // 10,
'prompt': prompt
}
f.write(json.dumps(entry) + '\n')
f.flush()
self.batch_count += 1
if self.batch_count % 100 == 0:
print(f" πŸ“ Logged {self.batch_count} batches ({self.batch_count * len(prompts):,} prompts)")
def get_stats(self) -> dict:
"""Get statistics about logged prompts."""
if not self.output_path.exists():
return {'total': 0}
with open(self.output_path, 'r') as f:
lines = f.readlines()
return {
'total': len(lines),
'size_mb': self.output_path.stat().st_size / 1024**2
}
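# Minimal usage sketch for PromptLogger (illustrative; not called by the trainer).
def _example_prompt_logger():
    logger = PromptLogger("./prompts_example.jsonl")
    logger.log_batch(
        prompts=["a red cube on a wooden table"],
        timesteps=torch.tensor([500]),
        epoch=0,
        batch_idx=0,
        global_step=0
    )
    print(logger.get_stats())  # e.g. {'total': 1, 'size_mb': ...}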
# ============================================================================
# SD1.5 FEATURE EXTRACTOR
# ============================================================================
class StreamingSD15Extractor:
"""
Extract features from SD1.5 UNet blocks.
Returns SPATIAL features [B, C, H, W], not pooled.
"""
def __init__(
self,
model_id: str = "runwayml/stable-diffusion-v1-5",
device: str = "cuda",
active_blocks: List[str] = None
):
self.device = device
# Default blocks compatible with GeoDavidCollective
self.active_blocks = active_blocks or ['down_0', 'down_1', 'mid', 'up_0']
# Load pipeline
self.pipe = StableDiffusionPipeline.from_pretrained(
model_id,
torch_dtype=torch.float16,
safety_checker=None
).to(device)
self.unet = self.pipe.unet
self.unet.eval()
# Setup hooks
self.features = {}
self._register_hooks()
print(f"βœ“ StreamingSD15Extractor initialized")
print(f" Active blocks: {self.active_blocks}")
def _register_hooks(self):
"""Register forward hooks to capture block features."""
def make_hook(name):
def hook(module, input, output):
# Store spatial features [B, C, H, W]
if isinstance(output, tuple):
output = output[0]
self.features[name] = output.detach()
return hook
# Down blocks
for i, block in enumerate(self.unet.down_blocks):
name = f'down_{i}'
if name in self.active_blocks:
block.register_forward_hook(make_hook(name))
# Mid block
if 'mid' in self.active_blocks:
self.unet.mid_block.register_forward_hook(make_hook('mid'))
# Up blocks
for i, block in enumerate(self.unet.up_blocks):
name = f'up_{i}'
if name in self.active_blocks:
block.register_forward_hook(make_hook(name))
@torch.no_grad()
def extract_features(
self,
prompts: List[str],
timesteps: torch.Tensor
) -> Dict[str, torch.Tensor]:
"""
Extract features for a batch of prompts at given timesteps.
Returns:
Dict mapping block names to spatial features [B, C, H, W] in float32
"""
self.features = {}
# Encode prompts
text_inputs = self.pipe.tokenizer(
prompts,
padding="max_length",
max_length=self.pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt"
)
text_embeddings = self.pipe.text_encoder(
text_inputs.input_ids.to(self.device)
)[0]
        # Random latents (pure Gaussian noise) used as the UNet input
latents = torch.randn(
len(prompts), 4, 64, 64,
device=self.device,
dtype=torch.float16
)
# Forward pass through UNet (features captured by hooks)
_ = self.unet(
latents,
timesteps,
encoder_hidden_states=text_embeddings
)
# Convert features to float32 (collective expects float32)
features_float32 = {
name: feat.float()
for name, feat in self.features.items()
}
return features_float32
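# Minimal usage sketch for StreamingSD15Extractor (illustrative; needs a GPU and
# downloads SD1.5 on first use). Shapes below assume 64x64 latents and may vary.
def _example_extract_features():
    extractor = StreamingSD15Extractor(device="cuda", active_blocks=["down_0", "mid"])
    timesteps = torch.tensor([250], device="cuda")
    feats = extractor.extract_features(["a photo of a cat"], timesteps)
    for name, feat in feats.items():
        print(name, tuple(feat.shape))  # e.g. mid -> (1, 1280, 8, 8)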
# ============================================================================
# DATASET
# ============================================================================
class SymbolicPromptDataset(Dataset):
"""Generate prompts on-the-fly using synthesis system."""
def __init__(
self,
num_samples: int = 10000,
complexity_distribution: Optional[Dict[int, float]] = None,
bias_weights_path: Optional[str] = None,
seed: Optional[int] = None,
log_synthesis_stats: bool = False
):
self.num_samples = num_samples
self.log_stats = log_synthesis_stats
# Initialize synthesis system
self.synthesizer = SynthesisSystem(seed=seed)
# Load bias weights if provided
if bias_weights_path:
self.synthesizer.load_bias_weights(bias_weights_path)
# Complexity distribution (1-5)
self.complexity_dist = complexity_distribution or {
1: 0.05,
2: 0.15,
3: 0.40,
4: 0.30,
5: 0.10
}
# Precompute complexity for each sample
complexities = list(self.complexity_dist.keys())
probs = [self.complexity_dist[c] for c in complexities]
rng = np.random.RandomState(seed)
self.complexities = rng.choice(
complexities,
size=num_samples,
p=probs
)
print(f"βœ“ SymbolicPromptDataset: {num_samples:,} samples")
print(f" Complexity distribution: {self.complexity_dist}")
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
complexity = self.complexities[idx]
# Generate prompt
result = self.synthesizer.synthesize(complexity=complexity)
prompt = result['text'] # Extract text from synthesis result dict
# Random timestep [0, 999]
timestep = np.random.randint(0, 1000)
return {
'prompt': prompt,
'timestep': timestep,
'complexity': complexity
}
def collate_symbolic_batch(batch):
"""Collate batch for DataLoader."""
return {
'prompts': [item['prompt'] for item in batch],
'timesteps': torch.tensor([item['timestep'] for item in batch], dtype=torch.long),
'complexities': torch.tensor([item['complexity'] for item in batch], dtype=torch.long)
}
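# Minimal usage sketch for the synthetic prompt dataset and collate function
# (illustrative; not called by the trainer).
def _example_symbolic_dataloader():
    ds = SymbolicPromptDataset(num_samples=64, seed=0)
    dl = DataLoader(ds, batch_size=8, collate_fn=collate_symbolic_batch)
    batch = next(iter(dl))
    print(batch['prompts'][0])       # synthesized prompt text
    print(batch['timesteps'].shape)  # torch.Size([8]), dtype=torch.long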
# ============================================================================
# SPATIAL POOLING
# ============================================================================
def spatial_pool_features(
features_dict: Dict[str, torch.Tensor],
pool_mode: str = 'mean'
) -> Dict[str, torch.Tensor]:
"""
Pool spatial dimensions [B, C, H, W] β†’ [B, C].
Args:
features_dict: Dict of spatial features
pool_mode: 'mean', 'max', or 'adaptive'
Returns:
Dict of pooled features [B, C]
"""
pooled = {}
for name, feat in features_dict.items():
if feat.dim() == 4: # [B, C, H, W]
if pool_mode == 'mean':
pooled[name] = feat.mean(dim=[-2, -1]) # [B, C]
elif pool_mode == 'max':
pooled[name] = feat.flatten(2).max(dim=-1)[0] # [B, C]
elif pool_mode == 'adaptive':
# Mix mean and max
mean_pool = feat.mean(dim=[-2, -1])
max_pool = feat.flatten(2).max(dim=-1)[0]
                pooled[name] = 0.7 * mean_pool + 0.3 * max_pool
            else:
                raise ValueError(f"Unknown pool_mode: {pool_mode}")
        else:
            pooled[name] = feat
return pooled
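# Minimal sketch of spatial pooling on a dummy feature map (illustrative only).
def _example_spatial_pool():
    dummy = {"mid": torch.randn(2, 1280, 8, 8)}
    pooled = spatial_pool_features(dummy, pool_mode="adaptive")
    print(pooled["mid"].shape)  # torch.Size([2, 1280])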
# ============================================================================
# TRAINING FUNCTION
# ============================================================================
def train_geo_collective(
collective: GeoDavidCollective,
extractor: StreamingSD15Extractor,
dataloader: DataLoader,
num_epochs: int,
device: str,
learning_rate: float = 1e-4,
weight_decay: float = 0.01,
log_dir: str = "./runs/geo_collective",
prompt_log_path: str = "./prompts_all_epochs.jsonl",
checkpoint_interval: int = 5,
checkpoint_dir: str = "./checkpoints",
pool_mode: str = 'mean'
):
"""
Train GeoDavidCollective with full data pipeline.
Args:
collective: GeoDavidCollective model (enhanced version)
extractor: StreamingSD15Extractor
dataloader: DataLoader with symbolic prompts
num_epochs: Number of training epochs
device: 'cuda' or 'cpu'
learning_rate: Learning rate
weight_decay: Weight decay for AdamW
log_dir: TensorBoard log directory
prompt_log_path: Path to save prompt logs
checkpoint_interval: Save checkpoint every N epochs
checkpoint_dir: Checkpoint directory
pool_mode: Spatial pooling mode ('mean', 'max', 'adaptive')
"""
# Setup
collective = collective.to(device)
collective.train()
# Optimizer & Scheduler
optimizer = torch.optim.AdamW(
collective.parameters(),
lr=learning_rate,
weight_decay=weight_decay
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer,
T_max=num_epochs * len(dataloader)
)
# Logging
writer = SummaryWriter(log_dir=log_dir)
prompt_logger = PromptLogger(output_path=prompt_log_path)
# Checkpoint dir
Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
# Training history
history = {
'total_loss': [],
'avg_cayley': [],
'avg_timestep_acc': [],
'avg_pattern_acc': [],
'avg_full_acc': []
}
global_step = 0
print("\n" + "="*80)
print("STARTING TRAINING")
print("="*80)
print(f" Device: {device}")
print(f" Epochs: {num_epochs}")
print(f" Batches per epoch: {len(dataloader)}")
print(f" Learning rate: {learning_rate}")
print(f" Spatial pooling: {pool_mode}")
print("="*80 + "\n")
for epoch in range(num_epochs):
epoch_metrics = {
'total_loss': 0.0,
'avg_cayley': 0.0,
'avg_timestep_acc': 0.0,
'avg_pattern_acc': 0.0,
'avg_full_acc': 0.0,
'num_batches': 0
}
pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
for batch_idx, batch in enumerate(pbar):
prompts = batch['prompts']
timesteps = batch['timesteps'].to(device)
# Log prompts
prompt_logger.log_batch(
prompts,
timesteps.cpu(),
epoch,
batch_idx,
global_step
)
# Extract SD1.5 features (spatial [B, C, H, W])
with torch.no_grad():
teacher_features_spatial = extractor.extract_features(prompts, timesteps)
# Pool to [B, C]
teacher_features = spatial_pool_features(teacher_features_spatial, pool_mode)
            # Student input: pooled teacher features perturbed with small Gaussian noise
            features_dict = {
                name: feat.clone() + 0.01 * torch.randn_like(feat)
                for name, feat in teacher_features.items()
            }
# Forward pass
outputs = collective(features_dict, timesteps.float())
# Compute loss (now internal to model)
loss, metrics = collective.compute_loss(
outputs,
teacher_features,
timesteps.float()
)
# Backward pass
optimizer.zero_grad()
loss.backward()
# Gradient clipping
grad_norm = torch.nn.utils.clip_grad_norm_(
collective.parameters(), max_norm=1.0
)
optimizer.step()
scheduler.step()
# Accumulate metrics
batch_metrics = {
'total_loss': metrics['total_loss'],
'avg_cayley': metrics['avg/cayley'],
'avg_timestep_acc': metrics['avg/timestep_acc'],
'avg_pattern_acc': metrics['avg/pattern_acc'],
'avg_full_acc': metrics['avg/full_acc']
}
for k, v in batch_metrics.items():
epoch_metrics[k] += v
epoch_metrics['num_batches'] += 1
# TensorBoard logging (every step)
writer.add_scalar('Train/total_loss', batch_metrics['total_loss'], global_step)
writer.add_scalar('Train/cayley', batch_metrics['avg_cayley'], global_step)
writer.add_scalar('Train/timestep_acc', batch_metrics['avg_timestep_acc'], global_step)
writer.add_scalar('Train/pattern_acc', batch_metrics['avg_pattern_acc'], global_step)
writer.add_scalar('Train/full_acc', batch_metrics['avg_full_acc'], global_step)
writer.add_scalar('Train/grad_norm', grad_norm.item(), global_step)
writer.add_scalar('Train/lr', optimizer.param_groups[0]['lr'], global_step)
# Update progress bar
pbar.set_postfix({
'loss': f"{batch_metrics['total_loss']:.4f}",
'cayley': f"{batch_metrics['avg_cayley']:.4f}",
't_acc': f"{batch_metrics['avg_timestep_acc']:.1%}",
'p_acc': f"{batch_metrics['avg_pattern_acc']:.1%}",
'f_acc': f"{batch_metrics['avg_full_acc']:.1%}"
})
global_step += 1
# Cleanup
del teacher_features_spatial, teacher_features, features_dict, outputs, loss
torch.cuda.empty_cache()
# Epoch summary
for k in ['total_loss', 'avg_cayley', 'avg_timestep_acc', 'avg_pattern_acc', 'avg_full_acc']:
avg = epoch_metrics[k] / epoch_metrics['num_batches']
history[k].append(avg)
writer.add_scalar(f'Epoch/{k}', avg, epoch)
print(f"\nEpoch {epoch+1} Summary:")
print(f" Loss: {history['total_loss'][-1]:.4f}")
print(f" Cayley: {history['avg_cayley'][-1]:.4f}")
print(f" Timestep Acc: {history['avg_timestep_acc'][-1]:.2%}")
print(f" Pattern Acc: {history['avg_pattern_acc'][-1]:.2%}")
print(f" Full Acc: {history['avg_full_acc'][-1]:.2%}")
# Get Cantor alphas
alphas = collective.get_cantor_alphas()
print(f" Cantor Alphas: {', '.join([f'{k}={v:.3f}' for k, v in list(alphas.items())[:]])}")
# Save checkpoint
if (epoch + 1) % checkpoint_interval == 0:
checkpoint_path = Path(checkpoint_dir) / f"checkpoint_epoch_{epoch+1:03d}.pt"
torch.save({
'epoch': epoch + 1,
'global_step': global_step,
'model_state_dict': collective.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'history': history,
'model_info': collective.get_model_info()
}, checkpoint_path)
print(f" βœ“ Saved: {checkpoint_path}")
# Convert to safetensors
if HF_AVAILABLE:
safetensors_path = checkpoint_path.with_suffix('.safetensors')
save_file(collective.state_dict(), str(safetensors_path))
print(f" βœ“ Safetensors: {safetensors_path}")
# Final checkpoint
final_path = Path(checkpoint_dir) / "final.pt"
torch.save({
'epoch': num_epochs,
'global_step': global_step,
'model_state_dict': collective.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'history': history,
'model_info': collective.get_model_info()
}, final_path)
print(f"\nβœ… Final checkpoint: {final_path}")
# Prompt stats
prompt_stats = prompt_logger.get_stats()
print(f"βœ… Prompts logged: {prompt_stats['total']:,} ({prompt_stats['size_mb']:.2f} MB)")
writer.close()
return collective, history
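# Sketch of resuming from a saved checkpoint before further training (illustrative;
# the path below is a placeholder). Keys match what train_geo_collective saves.
def _example_resume_from_checkpoint(
    collective: GeoDavidCollective,
    checkpoint_path: str = "./checkpoints/checkpoint_epoch_005.pt"
):
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    collective.load_state_dict(ckpt['model_state_dict'])
    # 'optimizer_state_dict', 'scheduler_state_dict', and 'history' are also stored.
    return collective, ckpt.get('history', {})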
# ============================================================================
# MAIN
# ============================================================================
def main():
print("\n" + "="*80)
print("GEODAVIDCOLLECTIVE TRAINER - ENHANCED VERSION")
print("ProjectiveHead multi-expert architecture with proven data pipeline")
print("="*80)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"\nDevice: {device}")
if device == "cpu":
print("⚠️ WARNING: Training requires GPU!")
return
# ========================================================================
# CONFIGURATION - ENHANCED
# ========================================================================
# Block configurations with ProjectiveHead parameters
# These use auto-configuration based on scale_dim, but you can override
block_configs = {
# Down blocks (4)
'down_0': {
'input_dim': 320,
'scale_dim': 128, # Compressed for efficiency
'use_belly': True,
'belly_expand': 2.0,
# ProjectiveHead auto-configured (3 experts, 3 gates)
},
'down_1': {
'input_dim': 640,
'scale_dim': 192,
'use_belly': True,
'belly_expand': 2.0,
# ProjectiveHead auto-configured (3 experts, 3 gates)
},
'down_2': {
'input_dim': 1280,
'scale_dim': 256,
'use_belly': True,
'belly_expand': 2.0,
# ProjectiveHead auto-configured (3 experts, 3 gates)
},
'down_3': {
'input_dim': 1280,
'scale_dim': 256,
'use_belly': True,
'belly_expand': 2.0,
# ProjectiveHead auto-configured (3 experts, 3 gates)
},
# Mid block (1) - Most important, use higher capacity
'mid': {
'input_dim': 1280,
'scale_dim': 256,
'use_belly': True,
'belly_expand': 1.5,
# Custom ProjectiveHead: more experts for mid block
'num_experts': 4,
'num_gate_heads': 4,
},
# Up blocks (4)
'up_0': {
'input_dim': 1280,
'scale_dim': 256,
'use_belly': True,
'belly_expand': 2.0,
# ProjectiveHead auto-configured
},
'up_1': {
'input_dim': 1280,
'scale_dim': 256,
'use_belly': True,
'belly_expand': 2.0,
# ProjectiveHead auto-configured
},
'up_2': {
'input_dim': 640,
'scale_dim': 192,
'use_belly': True,
'belly_expand': 2.0,
# ProjectiveHead auto-configured
},
'up_3': {
'input_dim': 320,
'scale_dim': 128,
'use_belly': True,
'belly_expand': 1.5,
# ProjectiveHead auto-configured
}
}
# Block importance weights (mid-block most important)
block_weights = {
'down_0': 0.8,
'down_1': 1.0,
'down_2': 1.2,
'down_3': 1.3,
'mid': 1.5, # Highest importance
'up_0': 1.3,
'up_1': 1.2,
'up_2': 1.0,
'up_3': 0.8
}
# Geometric loss configuration - FIXED cayley_weight
loss_config = {
'feature_similarity_weight': 0.4,
'rose_weight': 0.25,
'ce_weight': 0.15,
'pattern_diversity_weight': 0.05,
'cayley_weight': 0.10, # FIXED: Was 0.0001, now 0.10 for proper geometry
'cantor_coherence_weight': 0.05,
'use_soft_assignment': True,
'temperature': 0.1,
# Cayley loss parameters
'cayley_volume_floor': 1e-4,
'cayley_chaos_scale': 1.0,
'cayley_edge_weight': 0.5,
'cayley_gram_weight': 0.1,
}
print("\nβœ“ Configuration loaded (ENHANCED)")
print(f" Blocks: {len(block_configs)}")
print(f" ProjectiveHead: Auto-configured based on scale_dim")
print(f" Loss weights: feature={loss_config['feature_similarity_weight']:.2f}, "
f"rose={loss_config['rose_weight']:.2f}, cayley={loss_config['cayley_weight']:.2f}")
# ========================================================================
# LOAD SD1.5
# ========================================================================
print(f"\n[1/4] Loading SD1.5...")
extractor = StreamingSD15Extractor(
model_id="runwayml/stable-diffusion-v1-5",
device=device,
active_blocks=list(block_configs.keys())
)
# ========================================================================
# CREATE DATASET
# ========================================================================
print(f"\n[2/4] Creating symbolic dataset...")
dataset = SymbolicPromptDataset(
num_samples=10000,
complexity_distribution={
1: 0.05, 2: 0.15, 3: 0.40, 4: 0.25, 5: 0.15
},
seed=42
)
dataloader = DataLoader(
dataset,
batch_size=16, # Adjusted for GPU memory
shuffle=True,
num_workers=2,
pin_memory=True,
collate_fn=collate_symbolic_batch
)
print(f" βœ“ Dataset: {len(dataset):,} samples")
print(f" βœ“ Batch size: 16")
# ========================================================================
# INITIALIZE MODEL - ENHANCED
# ========================================================================
print(f"\n[3/4] Initializing GeoDavidCollective (ENHANCED)...")
collective = GeoDavidCollective(
block_configs=block_configs,
num_timestep_bins=100,
num_patterns_per_bin=10,
block_weights=block_weights,
loss_config=loss_config
)
model_info = collective.get_model_info()
print(f" βœ“ Architecture: {model_info['architecture']}")
print(f" βœ“ Blocks: {model_info['num_blocks']}")
print(f" βœ“ Total parameters: {model_info['total_parameters']:,}")
print(f" βœ“ Timestep bins: {model_info['num_timestep_bins']}")
print(f" βœ“ Patterns per bin: {model_info['num_patterns_per_bin']}")
# Show ProjectiveHead configs
print(f"\n ProjectiveHead Configurations:")
for block_name, companion_info in list(model_info['companions'].items())[:3]:
print(f" {block_name}:")
print(f" Timestep head: {companion_info['timestep_head']['num_experts']} experts, "
f"{companion_info['timestep_head']['num_gate_heads']} gates")
print(f" ... and {len(model_info['companions'])-3} more blocks")
# ========================================================================
# TRAIN
# ========================================================================
print(f"\n[4/4] Starting training...")
collective, history = train_geo_collective(
collective=collective,
extractor=extractor,
dataloader=dataloader,
num_epochs=10,
device=device,
learning_rate=1e-3,
weight_decay=0.001,
log_dir="./runs/geo_collective_enhanced",
prompt_log_path="./prompts_enhanced.jsonl",
checkpoint_interval=2,
checkpoint_dir="./checkpoints_enhanced",
pool_mode='mean'
)
print("\n" + "="*80)
print("TRAINING COMPLETE!")
print("="*80)
print(f"\nπŸ“Š Final Metrics:")
print(f" Loss: {history['total_loss'][-1]:.4f}")
print(f" Cayley: {history['avg_cayley'][-1]:.4f}")
print(f" Timestep Acc: {history['avg_timestep_acc'][-1]:.2%}")
print(f" Pattern Acc: {history['avg_pattern_acc'][-1]:.2%}")
print(f" Full Acc: {history['avg_full_acc'][-1]:.2%}")
return collective, history
if __name__ == "__main__":
collective, history = main()