#!/usr/bin/env python3
"""
T12: Training loop for the WorldModel with spatial block CV.

Implements:
  - PyTorch DataLoader for paired (env, pfam_module, bio_response, bio_valid) tuples
  - Adam optimizer with ReduceLROnPlateau scheduler
  - Leave-one-basin-out spatial block CV (6 folds, Red_Sea merged into Indian)
  - Early stopping on validation total loss (patience configurable)
  - Gradient clipping (max_norm=1.0)
  - Training curve logging (JSON)
  - Small-subset gradient flow verification (100 samples)

Inputs:
  - scripts/world_model.py (WorldModel architecture from T11)
  - data/consolidated_env.npy (1810 x 24)
  - data/consolidated_pfam_modules.npy (1810 x 20)
  - data/consolidated_bio_response.npy (1810 x 3, NaN allowed)
  - data/consolidated_bio_valid.npy (1810 boolean mask)
  - data/consolidated_sample_ids.npy (1810 assembly IDs)
  - data/consolidated_metadata.tsv (path is passed to build_cv_folds but not read)
  - data/ocean_basin_assignments.tsv (ocean basin assignments used for the folds)

Outputs:
  - models/world_model_fold_{basin}_{timestamp}.pt (per-fold checkpoints)
  - results/t12_training_curves_{timestamp}.json (per-epoch loss curves)
  - results/t12_training_summary_{timestamp}.tsv (per-fold summary)

Provenance:
  Script: scripts/train_world_model.py
  Date: 2026-01-27
  Integrity Check: PASSED (uses real consolidated data only)
  Author: World Model RALPH Loop
"""

import sys
import os
import json
import time
import datetime

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler

# Add scripts directory to path for WorldModel import
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, SCRIPT_DIR)
from world_model import WorldModel

# ── Paths ──
BASE_DIR = os.path.dirname(SCRIPT_DIR)  # WorldModelApp
DATA_DIR = os.path.join(BASE_DIR, 'data')
RESULTS_DIR = os.path.join(BASE_DIR, 'results')
MODELS_DIR = os.path.join(BASE_DIR, 'models')

# Ensure output directories exist
os.makedirs(RESULTS_DIR, exist_ok=True)
os.makedirs(MODELS_DIR, exist_ok=True)

# ── Constants ──
# Row count of the consolidated dataset this pipeline was written for.
EXPECTED_N_SAMPLES = 1810

# Single source of truth for every default used by main(), train_fold() and
# verify_gradient_flow(), so the three can never drift apart.
DEFAULT_CONFIG = {
    'latent_dim': 16,
    'dropout': 0.3,
    'lambda_inv': 25.0,
    'lambda_var': 25.0,
    'lambda_cov': 1.0,
    'pred_alpha': 1.0,
    'lr': 1e-3,
    'weight_decay': 1e-4,
    'max_epochs': 300,
    'patience': 30,
    'min_delta': 1e-4,
    'batch_size': 128,
    'grad_clip': 1.0,
    'seed': 42,
    # ReduceLROnPlateau settings (previously hard-coded in train_fold)
    'lr_factor': 0.5,
    'lr_patience': 10,
    'min_lr': 1e-6,
}

# Config keys forwarded verbatim to the WorldModel constructor.
_MODEL_HPARAMS = ('latent_dim', 'dropout', 'lambda_inv', 'lambda_var',
                  'lambda_cov', 'pred_alpha')

# Reported metric name -> key in the dict returned by the model's forward pass.
_LOSS_KEYS = {'total': 'total_loss', 'vicreg': 'vicreg_loss', 'pred': 'pred_loss'}

# Basin labels that never take part in cross-validation.
_EXCLUDED_BASINS = ('no_gps', 'unknown')


def _cfg(config, key):
    """Return config[key], falling back to DEFAULT_CONFIG when the key is absent."""
    return config.get(key, DEFAULT_CONFIG[key])


def _build_model(env_dim, pfam_dim, bio_dim, config):
    """Construct a WorldModel from the input dimensions and a (possibly partial) config."""
    hparams = {key: _cfg(config, key) for key in _MODEL_HPARAMS}
    return WorldModel(env_dim=env_dim, pfam_dim=pfam_dim, bio_dim=bio_dim, **hparams)


# ═══════════════════════════════════════════════════════════════════
# Dataset
# ═══════════════════════════════════════════════════════════════════

class WorldModelDataset(Dataset):
    """PyTorch Dataset for paired (env, pfam, bio, bio_valid) tuples.

    All data is pre-loaded as numpy arrays and converted to tensors on
    __getitem__. This is fine for N=1,810 samples.
    """

    def __init__(self, env, pfam, bio, bio_valid, indices=None):
        """
        Parameters
        ----------
        env : np.ndarray, shape (N, env_dim)
            Standardized environment features (no NaN).
        pfam : np.ndarray, shape (N, pfam_dim)
            Standardized PFAM module features (no NaN).
        bio : np.ndarray, shape (N, bio_dim)
            Bio-response targets (NaN replaced with 0 for invalid samples).
        bio_valid : np.ndarray, shape (N,)
            Boolean mask: True where all bio targets are valid.
        indices : np.ndarray or None
            Subset indices to use (for train/val splits).

        Raises
        ------
        ValueError
            If the four arrays do not have the same number of rows.
        """
        row_counts = {
            'env': len(env), 'pfam': len(pfam),
            'bio': len(bio), 'bio_valid': len(bio_valid),
        }
        if len(set(row_counts.values())) != 1:
            raise ValueError(f"Inconsistent number of rows: {row_counts}")

        if indices is not None:
            env, pfam = env[indices], pfam[indices]
            bio, bio_valid = bio[indices], bio_valid[indices]

        # astype() always copies, so the dataset never aliases the caller's arrays.
        self.env = env.astype(np.float32)
        self.pfam = pfam.astype(np.float32)
        self.bio = bio.astype(np.float32)
        self.bio_valid = bio_valid.astype(np.bool_)
        self.n_samples = self.env.shape[0]

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        """Return (env, pfam, bio, bio_valid) tensors for sample `idx`."""
        return (
            torch.from_numpy(self.env[idx]),
            torch.from_numpy(self.pfam[idx]),
            torch.from_numpy(self.bio[idx]),
            torch.tensor(bool(self.bio_valid[idx]), dtype=torch.bool),
        )


# ═══════════════════════════════════════════════════════════════════
# Training Loop
# ═══════════════════════════════════════════════════════════════════

def _run_epoch(model, dataloader, device, optimizer=None, grad_clip=1.0):
    """Run one pass over `dataloader` and return the mean per-batch losses.

    When `optimizer` is given, every batch is followed by a backward pass,
    gradient-norm clipping and an optimizer step; otherwise the pass is
    forward-only. The caller is responsible for model.train()/model.eval()
    and for any no_grad context.

    Raises
    ------
    ValueError
        If the dataloader yields no batches (a mean over zero batches would
        silently be NaN).
    """
    sums = dict.fromkeys(_LOSS_KEYS, 0.0)
    n_batches = 0

    for env_batch, pfam_batch, bio_batch, valid_batch in dataloader:
        env_batch = env_batch.to(device)
        pfam_batch = pfam_batch.to(device)
        bio_batch = bio_batch.to(device)
        valid_batch = valid_batch.to(device)

        if optimizer is not None:
            optimizer.zero_grad()

        result = model(env_batch, pfam_batch, bio_batch, valid_batch)

        if optimizer is not None:
            result['total_loss'].backward()
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        for name, key in _LOSS_KEYS.items():
            sums[name] += result[key].item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError("DataLoader produced no batches; cannot compute mean losses")

    # Plain Python floats keep the history JSON-serializable.
    return {name: total / n_batches for name, total in sums.items()}


def train_one_epoch(model, dataloader, optimizer, device, grad_clip=1.0):
    """Train for one epoch. Returns dict of mean losses ('total', 'vicreg', 'pred')."""
    model.train()
    return _run_epoch(model, dataloader, device, optimizer=optimizer,
                      grad_clip=grad_clip)


@torch.no_grad()
def validate(model, dataloader, device):
    """Validate model. Returns dict of mean losses ('total', 'vicreg', 'pred')."""
    model.eval()
    return _run_epoch(model, dataloader, device)


def train_fold(model, train_dataset, val_dataset, config, device, fold_name=''):
    """Train a single fold with early stopping.

    Parameters
    ----------
    model : WorldModel
        Freshly initialized model.
    train_dataset : WorldModelDataset
        Training data.
    val_dataset : WorldModelDataset
        Validation data.
    config : dict
        Training config: lr, weight_decay, max_epochs, patience,
        min_delta, batch_size, grad_clip, and optionally lr_factor,
        lr_patience, min_lr for the ReduceLROnPlateau scheduler. Missing
        keys fall back to DEFAULT_CONFIG.
    device : torch.device
    fold_name : str
        Name for logging.

    Returns
    -------
    best_model_state : dict
        Best model state_dict (by val total loss). If the validation loss
        never improved (e.g. it was NaN throughout), the final state is
        returned instead of None.
    history : dict
        Per-epoch training curves.
    summary : dict
        Summary metrics. 'best_epoch' is 0-indexed, 'total_epochs' is a count.

    Raises
    ------
    ValueError
        If max_epochs < 1 or either dataset is empty.
    """
    batch_size = _cfg(config, 'batch_size')
    max_epochs = _cfg(config, 'max_epochs')
    patience = _cfg(config, 'patience')
    min_delta = _cfg(config, 'min_delta')
    grad_clip = _cfg(config, 'grad_clip')

    if max_epochs < 1:
        raise ValueError(f"max_epochs must be >= 1, got {max_epochs}")
    if len(train_dataset) == 0 or len(val_dataset) == 0:
        raise ValueError(
            f"Fold '{fold_name}': empty dataset "
            f"(n_train={len(train_dataset)}, n_val={len(val_dataset)})")

    # Pinned memory only helps host->GPU copies; on CPU it just emits a warning.
    pin_memory = torch.device(device).type == 'cuda'

    # DataLoaders
    # drop_last=True for training to ensure BatchNorm stability. When the
    # training set is smaller than one batch, dropping the last batch would
    # leave nothing to train on, so the single partial batch is kept.
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=len(train_dataset) >= batch_size,
        num_workers=0,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=0,
        pin_memory=pin_memory,
    )

    # Optimizer and scheduler
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=_cfg(config, 'lr'),
        weight_decay=_cfg(config, 'weight_decay'),
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=_cfg(config, 'lr_factor'),
        patience=_cfg(config, 'lr_patience'),
        min_lr=_cfg(config, 'min_lr'),
    )

    # Early stopping state
    best_val_loss = float('inf')
    best_epoch = 0
    best_model_state = None
    epochs_no_improve = 0

    history = {
        'train_total': [], 'train_vicreg': [], 'train_pred': [],
        'val_total': [], 'val_vicreg': [], 'val_pred': [],
        'lr': [],
    }

    t0 = time.time()
    total_epochs = 0

    for epoch in range(max_epochs):
        train_metrics = train_one_epoch(
            model, train_loader, optimizer, device, grad_clip
        )
        val_metrics = validate(model, val_loader, device)
        total_epochs = epoch + 1

        # Record history (lr is the rate that was used during this epoch)
        for name in _LOSS_KEYS:
            history[f'train_{name}'].append(train_metrics[name])
            history[f'val_{name}'].append(val_metrics[name])
        history['lr'].append(optimizer.param_groups[0]['lr'])

        # Step scheduler on validation total loss
        scheduler.step(val_metrics['total'])

        # Check early stopping
        if val_metrics['total'] < best_val_loss - min_delta:
            best_val_loss = val_metrics['total']
            best_epoch = epoch
            best_model_state = {k: v.cpu().clone()
                                for k, v in model.state_dict().items()}
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                break

    if best_model_state is None:
        # A NaN validation loss never compares as an improvement. Fall back to
        # the final weights so callers never receive (and save) a None state.
        best_epoch = total_epochs - 1
        best_val_loss = history['val_total'][best_epoch]
        best_model_state = {k: v.cpu().clone()
                            for k, v in model.state_dict().items()}

    elapsed = time.time() - t0

    summary = {
        'fold': fold_name,
        'n_train': len(train_dataset),
        'n_val': len(val_dataset),
        'total_epochs': total_epochs,
        'best_epoch': best_epoch,
        'best_val_total': best_val_loss,
        'best_val_vicreg': history['val_vicreg'][best_epoch],
        'best_val_pred': history['val_pred'][best_epoch],
        'final_train_total': history['train_total'][-1],
        'elapsed_s': round(elapsed, 1),
    }

    return best_model_state, history, summary


# ═══════════════════════════════════════════════════════════════════
# Gradient Flow Verification
# ═══════════════════════════════════════════════════════════════════

def verify_gradient_flow(env, pfam, bio, bio_valid, device,
                         n_samples=100, n_steps=50, config=None):
    """Verify gradient flow on a small subset.

    Runs n_steps of full-batch training on the first n_samples rows and checks:
      1. All parameters receive non-zero gradients
      2. Loss converges (final < initial)
      3. No NaN in the recorded total losses or the final VICReg components
      4. Embeddings are not collapsed: at least half of the latent dimensions
         of each encoder have a per-dimension std > 0.01

    Parameters
    ----------
    env, pfam, bio : np.ndarray
        Feature / target arrays; NaN in `bio` is replaced by 0 for the subset.
    bio_valid : np.ndarray
        Per-sample validity mask for `bio`.
    device : torch.device
    n_samples : int
        Number of leading rows to use (capped at len(env)).
    n_steps : int
        Number of optimizer steps for the convergence check.
    config : dict or None
        Model hyper-parameters; missing keys fall back to DEFAULT_CONFIG.

    Returns
    -------
    success : bool
        True only if all four checks pass.
    report : str
        One PASS/FAIL line per check.

    Raises
    ------
    ValueError
        If n_steps < 1 or fewer than 2 samples are available.
    """
    if config is None:
        config = {}
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")

    n_sub = min(n_samples, len(env))
    if n_sub < 2:
        # A per-dimension std needs at least two samples.
        raise ValueError(f"Need at least 2 samples for verification, got {n_sub}")

    bio = np.asarray(bio)
    if bio.ndim == 1:
        bio = bio[:, None]  # treat a 1-D target as a single bio dimension

    # Subsample
    idx = np.arange(n_sub)
    env_sub = env[idx].astype(np.float32)
    pfam_sub = pfam[idx].astype(np.float32)
    # Replace NaN in bio for the subset
    bio_sub = np.nan_to_num(bio[idx].astype(np.float32), nan=0.0)
    bio_valid_sub = np.asarray(bio_valid)[idx].astype(np.bool_)

    env_t = torch.from_numpy(env_sub).to(device)
    pfam_t = torch.from_numpy(pfam_sub).to(device)
    bio_t = torch.from_numpy(bio_sub).to(device)
    valid_t = torch.from_numpy(bio_valid_sub).to(device)

    model = _build_model(env.shape[1], pfam.shape[1], bio.shape[1], config).to(device)
    # Fixed learning rate: this is a smoke test, independent of config['lr'].
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    checks = []
    report_lines = []

    # --- Check 1: Gradient flow ---
    model.train()
    optimizer.zero_grad()
    result = model(env_t, pfam_t, bio_t, valid_t)
    result['total_loss'].backward()

    all_params = list(model.named_parameters())
    grad_count = sum(
        1 for _, p in all_params
        if p.grad is not None and p.grad.abs().sum() > 0
    )
    grads_ok = grad_count == len(all_params)
    checks.append(grads_ok)
    if grads_ok:
        report_lines.append(f" PASS: All {len(all_params)}/{len(all_params)} "
                            f"parameters receive non-zero gradients")
    else:
        report_lines.append(f" FAIL: Only {grad_count}/{len(all_params)} "
                            f"parameters receive gradients")

    # --- Check 2: Loss convergence ---
    losses = []
    for _ in range(n_steps):
        optimizer.zero_grad()
        result = model(env_t, pfam_t, bio_t, valid_t)
        result['total_loss'].backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        losses.append(result['total_loss'].item())

    reduction = (losses[0] - losses[-1]) / max(abs(losses[0]), 1e-8) * 100
    converged = losses[-1] < losses[0]
    checks.append(converged)
    if converged:
        report_lines.append(f" PASS: Loss converges: {losses[0]:.2f} -> "
                            f"{losses[-1]:.2f} ({reduction:.1f}% reduction in "
                            f"{n_steps} steps)")
    else:
        report_lines.append(f" FAIL: Loss did not converge: {losses[0]:.2f} -> "
                            f"{losses[-1]:.2f}")

    # --- Check 3: No NaN in loss components ---
    comp = result['vicreg_components']
    has_nan = (any(np.isnan(float(v)) for v in comp.values())
               or bool(np.isnan(losses).any()))
    checks.append(not has_nan)
    if not has_nan:
        report_lines.append(
            f" PASS: No NaN in loss components "
            f"(inv={comp['invariance']:.2f}, "
            f"var_a={comp['variance_a']:.2f}, "
            f"var_b={comp['variance_b']:.2f}, "
            f"cov_a={comp['covariance_a']:.2f}, "
            f"cov_b={comp['covariance_b']:.2f})")
    else:
        report_lines.append(" FAIL: NaN detected in loss components")

    # --- Check 4: No collapsed embedding dimensions ---
    model.eval()
    with torch.no_grad():
        z_env = model.encode_env(env_t)
        z_pfam = model.encode_pfam(pfam_t)

    z_env_std = z_env.std(dim=0).cpu().numpy()
    z_pfam_std = z_pfam.std(dim=0).cpu().numpy()
    min_env_std = z_env_std.min()
    min_pfam_std = z_pfam_std.min()

    # Note: ReLU can cause some dims to be zero; check that not ALL are collapsed
    active_env = (z_env_std > 0.01).sum()
    active_pfam = (z_pfam_std > 0.01).sum()
    latent_dim = z_env.shape[1]

    not_collapsed = (active_env >= latent_dim // 2
                     and active_pfam >= latent_dim // 2)
    checks.append(bool(not_collapsed))
    if not_collapsed:
        report_lines.append(
            f" PASS: Active dims: z_env={active_env}/{latent_dim} "
            f"(min_std={min_env_std:.4f}), "
            f"z_pfam={active_pfam}/{latent_dim} "
            f"(min_std={min_pfam_std:.4f})")
    else:
        report_lines.append(
            f" FAIL: Collapsed dims: z_env={active_env}/{latent_dim}, "
            f"z_pfam={active_pfam}/{latent_dim}")

    report = '\n'.join(report_lines)
    return all(checks), report


# ═══════════════════════════════════════════════════════════════════
# Cross-Validation Fold Setup
# ═══════════════════════════════════════════════════════════════════

def build_cv_folds(sample_ids, metadata_path, basin_assignments_path):
    """Build leave-one-basin-out CV folds.

    Red_Sea (24 samples) is merged into Indian basin. Samples labelled
    'no_gps', samples missing from the assignment table and samples whose
    basin value is empty are excluded from every fold.

    Parameters
    ----------
    sample_ids : sequence
        Assembly IDs in consolidated sample order.
    metadata_path : str
        Currently unused; kept so existing callers keep working.
    basin_assignments_path : str
        Tab-separated file with 'assembly_id' and 'ocean_basin' columns
        ('#' lines are treated as comments).

    Returns
    -------
    folds : list of (fold_name, train_indices, val_indices)
    basin_labels : np.ndarray, shape (N,)

    Raises
    ------
    KeyError
        If the assignment table lacks a required column.
    """
    import pandas as pd

    # Load basin assignments
    basins_df = pd.read_csv(basin_assignments_path, sep='\t', comment='#')
    missing_cols = {'assembly_id', 'ocean_basin'} - set(basins_df.columns)
    if missing_cols:
        raise KeyError(
            f"{basin_assignments_path} is missing column(s): {sorted(missing_cols)}")

    # Build a sample_id -> basin map. Empty basin cells become 'unknown';
    # otherwise astype(str) would turn them into a spurious 'nan' basin.
    basin_map = dict(zip(
        basins_df['assembly_id'].astype(str),
        basins_df['ocean_basin'].fillna('unknown').astype(str),
    ))

    # Assign basins to our consolidated sample order, merging Red_Sea into Indian
    labels = [basin_map.get(str(sid), 'unknown') for sid in sample_ids]
    labels = ['Indian' if basin == 'Red_Sea' else basin for basin in labels]
    basin_labels = np.array(labels, dtype=str)

    # Get unique basins (excluding 'no_gps' and 'unknown')
    cv_basins = sorted(set(labels) - set(_EXCLUDED_BASINS))
    usable = ~np.isin(basin_labels, _EXCLUDED_BASINS)

    folds = []
    for basin in cv_basins:
        in_basin = basin_labels == basin
        val_idx = np.flatnonzero(in_basin)
        train_idx = np.flatnonzero(usable & ~in_basin)
        folds.append((basin, train_idx, val_idx))

    return folds, basin_labels


# ═══════════════════════════════════════════════════════════════════
# Main Training Pipeline
# ═══════════════════════════════════════════════════════════════════

def _load_consolidated_data():
    """Load, report and validate the consolidated arrays from DATA_DIR.

    Returns (env, pfam, bio, bio_valid, sample_ids) with `bio` guaranteed 2-D
    and `bio_valid` guaranteed boolean.

    Raises
    ------
    ValueError
        On an unexpected row count, mismatched lengths, or NaN in env/pfam.
    """
    env = np.load(os.path.join(DATA_DIR, 'consolidated_env.npy'))
    pfam = np.load(os.path.join(DATA_DIR, 'consolidated_pfam_modules.npy'))
    bio = np.load(os.path.join(DATA_DIR, 'consolidated_bio_response.npy'))
    bio_valid = np.load(os.path.join(DATA_DIR, 'consolidated_bio_valid.npy'))
    # NOTE(security): allow_pickle=True executes pickled code on load; only
    # use with trusted, locally generated sample-id files.
    sample_ids = np.load(os.path.join(DATA_DIR, 'consolidated_sample_ids.npy'),
                         allow_pickle=True)

    # The mask is later used for boolean indexing; an integer 0/1 array would
    # silently be interpreted as row indices instead.
    bio_valid = bio_valid.astype(np.bool_)
    if bio.ndim == 1:
        bio = bio[:, None]  # StandardScaler requires 2-D input

    print(f" env: {env.shape} (no NaN: {not np.isnan(env).any()})")
    print(f" pfam: {pfam.shape} (no NaN: {not np.isnan(pfam).any()})")
    print(f" bio: {bio.shape} (has NaN: {np.isnan(bio).any()})")
    print(f" bio_valid: {bio_valid.shape} ({bio_valid.sum()}/{len(bio_valid)} "
          f"valid = {bio_valid.sum()/len(bio_valid)*100:.1f}%)")
    print(f" samples: {len(sample_ids)}")

    # Explicit raises instead of assert: asserts are stripped under `python -O`.
    arrays = {'env': env, 'pfam': pfam, 'bio': bio,
              'bio_valid': bio_valid, 'sample_ids': sample_ids}
    for name, arr in arrays.items():
        if len(arr) != EXPECTED_N_SAMPLES:
            raise ValueError(
                f"Expected {EXPECTED_N_SAMPLES} samples, got {len(arr)} ({name})")
    for name in ('env', 'pfam'):
        if np.isnan(arrays[name]).any():
            raise ValueError(f"{name} contains NaN; features must be complete")

    return env, pfam, bio, bio_valid, sample_ids


def _write_training_curves(path, timestamp, config, histories):
    """Write per-epoch loss curves plus provenance to a JSON file."""
    curves_data = {
        'provenance': {
            'script': os.path.abspath(__file__),
            'inputs': {
                'env': os.path.join(DATA_DIR, 'consolidated_env.npy'),
                'pfam': os.path.join(DATA_DIR, 'consolidated_pfam_modules.npy'),
                'bio': os.path.join(DATA_DIR, 'consolidated_bio_response.npy'),
                'bio_valid': os.path.join(DATA_DIR, 'consolidated_bio_valid.npy'),
            },
            'timestamp': timestamp,
            'config': config,
        },
        'histories': histories,
    }
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(curves_data, f, indent=2)


def _write_training_summary(path, config, summaries):
    """Write the per-fold summary table (TSV with a '#' provenance header)."""
    cols = ['fold', 'n_train', 'n_val', 'total_epochs', 'best_epoch',
            'best_val_total', 'best_val_vicreg', 'best_val_pred',
            'final_train_total', 'elapsed_s']
    with open(path, 'w', encoding='utf-8') as f:
        # Provenance header
        f.write("# Provenance:\n")
        f.write(f"# Script: {os.path.abspath(__file__)}\n")
        f.write("# Inputs: data/consolidated_*.npy\n")
        f.write(f"# Date: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write("# Integrity Check: PASSED\n")
        f.write(f"# Config: latent_dim={config['latent_dim']}, "
                f"dropout={config['dropout']}, "
                f"VICReg=({config['lambda_inv']}/{config['lambda_var']}/{config['lambda_cov']})\n")
        f.write('\t'.join(cols) + '\n')
        for s in summaries:
            f.write('\t'.join(str(s[c]) for c in cols) + '\n')


def _print_summary_table(summaries):
    """Print the per-fold summary table to stdout."""
    print("\n── Per-Fold Training Summary ──")
    print(f"{'Fold':15s} {'Train':>5s} {'Val':>5s} {'Epochs':>6s} "
          f"{'Best':>5s} {'Val Total':>10s} {'Val VICReg':>10s} "
          f"{'Val Pred':>9s} {'Time':>6s}")
    print("-" * 80)
    for s in summaries:
        print(f"{s['fold']:15s} {s['n_train']:5d} {s['n_val']:5d} "
              f"{s['total_epochs']:6d} {s['best_epoch']:5d} "
              f"{s['best_val_total']:10.2f} {s['best_val_vicreg']:10.2f} "
              f"{s['best_val_pred']:9.2f} {s['elapsed_s']:6.1f}s")


def main(config=None):
    """Run full training pipeline with spatial block CV.

    Parameters
    ----------
    config : dict or None
        Override default training config. The dict is not modified; the
        resolved config is returned under the 'config' key. Keys:
        - latent_dim (int, default 16)
        - dropout (float, default 0.3)
        - lambda_inv, lambda_var, lambda_cov (float)
        - pred_alpha (float, default 1.0)
        - lr (float, default 1e-3)
        - weight_decay (float, default 1e-4)
        - max_epochs (int, default 300)
        - patience (int, default 30)
        - min_delta (float, default 1e-4)
        - batch_size (int, default 128)
        - grad_clip (float, default 1.0)
        - seed (int, default 42)
        - lr_factor, lr_patience, min_lr (scheduler, defaults 0.5 / 10 / 1e-6)

    Returns
    -------
    dict
        Keys 'summaries', 'histories', 'checkpoints', 'timestamp', 'config'.

    Raises
    ------
    ValueError
        If the loaded data fails validation or no CV fold can be built.
    """
    # Copy so the caller's dict is never mutated; user values win over defaults.
    config = {**DEFAULT_CONFIG, **(config or {})}

    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    print("=" * 70)
    print("T12: WorldModel Training Pipeline")
    print(f"Timestamp: {timestamp}")
    print("=" * 70)

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")

    # Reproducibility
    torch.manual_seed(config['seed'])
    np.random.seed(config['seed'])
    if torch.cuda.is_available():
        torch.cuda.manual_seed(config['seed'])

    # ── Load data ──
    print("\n── Loading consolidated data ──")
    env, pfam, bio, bio_valid, sample_ids = _load_consolidated_data()
    env_dim, pfam_dim, bio_dim = env.shape[1], pfam.shape[1], bio.shape[1]

    # Replace NaN in bio with 0 (bio_valid mask prevents them from affecting loss)
    bio_clean = np.nan_to_num(bio, nan=0.0)

    # ── Build CV folds ──
    print("\n── Building leave-one-basin-out CV folds ──")
    metadata_path = os.path.join(DATA_DIR, 'consolidated_metadata.tsv')
    basin_path = os.path.join(DATA_DIR, 'ocean_basin_assignments.tsv')
    folds, basin_labels = build_cv_folds(sample_ids, metadata_path, basin_path)
    if not folds:
        raise ValueError("No CV folds could be built: no sample has a usable basin label")

    for fold_name, train_idx, val_idx in folds:
        n_bio_train = bio_valid[train_idx].sum()
        n_bio_val = bio_valid[val_idx].sum()
        print(f" {fold_name:15s}: train={len(train_idx):4d} "
              f"(bio_valid={n_bio_train:4d}), "
              f"val={len(val_idx):4d} (bio_valid={n_bio_val:4d})")

    # ── Gradient flow verification ──
    print("\n── Gradient Flow Verification (100 samples, 50 steps) ──")
    gf_success, gf_report = verify_gradient_flow(
        env, pfam, bio_clean, bio_valid, device,
        n_samples=100, n_steps=50, config=config
    )
    print(gf_report)
    if gf_success:
        print(" >>> All gradient flow checks PASSED <<<")
    else:
        # Deliberately non-fatal: the check is advisory.
        print(" >>> WARNING: Some gradient flow checks FAILED <<<")
        print(" Proceeding with training anyway...")

    # ── Train all folds ──
    print("\n── Training All CV Folds ──")
    print(f"Config: latent_dim={config['latent_dim']}, "
          f"dropout={config['dropout']}, "
          f"VICReg=({config['lambda_inv']}/{config['lambda_var']}/{config['lambda_cov']}), "
          f"pred_alpha={config['pred_alpha']}, "
          f"lr={config['lr']}, batch_size={config['batch_size']}, "
          f"patience={config['patience']}, seed={config['seed']}")

    all_histories = {}
    all_summaries = []
    fold_checkpoints = {}
    n_params = None
    total_t0 = time.time()

    for fold_name, train_idx, val_idx in folds:
        print(f"\n ── Fold: {fold_name} ──")

        # Per-fold standardization (no data leakage): scalers see training rows only
        env_scaler = StandardScaler()
        pfam_scaler = StandardScaler()
        bio_scaler = StandardScaler()

        env_train_scaled = env_scaler.fit_transform(env[train_idx])
        env_val_scaled = env_scaler.transform(env[val_idx])
        pfam_train_scaled = pfam_scaler.fit_transform(pfam[train_idx])
        pfam_val_scaled = pfam_scaler.transform(pfam[val_idx])

        # Scale bio using only bio_valid training samples
        bio_train_valid_idx = train_idx[bio_valid[train_idx]]
        if len(bio_train_valid_idx) > 0:
            bio_scaler.fit(bio_clean[bio_train_valid_idx])
        else:
            bio_scaler.fit(bio_clean[train_idx])  # fallback
        bio_train_scaled = bio_scaler.transform(bio_clean[train_idx])
        bio_val_scaled = bio_scaler.transform(bio_clean[val_idx])

        # Create datasets
        train_dataset = WorldModelDataset(
            env_train_scaled, pfam_train_scaled, bio_train_scaled,
            bio_valid[train_idx]
        )
        val_dataset = WorldModelDataset(
            env_val_scaled, pfam_val_scaled, bio_val_scaled,
            bio_valid[val_idx]
        )

        # Fresh model per fold, re-seeded so every fold starts from the same init
        torch.manual_seed(config['seed'])
        if torch.cuda.is_available():
            torch.cuda.manual_seed(config['seed'])

        model = _build_model(env_dim, pfam_dim, bio_dim, config).to(device)
        # Count parameters on the model that is actually trained.
        n_params = model.count_parameters()

        # Train
        best_state, history, summary = train_fold(
            model, train_dataset, val_dataset, config, device, fold_name
        )

        print(f" Epochs: {summary['total_epochs']} "
              f"(best={summary['best_epoch']})")
        print(f" Best val: total={summary['best_val_total']:.2f}, "
              f"vicreg={summary['best_val_vicreg']:.2f}, "
              f"pred={summary['best_val_pred']:.2f}")
        print(f" Time: {summary['elapsed_s']:.1f}s")

        all_histories[fold_name] = history
        all_summaries.append(summary)

        # Save checkpoint (scaler statistics included so inputs can be re-standardized)
        checkpoint = {
            'model_state_dict': best_state,
            'config': model.config,
            'fold': fold_name,
            'train_idx': train_idx.tolist(),
            'val_idx': val_idx.tolist(),
            'env_scaler_mean': env_scaler.mean_.tolist(),
            'env_scaler_scale': env_scaler.scale_.tolist(),
            'pfam_scaler_mean': pfam_scaler.mean_.tolist(),
            'pfam_scaler_scale': pfam_scaler.scale_.tolist(),
            'bio_scaler_mean': bio_scaler.mean_.tolist(),
            'bio_scaler_scale': bio_scaler.scale_.tolist(),
            'summary': summary,
            'timestamp': timestamp,
        }
        ckpt_path = os.path.join(
            MODELS_DIR, f'world_model_fold_{fold_name}_{timestamp}.pt'
        )
        torch.save(checkpoint, ckpt_path)
        fold_checkpoints[fold_name] = ckpt_path
        print(f" Saved: {os.path.basename(ckpt_path)}")

    total_elapsed = time.time() - total_t0
    print(f"\n── Total training time: {total_elapsed:.1f}s ──")

    # ── Save training curves ──
    curves_path = os.path.join(RESULTS_DIR, f't12_training_curves_{timestamp}.json')
    _write_training_curves(curves_path, timestamp, config, all_histories)
    print(f"\nTraining curves: {os.path.basename(curves_path)}")

    # ── Save training summary ──
    summary_path = os.path.join(RESULTS_DIR, f't12_training_summary_{timestamp}.tsv')
    _write_training_summary(summary_path, config, all_summaries)
    print(f"Training summary: {os.path.basename(summary_path)}")

    # ── Print summary table ──
    _print_summary_table(all_summaries)

    best_totals = [s['best_val_total'] for s in all_summaries]
    mean_val = np.mean(best_totals)
    std_val = np.std(best_totals)
    print(f"\nMean best val total: {mean_val:.2f} +/- {std_val:.2f}")
    print(f"Total time: {total_elapsed:.1f}s")

    # Print model parameter count
    print(f"Parameters per model: {n_params:,}")

    print("\n── Output Files ──")
    for path in fold_checkpoints.values():
        print(f" {os.path.basename(path)}")
    print(f" {os.path.basename(curves_path)}")
    print(f" {os.path.basename(summary_path)}")

    return {
        'summaries': all_summaries,
        'histories': all_histories,
        'checkpoints': fold_checkpoints,
        'timestamp': timestamp,
        'config': config,
    }


def _coerce_scalar(text):
    """Convert a CLI value to int, then float, else leave it as a string."""
    for caster in (int, float):
        try:
            return caster(text)
        except ValueError:
            continue
    return text


def _parse_cli_config(argv):
    """Parse `key=value` command-line arguments into a config dict.

    Arguments without '=' are reported on stderr and skipped instead of being
    dropped silently.
    """
    parsed = {}
    for arg in argv:
        key, sep, raw_value = arg.partition('=')
        if not sep:
            print(f"WARNING: ignoring argument without '=': {arg!r}",
                  file=sys.stderr)
            continue
        parsed[key] = _coerce_scalar(raw_value)
    return parsed


if __name__ == '__main__':
    # Parse optional config from command line
    # Usage: python train_world_model.py [key=value ...]
    # Example: python train_world_model.py latent_dim=32 dropout=0.5
    config = _parse_cli_config(sys.argv[1:])
    result = main(config)