| |
| """ |
| WorldModel: Joint Environment-Genome Embedding for Productivity Prediction. |
| |
| Architecture: |
| - Encoder_E: Environment MLP (env_dim -> 128 -> latent_dim) |
| - Encoder_P: PFAM Module MLP (pfam_dim -> 256 -> 128 -> latent_dim) |
| - Predictor: Productivity head (latent_dim -> 64 -> 3) |
| |
| Training: |
| Loss = VICReg(z_env, z_pfam) + pred_alpha * MSE(Predictor(z_env), bio_targets) |
| (the MSE term is computed only over rows flagged valid by bio_valid) |
| |
| Inference (environment-only): |
| env -> Encoder_E -> z_env -> Predictor -> productivity (chl-a, POC, NFLH) |
| |
| Designed for: |
| - 1,810 ocean samples with 24 environmental variables, 20 PFAM modules, 3 bio targets |
| - Spatial block CV (leave-one-basin-out) |
| - VICReg non-contrastive alignment (Bardes et al., ICLR 2022) |
| |
| Author: World Model RALPH Loop |
| Date: 2026-01-27 |
| """ |
|
|
| import sys |
| import os |
| import torch |
| import torch.nn as nn |
|
|
| |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| from vicreg_loss import VICRegLoss |
|
|
|
|
class EncoderE(nn.Module):
    """Environment encoder: standardized env features -> latent embedding.

    Layer widths are ``env_dim -> 128 -> latent_dim``. Both stages are
    ``Linear -> BatchNorm1d -> ReLU -> Dropout``; because the last stage keeps
    its ReLU, the produced embedding is element-wise non-negative.

    Parameters
    ----------
    env_dim : int
        Size of the environment feature vector (default 24).
    latent_dim : int
        Size of the output embedding (default 16).
    dropout : float
        Dropout probability applied after every stage (default 0.3).
    """

    def __init__(self, env_dim=24, latent_dim=16, dropout=0.3):
        super().__init__()
        self.env_dim = env_dim
        self.latent_dim = latent_dim

        # One (Linear, BatchNorm1d, ReLU, Dropout) group per consecutive
        # width pair; modules are created in the same order as a hand-written
        # nn.Sequential, so state_dict keys (layers.0 ... layers.7) match.
        widths = (env_dim, 128, latent_dim)
        stages = []
        for n_in, n_out in zip(widths, widths[1:]):
            stages += [
                nn.Linear(n_in, n_out),
                nn.BatchNorm1d(n_out),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
        # NOTE(review): the final stage applies ReLU and Dropout to the
        # embedding itself. VICReg (Bardes et al.) usually ends with a plain
        # Linear layer; confirm a non-negative, dropped-out z_env is intended.
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Embed a batch of environment samples.

        Parameters
        ----------
        x : torch.Tensor, shape (N, env_dim)
            Standardized environment features.

        Returns
        -------
        torch.Tensor, shape (N, latent_dim)
            Environment embedding ``z_env``.
        """
        z_env = self.layers(x)
        return z_env
|
|
|
|
class EncoderP(nn.Module):
    """PFAM-module encoder: standardized PFAM features -> latent embedding.

    Layer widths are ``pfam_dim -> 256 -> 128 -> latent_dim``; each of the
    three stages is ``Linear -> BatchNorm1d -> ReLU -> Dropout``. It is one
    stage deeper than the environment encoder, on the premise that PFAM
    modules carry richer combinatorial information.

    Parameters
    ----------
    pfam_dim : int
        Size of the PFAM module feature vector (default 20).
    latent_dim : int
        Size of the output embedding (default 16).
    dropout : float
        Dropout probability applied after every stage (default 0.3).
    """

    def __init__(self, pfam_dim=20, latent_dim=16, dropout=0.3):
        super().__init__()
        self.pfam_dim = pfam_dim
        self.latent_dim = latent_dim

        # Build the stages pairwise over the width chain; creation order (and
        # therefore state_dict keys layers.0 ... layers.11) is unchanged
        # from a hand-written nn.Sequential.
        widths = (pfam_dim, 256, 128, latent_dim)
        stages = []
        for n_in, n_out in zip(widths, widths[1:]):
            stages += [
                nn.Linear(n_in, n_out),
                nn.BatchNorm1d(n_out),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
        # NOTE(review): as in the environment encoder, the last stage keeps
        # ReLU + Dropout on the embedding; confirm this is intended for the
        # VICReg alignment target.
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Embed a batch of PFAM module profiles.

        Parameters
        ----------
        x : torch.Tensor, shape (N, pfam_dim)
            Standardized PFAM module features.

        Returns
        -------
        torch.Tensor, shape (N, latent_dim)
            PFAM module embedding ``z_pfam``.
        """
        z_pfam = self.layers(x)
        return z_pfam
|
|
|
|
class Predictor(nn.Module):
    """Productivity regression head.

    Shape ``input_dim -> 64 -> bio_dim`` as ``Linear -> ReLU -> Linear``.
    Unlike the encoders it has no BatchNorm or Dropout, and its final layer is
    linear, so outputs are unbounded.

    Parameters
    ----------
    input_dim : int
        Width of the incoming embedding (``latent_dim`` when fed ``z_env``
        alone, ``2 * latent_dim`` for a concatenated ``[z_env, z_pfam]``).
        Default 16.
    bio_dim : int
        Number of regressed bio-response targets (default 3: chl-a, POC,
        NFLH).
    """

    def __init__(self, input_dim=16, bio_dim=3):
        super().__init__()
        self.input_dim = input_dim
        self.bio_dim = bio_dim

        hidden = 64
        first = nn.Linear(input_dim, hidden)
        second = nn.Linear(hidden, bio_dim)
        # Positional Sequential keeps state_dict keys layers.0 / layers.2.
        self.layers = nn.Sequential(first, nn.ReLU(), second)

    def forward(self, z):
        """Regress productivity targets from an embedding.

        Parameters
        ----------
        z : torch.Tensor, shape (N, input_dim)
            Latent embedding (``z_env`` or ``[z_env, z_pfam]``).

        Returns
        -------
        torch.Tensor, shape (N, bio_dim)
            Predicted productivity (chl-a, POC, NFLH).
        """
        y_pred = self.layers(z)
        return y_pred
|
|
|
|
class WorldModel(nn.Module):
    """Joint environment/genome embedding model with an env-only predictor.

    Bundles an environment encoder (``encoder_e``), a PFAM-module encoder
    (``encoder_p``), a productivity head (``predictor``) and a VICReg
    criterion (``vicreg``) in a single ``nn.Module``.

    Training::

        env  -> encoder_e -> z_env  --+--> VICReg(z_env, z_pfam)
        pfam -> encoder_p -> z_pfam --+
        z_env -> predictor -> y_pred ----> MSE(y_pred, bio_targets)

        total_loss = vicreg_loss + pred_alpha * pred_loss

    Inference (environment only)::

        env -> encoder_e -> z_env -> predictor -> productivity

    Parameters
    ----------
    env_dim : int
        Environment feature count (default 24).
    pfam_dim : int
        PFAM module feature count (default 20).
    bio_dim : int
        Bio-response target count (default 3).
    latent_dim : int
        Shared embedding width for both encoders (default 16).
    dropout : float
        Dropout probability used inside both encoders (default 0.3).
    lambda_inv : float
        VICReg invariance weight (default 25.0).
    lambda_var : float
        VICReg variance weight (default 25.0).
    lambda_cov : float
        VICReg covariance weight (default 1.0).
    pred_alpha : float
        Multiplier on the productivity MSE term (default 1.0).
    """

    def __init__(self, env_dim=24, pfam_dim=20, bio_dim=3, latent_dim=16,
                 dropout=0.3, lambda_inv=25.0, lambda_var=25.0,
                 lambda_cov=1.0, pred_alpha=1.0):
        super().__init__()

        # Constructor arguments, kept so a checkpoint can rebuild the model
        # with ``WorldModel(**config)``.
        self.config = dict(
            env_dim=env_dim,
            pfam_dim=pfam_dim,
            bio_dim=bio_dim,
            latent_dim=latent_dim,
            dropout=dropout,
            lambda_inv=lambda_inv,
            lambda_var=lambda_var,
            lambda_cov=lambda_cov,
            pred_alpha=pred_alpha,
        )
        self.env_dim, self.pfam_dim, self.bio_dim = env_dim, pfam_dim, bio_dim
        self.latent_dim = latent_dim
        self.pred_alpha = pred_alpha

        # Keep this creation order: it fixes submodule registration order and
        # the sequence in which weight initialisation draws from the RNG.
        self.encoder_e = EncoderE(env_dim, latent_dim, dropout)
        self.encoder_p = EncoderP(pfam_dim, latent_dim, dropout)
        self.predictor = Predictor(latent_dim, bio_dim)
        self.vicreg = VICRegLoss(lambda_inv, lambda_var, lambda_cov)

    @staticmethod
    def _prediction_loss(y_pred, bio_targets, bio_valid, device):
        """Return the (optionally masked) MSE, or a 0.0 tensor if unavailable.

        No targets -> 0.0. Targets without a mask -> MSE over every row.
        Targets with a mask -> MSE over rows where the mask is True, or 0.0
        when no row is valid.
        """
        no_loss = torch.tensor(0.0, device=device)
        if bio_targets is None:
            return no_loss
        if bio_valid is None:
            return nn.functional.mse_loss(y_pred, bio_targets)
        keep = bio_valid.bool()
        if not keep.any():
            return no_loss
        return nn.functional.mse_loss(y_pred[keep], bio_targets[keep])

    def forward(self, env, pfam, bio_targets=None, bio_valid=None):
        """Run the full training pass and compute all loss terms.

        Parameters
        ----------
        env : torch.Tensor, shape (N, env_dim)
            Standardized environment features.
        pfam : torch.Tensor, shape (N, pfam_dim)
            Standardized PFAM module features.
        bio_targets : torch.Tensor or None, shape (N, bio_dim)
            Standardized bio-response targets; when None the prediction
            loss is 0.
        bio_valid : torch.Tensor or None, shape (N,)
            Row mask, True where all bio targets are valid. When None while
            targets are given, every row is treated as valid.

        Returns
        -------
        dict
            'z_env' (N, latent_dim), 'z_pfam' (N, latent_dim),
            'y_pred' (N, bio_dim), scalar 'total_loss', 'vicreg_loss' and
            'pred_loss' (0 without usable targets), plus
            'vicreg_components' (the per-term dict from the VICReg loss).
        """
        # Encode both views; productivity is predicted from the env view only.
        z_env = self.encoder_e(env)
        z_pfam = self.encoder_p(pfam)
        y_pred = self.predictor(z_env)

        vicreg_loss, vicreg_components = self.vicreg(z_env, z_pfam)
        pred_loss = self._prediction_loss(y_pred, bio_targets, bio_valid,
                                          env.device)

        return dict(
            z_env=z_env,
            z_pfam=z_pfam,
            y_pred=y_pred,
            total_loss=vicreg_loss + self.pred_alpha * pred_loss,
            vicreg_loss=vicreg_loss,
            pred_loss=pred_loss,
            vicreg_components=vicreg_components,
        )

    def encode_env(self, env):
        """Map env features (N, env_dim) to ``z_env`` (N, latent_dim)."""
        return self.encoder_e(env)

    def encode_pfam(self, pfam):
        """Map PFAM features (N, pfam_dim) to ``z_pfam`` (N, latent_dim)."""
        return self.encoder_p(pfam)

    def inference(self, env):
        """Predict productivity (N, bio_dim) from standardized env features.

        Uses only the environment encoder and the predictor head; no PFAM
        input is required.
        """
        return self.predictor(self.encoder_e(env))

    @staticmethod
    def _n_trainable(module):
        """Number of scalar parameters in *module* with requires_grad=True."""
        return sum(t.numel() for t in module.parameters() if t.requires_grad)

    def count_parameters(self):
        """Return the total number of trainable parameters in the model."""
        return self._n_trainable(self)

    def count_parameters_by_component(self):
        """Return trainable parameter counts per sub-module.

        Returns
        -------
        dict
            {'encoder_e': int, 'encoder_p': int, 'predictor': int,
            'total': int}, where 'total' is the sum of the three components.
        """
        counts = {name: self._n_trainable(getattr(self, name))
                  for name in ('encoder_e', 'encoder_p', 'predictor')}
        counts['total'] = sum(counts.values())
        return counts
|
|
|
|
def self_test(seed=0):
    """Run the WorldModel self-test suite, printing PASS/FAIL per check.

    Covers instantiation and parameter accounting, forward-pass shapes, the
    VICReg-only and masked-prediction paths, gradient flow, env-only
    inference, a short training run, several latent sizes and VICReg
    weightings, a batch of two, the standalone encoders, an optional CUDA
    pass, and a checkpoint save/load round trip.

    Parameters
    ----------
    seed : int or None
        Passed to ``torch.manual_seed`` before any data or model is created,
        so random inputs, initial weights and dropout masks (and therefore
        threshold checks like "loss decreases") are reproducible. ``None``
        leaves the global RNG untouched. Default 0.

    Returns
    -------
    bool
        True if every check passed (a skipped CUDA test is counted as passed).
    """
    import math
    import tempfile

    tests_passed = 0
    tests_total = 0

    def check(name, condition):
        """Tally one named check and print its PASS/FAIL line."""
        nonlocal tests_passed, tests_total
        tests_total += 1
        if condition:
            tests_passed += 1
            print(f" PASS: {name}")
        else:
            print(f" FAIL: {name}")

    print("=" * 70)
    print("WorldModel Self-Tests")
    print("=" * 70)

    if seed is not None:
        torch.manual_seed(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    print("\nTest 1: Instantiation with default parameters")
    model = WorldModel(env_dim=24, pfam_dim=20, bio_dim=3,
                       latent_dim=16, dropout=0.3).to(device)
    params = model.count_parameters()
    param_detail = model.count_parameters_by_component()
    print(f" Total parameters: {params:,}")
    print(f" Encoder_E: {param_detail['encoder_e']:,}")
    print(f" Encoder_P: {param_detail['encoder_p']:,}")
    print(f" Predictor: {param_detail['predictor']:,}")
    check("model instantiates", model is not None)
    check("total params > 0", params > 0)
    check("param counts sum correctly",
          param_detail['total'] == params)

    # Shared random batch reused by most of the tests below.
    print("\nTest 2: Forward pass shapes")
    N = 64
    env = torch.randn(N, 24, device=device)
    pfam = torch.randn(N, 20, device=device)
    bio = torch.randn(N, 3, device=device)
    bio_valid = torch.ones(N, dtype=torch.bool, device=device)

    model.train()
    result = model(env, pfam, bio, bio_valid)
    check("z_env shape", result['z_env'].shape == (N, 16))
    check("z_pfam shape", result['z_pfam'].shape == (N, 16))
    check("y_pred shape", result['y_pred'].shape == (N, 3))
    check("total_loss is scalar", result['total_loss'].dim() == 0)
    check("vicreg_loss is scalar", result['vicreg_loss'].dim() == 0)
    check("pred_loss is scalar", result['pred_loss'].dim() == 0)
    check("vicreg_components present",
          all(k in result['vicreg_components']
              for k in ['invariance', 'variance_a', 'variance_b',
                        'covariance_a', 'covariance_b', 'total']))

    print("\nTest 3: Forward without bio targets (VICReg-only)")
    result_no_bio = model(env, pfam, bio_targets=None)
    check("works without bio targets", result_no_bio['total_loss'].item() > 0)
    check("pred_loss is zero", result_no_bio['pred_loss'].item() == 0.0)

    print("\nTest 4: Forward with partial bio_valid mask")
    partial_valid = torch.zeros(N, dtype=torch.bool, device=device)
    partial_valid[:32] = True
    result_partial = model(env, pfam, bio, partial_valid)
    check("works with partial bio_valid", result_partial['total_loss'].item() > 0)
    # Recompute the MSE over the first 32 rows from the SAME forward pass
    # (a second pass would draw different dropout masks) to confirm the mask
    # is actually applied, not merely that the loss is non-negative.
    expected_partial = nn.functional.mse_loss(
        result_partial['y_pred'][:32], bio[:32])
    check("pred_loss computed on valid subset",
          torch.allclose(result_partial['pred_loss'], expected_partial))

    all_invalid = torch.zeros(N, dtype=torch.bool, device=device)
    result_novalid = model(env, pfam, bio, all_invalid)
    check("works with all-invalid mask",
          result_novalid['pred_loss'].item() == 0.0)

    print("\nTest 5: Gradient flow")
    model.zero_grad()
    result = model(env, pfam, bio, bio_valid)
    result['total_loss'].backward()
    all_params = list(model.named_parameters())
    grad_count = sum(1 for _, p in all_params
                     if p.grad is not None and p.grad.abs().sum() > 0)
    check(f"all {len(all_params)} param tensors receive gradients",
          grad_count == len(all_params))
    no_nan = all(not torch.isnan(p.grad).any()
                 for _, p in all_params if p.grad is not None)
    check("no NaN in any gradient", no_nan)

    print("\nTest 6: Inference mode (env-only)")
    model.eval()
    with torch.no_grad():
        y_pred_inf = model.inference(env)
    check("inference returns correct shape", y_pred_inf.shape == (N, 3))
    check("no NaN in inference output", not torch.isnan(y_pred_inf).any())

    print("\nTest 7: Training convergence (50 steps)")
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    losses = []
    for _ in range(50):
        optimizer.zero_grad()
        result = model(env, pfam, bio, bio_valid)
        result['total_loss'].backward()
        optimizer.step()
        losses.append(result['total_loss'].item())
    reduction = (losses[0] - losses[-1]) / losses[0] * 100
    print(f" Loss: {losses[0]:.2f} -> {losses[-1]:.2f} ({reduction:.1f}% reduction)")
    check("loss decreases over 50 steps", losses[-1] < losses[0])
    # isfinite rejects both NaN and +/-Inf.
    check("no NaN/Inf in loss",
          all(math.isfinite(loss_value) for loss_value in losses))

    print("\nTest 8: Different latent dimensions {16, 32, 64}")
    for ld in [16, 32, 64]:
        m = WorldModel(env_dim=24, pfam_dim=20, latent_dim=ld).to(device)
        m.train()
        r = m(env, pfam, bio, bio_valid)
        check(f"latent_dim={ld}: z_env shape ({N},{ld})",
              r['z_env'].shape == (N, ld))
        check(f"latent_dim={ld}: valid loss",
              r['total_loss'].item() > 0 and not torch.isnan(r['total_loss']))

    print("\nTest 9: Custom VICReg configurations")
    configs = {
        'default': dict(lambda_inv=25.0, lambda_var=25.0, lambda_cov=1.0),
        'high_variance': dict(lambda_inv=10.0, lambda_var=50.0, lambda_cov=1.0),
        'high_covariance': dict(lambda_inv=25.0, lambda_var=25.0, lambda_cov=10.0),
    }
    for name, cfg in configs.items():
        m = WorldModel(env_dim=24, pfam_dim=20, **cfg).to(device)
        m.train()
        r = m(env, pfam, bio, bio_valid)
        check(f"{name}: valid loss",
              r['total_loss'].item() > 0 and not torch.isnan(r['total_loss']))

    print("\nTest 10: Minimum batch size (N=2)")
    env_small = torch.randn(2, 24, device=device)
    pfam_small = torch.randn(2, 20, device=device)
    bio_small = torch.randn(2, 3, device=device)
    valid_small = torch.ones(2, dtype=torch.bool, device=device)
    model.train()
    r_small = model(env_small, pfam_small, bio_small, valid_small)
    check("batch size 2 works", not torch.isnan(r_small['total_loss']))

    print("\nTest 11: Standalone encoder methods")
    model.eval()
    with torch.no_grad():
        ze = model.encode_env(env)
        zp = model.encode_pfam(pfam)
    check("encode_env shape", ze.shape == (N, 16))
    check("encode_pfam shape", zp.shape == (N, 16))

    print("\nTest 12: GPU computation")
    if torch.cuda.is_available():
        m_gpu = WorldModel(env_dim=24, pfam_dim=20).to('cuda')
        m_gpu.train()
        e_gpu = torch.randn(32, 24, device='cuda')
        p_gpu = torch.randn(32, 20, device='cuda')
        b_gpu = torch.randn(32, 3, device='cuda')
        v_gpu = torch.ones(32, dtype=torch.bool, device='cuda')
        r_gpu = m_gpu(e_gpu, p_gpu, b_gpu, v_gpu)
        r_gpu['total_loss'].backward()
        check("GPU forward + backward succeeded",
              not torch.isnan(r_gpu['total_loss']))
    else:
        print(" SKIP: CUDA not available")
        # A skip is counted as a pass so CPU-only runs can still succeed.
        tests_total += 1
        tests_passed += 1

    print("\nTest 13: Model serialization (save/load)")
    model.eval()
    with torch.no_grad():
        y_before = model.inference(env)

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'config': model.config,
    }
    # TemporaryDirectory removes the checkpoint file even if saving or
    # loading raises. weights_only=False is kept because the file was
    # written by this function just above (trusted input).
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = os.path.join(tmp_dir, 'world_model_self_test.pt')
        torch.save(checkpoint, tmp_path)
        loaded = torch.load(tmp_path, map_location=device, weights_only=False)

    model2 = WorldModel(**loaded['config']).to(device)
    model2.load_state_dict(loaded['model_state_dict'])
    model2.eval()
    with torch.no_grad():
        y_after = model2.inference(env)

    max_diff = (y_before - y_after).abs().max().item()
    print(f" Max prediction diff after save/load: {max_diff:.2e}")
    check("save/load produces identical predictions", max_diff < 1e-6)

    print(f"\n{'=' * 70}")
    print(f"Results: {tests_passed}/{tests_total} tests passed")
    print(f"{'=' * 70}")

    return tests_passed == tests_total
|
|
|
|
if __name__ == '__main__':
    # Process exit status: 0 when every self-test check passed, 1 otherwise.
    sys.exit(int(not self_test()))
|
|