TARA-WorldModel-VICReg / scripts /world_model.py
GreenGenomicsLab's picture
Upload scripts/world_model.py with huggingface_hub
37171a0 verified
Raw
History Blame Contribute Delete
20.4 kB
#!/usr/bin/env python3
"""
WorldModel: Joint Environment-Genome Embedding for Productivity Prediction.
Architecture:
- Encoder_E: Environment MLP (env_dim -> 128 -> latent_dim)
- Encoder_P: PFAM Module MLP (pfam_dim -> 256 -> 128 -> latent_dim)
- Predictor: Productivity head (latent_dim -> 64 -> 3)
Training:
Loss = VICReg(z_env, z_pfam) + pred_alpha * MSE(Predictor(z_env), bio_targets)   (MSE over bio-valid samples only)
Inference (environment-only):
env -> Encoder_E -> z_env -> Predictor -> productivity (chl-a, POC, NFLH)
Designed for:
- 1,810 ocean samples with 24 environmental variables, 20 PFAM modules, 3 bio targets
- Spatial block CV (leave-one-basin-out)
- VICReg non-contrastive alignment (Bardes et al., ICLR 2022)
Author: World Model RALPH Loop
Date: 2026-01-27
"""
import sys
import os
import torch
import torch.nn as nn
# Import VICReg loss from sibling module
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from vicreg_loss import VICRegLoss
class EncoderE(nn.Module):
    """Map standardized environment features to a latent embedding.

    Layout: ``env_dim -> 128 -> latent_dim``. Every stage, including the
    last one, is ``Linear -> BatchNorm1d -> ReLU -> Dropout``, so the
    emitted embedding is non-negative.

    Parameters
    ----------
    env_dim : int
        Width of the environment input (default 24).
    latent_dim : int
        Width of the output embedding (default 16).
    dropout : float
        Dropout probability applied after every stage (default 0.3).
    """

    def __init__(self, env_dim=24, latent_dim=16, dropout=0.3):
        super().__init__()
        self.env_dim = env_dim
        self.latent_dim = latent_dim

        widths = (env_dim, 128, latent_dim)
        stages = []
        # One Linear/BN/ReLU/Dropout stage per consecutive pair of widths;
        # module indices inside ``self.layers`` are 0..7, as before.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.extend((
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
                nn.Dropout(dropout),
            ))
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Embed a batch of environment rows.

        Parameters
        ----------
        x : torch.Tensor, shape (N, env_dim)
            Standardized environment features.

        Returns
        -------
        z_env : torch.Tensor, shape (N, latent_dim)
            Environment embedding.
        """
        return self.layers(x)
class EncoderP(nn.Module):
    """Map standardized PFAM module features to a latent embedding.

    Layout: ``pfam_dim -> 256 -> 128 -> latent_dim``. Every stage, including
    the last one, is ``Linear -> BatchNorm1d -> ReLU -> Dropout``. It has one
    more hidden stage than EncoderE because PFAM modules carry richer
    combinatorial information.

    Parameters
    ----------
    pfam_dim : int
        Width of the PFAM module input (default 20).
    latent_dim : int
        Width of the output embedding (default 16).
    dropout : float
        Dropout probability applied after every stage (default 0.3).
    """

    def __init__(self, pfam_dim=20, latent_dim=16, dropout=0.3):
        super().__init__()
        self.pfam_dim = pfam_dim
        self.latent_dim = latent_dim

        dims = [pfam_dim, 256, 128, latent_dim]
        # Flatten one 4-module stage per adjacent (d_in, d_out) pair, keeping
        # construction order (and so parameter init order) unchanged.
        modules = [
            layer
            for d_in, d_out in zip(dims, dims[1:])
            for layer in (
                nn.Linear(d_in, d_out),
                nn.BatchNorm1d(d_out),
                nn.ReLU(),
                nn.Dropout(dropout),
            )
        ]
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Embed a batch of PFAM module rows.

        Parameters
        ----------
        x : torch.Tensor, shape (N, pfam_dim)
            Standardized PFAM module features.

        Returns
        -------
        z_pfam : torch.Tensor, shape (N, latent_dim)
            PFAM module embedding.
        """
        return self.layers(x)
class Predictor(nn.Module):
    """Productivity regression head.

    Layout: ``input_dim -> 64 -> bio_dim`` as ``Linear -> ReLU -> Linear``
    with no BatchNorm or Dropout.

    Parameters
    ----------
    input_dim : int
        Width of the incoming embedding (latent_dim for env-only,
        2*latent_dim for a concatenated joint embedding).
    bio_dim : int
        Number of bio-response targets (default 3: chl-a, POC, NFLH).
    """

    def __init__(self, input_dim=16, bio_dim=3):
        super().__init__()
        self.input_dim = input_dim
        self.bio_dim = bio_dim
        # Built in the same order as a single Sequential literal would be.
        hidden = nn.Linear(input_dim, 64)
        activation = nn.ReLU()
        head = nn.Linear(64, bio_dim)
        self.layers = nn.Sequential(hidden, activation, head)

    def forward(self, z):
        """Regress productivity targets from an embedding.

        Parameters
        ----------
        z : torch.Tensor, shape (N, input_dim)
            Latent embedding (z_env or [z_env, z_pfam]).

        Returns
        -------
        y_pred : torch.Tensor, shape (N, bio_dim)
            Predicted productivity (chl-a, POC, NFLH).
        """
        return self.layers(z)
class WorldModel(nn.Module):
    """Joint Environment-Genome Embedding Model.

    Wraps Encoder_E, Encoder_P, Predictor, and VICRegLoss into a single
    module for training and inference.

    Training flow::

        env  -> Encoder_E -> z_env  ---+---> Predictor(z_env) -> y_pred
                                       |       MSE(y_pred, bio_targets)
                                       +---> VICReg(z_env, z_pfam)
        pfam -> Encoder_P -> z_pfam ---+

        total_loss = VICReg + pred_alpha * MSE

    Inference flow (env-only)::

        env -> Encoder_E -> z_env -> Predictor -> productivity

    Parameters
    ----------
    env_dim : int
        Number of environment input features (default 24).
    pfam_dim : int
        Number of PFAM module input features (default 20).
    bio_dim : int
        Number of bio-response targets (default 3).
    latent_dim : int
        Latent embedding dimension (default 16).
    dropout : float
        Dropout probability (default 0.3).
    lambda_inv : float
        VICReg invariance weight (default 25.0).
    lambda_var : float
        VICReg variance weight (default 25.0).
    lambda_cov : float
        VICReg covariance weight (default 1.0).
    pred_alpha : float
        Weight for productivity prediction loss (default 1.0).
    """

    def __init__(self, env_dim=24, pfam_dim=20, bio_dim=3, latent_dim=16,
                 dropout=0.3, lambda_inv=25.0, lambda_var=25.0,
                 lambda_cov=1.0, pred_alpha=1.0):
        super().__init__()
        self.env_dim = env_dim
        self.pfam_dim = pfam_dim
        self.bio_dim = bio_dim
        self.latent_dim = latent_dim
        self.pred_alpha = pred_alpha
        # Sub-modules (construction order fixes parameter-init RNG order).
        self.encoder_e = EncoderE(env_dim, latent_dim, dropout)
        self.encoder_p = EncoderP(pfam_dim, latent_dim, dropout)
        self.predictor = Predictor(latent_dim, bio_dim)
        self.vicreg = VICRegLoss(lambda_inv, lambda_var, lambda_cov)
        # Constructor kwargs, so WorldModel(**config) rebuilds the same model.
        self.config = {
            'env_dim': env_dim,
            'pfam_dim': pfam_dim,
            'bio_dim': bio_dim,
            'latent_dim': latent_dim,
            'dropout': dropout,
            'lambda_inv': lambda_inv,
            'lambda_var': lambda_var,
            'lambda_cov': lambda_cov,
            'pred_alpha': pred_alpha,
        }

    def forward(self, env, pfam, bio_targets=None, bio_valid=None):
        """Full training forward pass.

        Parameters
        ----------
        env : torch.Tensor, shape (N, env_dim)
            Standardized environment features.
        pfam : torch.Tensor, shape (N, pfam_dim)
            Standardized PFAM module features.
        bio_targets : torch.Tensor or None, shape (N, bio_dim)
            Standardized bio-response targets. If None, skip pred loss.
        bio_valid : torch.Tensor or None, shape (N,)
            Mask (converted with ``.bool()``): True where all bio targets
            are valid. If None and bio_targets given, assume all valid.

        Returns
        -------
        result : dict
            'z_env': (N, latent_dim) environment embedding
            'z_pfam': (N, latent_dim) PFAM module embedding
            'y_pred': (N, bio_dim) predicted productivity
            'total_loss': scalar, vicreg_loss + pred_alpha * pred_loss
            'vicreg_loss': scalar VICReg loss
            'pred_loss': scalar prediction MSE loss; a zero 0-d tensor with
                y_pred's dtype/device when there are no targets or no
                valid rows
            'vicreg_components': dict of individual VICReg terms
        """
        # Encode both modalities
        z_env = self.encoder_e(env)
        z_pfam = self.encoder_p(pfam)
        # Predict productivity from the environment embedding only, so the
        # same head serves the env-only inference path.
        y_pred = self.predictor(z_env)
        # VICReg alignment between the two embeddings
        vicreg_loss, vicreg_components = self.vicreg(z_env, z_pfam)
        pred_loss = self._prediction_loss(y_pred, bio_targets, bio_valid)
        total_loss = vicreg_loss + self.pred_alpha * pred_loss
        return {
            'z_env': z_env,
            'z_pfam': z_pfam,
            'y_pred': y_pred,
            'total_loss': total_loss,
            'vicreg_loss': vicreg_loss,
            'pred_loss': pred_loss,
            'vicreg_components': vicreg_components,
        }

    @staticmethod
    def _prediction_loss(y_pred, bio_targets, bio_valid):
        """Masked MSE between predictions and targets (zero if nothing to score)."""
        # Zero placeholder follows y_pred's dtype/device so the total loss
        # keeps the model's precision (e.g. float64 or half models).
        zero = y_pred.new_zeros(())
        if bio_targets is None:
            return zero
        if bio_valid is None:
            return nn.functional.mse_loss(y_pred, bio_targets)
        valid_mask = bio_valid.bool()
        if not valid_mask.any():
            # mse_loss on an empty selection would return NaN.
            return zero
        return nn.functional.mse_loss(y_pred[valid_mask],
                                      bio_targets[valid_mask])

    def encode_env(self, env):
        """Encode environment features to latent space.

        Parameters
        ----------
        env : torch.Tensor, shape (N, env_dim)

        Returns
        -------
        z_env : torch.Tensor, shape (N, latent_dim)
        """
        return self.encoder_e(env)

    def encode_pfam(self, pfam):
        """Encode PFAM module features to latent space.

        Parameters
        ----------
        pfam : torch.Tensor, shape (N, pfam_dim)

        Returns
        -------
        z_pfam : torch.Tensor, shape (N, latent_dim)
        """
        return self.encoder_p(pfam)

    def inference(self, env):
        """Environment-only inference path.

        This method neither switches to eval mode nor disables autograd.
        Call ``model.eval()`` (and ``torch.no_grad()`` if desired) first, so
        that the encoder's BatchNorm/Dropout layers behave deterministically.

        Parameters
        ----------
        env : torch.Tensor, shape (N, env_dim)
            Standardized environment features.

        Returns
        -------
        y_pred : torch.Tensor, shape (N, bio_dim)
            Predicted productivity.
        """
        z_env = self.encoder_e(env)
        return self.predictor(z_env)

    def count_parameters(self):
        """Count total trainable parameters.

        Returns
        -------
        int
            Total number of trainable parameters.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def count_parameters_by_component(self):
        """Count trainable parameters per sub-module.

        Returns
        -------
        dict
            {'encoder_e': int, 'encoder_p': int, 'predictor': int, 'total': int}
            where 'total' is the sum of the three listed components.
        """
        components = (('encoder_e', self.encoder_e),
                      ('encoder_p', self.encoder_p),
                      ('predictor', self.predictor))
        counts = {
            name: sum(p.numel() for p in module.parameters()
                      if p.requires_grad)
            for name, module in components
        }
        counts['total'] = sum(counts.values())
        return counts
def self_test():
    """Run comprehensive self-tests for WorldModel.

    Uses CUDA when available (CPU otherwise). The tests exercise:
    - output shapes
    - prediction-loss masking
    - gradient flow
    - env-only inference
    - a short training run
    - several latent/VICReg configurations
    - a batch of two
    - a save/load round trip

    Returns
    -------
    bool
        True if every check passed.
    """
    import math
    import tempfile

    tests_passed = 0
    tests_total = 0

    def check(name, condition):
        """Record one check and print its PASS/FAIL line."""
        nonlocal tests_passed, tests_total
        tests_total += 1
        if condition:
            tests_passed += 1
            print(f" PASS: {name}")
        else:
            print(f" FAIL: {name}")

    print("=" * 70)
    print("WorldModel Self-Tests")
    print("=" * 70)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # ── Test 1: Instantiation with default parameters ──
    print("\nTest 1: Instantiation with default parameters")
    model = WorldModel(env_dim=24, pfam_dim=20, bio_dim=3,
                       latent_dim=16, dropout=0.3).to(device)
    params = model.count_parameters()
    param_detail = model.count_parameters_by_component()
    print(f" Total parameters: {params:,}")
    print(f" Encoder_E: {param_detail['encoder_e']:,}")
    print(f" Encoder_P: {param_detail['encoder_p']:,}")
    print(f" Predictor: {param_detail['predictor']:,}")
    check("model instantiates", model is not None)
    check("total params > 0", params > 0)
    check("param counts sum correctly",
          param_detail['total'] == params)

    # ── Test 2: Forward pass shapes ──
    print("\nTest 2: Forward pass shapes")
    N = 64
    env = torch.randn(N, 24, device=device)
    pfam = torch.randn(N, 20, device=device)
    bio = torch.randn(N, 3, device=device)
    bio_valid = torch.ones(N, dtype=torch.bool, device=device)
    model.train()
    result = model(env, pfam, bio, bio_valid)
    check("z_env shape", result['z_env'].shape == (N, 16))
    check("z_pfam shape", result['z_pfam'].shape == (N, 16))
    check("y_pred shape", result['y_pred'].shape == (N, 3))
    check("total_loss is scalar", result['total_loss'].dim() == 0)
    check("vicreg_loss is scalar", result['vicreg_loss'].dim() == 0)
    check("pred_loss is scalar", result['pred_loss'].dim() == 0)
    check("vicreg_components present",
          all(k in result['vicreg_components']
              for k in ['invariance', 'variance_a', 'variance_b',
                        'covariance_a', 'covariance_b', 'total']))

    # ── Test 3: Forward without bio targets (VICReg-only mode) ──
    print("\nTest 3: Forward without bio targets (VICReg-only)")
    result_no_bio = model(env, pfam, bio_targets=None)
    check("works without bio targets", result_no_bio['total_loss'].item() > 0)
    check("pred_loss is zero", result_no_bio['pred_loss'].item() == 0.0)

    # ── Test 4: Forward with partial bio_valid mask ──
    print("\nTest 4: Forward with partial bio_valid mask")
    partial_valid = torch.zeros(N, dtype=torch.bool, device=device)
    partial_valid[:32] = True  # Only half valid
    result_partial = model(env, pfam, bio, partial_valid)
    check("works with partial bio_valid", result_partial['total_loss'].item() > 0)
    check("pred_loss computed on valid subset",
          result_partial['pred_loss'].item() >= 0)
    # Forward with all-invalid bio_valid mask
    all_invalid = torch.zeros(N, dtype=torch.bool, device=device)
    result_novalid = model(env, pfam, bio, all_invalid)
    check("works with all-invalid mask",
          result_novalid['pred_loss'].item() == 0.0)

    # ── Test 5: Gradient flow ──
    print("\nTest 5: Gradient flow")
    model.zero_grad()
    result = model(env, pfam, bio, bio_valid)
    result['total_loss'].backward()
    all_params = list(model.named_parameters())
    grad_count = sum(1 for _, p in all_params
                     if p.grad is not None and p.grad.abs().sum() > 0)
    check(f"all {len(all_params)} param tensors receive gradients",
          grad_count == len(all_params))
    no_nan = all(not torch.isnan(p.grad).any()
                 for _, p in all_params if p.grad is not None)
    check("no NaN in any gradient", no_nan)

    # ── Test 6: Inference mode (env-only) ──
    print("\nTest 6: Inference mode (env-only)")
    model.eval()
    with torch.no_grad():
        y_pred_inf = model.inference(env)
    check("inference returns correct shape", y_pred_inf.shape == (N, 3))
    check("no NaN in inference output", not torch.isnan(y_pred_inf).any())

    # ── Test 7: Training convergence (50 steps) ──
    print("\nTest 7: Training convergence (50 steps)")
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    losses = []
    for step in range(50):
        optimizer.zero_grad()
        result = model(env, pfam, bio, bio_valid)
        result['total_loss'].backward()
        optimizer.step()
        losses.append(result['total_loss'].item())
    # Guard the percentage against a (degenerate) zero initial loss.
    if losses[0]:
        reduction = (losses[0] - losses[-1]) / losses[0] * 100
    else:
        reduction = 0.0
    print(f" Loss: {losses[0]:.2f} -> {losses[-1]:.2f} ({reduction:.1f}% reduction)")
    check("loss decreases over 50 steps", losses[-1] < losses[0])
    check("no NaN in loss", not any(math.isnan(loss_value) for loss_value in losses))

    # ── Test 8: Different latent dimensions ──
    print("\nTest 8: Different latent dimensions {16, 32, 64}")
    for ld in [16, 32, 64]:
        m = WorldModel(env_dim=24, pfam_dim=20, latent_dim=ld).to(device)
        m.train()
        r = m(env, pfam, bio, bio_valid)
        check(f"latent_dim={ld}: z_env shape ({N},{ld})",
              r['z_env'].shape == (N, ld))
        check(f"latent_dim={ld}: valid loss",
              r['total_loss'].item() > 0 and not torch.isnan(r['total_loss']))

    # ── Test 9: Custom VICReg configs ──
    print("\nTest 9: Custom VICReg configurations")
    configs = {
        'default': dict(lambda_inv=25.0, lambda_var=25.0, lambda_cov=1.0),
        'high_variance': dict(lambda_inv=10.0, lambda_var=50.0, lambda_cov=1.0),
        'high_covariance': dict(lambda_inv=25.0, lambda_var=25.0, lambda_cov=10.0),
    }
    for name, cfg in configs.items():
        m = WorldModel(env_dim=24, pfam_dim=20, **cfg).to(device)
        m.train()
        r = m(env, pfam, bio, bio_valid)
        check(f"{name}: valid loss",
              r['total_loss'].item() > 0 and not torch.isnan(r['total_loss']))

    # ── Test 10: Minimum batch size (N=2) ──
    print("\nTest 10: Minimum batch size (N=2)")
    env_small = torch.randn(2, 24, device=device)
    pfam_small = torch.randn(2, 20, device=device)
    bio_small = torch.randn(2, 3, device=device)
    valid_small = torch.ones(2, dtype=torch.bool, device=device)
    model.train()
    r_small = model(env_small, pfam_small, bio_small, valid_small)
    check("batch size 2 works", not torch.isnan(r_small['total_loss']))

    # ── Test 11: Standalone encoder methods ──
    print("\nTest 11: Standalone encoder methods")
    model.eval()
    with torch.no_grad():
        ze = model.encode_env(env)
        zp = model.encode_pfam(pfam)
    check("encode_env shape", ze.shape == (N, 16))
    check("encode_pfam shape", zp.shape == (N, 16))

    # ── Test 12: GPU computation (if available) ──
    print("\nTest 12: GPU computation")
    if torch.cuda.is_available():
        m_gpu = WorldModel(env_dim=24, pfam_dim=20).to('cuda')
        m_gpu.train()
        e_gpu = torch.randn(32, 24, device='cuda')
        p_gpu = torch.randn(32, 20, device='cuda')
        b_gpu = torch.randn(32, 3, device='cuda')
        v_gpu = torch.ones(32, dtype=torch.bool, device='cuda')
        r_gpu = m_gpu(e_gpu, p_gpu, b_gpu, v_gpu)
        r_gpu['total_loss'].backward()
        check("GPU forward + backward succeeded",
              not torch.isnan(r_gpu['total_loss']))
    else:
        print(" SKIP: CUDA not available")
        # A skipped GPU test is counted as passed.
        tests_total += 1
        tests_passed += 1

    # ── Test 13: Model serialization (save/load) ──
    print("\nTest 13: Model serialization (save/load)")
    model.eval()
    with torch.no_grad():
        y_before = model.inference(env)
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'config': model.config,
    }
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix='.pt', delete=False) as f:
            tmp_path = f.name
            torch.save(checkpoint, f)
        loaded = torch.load(tmp_path, map_location=device, weights_only=False)
    finally:
        # Remove the temp checkpoint even if saving or loading raised.
        if tmp_path is not None:
            os.unlink(tmp_path)
    # Rebuild a fresh model from the stored config and weights
    model2 = WorldModel(**loaded['config']).to(device)
    model2.load_state_dict(loaded['model_state_dict'])
    model2.eval()
    with torch.no_grad():
        y_after = model2.inference(env)
    max_diff = (y_before - y_after).abs().max().item()
    print(f" Max prediction diff after save/load: {max_diff:.2e}")
    check("save/load produces identical predictions", max_diff < 1e-6)

    # ── Summary ──
    print(f"\n{'=' * 70}")
    print(f"Results: {tests_passed}/{tests_total} tests passed")
    print(f"{'=' * 70}")
    return tests_passed == tests_total
if __name__ == '__main__':
    # Exit status mirrors the self-test outcome: 0 on full pass, 1 otherwise.
    all_passed = self_test()
    exit_code = 0 if all_passed else 1
    sys.exit(exit_code)