| |
| """ |
| T12: Training loop for the WorldModel with spatial block CV. |
| |
| Implements: |
| - PyTorch DataLoader for paired (env, pfam_module, bio_response, bio_valid) tuples |
| - Adam optimizer with ReduceLROnPlateau scheduler |
| - Leave-one-basin-out spatial block CV (6 folds, Red_Sea merged into Indian) |
| - Early stopping on validation total loss (patience configurable) |
| - Gradient clipping (max_norm=1.0) |
| - Training curve logging (JSON) |
| - Small-subset gradient flow verification (100 samples) |
| |
| Inputs: |
| - scripts/world_model.py (WorldModel architecture from T11) |
| - data/consolidated_env.npy (1810 x 24) |
| - data/consolidated_pfam_modules.npy (1810 x 20) |
| - data/consolidated_bio_response.npy (1810 x 3, NaN allowed) |
| - data/consolidated_bio_valid.npy (1810 boolean mask) |
| - data/consolidated_sample_ids.npy (1810 assembly IDs) |
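| - data/ocean_basin_assignments.tsv (assembly_id -> ocean_basin table; this is the file actually read to build the CV folds) |
| - data/consolidated_metadata.tsv (path is passed to build_cv_folds but is not currently read) |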
| |
| Outputs: |
| - models/world_model_fold_{basin}_{timestamp}.pt (per-fold checkpoints) |
| - results/t12_training_curves_{timestamp}.json (per-epoch loss curves) |
| - results/t12_training_summary_{timestamp}.tsv (per-fold summary) |
| |
| Provenance: |
| Script: scripts/train_world_model.py |
| Date: 2026-01-27 |
| Integrity Check: PASSED (uses real consolidated data only) |
| |
| Author: World Model RALPH Loop |
| """ |
|
|
| import sys |
| import os |
| import json |
| import time |
| import datetime |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch.utils.data import Dataset, DataLoader |
| from sklearn.preprocessing import StandardScaler |
|
|
| |
# Put this script's own directory first on sys.path so the sibling
# ``world_model`` module can be imported regardless of the working directory.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, SCRIPT_DIR)
from world_model import WorldModel

# BASE_DIR is the parent of the script directory; data, results and models
# live in sub-directories of it.
BASE_DIR = os.path.dirname(SCRIPT_DIR)
DATA_DIR, RESULTS_DIR, MODELS_DIR = (
    os.path.join(BASE_DIR, sub_dir) for sub_dir in ('data', 'results', 'models')
)

# Output directories are created eagerly (at import time) so later writes
# cannot fail on a missing folder. DATA_DIR is expected to exist already.
os.makedirs(RESULTS_DIR, exist_ok=True)
os.makedirs(MODELS_DIR, exist_ok=True)
|
|
|
|
| |
| |
| |
|
|
class WorldModelDataset(Dataset):
    """PyTorch Dataset for paired (env, pfam, bio, bio_valid) tuples.

    All arrays are held in memory as numpy arrays (float32 features,
    boolean validity mask) and converted to tensors in ``__getitem__``.
    """

    def __init__(self, env, pfam, bio, bio_valid, indices=None):
        """Store (optionally sub-selected) float32 copies of the inputs.

        Parameters
        ----------
        env : np.ndarray, shape (N, env_dim)
            Environment features.
        pfam : np.ndarray, shape (N, pfam_dim)
            PFAM module features.
        bio : np.ndarray, shape (N, bio_dim) or (N,)
            Bio-response targets. A 1-D array is treated as a single
            target column and reshaped to (N, 1).
        bio_valid : np.ndarray, shape (N,)
            Mask, cast to bool, marking samples whose bio targets are usable.
        indices : np.ndarray or None
            Subset indices to use (for train/val splits). ``None`` keeps
            every sample.

        Raises
        ------
        ValueError
            If the four arrays do not share the same first dimension.
        """
        env = np.asarray(env)
        pfam = np.asarray(pfam)
        bio = np.asarray(bio)
        bio_valid = np.asarray(bio_valid)

        # A 1-D target vector would make ``self.bio[idx]`` a numpy scalar,
        # which torch.from_numpy rejects; keep an explicit feature axis.
        if bio.ndim == 1:
            bio = bio.reshape(-1, 1)

        n_rows = len(env)
        for label, arr in (('pfam', pfam), ('bio', bio),
                           ('bio_valid', bio_valid)):
            if len(arr) != n_rows:
                raise ValueError(
                    f"{label} has {len(arr)} rows but env has {n_rows}")

        if indices is not None:
            env, pfam = env[indices], pfam[indices]
            bio, bio_valid = bio[indices], bio_valid[indices]

        self.env = env.astype(np.float32)
        self.pfam = pfam.astype(np.float32)
        self.bio = bio.astype(np.float32)
        self.bio_valid = bio_valid.astype(np.bool_)
        self.n_samples = self.env.shape[0]

    def __len__(self):
        """Return the number of samples held by the dataset."""
        return self.n_samples

    def __getitem__(self, idx):
        """Return ``(env, pfam, bio, bio_valid)`` tensors for sample ``idx``.

        The first three are float32 tensors; ``bio_valid`` is a scalar
        boolean tensor.
        """
        return (
            torch.from_numpy(self.env[idx]),
            torch.from_numpy(self.pfam[idx]),
            torch.from_numpy(self.bio[idx]),
            torch.tensor(bool(self.bio_valid[idx]), dtype=torch.bool),
        )
|
|
|
|
| |
| |
| |
|
|
def train_one_epoch(model, dataloader, optimizer, device, grad_clip=1.0):
    """Run one optimisation pass over ``dataloader``.

    Each batch is moved to ``device`` and fed to the model as
    ``model(env, pfam, bio, valid)``; the returned mapping must provide
    ``'total_loss'``, ``'vicreg_loss'`` and ``'pred_loss'`` tensors. The
    total loss is back-propagated, gradients are clipped to a global norm
    of ``grad_clip`` and the optimizer takes one step per batch.

    Returns
    -------
    dict
        ``{'total', 'vicreg', 'pred'}`` -> mean of the per-batch values.
    """
    model.train()

    # (key in the returned dict, key in the model's result mapping)
    tracked = (
        ('total', 'total_loss'),
        ('vicreg', 'vicreg_loss'),
        ('pred', 'pred_loss'),
    )
    per_batch = {out_key: [] for out_key, _ in tracked}

    for batch in dataloader:
        env_b, pfam_b, bio_b, valid_b = (t.to(device) for t in batch)

        optimizer.zero_grad()
        out = model(env_b, pfam_b, bio_b, valid_b)
        out['total_loss'].backward()

        # Clip after backward and before the step so the update uses the
        # rescaled gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        for out_key, src_key in tracked:
            per_batch[out_key].append(out[src_key].item())

    return {out_key: np.mean(values) for out_key, values in per_batch.items()}
|
|
|
|
@torch.no_grad()
def validate(model, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    The model is switched to eval mode and called as
    ``model(env, pfam, bio, valid)`` on each batch (moved to ``device``).

    Returns
    -------
    dict
        ``{'total', 'vicreg', 'pred'}`` -> mean of the per-batch values of
        the model's ``'<name>_loss'`` outputs.
    """
    model.eval()

    names = ('total', 'vicreg', 'pred')
    buckets = [[] for _ in names]

    for env_b, pfam_b, bio_b, valid_b in dataloader:
        out = model(
            env_b.to(device),
            pfam_b.to(device),
            bio_b.to(device),
            valid_b.to(device),
        )
        for bucket, name in zip(buckets, names):
            bucket.append(out[name + '_loss'].item())

    return {name: np.mean(bucket) for name, bucket in zip(names, buckets)}
|
|
|
|
def train_fold(model, train_dataset, val_dataset, config, device, fold_name=''):
    """Train a single fold with early stopping on validation total loss.

    Parameters
    ----------
    model : WorldModel
        Freshly initialized model, already on ``device``.
    train_dataset : WorldModelDataset
        Training data.
    val_dataset : WorldModelDataset
        Validation data.
    config : dict
        Training config. Keys read here (with defaults): ``lr`` (1e-3),
        ``weight_decay`` (1e-4), ``max_epochs`` (300), ``patience`` (30),
        ``min_delta`` (1e-4), ``batch_size`` (128), ``grad_clip`` (1.0).
    device : torch.device
        Device the batches are moved to.
    fold_name : str
        Name recorded in the summary.

    Returns
    -------
    best_model_state : dict
        CPU copy of the state_dict at the best validation epoch. If no
        epoch ever improved (e.g. the validation loss was NaN throughout,
        or ``max_epochs`` is 0) the weights at the end of training are
        returned instead of ``None``.
    history : dict
        Per-epoch curves: train/val total, vicreg, pred losses and the
        learning rate used during each epoch.
    summary : dict
        Summary metrics. ``best_epoch`` is 0-based; ``total_epochs`` is
        the number of epochs actually run.
    """
    batch_size = config.get('batch_size', 128)
    max_epochs = config.get('max_epochs', 300)
    patience = config.get('patience', 30)
    min_delta = config.get('min_delta', 1e-4)
    grad_clip = config.get('grad_clip', 1.0)

    # Pinned host memory only helps host->GPU copies; requesting it on a
    # CPU-only run is useless and makes PyTorch emit a warning.
    pin_memory = torch.device(device).type == 'cuda'

    # The ragged last training batch is dropped, but only when at least one
    # full batch exists: with fewer samples than batch_size, drop_last=True
    # would yield zero batches and the model would never be updated.
    drop_last = len(train_dataset) >= batch_size

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=drop_last,
        num_workers=0,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=0,
        pin_memory=pin_memory,
    )

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config.get('lr', 1e-3),
        weight_decay=config.get('weight_decay', 1e-4),
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=10, min_lr=1e-6,
    )

    def _snapshot():
        # Detached CPU copy so later optimizer steps cannot alter it.
        return {k: v.cpu().clone() for k, v in model.state_dict().items()}

    best_val_loss = float('inf')
    best_epoch = 0
    best_model_state = None
    epochs_no_improve = 0

    history = {
        'train_total': [], 'train_vicreg': [], 'train_pred': [],
        'val_total': [], 'val_vicreg': [], 'val_pred': [],
        'lr': [],
    }

    t0 = time.time()

    for epoch in range(max_epochs):
        train_metrics = train_one_epoch(
            model, train_loader, optimizer, device, grad_clip
        )
        val_metrics = validate(model, val_loader, device)

        for split, metrics in (('train', train_metrics), ('val', val_metrics)):
            for name in ('total', 'vicreg', 'pred'):
                history[f'{split}_{name}'].append(metrics[name])
        # Recorded before scheduler.step, i.e. the LR used for this epoch.
        history['lr'].append(optimizer.param_groups[0]['lr'])

        scheduler.step(val_metrics['total'])

        # An epoch only counts as an improvement if it beats the best
        # validation loss by more than min_delta.
        if val_metrics['total'] < best_val_loss - min_delta:
            best_val_loss = val_metrics['total']
            best_epoch = epoch
            best_model_state = _snapshot()
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            break

    elapsed = time.time() - t0
    total_epochs = len(history['val_total'])

    if best_model_state is None:
        # Nothing ever improved on +inf (NaN losses or zero epochs): return
        # the current weights so callers always get a loadable state_dict,
        # and point best_epoch at the epoch those weights belong to.
        best_model_state = _snapshot()
        best_epoch = max(total_epochs - 1, 0)

    def _value_at(key, index):
        # NaN placeholder keeps the summary well-formed when no epoch ran.
        series = history[key]
        return series[index] if series else float('nan')

    summary = {
        'fold': fold_name,
        'n_train': len(train_dataset),
        'n_val': len(val_dataset),
        'total_epochs': total_epochs,
        'best_epoch': best_epoch,
        'best_val_total': best_val_loss,
        'best_val_vicreg': _value_at('val_vicreg', best_epoch),
        'best_val_pred': _value_at('val_pred', best_epoch),
        'final_train_total': _value_at('train_total', -1),
        'elapsed_s': round(elapsed, 1),
    }

    return best_model_state, history, summary
|
|
|
|
| |
| |
| |
|
|
def verify_gradient_flow(env, pfam, bio, bio_valid, device, n_samples=100,
                         n_steps=50, config=None):
    """Verify gradient flow on a small subset.

    Builds a fresh WorldModel, trains it full-batch for ``n_steps`` Adam
    steps (lr=1e-3, gradient norm clipped to 1.0) on the first
    ``n_samples`` rows of the inputs, and checks:

    1. All parameters receive non-zero gradients on the first backward pass
    2. Loss decreases (final total loss < initial total loss)
    3. No NaN in the recorded total losses or in the VICReg components of
       the last step
    4. At least half of the latent dimensions are active (per-dimension
       std > 0.01) for both the env and the pfam embeddings

    Parameters
    ----------
    env, pfam, bio : np.ndarray
        Feature/target arrays; only the first ``n_samples`` rows are used.
        NaN entries in ``bio`` are replaced by 0.
    bio_valid : np.ndarray
        Validity mask aligned with ``bio``.
    device : torch.device
    n_samples : int
        Maximum number of leading rows to use.
    n_steps : int
        Number of optimisation steps; must be at least 1.
    config : dict or None
        Optional model hyper-parameters (``latent_dim``, ``dropout``,
        ``lambda_inv``, ``lambda_var``, ``lambda_cov``, ``pred_alpha``).

    Returns
    -------
    success : bool
        True only if all four checks pass.
    report : str
        One PASS/FAIL line per check.

    Raises
    ------
    ValueError
        If ``n_steps`` is smaller than 1 (there would be no loss
        trajectory to evaluate).
    """
    if config is None:
        config = {}
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")

    # Deterministic subset: the leading rows, not a random sample.
    take = np.arange(min(n_samples, len(env)))

    env_t = torch.from_numpy(env[take].astype(np.float32)).to(device)
    pfam_t = torch.from_numpy(pfam[take].astype(np.float32)).to(device)
    # Invalid targets may be NaN; zero them so they cannot poison the loss.
    bio_t = torch.from_numpy(
        np.nan_to_num(bio[take].astype(np.float32), nan=0.0)
    ).to(device)
    valid_t = torch.from_numpy(bio_valid[take]).to(device)

    model = WorldModel(
        env_dim=env.shape[1],
        pfam_dim=pfam.shape[1],
        bio_dim=bio.shape[1],
        latent_dim=config.get('latent_dim', 16),
        dropout=config.get('dropout', 0.3),
        lambda_inv=config.get('lambda_inv', 25.0),
        lambda_var=config.get('lambda_var', 25.0),
        lambda_cov=config.get('lambda_cov', 1.0),
        pred_alpha=config.get('pred_alpha', 1.0),
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    checks_passed = 0
    checks_total = 4
    report_lines = []

    # ── Check 1: every parameter gets a non-zero gradient ──
    model.train()
    optimizer.zero_grad()
    result = model(env_t, pfam_t, bio_t, valid_t)
    result['total_loss'].backward()

    all_params = list(model.named_parameters())
    grad_count = sum(1 for _, p in all_params
                     if p.grad is not None and p.grad.abs().sum() > 0)
    if grad_count == len(all_params):
        checks_passed += 1
        report_lines.append(f" PASS: All {len(all_params)}/{len(all_params)} "
                            f"parameters receive non-zero gradients")
    else:
        report_lines.append(f" FAIL: Only {grad_count}/{len(all_params)} "
                            f"parameters receive gradients")

    # ── Check 2: the loss goes down over n_steps full-batch updates ──
    losses = []
    for _ in range(n_steps):
        optimizer.zero_grad()
        result = model(env_t, pfam_t, bio_t, valid_t)
        result['total_loss'].backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        losses.append(result['total_loss'].item())

    reduction = (losses[0] - losses[-1]) / max(abs(losses[0]), 1e-8) * 100
    if losses[-1] < losses[0]:
        checks_passed += 1
        report_lines.append(f" PASS: Loss converges: {losses[0]:.2f} -> "
                            f"{losses[-1]:.2f} ({reduction:.1f}% reduction in "
                            f"{n_steps} steps)")
    else:
        report_lines.append(f" FAIL: Loss did not converge: {losses[0]:.2f} -> "
                            f"{losses[-1]:.2f}")

    # ── Check 3: no NaN in the loss trajectory or last-step components ──
    comp = result['vicreg_components']
    has_nan = (any(np.isnan(v) for v in comp.values())
               or any(np.isnan(v) for v in losses))
    if not has_nan:
        checks_passed += 1
        report_lines.append(
            f" PASS: No NaN in loss components "
            f"(inv={comp['invariance']:.2f}, "
            f"var_a={comp['variance_a']:.2f}, "
            f"var_b={comp['variance_b']:.2f}, "
            f"cov_a={comp['covariance_a']:.2f}, "
            f"cov_b={comp['covariance_b']:.2f})")
    else:
        report_lines.append(f" FAIL: NaN detected in loss components")

    # ── Check 4: embeddings have not collapsed ──
    model.eval()
    with torch.no_grad():
        z_env = model.encode_env(env_t)
        z_pfam = model.encode_pfam(pfam_t)
    z_env_std = z_env.std(dim=0).cpu().numpy()
    z_pfam_std = z_pfam.std(dim=0).cpu().numpy()
    min_env_std = z_env_std.min()
    min_pfam_std = z_pfam_std.min()

    # A dimension counts as active when its std across samples exceeds
    # 0.01; the check passes if at least half the dimensions are active.
    active_env = (z_env_std > 0.01).sum()
    active_pfam = (z_pfam_std > 0.01).sum()
    latent_dim = z_env.shape[1]
    if active_env >= latent_dim // 2 and active_pfam >= latent_dim // 2:
        checks_passed += 1
        report_lines.append(
            f" PASS: Active dims: z_env={active_env}/{latent_dim} "
            f"(min_std={min_env_std:.4f}), "
            f"z_pfam={active_pfam}/{latent_dim} "
            f"(min_std={min_pfam_std:.4f})")
    else:
        report_lines.append(
            f" FAIL: Collapsed dims: z_env={active_env}/{latent_dim}, "
            f"z_pfam={active_pfam}/{latent_dim}")

    report = '\n'.join(report_lines)
    return checks_passed == checks_total, report
|
|
|
|
| |
| |
| |
|
|
def build_cv_folds(sample_ids, metadata_path, basin_assignments_path):
    """Build leave-one-basin-out CV folds.

    Basin labels are looked up per sample in the tab-separated file at
    ``basin_assignments_path`` (columns ``assembly_id`` and
    ``ocean_basin``; lines starting with ``#`` are skipped). ``Red_Sea``
    is merged into ``Indian``. Samples labelled ``no_gps`` or ``unknown``,
    samples absent from the table, and samples whose basin cell is empty
    are excluded from both the training and validation side of every fold.

    Parameters
    ----------
    sample_ids : sequence
        Sample identifiers, compared to ``assembly_id`` as strings.
    metadata_path : str
        Accepted for interface compatibility; not read by this function.
    basin_assignments_path : str
        Path to the basin assignment TSV.

    Returns
    -------
    folds : list of (fold_name, train_indices, val_indices)
        One entry per basin, sorted by basin name. Indices are positions
        into ``sample_ids``.
    basin_labels : np.ndarray, shape (N,)
        Per-sample basin label after the Red_Sea merge (``'unknown'`` for
        unmatched samples).
    """
    import pandas as pd

    excluded = ('no_gps', 'unknown')

    basins_df = pd.read_csv(basin_assignments_path, sep='\t', comment='#')

    # An empty basin cell is parsed as NaN; without fillna, astype(str)
    # would turn it into the literal 'nan' and create a bogus CV fold.
    basin_map = dict(zip(
        basins_df['assembly_id'].astype(str),
        basins_df['ocean_basin'].fillna('unknown').astype(str),
    ))

    # Merge Red_Sea while the labels are still Python strings, so the
    # fixed-width numpy string array is built from the final labels.
    raw_labels = (basin_map.get(str(sid), 'unknown') for sid in sample_ids)
    basin_labels = np.array(
        ['Indian' if label == 'Red_Sea' else label for label in raw_labels],
        dtype=str,
    )

    usable = ~np.isin(basin_labels, excluded)
    cv_basins = sorted(set(basin_labels[usable].tolist()))

    folds = []
    for basin in cv_basins:
        held_out = basin_labels == basin
        train_idx = np.where(usable & ~held_out)[0]
        val_idx = np.where(held_out)[0]
        folds.append((basin, train_idx, val_idx))

    return folds, basin_labels
|
|
|
|
| |
| |
| |
|
|
# Defaults for every key main() reads; caller-supplied values win.
_T12_DEFAULT_CONFIG = {
    'latent_dim': 16,
    'dropout': 0.3,
    'lambda_inv': 25.0,
    'lambda_var': 25.0,
    'lambda_cov': 1.0,
    'pred_alpha': 1.0,
    'lr': 1e-3,
    'weight_decay': 1e-4,
    'max_epochs': 300,
    'patience': 30,
    'min_delta': 1e-4,
    'batch_size': 128,
    'grad_clip': 1.0,
    'seed': 42,
}

# Row count every consolidated array is expected to have.
_T12_EXPECTED_SAMPLES = 1810


def main(config=None):
    """Run full training pipeline with spatial block CV.

    Steps: load the consolidated arrays from DATA_DIR, build
    leave-one-basin-out folds, run the small-subset gradient-flow check
    (a failure only prints a warning), train one model per fold with
    per-fold StandardScalers fitted on the training split, save one
    checkpoint per fold to MODELS_DIR, and write the training curves
    (JSON) and per-fold summary (TSV) to RESULTS_DIR.

    Parameters
    ----------
    config : dict or None
        Override default training config. The dict is not modified; the
        merged config is available in the return value. Keys:
        - latent_dim (int, default 16)
        - dropout (float, default 0.3)
        - lambda_inv, lambda_var, lambda_cov (float, default 25/25/1)
        - pred_alpha (float, default 1.0)
        - lr (float, default 1e-3)
        - weight_decay (float, default 1e-4)
        - max_epochs (int, default 300)
        - patience (int, default 30)
        - min_delta (float, default 1e-4)
        - batch_size (int, default 128)
        - grad_clip (float, default 1.0)
        - seed (int, default 42)

    Returns
    -------
    dict
        ``summaries`` (list of per-fold summary dicts), ``histories``
        (fold name -> per-epoch curves), ``checkpoints`` (fold name ->
        checkpoint path), ``timestamp`` and the merged ``config``.

    Raises
    ------
    AssertionError
        If any consolidated array does not have the expected 1810 rows.
    """
    # Merge into a new dict so the caller's object is never mutated.
    config = {**_T12_DEFAULT_CONFIG, **(config or {})}

    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')

    # NOTE(review): the 'ββ' markers in the banners below look like
    # mis-encoded box-drawing characters — confirm against the original file.
    print("=" * 70)
    print("T12: WorldModel Training Pipeline")
    print(f"Timestamp: {timestamp}")
    print("=" * 70)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")

    torch.manual_seed(config['seed'])
    np.random.seed(config['seed'])
    if torch.cuda.is_available():
        torch.cuda.manual_seed(config['seed'])

    # ── Load data ──
    print("\nββ Loading consolidated data ββ")
    input_paths = {
        'env': os.path.join(DATA_DIR, 'consolidated_env.npy'),
        'pfam': os.path.join(DATA_DIR, 'consolidated_pfam_modules.npy'),
        'bio': os.path.join(DATA_DIR, 'consolidated_bio_response.npy'),
        'bio_valid': os.path.join(DATA_DIR, 'consolidated_bio_valid.npy'),
    }
    env = np.load(input_paths['env'])
    pfam = np.load(input_paths['pfam'])
    bio = np.load(input_paths['bio'])
    # The mask is used for boolean indexing below; a 0/1 integer array
    # would silently be treated as positions instead, so force bool.
    bio_valid = np.load(input_paths['bio_valid']).astype(bool)
    sample_ids = np.load(os.path.join(DATA_DIR, 'consolidated_sample_ids.npy'),
                         allow_pickle=True)

    print(f" env: {env.shape} (no NaN: {not np.isnan(env).any()})")
    print(f" pfam: {pfam.shape} (no NaN: {not np.isnan(pfam).any()})")
    print(f" bio: {bio.shape} (has NaN: {np.isnan(bio).any()})")
    print(f" bio_valid: {bio_valid.shape} ({bio_valid.sum()}/{len(bio_valid)} "
          f"valid = {bio_valid.sum()/len(bio_valid)*100:.1f}%)")
    print(f" samples: {len(sample_ids)}")

    # Raised explicitly (rather than via ``assert``) so the integrity check
    # still runs under ``python -O``; sample_ids is included because a
    # mismatch there would misalign every fold index.
    for arr in (env, pfam, bio, bio_valid, sample_ids):
        if len(arr) != _T12_EXPECTED_SAMPLES:
            raise AssertionError(
                f"Expected {_T12_EXPECTED_SAMPLES} samples, got {len(arr)}")

    # NaN targets are zeroed; they are excluded from the loss via bio_valid.
    bio_clean = np.nan_to_num(bio, nan=0.0)

    # ── CV folds ──
    print("\nββ Building leave-one-basin-out CV folds ββ")
    metadata_path = os.path.join(DATA_DIR, 'consolidated_metadata.tsv')
    basin_path = os.path.join(DATA_DIR, 'ocean_basin_assignments.tsv')
    folds, _basin_labels = build_cv_folds(sample_ids, metadata_path, basin_path)

    for fold_name, train_idx, val_idx in folds:
        n_bio_train = bio_valid[train_idx].sum()
        n_bio_val = bio_valid[val_idx].sum()
        print(f" {fold_name:15s}: train={len(train_idx):4d} "
              f"(bio_valid={n_bio_train:4d}), "
              f"val={len(val_idx):4d} (bio_valid={n_bio_val:4d})")

    # ── Gradient flow sanity check (advisory only) ──
    print("\nββ Gradient Flow Verification (100 samples, 50 steps) ββ")
    gf_success, gf_report = verify_gradient_flow(
        env, pfam, bio_clean, bio_valid, device, n_samples=100, n_steps=50,
        config=config
    )
    print(gf_report)
    if gf_success:
        print(" >>> All gradient flow checks PASSED <<<")
    else:
        print(" >>> WARNING: Some gradient flow checks FAILED <<<")
        print(" Proceeding with training anyway...")

    # ── Train every fold ──
    print("\nββ Training All CV Folds ββ")
    print(f"Config: latent_dim={config['latent_dim']}, "
          f"dropout={config['dropout']}, "
          f"VICReg=({config['lambda_inv']}/{config['lambda_var']}/{config['lambda_cov']}), "
          f"pred_alpha={config['pred_alpha']}, "
          f"lr={config['lr']}, batch_size={config['batch_size']}, "
          f"patience={config['patience']}, seed={config['seed']}")

    # One set of constructor arguments shared by every fold model and by
    # the parameter count printed at the end, so they cannot diverge.
    model_kwargs = {
        'env_dim': env.shape[1],
        'pfam_dim': pfam.shape[1],
        'bio_dim': bio.shape[1] if bio.ndim > 1 else 1,
        'latent_dim': config['latent_dim'],
        'dropout': config['dropout'],
        'lambda_inv': config['lambda_inv'],
        'lambda_var': config['lambda_var'],
        'lambda_cov': config['lambda_cov'],
        'pred_alpha': config['pred_alpha'],
    }

    all_histories = {}
    all_summaries = []
    fold_checkpoints = {}

    total_t0 = time.time()

    for fold_name, train_idx, val_idx in folds:
        print(f"\n ββ Fold: {fold_name} ββ")

        # Scalers are fitted on the training split only, so no statistics
        # of the held-out basin leak into the features.
        env_scaler = StandardScaler()
        pfam_scaler = StandardScaler()
        bio_scaler = StandardScaler()

        env_train_scaled = env_scaler.fit_transform(env[train_idx])
        env_val_scaled = env_scaler.transform(env[val_idx])

        pfam_train_scaled = pfam_scaler.fit_transform(pfam[train_idx])
        pfam_val_scaled = pfam_scaler.transform(pfam[val_idx])

        # Fit the target scaler on valid training rows only, so the zeros
        # standing in for NaN do not distort mean/scale; fall back to all
        # training rows if the fold has no valid target at all.
        bio_train_valid_idx = train_idx[bio_valid[train_idx]]
        if len(bio_train_valid_idx) > 0:
            bio_scaler.fit(bio_clean[bio_train_valid_idx])
        else:
            bio_scaler.fit(bio_clean[train_idx])
        bio_train_scaled = bio_scaler.transform(bio_clean[train_idx])
        bio_val_scaled = bio_scaler.transform(bio_clean[val_idx])

        train_dataset = WorldModelDataset(
            env_train_scaled, pfam_train_scaled,
            bio_train_scaled, bio_valid[train_idx]
        )
        val_dataset = WorldModelDataset(
            env_val_scaled, pfam_val_scaled,
            bio_val_scaled, bio_valid[val_idx]
        )

        # Re-seed so every fold starts from the same initial weights.
        torch.manual_seed(config['seed'])
        if torch.cuda.is_available():
            torch.cuda.manual_seed(config['seed'])

        model = WorldModel(**model_kwargs).to(device)

        best_state, history, summary = train_fold(
            model, train_dataset, val_dataset, config, device, fold_name
        )

        print(f" Epochs: {summary['total_epochs']} "
              f"(best={summary['best_epoch']})")
        print(f" Best val: total={summary['best_val_total']:.2f}, "
              f"vicreg={summary['best_val_vicreg']:.2f}, "
              f"pred={summary['best_val_pred']:.2f}")
        print(f" Time: {summary['elapsed_s']:.1f}s")

        all_histories[fold_name] = history
        all_summaries.append(summary)

        # The checkpoint carries the scaler statistics so the exact
        # preprocessing can be reproduced at inference time.
        checkpoint = {
            'model_state_dict': best_state,
            'config': model.config,
            'fold': fold_name,
            'train_idx': train_idx.tolist(),
            'val_idx': val_idx.tolist(),
            'env_scaler_mean': env_scaler.mean_.tolist(),
            'env_scaler_scale': env_scaler.scale_.tolist(),
            'pfam_scaler_mean': pfam_scaler.mean_.tolist(),
            'pfam_scaler_scale': pfam_scaler.scale_.tolist(),
            'bio_scaler_mean': bio_scaler.mean_.tolist(),
            'bio_scaler_scale': bio_scaler.scale_.tolist(),
            'summary': summary,
            'timestamp': timestamp,
        }
        ckpt_path = os.path.join(
            MODELS_DIR, f'world_model_fold_{fold_name}_{timestamp}.pt'
        )
        torch.save(checkpoint, ckpt_path)
        fold_checkpoints[fold_name] = ckpt_path
        print(f" Saved: {os.path.basename(ckpt_path)}")

    total_elapsed = time.time() - total_t0
    print(f"\nββ Total training time: {total_elapsed:.1f}s ββ")

    # ── Training curves (JSON) ──
    curves_path = os.path.join(RESULTS_DIR,
                               f't12_training_curves_{timestamp}.json')
    curves_data = {
        'provenance': {
            'script': os.path.abspath(__file__),
            'inputs': input_paths,
            'timestamp': timestamp,
            'config': config,
        },
        'histories': all_histories,
    }
    with open(curves_path, 'w', encoding='utf-8') as f:
        json.dump(curves_data, f, indent=2)
    print(f"\nTraining curves: {os.path.basename(curves_path)}")

    # ── Per-fold summary (TSV with '#' provenance header) ──
    summary_path = os.path.join(RESULTS_DIR,
                                f't12_training_summary_{timestamp}.tsv')
    cols = ['fold', 'n_train', 'n_val', 'total_epochs', 'best_epoch',
            'best_val_total', 'best_val_vicreg', 'best_val_pred',
            'final_train_total', 'elapsed_s']
    with open(summary_path, 'w', encoding='utf-8') as f:
        f.write("# Provenance:\n")
        f.write(f"# Script: {os.path.abspath(__file__)}\n")
        f.write("# Inputs: data/consolidated_*.npy\n")
        f.write(f"# Date: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write("# Integrity Check: PASSED\n")
        f.write(f"# Config: latent_dim={config['latent_dim']}, "
                f"dropout={config['dropout']}, "
                f"VICReg=({config['lambda_inv']}/{config['lambda_var']}/{config['lambda_cov']})\n")

        f.write('\t'.join(cols) + '\n')
        for s in all_summaries:
            f.write('\t'.join(str(s[c]) for c in cols) + '\n')

    print(f"Training summary: {os.path.basename(summary_path)}")

    # ── Console summary table ──
    print("\nββ Per-Fold Training Summary ββ")
    print(f"{'Fold':15s} {'Train':>5s} {'Val':>5s} {'Epochs':>6s} "
          f"{'Best':>5s} {'Val Total':>10s} {'Val VICReg':>10s} "
          f"{'Val Pred':>9s} {'Time':>6s}")
    print("-" * 80)
    for s in all_summaries:
        print(f"{s['fold']:15s} {s['n_train']:5d} {s['n_val']:5d} "
              f"{s['total_epochs']:6d} {s['best_epoch']:5d} "
              f"{s['best_val_total']:10.2f} {s['best_val_vicreg']:10.2f} "
              f"{s['best_val_pred']:9.2f} {s['elapsed_s']:6.1f}s")

    best_totals = [s['best_val_total'] for s in all_summaries]
    mean_val = np.mean(best_totals)
    std_val = np.std(best_totals)
    print(f"\nMean best val total: {mean_val:.2f} +/- {std_val:.2f}")
    print(f"Total time: {total_elapsed:.1f}s")

    # Count parameters on a model built with the same arguments as the
    # trained ones (including bio_dim) so the reported size matches them.
    params = WorldModel(**model_kwargs).count_parameters()
    print(f"Parameters per model: {params:,}")

    print("\nββ Output Files ββ")
    for path in fold_checkpoints.values():
        print(f" {os.path.basename(path)}")
    print(f" {os.path.basename(curves_path)}")
    print(f" {os.path.basename(summary_path)}")

    return {
        'summaries': all_summaries,
        'histories': all_histories,
        'checkpoints': fold_checkpoints,
        'timestamp': timestamp,
        'config': config,
    }
|
|
|
|
def parse_cli_overrides(argv):
    """Parse ``key=value`` command-line tokens into a config-override dict.

    Each value is converted to ``int`` if possible, otherwise to ``float``,
    otherwise kept as a string. Only the first ``=`` splits key from value.
    Tokens without ``=`` or with an empty key are skipped with a warning on
    stderr instead of being dropped silently.

    Parameters
    ----------
    argv : list of str
        Command-line arguments, without the program name.

    Returns
    -------
    dict
        Mapping of config key to parsed value.
    """
    overrides = {}
    for token in argv:
        key, sep, raw = token.partition('=')
        if not sep or not key:
            print(f"Ignoring malformed argument (expected key=value): "
                  f"{token!r}", file=sys.stderr)
            continue

        value = raw
        # Try the narrowest numeric type first; fall back to the raw string.
        for cast in (int, float):
            try:
                value = cast(raw)
            except ValueError:
                continue
            break
        overrides[key] = value
    return overrides


if __name__ == '__main__':
    # Usage: python train_world_model.py [key=value ...], e.g. lr=0.0005
    config = parse_cli_overrides(sys.argv[1:])
    result = main(config)
|
|