---
language: en
license: apache-2.0
tags:
- marine-biology
- metagenomics
- environmental-modeling
- protein-domains
- tara-oceans
- vicreg
- joint-embedding
- self-supervised-learning
- pytorch
library_name: pytorch
pipeline_tag: tabular-regression
---
# TARA-WorldModel-VICReg

Joint environment–genome embedding model for marine ecosystem productivity prediction, trained with Variance-Invariance-Covariance Regularization (VICReg) on TARA Oceans data.
## Model Description
The World Model learns a shared latent space that aligns environmental context (satellite-derived variables) with microalgal protein domain composition (Pfam module abundances), then predicts marine productivity (chlorophyll-a, POC, NFLH) from the joint embedding.
### Architecture

Training:

```
env  (24 dims) ──→ EncoderE(128 → 32)        ──→ z_env  ─┐
                                                         ├─→ VICReg loss
pfam (20 dims) ──→ EncoderP(256 → 128 → 32)  ──→ z_pfam ─┘
                                                         └─→ Predictor(64 → 3) → productivity
```

Inference (environment-only):

```
env → EncoderE → z_env → Predictor → [chl-a, POC, NFLH]
```
- EncoderE: Linear(24→128) + BN + ReLU + Dropout(0.3) → Linear(128→32) + BN + ReLU + Dropout(0.3)
- EncoderP: Linear(20→256) + BN + ReLU + Dropout(0.3) → Linear(256→128) + BN + ReLU + Dropout(0.3) → Linear(128→32) + BN + ReLU + Dropout(0.3)
- Predictor: Linear(32→64) + ReLU → Linear(64→3)
- Total parameters: 53,187
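The layer list above can be sketched as a minimal PyTorch module. Layer sizes and the parameter count follow the card, but the class and attribute names (`WorldModel`, `encoder_env`, `encoder_pfam`, `predictor`) are assumptions; the training codebase may organize the modules differently.

```python
import torch
import torch.nn as nn

def mlp_block(d_in, d_out, dropout=0.3):
    """Linear + BatchNorm + ReLU + Dropout, as listed above."""
    return nn.Sequential(
        nn.Linear(d_in, d_out),
        nn.BatchNorm1d(d_out),
        nn.ReLU(),
        nn.Dropout(dropout),
    )

class WorldModel(nn.Module):
    def __init__(self, env_dim=24, pfam_dim=20, latent_dim=32, dropout=0.3):
        super().__init__()
        self.encoder_env = nn.Sequential(
            mlp_block(env_dim, 128, dropout),
            mlp_block(128, latent_dim, dropout),
        )
        self.encoder_pfam = nn.Sequential(
            mlp_block(pfam_dim, 256, dropout),
            mlp_block(256, 128, dropout),
            mlp_block(128, latent_dim, dropout),
        )
        self.predictor = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 3),  # [chl-a, POC, NFLH]
        )

    def forward(self, env, pfam=None):
        # pfam is optional so the same module supports environment-only inference
        z_env = self.encoder_env(env)
        z_pfam = self.encoder_pfam(pfam) if pfam is not None else None
        return z_env, z_pfam, self.predictor(z_env)
```

With the defaults above, the sketch reproduces the stated 53,187 trainable parameters (BatchNorm affine weights included; running statistics are buffers, not parameters).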
### VICReg Loss
Non-contrastive self-supervised alignment (Bardes et al., ICLR 2022):
- Invariance: MSE between co-located env/pfam embeddings (λ=25)
- Variance: hinge loss preventing embedding collapse (λ=25)
- Covariance: off-diagonal penalty decorrelating dimensions (λ=1)
- Prediction: MSE on productivity targets (α=1)
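The three regularization terms can be written out as a short sketch following Bardes et al. (2022). The function name, the variance target (std ≥ 1), and the epsilon are assumptions; the prediction MSE term is computed separately on the productivity head and simply added with weight α.

```python
import torch
import torch.nn.functional as F

def vicreg_loss(z_env, z_pfam, lambda_inv=25.0, lambda_var=25.0,
                lambda_cov=1.0, eps=1e-4):
    n, d = z_env.shape

    # Invariance: MSE between co-located env and pfam embeddings
    inv = F.mse_loss(z_env, z_pfam)

    # Variance: hinge loss keeping each dimension's std above 1
    def variance_term(z):
        std = torch.sqrt(z.var(dim=0) + eps)
        return torch.relu(1.0 - std).mean()
    var = variance_term(z_env) + variance_term(z_pfam)

    # Covariance: penalize off-diagonal entries of the covariance matrix
    def covariance_term(z):
        z = z - z.mean(dim=0)
        cov = (z.T @ z) / (n - 1)
        off_diag = cov - torch.diag(torch.diag(cov))
        return (off_diag ** 2).sum() / d
    cov = covariance_term(z_env) + covariance_term(z_pfam)

    return lambda_inv * inv + lambda_var * var + lambda_cov * cov
```

All three terms are non-negative, so collapse (zero variance) and redundancy (correlated dimensions) are penalized without any negative pairs.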
## Performance

The joint embedding improves POC prediction (R² 0.422 → 0.532, a 26% relative improvement) over the environment-only baseline. Chlorophyll-a and NFLH are better predicted by environment alone, since both are directly satellite-measured.
## Files

### Fold Checkpoints (leave-one-basin-out spatial CV)

Two training runs are provided:

- `world_model_fold_*_20260127_110243.pt` – initial configuration (latent_dim=16)
- `world_model_fold_*_20260127_111754.pt` – best configuration from the hyperparameter sweep (latent_dim=32)

Six folds per run: Arctic, Atlantic, Indian, Mediterranean, Pacific, Southern.
### Configuration

- `phase2_best_config.json` – hyperparameter sweep results (54 configurations, 3 seeds each)
### Hyperparameters (Best Config)
| Parameter | Value |
|---|---|
| latent_dim | 32 |
| dropout | 0.3 |
| λ_invariance | 25.0 |
| λ_variance | 25.0 |
| λ_covariance | 1.0 |
| pred_alpha | 1.0 |
| learning_rate | 0.001 |
| weight_decay | 1e-4 |
| batch_size | 128 |
| max_epochs | 300 |
| patience | 30 |
| grad_clip | 1.0 |
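The optimization settings in the table can be wired up as below. The learning rate, weight decay, and gradient clipping values come from the table; the choice of Adam and the helper names are assumptions, since the card does not name the optimizer.

```python
import torch
import torch.nn as nn

def make_optimizer(model):
    # lr and weight_decay from the best config; Adam itself is an assumption
    return torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

def train_step(model, optimizer, inputs, targets, loss_fn, grad_clip=1.0):
    """One optimization step with gradient-norm clipping (grad_clip=1.0 above)."""
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
    optimizer.step()
    return loss.item()
```

Training would run for up to max_epochs=300 with early stopping after patience=30 epochs without validation improvement.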
## Usage

```python
import torch

# Load a fold checkpoint on CPU. The checkpoint stores `model_state_dict`
# for the full WorldModel; the WorldModel class from the training codebase
# is required to restore it.
ckpt = torch.load("world_model_fold_Atlantic_20260127_111754.pt", map_location="cpu")
```
## Dataset
- 1,810 ocean samples with co-located environment and Pfam profiles
- 24 environmental variables (GEE oceanographic/atmospheric)
- 20 Pfam module features (aggregated from 9,466 domains via co-occurrence clustering)
- 3 productivity targets (chlorophyll-a, POC, NFLH)
- Spatial cross-validation: Leave-one-basin-out (6 ocean basins)
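The leave-one-basin-out scheme above can be sketched in plain Python: each fold holds out every sample from one basin and trains on the rest. The function name and the per-sample basin-label input are illustrative.

```python
BASINS = ["Arctic", "Atlantic", "Indian", "Mediterranean", "Pacific", "Southern"]

def leave_one_basin_out(sample_basins):
    """Yield (held_out_basin, train_idx, test_idx) tuples, one fold per basin.

    sample_basins: list assigning each of the 1,810 samples to an ocean basin.
    """
    for held_out in BASINS:
        test_idx = [i for i, b in enumerate(sample_basins) if b == held_out]
        train_idx = [i for i, b in enumerate(sample_basins) if b != held_out]
        yield held_out, train_idx, test_idx
```

Because entire basins are withheld, test scores reflect generalization to unseen ocean regions rather than to spatially adjacent (and correlated) samples.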
## Related Models

- GreenGenomicsLab/algaGPT – AlgaGPT protein classification
- GreenGenomicsLab/LA4SR-Pythia70m-b-ckpt-55000 – LA4SR-Pythia classification
- GreenGenomicsLab/TARA-ELF-NET – Deep bidirectional env↔pfam models
- GreenGenomicsLab/TARA-XGBoost-Bidirectional – XGBoost bidirectional models
## Citation
LA4SR classification models:
Nelson DR, Jaiswal AK, Ismail NS, Mystikou A, Salehi-Ashtiani K. Patterns. 2024;6(11).
## License
Apache 2.0