TARA-WorldModel-VICReg / scripts /train_world_model.py
GreenGenomicsLab's picture
Upload scripts/train_world_model.py with huggingface_hub
a609de0 verified
Raw
History Blame Contribute Delete
30.5 kB
#!/usr/bin/env python3
"""
T12: Training loop for the WorldModel with spatial block CV.
Implements:
- PyTorch DataLoader for paired (env, pfam_module, bio_response, bio_valid) tuples
- Adam optimizer with ReduceLROnPlateau scheduler
- Leave-one-basin-out spatial block CV (6 folds, Red_Sea merged into Indian)
- Early stopping on validation total loss (patience configurable)
- Gradient clipping (max_norm=1.0)
- Training curve logging (JSON)
- Small-subset gradient flow verification (100 samples)
Inputs:
- scripts/world_model.py (WorldModel architecture from T11)
- data/consolidated_env.npy (1810 x 24)
- data/consolidated_pfam_modules.npy (1810 x 20)
- data/consolidated_bio_response.npy (1810 x 3, NaN allowed)
- data/consolidated_bio_valid.npy (1810 boolean mask)
- data/consolidated_sample_ids.npy (1810 assembly IDs)
- data/ocean_basin_assignments.tsv (ocean basin assignment per assembly ID)
- data/consolidated_metadata.tsv (path is passed to build_cv_folds but not read)
Outputs:
- models/world_model_fold_{basin}_{timestamp}.pt (per-fold checkpoints)
- results/t12_training_curves_{timestamp}.json (per-epoch loss curves)
- results/t12_training_summary_{timestamp}.tsv (per-fold summary)
Provenance:
Script: scripts/train_world_model.py
Date: 2026-01-27
Integrity Check: PASSED (uses real consolidated data only)
Author: World Model RALPH Loop
"""
import sys
import os
import json
import time
import datetime
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
# Add scripts directory to path for WorldModel import.
# NOTE: order matters here -- the sys.path insertion must run before the
# `world_model` import below, otherwise the sibling module is not found.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, SCRIPT_DIR)
from world_model import WorldModel
# ── Paths ──
# BASE_DIR is the parent directory of scripts/ (the project root); data,
# results and models directories all live directly under it.
BASE_DIR = os.path.dirname(SCRIPT_DIR) # WorldModelApp
DATA_DIR = os.path.join(BASE_DIR, 'data')
RESULTS_DIR = os.path.join(BASE_DIR, 'results')
MODELS_DIR = os.path.join(BASE_DIR, 'models')
# Ensure output directories exist.
# NOTE(review): this runs at import time, so merely importing this module
# creates results/ and models/ under BASE_DIR.
os.makedirs(RESULTS_DIR, exist_ok=True)
os.makedirs(MODELS_DIR, exist_ok=True)
# ═══════════════════════════════════════════════════════════════════
# Dataset
# ═══════════════════════════════════════════════════════════════════
class WorldModelDataset(Dataset):
    """Map-style dataset of (env, pfam, bio, bio_valid) tensor tuples.

    The input arrays are cast once at construction time (float32 for features
    and targets, numpy bool for the mask); individual rows are wrapped as
    tensors lazily in ``__getitem__``. Keeping everything in memory is fine
    for the ~1.8k-sample datasets this script trains on.
    """

    def __init__(self, env, pfam, bio, bio_valid, indices=None):
        """
        Parameters
        ----------
        env : np.ndarray, shape (N, env_dim)
            Standardized environment features (expected to be NaN-free).
        pfam : np.ndarray, shape (N, pfam_dim)
            Standardized PFAM module features (expected to be NaN-free).
        bio : np.ndarray, shape (N, bio_dim)
            Bio-response targets. Rows flagged invalid are passed through
            unchanged, so their NaN should already have been replaced (e.g.
            with 0) by the caller.
        bio_valid : np.ndarray, shape (N,)
            Boolean mask: True where all bio targets are valid.
        indices : np.ndarray or None
            Optional row subset (train/val split). None keeps every row.
        """
        def select(arr):
            # Apply the optional row subset before casting.
            return arr if indices is None else arr[indices]

        self.env = select(env).astype(np.float32)
        self.pfam = select(pfam).astype(np.float32)
        self.bio = select(bio).astype(np.float32)
        self.bio_valid = select(bio_valid).astype(np.bool_)
        self.n_samples = self.env.shape[0]

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        env_row = torch.from_numpy(self.env[idx])
        pfam_row = torch.from_numpy(self.pfam[idx])
        bio_row = torch.from_numpy(self.bio[idx])
        # Scalar boolean tensor so the default collate stacks it into a
        # (batch,) bool mask.
        valid_flag = torch.tensor(bool(self.bio_valid[idx]), dtype=torch.bool)
        return env_row, pfam_row, bio_row, valid_flag
# ═══════════════════════════════════════════════════════════════════
# Training Loop
# ═══════════════════════════════════════════════════════════════════
def train_one_epoch(model, dataloader, optimizer, device, grad_clip=1.0):
    """Run one optimisation pass over ``dataloader``.

    For every batch the tensors are moved to ``device``, the model's loss dict
    is computed, ``total_loss`` is backpropagated, the global gradient norm is
    clipped to ``grad_clip`` and the optimizer takes a step.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(env, pfam, bio, bio_valid)``; must return a dict with
        scalar tensors 'total_loss', 'vicreg_loss' and 'pred_loss'.
    dataloader : iterable
        Yields (env, pfam, bio, bio_valid) tensor batches.
    optimizer : torch.optim.Optimizer
    device : torch.device
    grad_clip : float or None
        Max global gradient norm passed to ``clip_grad_norm_``; None disables
        clipping.

    Returns
    -------
    dict
        'total', 'vicreg', 'pred': per-batch losses averaged with each batch
        weighted by its sample count, so batches of unequal size are combined
        correctly. NaN for every key if the loader yields no batches (e.g. a
        drop_last loader over a dataset smaller than one batch).
    """
    model.train()
    loss_keys = (('total', 'total_loss'), ('vicreg', 'vicreg_loss'),
                 ('pred', 'pred_loss'))
    sums = {out_key: 0.0 for out_key, _ in loss_keys}
    n_seen = 0
    for env_batch, pfam_batch, bio_batch, valid_batch in dataloader:
        env_batch = env_batch.to(device)
        pfam_batch = pfam_batch.to(device)
        bio_batch = bio_batch.to(device)
        valid_batch = valid_batch.to(device)
        optimizer.zero_grad()
        result = model(env_batch, pfam_batch, bio_batch, valid_batch)
        result['total_loss'].backward()
        # Gradient clipping
        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        # Weight each batch mean by its size so the epoch value is a
        # per-sample average even when batch sizes differ.
        batch_n = env_batch.shape[0]
        for out_key, loss_key in loss_keys:
            sums[out_key] += result[loss_key].item() * batch_n
        n_seen += batch_n
    if n_seen == 0:
        return {out_key: float('nan') for out_key in sums}
    return {out_key: total / n_seen for out_key, total in sums.items()}
@torch.no_grad()
def validate(model, dataloader, device):
    """Evaluate ``model`` over ``dataloader`` without gradient tracking.

    The model is switched to eval mode first (dropout off; any normalisation
    layers use their running statistics).

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(env, pfam, bio, bio_valid)``; must return a dict with
        scalar tensors 'total_loss', 'vicreg_loss' and 'pred_loss'.
    dataloader : iterable
        Yields (env, pfam, bio, bio_valid) tensor batches.
    device : torch.device

    Returns
    -------
    dict
        'total', 'vicreg', 'pred': per-batch losses averaged with each batch
        weighted by its sample count, so a short final batch (e.g. from a
        loader with drop_last=False) does not count as much as a full one.
        NaN for every key if the loader yields no batches.
    """
    model.eval()
    loss_keys = (('total', 'total_loss'), ('vicreg', 'vicreg_loss'),
                 ('pred', 'pred_loss'))
    sums = {out_key: 0.0 for out_key, _ in loss_keys}
    n_seen = 0
    for env_batch, pfam_batch, bio_batch, valid_batch in dataloader:
        env_batch = env_batch.to(device)
        pfam_batch = pfam_batch.to(device)
        bio_batch = bio_batch.to(device)
        valid_batch = valid_batch.to(device)
        result = model(env_batch, pfam_batch, bio_batch, valid_batch)
        batch_n = env_batch.shape[0]
        for out_key, loss_key in loss_keys:
            sums[out_key] += result[loss_key].item() * batch_n
        n_seen += batch_n
    if n_seen == 0:
        return {out_key: float('nan') for out_key in sums}
    return {out_key: total / n_seen for out_key, total in sums.items()}
def train_fold(model, train_dataset, val_dataset, config, device, fold_name=''):
    """Train a single fold with early stopping.

    Each epoch runs one training pass and one validation pass, steps a
    ReduceLROnPlateau scheduler on the validation total loss, and keeps a CPU
    copy of the weights whenever that loss improves by more than
    ``min_delta``. Training stops after ``patience`` epochs without such an
    improvement, or after ``max_epochs`` epochs.

    Parameters
    ----------
    model : WorldModel
        Freshly initialized model, expected to already be on ``device``.
    train_dataset : WorldModelDataset
        Training data. Loaded shuffled with drop_last=True, so a dataset
        smaller than ``batch_size`` yields no training batches.
    val_dataset : WorldModelDataset
        Validation data (the final partial batch is kept).
    config : dict
        Keys read (defaults in brackets): lr [1e-3], weight_decay [1e-4],
        max_epochs [300], patience [30], min_delta [1e-4], batch_size [128],
        grad_clip [1.0]; LR scheduler: lr_factor [0.5], lr_patience [10],
        min_lr [1e-6].
    device : torch.device
    fold_name : str
        Name for logging.

    Returns
    -------
    best_model_state : dict
        CPU state_dict of the best epoch (by val total loss). If the
        validation loss never improved (e.g. it was NaN every epoch) the
        final weights are returned instead of None.
    history : dict
        Per-epoch training curves.
    summary : dict
        Summary metrics.

    Raises
    ------
    ValueError
        If ``max_epochs`` is less than 1.
    """
    batch_size = config.get('batch_size', 128)
    max_epochs = config.get('max_epochs', 300)
    if max_epochs < 1:
        raise ValueError(f"max_epochs must be >= 1, got {max_epochs}")
    patience = config.get('patience', 30)
    min_delta = config.get('min_delta', 1e-4)
    grad_clip = config.get('grad_clip', 1.0)

    # Pinned host memory only helps host->GPU copies; on CPU it merely
    # triggers a warning.
    pin_memory = torch.device(device).type == 'cuda'

    # DataLoaders
    # drop_last=True for training to ensure BatchNorm stability
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=0,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=0,
        pin_memory=pin_memory,
    )

    # Optimizer and scheduler
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config.get('lr', 1e-3),
        weight_decay=config.get('weight_decay', 1e-4),
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min',
        factor=config.get('lr_factor', 0.5),
        patience=config.get('lr_patience', 10),
        min_lr=config.get('min_lr', 1e-6),
    )

    # Early stopping state
    best_val_loss = float('inf')
    best_epoch = 0
    best_model_state = None
    epochs_no_improve = 0
    total_epochs = 0
    history = {
        'train_total': [], 'train_vicreg': [], 'train_pred': [],
        'val_total': [], 'val_vicreg': [], 'val_pred': [],
        'lr': [],
    }
    t0 = time.time()
    for epoch in range(max_epochs):
        total_epochs = epoch + 1
        train_metrics = train_one_epoch(
            model, train_loader, optimizer, device, grad_clip
        )
        val_metrics = validate(model, val_loader, device)

        history['train_total'].append(train_metrics['total'])
        history['train_vicreg'].append(train_metrics['vicreg'])
        history['train_pred'].append(train_metrics['pred'])
        history['val_total'].append(val_metrics['total'])
        history['val_vicreg'].append(val_metrics['vicreg'])
        history['val_pred'].append(val_metrics['pred'])
        # LR in effect during this epoch (recorded before the scheduler step).
        history['lr'].append(optimizer.param_groups[0]['lr'])

        # Step scheduler on validation total loss
        scheduler.step(val_metrics['total'])

        # Early stopping: a NaN loss never compares as an improvement.
        if val_metrics['total'] < best_val_loss - min_delta:
            best_val_loss = val_metrics['total']
            best_epoch = epoch
            best_model_state = {k: v.cpu().clone()
                                for k, v in model.state_dict().items()}
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                break
    elapsed = time.time() - t0

    if best_model_state is None:
        # Validation loss never improved on +inf (e.g. NaN every epoch):
        # report and return the final weights rather than None.
        print(f"    WARNING: fold {fold_name!r}: validation loss never "
              f"improved; using final-epoch weights")
        best_epoch = total_epochs - 1
        best_val_loss = history['val_total'][best_epoch]
        best_model_state = {k: v.cpu().clone()
                            for k, v in model.state_dict().items()}

    summary = {
        'fold': fold_name,
        'n_train': len(train_dataset),
        'n_val': len(val_dataset),
        'total_epochs': total_epochs,
        'best_epoch': best_epoch,
        'best_val_total': best_val_loss,
        'best_val_vicreg': history['val_vicreg'][best_epoch],
        'best_val_pred': history['val_pred'][best_epoch],
        'final_train_total': history['train_total'][-1],
        'elapsed_s': round(elapsed, 1),
    }
    return best_model_state, history, summary
# ═══════════════════════════════════════════════════════════════════
# Gradient Flow Verification
# ═══════════════════════════════════════════════════════════════════
def verify_gradient_flow(env, pfam, bio, bio_valid, device, n_samples=100,
                         n_steps=50, config=None):
    """Verify gradient flow on a small subset.

    Builds a fresh WorldModel (hyper-parameters from ``config``, with the same
    fallback defaults ``main`` uses) and trains it full-batch on the first
    ``n_samples`` rows. Inputs are used as given (not standardized); NaN in
    ``bio`` is replaced with 0. Four checks are run:

    1. Every parameter receives a non-zero gradient after one backward pass.
    2. Loss decreases: the total loss at the last of ``n_steps`` Adam steps
       (lr=1e-3, gradient-norm clip 1.0) is below that of the first step.
    3. No NaN in the recorded total-loss history, in the final step's
       ``total_loss`` / ``pred_loss``, or in its ``vicreg_components``.
    4. No collapse: in eval mode, at least half of the latent dimensions of
       both encoders have a per-dimension std above 0.01 over the subset.

    Returns
    -------
    success : bool
        True only if all four checks pass.
    report : str
        One PASS/FAIL line per check.
    """
    if config is None:
        config = {}
    # Subsample: the first n_samples rows (deterministic, no shuffling).
    idx = np.arange(min(n_samples, len(env)))
    env_sub = env[idx].astype(np.float32)
    pfam_sub = pfam[idx].astype(np.float32)
    bio_sub = np.nan_to_num(bio[idx].astype(np.float32), nan=0.0)
    # Cast to bool (as WorldModelDataset does) so an integer 0/1 mask still
    # yields a boolean tensor.
    bio_valid_sub = bio_valid[idx].astype(np.bool_)
    env_t = torch.from_numpy(env_sub).to(device)
    pfam_t = torch.from_numpy(pfam_sub).to(device)
    bio_t = torch.from_numpy(bio_sub).to(device)
    valid_t = torch.from_numpy(bio_valid_sub).to(device)
    model = WorldModel(
        env_dim=env.shape[1],
        pfam_dim=pfam.shape[1],
        bio_dim=bio.shape[1],
        latent_dim=config.get('latent_dim', 16),
        dropout=config.get('dropout', 0.3),
        lambda_inv=config.get('lambda_inv', 25.0),
        lambda_var=config.get('lambda_var', 25.0),
        lambda_cov=config.get('lambda_cov', 1.0),
        pred_alpha=config.get('pred_alpha', 1.0),
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    checks_passed = 0
    checks_total = 4
    report_lines = []

    # --- Check 1: Gradient flow ---
    model.train()
    optimizer.zero_grad()
    result = model(env_t, pfam_t, bio_t, valid_t)
    result['total_loss'].backward()
    all_params = list(model.named_parameters())
    grad_count = sum(1 for _, p in all_params
                     if p.grad is not None and p.grad.abs().sum() > 0)
    if grad_count == len(all_params):
        checks_passed += 1
        report_lines.append(f" PASS: All {len(all_params)}/{len(all_params)} "
                            f"parameters receive non-zero gradients")
    else:
        report_lines.append(f" FAIL: Only {grad_count}/{len(all_params)} "
                            f"parameters receive gradients")

    # --- Check 2: Loss convergence ---
    losses = []
    for _ in range(n_steps):
        optimizer.zero_grad()
        result = model(env_t, pfam_t, bio_t, valid_t)
        result['total_loss'].backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        losses.append(result['total_loss'].item())
    if not losses:
        report_lines.append(f" FAIL: Loss convergence not assessed "
                            f"(n_steps={n_steps})")
    elif losses[-1] < losses[0]:
        reduction = (losses[0] - losses[-1]) / max(abs(losses[0]), 1e-8) * 100
        checks_passed += 1
        report_lines.append(f" PASS: Loss converges: {losses[0]:.2f} -> "
                            f"{losses[-1]:.2f} ({reduction:.1f}% reduction in "
                            f"{n_steps} steps)")
    else:
        report_lines.append(f" FAIL: Loss did not converge: {losses[0]:.2f} -> "
                            f"{losses[-1]:.2f}")

    # --- Check 3: No NaN in loss components ---
    # float() accepts both Python floats and 0-d tensors (on any device).
    comp = result['vicreg_components']
    nan_terms = [name for name, v in comp.items() if np.isnan(float(v))]
    nan_terms += [name for name in ('total_loss', 'pred_loss')
                  if np.isnan(result[name].item())]
    if any(np.isnan(loss) for loss in losses):
        nan_terms.append('loss history')
    if not nan_terms:
        checks_passed += 1
        report_lines.append(
            f" PASS: No NaN in loss components "
            f"(inv={comp['invariance']:.2f}, "
            f"var_a={comp['variance_a']:.2f}, "
            f"var_b={comp['variance_b']:.2f}, "
            f"cov_a={comp['covariance_a']:.2f}, "
            f"cov_b={comp['covariance_b']:.2f})")
    else:
        report_lines.append(f" FAIL: NaN detected in loss components "
                            f"({', '.join(nan_terms)})")

    # --- Check 4: No collapsed embedding dimensions ---
    model.eval()
    with torch.no_grad():
        z_env = model.encode_env(env_t)
        z_pfam = model.encode_pfam(pfam_t)
    z_env_std = z_env.std(dim=0).cpu().numpy()
    z_pfam_std = z_pfam.std(dim=0).cpu().numpy()
    min_env_std = z_env_std.min()
    min_pfam_std = z_pfam_std.min()
    # Some dims may legitimately be (near) zero, e.g. after a ReLU; require
    # only that at least half of the dims are active.
    active_env = (z_env_std > 0.01).sum()
    active_pfam = (z_pfam_std > 0.01).sum()
    latent_dim = z_env.shape[1]
    if active_env >= latent_dim // 2 and active_pfam >= latent_dim // 2:
        checks_passed += 1
        report_lines.append(
            f" PASS: Active dims: z_env={active_env}/{latent_dim} "
            f"(min_std={min_env_std:.4f}), "
            f"z_pfam={active_pfam}/{latent_dim} "
            f"(min_std={min_pfam_std:.4f})")
    else:
        report_lines.append(
            f" FAIL: Collapsed dims: z_env={active_env}/{latent_dim}, "
            f"z_pfam={active_pfam}/{latent_dim}")

    report = '\n'.join(report_lines)
    return checks_passed == checks_total, report
# ═══════════════════════════════════════════════════════════════════
# Cross-Validation Fold Setup
# ═══════════════════════════════════════════════════════════════════
def build_cv_folds(sample_ids, metadata_path, basin_assignments_path):
    """Build leave-one-basin-out CV folds.

    Each sample is labelled with its ocean basin by looking up its ID in the
    ``assembly_id`` / ``ocean_basin`` columns of the tab-separated file at
    ``basin_assignments_path`` (lines starting with '#' are skipped). IDs
    missing from the file, and rows whose basin cell is empty, are labelled
    'unknown'. 'Red_Sea' is merged into 'Indian'. One fold is produced per
    remaining basin, in sorted order: that basin is the validation split and
    all other basins form the training split. Samples labelled 'no_gps' or
    'unknown' are excluded from both splits.

    Parameters
    ----------
    sample_ids : sequence
        Sample / assembly IDs in consolidated row order (compared as str).
    metadata_path : str
        Not read; kept for call-signature compatibility.
    basin_assignments_path : str
        Path to the basin-assignment TSV.

    Returns
    -------
    folds : list of (fold_name, train_indices, val_indices)
    basin_labels : np.ndarray of str, shape (N,)
    """
    import pandas as pd

    excluded = ['no_gps', 'unknown']
    basins_df = pd.read_csv(basin_assignments_path, sep='\t', comment='#')
    # Empty basin cells are read as NaN; without fillna they would become the
    # literal string 'nan' and be treated as a CV basin of their own.
    basin_col = basins_df['ocean_basin'].fillna('unknown').astype(str)
    basin_map = dict(zip(basins_df['assembly_id'].astype(str), basin_col))

    # Assign basins in consolidated sample order; dtype=str keeps the array a
    # string array even when sample_ids is empty.
    basin_labels = np.array([basin_map.get(str(sid), 'unknown')
                             for sid in sample_ids], dtype=str)
    # Merge Red_Sea into Indian
    basin_labels[basin_labels == 'Red_Sea'] = 'Indian'

    cv_basins = sorted(set(basin_labels) - set(excluded))
    usable = ~np.isin(basin_labels, excluded)
    folds = []
    for basin in cv_basins:
        in_basin = basin_labels == basin
        val_idx = np.where(in_basin)[0]
        train_idx = np.where(usable & ~in_basin)[0]
        folds.append((basin, train_idx, val_idx))
    return folds, basin_labels
# ═══════════════════════════════════════════════════════════════════
# Main Training Pipeline
# ═══════════════════════════════════════════════════════════════════
def main(config=None):
    """Run full training pipeline with spatial block CV.

    Loads the consolidated arrays from DATA_DIR and builds leave-one-basin-out
    folds. It then runs a small gradient-flow sanity check and trains each
    fold. For every fold it:

    - fits env/pfam/bio scalers on that fold's training rows only;
    - trains a freshly seeded WorldModel via ``train_fold``;
    - saves a checkpoint (best weights, scaler statistics, split indices) to
      MODELS_DIR.

    Per-epoch curves (JSON) and a per-fold summary (TSV) are written to
    RESULTS_DIR.

    Parameters
    ----------
    config : dict or None
        Override default training config. The caller's dict is not modified;
        a copy with defaults filled in is returned under 'config'. Keys:
        - latent_dim (int, default 16)
        - dropout (float, default 0.3)
        - lambda_inv, lambda_var, lambda_cov (float)
        - pred_alpha (float, default 1.0)
        - lr (float, default 1e-3)
        - weight_decay (float, default 1e-4)
        - max_epochs (int, default 300)
        - patience (int, default 30)
        - min_delta (float, default 1e-4)
        - batch_size (int, default 128)
        - grad_clip (float, default 1.0)
        - seed (int, default 42)

    Returns
    -------
    dict
        'summaries', 'histories', 'checkpoints', 'timestamp', 'config'.

    Raises
    ------
    ValueError
        If any loaded array (including sample_ids) does not hold exactly the
        expected 1810 samples.
    """
    # Work on a copy so filling in defaults never mutates the caller's dict.
    config = dict(config) if config else {}
    # Defaults
    config.setdefault('latent_dim', 16)
    config.setdefault('dropout', 0.3)
    config.setdefault('lambda_inv', 25.0)
    config.setdefault('lambda_var', 25.0)
    config.setdefault('lambda_cov', 1.0)
    config.setdefault('pred_alpha', 1.0)
    config.setdefault('lr', 1e-3)
    config.setdefault('weight_decay', 1e-4)
    config.setdefault('max_epochs', 300)
    config.setdefault('patience', 30)
    config.setdefault('min_delta', 1e-4)
    config.setdefault('batch_size', 128)
    config.setdefault('grad_clip', 1.0)
    config.setdefault('seed', 42)
    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    print("=" * 70)
    print("T12: WorldModel Training Pipeline")
    print(f"Timestamp: {timestamp}")
    print("=" * 70)

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")

    # Reproducibility
    torch.manual_seed(config['seed'])
    np.random.seed(config['seed'])
    if torch.cuda.is_available():
        torch.cuda.manual_seed(config['seed'])

    # ── Load data ──
    print("\n── Loading consolidated data ──")
    env = np.load(os.path.join(DATA_DIR, 'consolidated_env.npy'))
    pfam = np.load(os.path.join(DATA_DIR, 'consolidated_pfam_modules.npy'))
    bio = np.load(os.path.join(DATA_DIR, 'consolidated_bio_response.npy'))
    # Force a bool dtype: the mask is used for boolean indexing below
    # (train_idx[bio_valid[train_idx]]); an integer 0/1 array would silently
    # be interpreted as positions instead.
    bio_valid = np.load(
        os.path.join(DATA_DIR, 'consolidated_bio_valid.npy')).astype(bool)
    # allow_pickle is needed for object arrays of IDs; only load trusted files.
    sample_ids = np.load(os.path.join(DATA_DIR, 'consolidated_sample_ids.npy'),
                         allow_pickle=True)
    print(f" env: {env.shape} (no NaN: {not np.isnan(env).any()})")
    print(f" pfam: {pfam.shape} (no NaN: {not np.isnan(pfam).any()})")
    print(f" bio: {bio.shape} (has NaN: {np.isnan(bio).any()})")
    print(f" bio_valid: {bio_valid.shape} ({bio_valid.sum()}/{len(bio_valid)} "
          f"valid = {bio_valid.sum()/len(bio_valid)*100:.1f}%)")
    print(f" samples: {len(sample_ids)}")

    # Explicit checks rather than `assert` (stripped under `python -O`).
    # sample_ids is included because basin labels are matched to rows by
    # position, so a length mismatch would silently misassign basins.
    n_expected = 1810
    for name, arr in (('env', env), ('pfam', pfam), ('bio', bio),
                      ('bio_valid', bio_valid), ('sample_ids', sample_ids)):
        if len(arr) != n_expected:
            raise ValueError(
                f"Expected {n_expected} samples in {name}, got {len(arr)}")
    bio_dim = bio.shape[1] if bio.ndim > 1 else 1

    # Replace NaN in bio with 0; invalid rows are flagged by bio_valid, which
    # is passed to the model (presumably to mask them out of the prediction
    # loss -- masking happens inside WorldModel, not visible here).
    bio_clean = np.nan_to_num(bio, nan=0.0)

    # ── Build CV folds ──
    print("\n── Building leave-one-basin-out CV folds ──")
    metadata_path = os.path.join(DATA_DIR, 'consolidated_metadata.tsv')
    basin_path = os.path.join(DATA_DIR, 'ocean_basin_assignments.tsv')
    folds, basin_labels = build_cv_folds(sample_ids, metadata_path, basin_path)
    for fold_name, train_idx, val_idx in folds:
        n_bio_train = bio_valid[train_idx].sum()
        n_bio_val = bio_valid[val_idx].sum()
        print(f" {fold_name:15s}: train={len(train_idx):4d} "
              f"(bio_valid={n_bio_train:4d}), "
              f"val={len(val_idx):4d} (bio_valid={n_bio_val:4d})")

    # ── Gradient flow verification ──
    print("\n── Gradient Flow Verification (100 samples, 50 steps) ──")
    gf_success, gf_report = verify_gradient_flow(
        env, pfam, bio_clean, bio_valid, device, n_samples=100, n_steps=50,
        config=config
    )
    print(gf_report)
    if gf_success:
        print(" >>> All gradient flow checks PASSED <<<")
    else:
        print(" >>> WARNING: Some gradient flow checks FAILED <<<")
        print(" Proceeding with training anyway...")

    # ── Train all folds ──
    print("\n── Training All CV Folds ──")
    print(f"Config: latent_dim={config['latent_dim']}, "
          f"dropout={config['dropout']}, "
          f"VICReg=({config['lambda_inv']}/{config['lambda_var']}/{config['lambda_cov']}), "
          f"pred_alpha={config['pred_alpha']}, "
          f"lr={config['lr']}, batch_size={config['batch_size']}, "
          f"patience={config['patience']}, seed={config['seed']}")
    all_histories = {}
    all_summaries = []
    fold_checkpoints = {}
    total_t0 = time.time()
    for fold_name, train_idx, val_idx in folds:
        print(f"\n ── Fold: {fold_name} ──")
        # Per-fold standardization: scalers are fit on training rows only
        # (no leakage from the held-out basin).
        env_scaler = StandardScaler()
        pfam_scaler = StandardScaler()
        bio_scaler = StandardScaler()
        env_train_scaled = env_scaler.fit_transform(env[train_idx])
        env_val_scaled = env_scaler.transform(env[val_idx])
        pfam_train_scaled = pfam_scaler.fit_transform(pfam[train_idx])
        pfam_val_scaled = pfam_scaler.transform(pfam[val_idx])
        # Scale bio using only bio_valid training samples
        bio_train_valid_idx = train_idx[bio_valid[train_idx]]
        if len(bio_train_valid_idx) > 0:
            bio_scaler.fit(bio_clean[bio_train_valid_idx])
        else:
            bio_scaler.fit(bio_clean[train_idx])  # fallback
        bio_train_scaled = bio_scaler.transform(bio_clean[train_idx])
        bio_val_scaled = bio_scaler.transform(bio_clean[val_idx])

        # Create datasets
        train_dataset = WorldModelDataset(
            env_train_scaled, pfam_train_scaled,
            bio_train_scaled, bio_valid[train_idx]
        )
        val_dataset = WorldModelDataset(
            env_val_scaled, pfam_val_scaled,
            bio_val_scaled, bio_valid[val_idx]
        )

        # Fresh, identically seeded model per fold
        torch.manual_seed(config['seed'])
        if torch.cuda.is_available():
            torch.cuda.manual_seed(config['seed'])
        model = WorldModel(
            env_dim=env.shape[1],
            pfam_dim=pfam.shape[1],
            bio_dim=bio_dim,
            latent_dim=config['latent_dim'],
            dropout=config['dropout'],
            lambda_inv=config['lambda_inv'],
            lambda_var=config['lambda_var'],
            lambda_cov=config['lambda_cov'],
            pred_alpha=config['pred_alpha'],
        ).to(device)

        # Train
        best_state, history, summary = train_fold(
            model, train_dataset, val_dataset, config, device, fold_name
        )
        print(f" Epochs: {summary['total_epochs']} "
              f"(best={summary['best_epoch']})")
        print(f" Best val: total={summary['best_val_total']:.2f}, "
              f"vicreg={summary['best_val_vicreg']:.2f}, "
              f"pred={summary['best_val_pred']:.2f}")
        print(f" Time: {summary['elapsed_s']:.1f}s")
        all_histories[fold_name] = history
        all_summaries.append(summary)

        # Save checkpoint (weights + everything needed to re-apply scaling)
        checkpoint = {
            'model_state_dict': best_state,
            'config': model.config,
            'fold': fold_name,
            'train_idx': train_idx.tolist(),
            'val_idx': val_idx.tolist(),
            'env_scaler_mean': env_scaler.mean_.tolist(),
            'env_scaler_scale': env_scaler.scale_.tolist(),
            'pfam_scaler_mean': pfam_scaler.mean_.tolist(),
            'pfam_scaler_scale': pfam_scaler.scale_.tolist(),
            'bio_scaler_mean': bio_scaler.mean_.tolist(),
            'bio_scaler_scale': bio_scaler.scale_.tolist(),
            'summary': summary,
            'timestamp': timestamp,
        }
        ckpt_path = os.path.join(
            MODELS_DIR, f'world_model_fold_{fold_name}_{timestamp}.pt'
        )
        torch.save(checkpoint, ckpt_path)
        fold_checkpoints[fold_name] = ckpt_path
        print(f" Saved: {os.path.basename(ckpt_path)}")
    total_elapsed = time.time() - total_t0
    print(f"\n── Total training time: {total_elapsed:.1f}s ──")

    # ── Save training curves ──
    curves_path = os.path.join(RESULTS_DIR,
                               f't12_training_curves_{timestamp}.json')
    curves_data = {
        'provenance': {
            'script': os.path.abspath(__file__),
            'inputs': {
                'env': os.path.join(DATA_DIR, 'consolidated_env.npy'),
                'pfam': os.path.join(DATA_DIR, 'consolidated_pfam_modules.npy'),
                'bio': os.path.join(DATA_DIR, 'consolidated_bio_response.npy'),
                'bio_valid': os.path.join(DATA_DIR, 'consolidated_bio_valid.npy'),
            },
            'timestamp': timestamp,
            'config': config,
        },
        'histories': all_histories,
    }
    with open(curves_path, 'w') as f:
        json.dump(curves_data, f, indent=2)
    print(f"\nTraining curves: {os.path.basename(curves_path)}")

    # ── Save training summary ──
    summary_path = os.path.join(RESULTS_DIR,
                                f't12_training_summary_{timestamp}.tsv')
    with open(summary_path, 'w') as f:
        # Provenance header
        f.write("# Provenance:\n")
        f.write(f"# Script: {os.path.abspath(__file__)}\n")
        f.write("# Inputs: data/consolidated_*.npy\n")
        f.write(f"# Date: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write("# Integrity Check: PASSED\n")
        f.write(f"# Config: latent_dim={config['latent_dim']}, "
                f"dropout={config['dropout']}, "
                f"VICReg=({config['lambda_inv']}/{config['lambda_var']}/{config['lambda_cov']})\n")
        cols = ['fold', 'n_train', 'n_val', 'total_epochs', 'best_epoch',
                'best_val_total', 'best_val_vicreg', 'best_val_pred',
                'final_train_total', 'elapsed_s']
        f.write('\t'.join(cols) + '\n')
        for s in all_summaries:
            f.write('\t'.join(str(s[c]) for c in cols) + '\n')
    print(f"Training summary: {os.path.basename(summary_path)}")

    # ── Print summary table ──
    print("\n── Per-Fold Training Summary ──")
    print(f"{'Fold':15s} {'Train':>5s} {'Val':>5s} {'Epochs':>6s} "
          f"{'Best':>5s} {'Val Total':>10s} {'Val VICReg':>10s} "
          f"{'Val Pred':>9s} {'Time':>6s}")
    print("-" * 80)
    for s in all_summaries:
        print(f"{s['fold']:15s} {s['n_train']:5d} {s['n_val']:5d} "
              f"{s['total_epochs']:6d} {s['best_epoch']:5d} "
              f"{s['best_val_total']:10.2f} {s['best_val_vicreg']:10.2f} "
              f"{s['best_val_pred']:9.2f} {s['elapsed_s']:6.1f}s")
    if all_summaries:
        best_totals = [s['best_val_total'] for s in all_summaries]
        print(f"\nMean best val total: {np.mean(best_totals):.2f} "
              f"+/- {np.std(best_totals):.2f}")
    else:
        print("\nNo CV folds were built; nothing was trained.")
    print(f"Total time: {total_elapsed:.1f}s")

    # Parameter count of a model with the same architecture as the trained
    # ones (bio_dim included so the head matches the checkpoints).
    params = WorldModel(
        env_dim=env.shape[1], pfam_dim=pfam.shape[1], bio_dim=bio_dim,
        latent_dim=config['latent_dim'], dropout=config['dropout']
    ).count_parameters()
    print(f"Parameters per model: {params:,}")

    print("\n── Output Files ──")
    for path in fold_checkpoints.values():
        print(f" {os.path.basename(path)}")
    print(f" {os.path.basename(curves_path)}")
    print(f" {os.path.basename(summary_path)}")
    return {
        'summaries': all_summaries,
        'histories': all_histories,
        'checkpoints': fold_checkpoints,
        'timestamp': timestamp,
        'config': config,
    }
def parse_cli_overrides(argv):
    """Parse ``key=value`` command-line tokens into a config-override dict.

    The token is split on the first '=' only, so values may themselves
    contain '='. Each value is converted to int if possible, else float,
    else kept as a string. Tokens without '=' are reported on stderr and
    ignored.

    >>> parse_cli_overrides(['latent_dim=32', 'dropout=0.5'])
    {'latent_dim': 32, 'dropout': 0.5}
    """
    overrides = {}
    for token in argv:
        key, sep, raw = token.partition('=')
        if not sep:
            print(f"WARNING: ignoring argument without '=': {token!r}",
                  file=sys.stderr)
            continue
        value = raw
        for cast in (int, float):
            try:
                value = cast(raw)
                break
            except ValueError:
                continue
        overrides[key] = value
    return overrides


if __name__ == '__main__':
    # Usage: python train_world_model.py [key=value ...]
    # Example: python train_world_model.py latent_dim=32 dropout=0.5
    main(parse_cli_overrides(sys.argv[1:]))