---
language: en
license: apache-2.0
tags:
- marine-biology
- metagenomics
- environmental-modeling
- protein-domains
- tara-oceans
- vicreg
- joint-embedding
- self-supervised-learning
- pytorch
library_name: pytorch
pipeline_tag: tabular-regression
---

# TARA-WorldModel-VICReg

Joint environment–genome embedding model for predicting marine ecosystem productivity, trained with Variance-Invariance-Covariance Regularization (VICReg) on TARA Oceans data.

## Model Description

The World Model learns a shared latent space that aligns environmental context (satellite-derived variables) with microalgal protein domain composition (Pfam module abundances). It then predicts marine productivity (chlorophyll-a, particulate organic carbon (POC), and normalized fluorescence line height (NFLH)) from the joint embedding.

### Architecture

```
Training:
  env  (24 dims) → EncoderE(128 → 32)       → z_env  ─┐
                                                       ├→ VICReg loss
  pfam (20 dims) → EncoderP(256 → 128 → 32) → z_pfam
                                                       └→ Predictor(64 → 3) → productivity

Inference (environment-only):
  env → EncoderE → z_env → Predictor → [chl-a, POC, NFLH]
```

- **EncoderE**: Linear(24→128) + BN + ReLU + Dropout(0.3) → Linear(128→32) + BN + ReLU + Dropout(0.3)
- **EncoderP**: Linear(20→256) + BN + ReLU + Dropout(0.3) → Linear(256→128) + BN + ReLU + Dropout(0.3) → Linear(128→32) + BN + ReLU + Dropout(0.3)
- **Predictor**: Linear(32→64) + ReLU → Linear(64→3)
- **Total parameters**: 53,187

### VICReg Loss

Non-contrastive self-supervised alignment (Bardes et al., ICLR 2022):

- **Invariance**: MSE between co-located env and pfam embeddings (λ=25)
- **Variance**: Hinge loss on the per-dimension standard deviation of each embedding, preventing collapse (λ=25)
- **Covariance**: Penalty on off-diagonal covariance entries, decorrelating embedding dimensions (λ=1)
- **Prediction**: MSE on the productivity targets (α=1)

## Performance

The joint embedding improves POC prediction over the environment-only baseline (R² 0.422 → 0.532, a 26% relative improvement). Chlorophyll-a and NFLH are better predicted by the environment-only baseline, as both are measured directly by satellite.
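The architecture and loss described above can be sketched in PyTorch as follows. This is a minimal illustration reconstructed from the layer list and loss weights in this card, not the training codebase: the names `WorldModelSketch`, `mlp`, `variance_term`, `covariance_term` and `vicreg_loss` are hypothetical, and its state-dict keys will not match the released checkpoints. Several details are assumptions: the predictor reads `z_env` (mirroring the environment-only inference path), the variance hinge uses the VICReg paper's target γ=1 with ε=1e-4, the covariance penalty is normalized by the embedding dimension, and the optimizer is Adam with norm-based gradient clipping.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def mlp(dims, dropout=0.3):
    """Linear -> BatchNorm1d -> ReLU -> Dropout for each consecutive pair in dims."""
    layers = []
    for d_in, d_out in zip(dims[:-1], dims[1:]):
        layers += [nn.Linear(d_in, d_out), nn.BatchNorm1d(d_out), nn.ReLU(), nn.Dropout(dropout)]
    return nn.Sequential(*layers)


class WorldModelSketch(nn.Module):
    """Hypothetical re-creation of the EncoderE / EncoderP / Predictor stack listed in the card."""

    def __init__(self, env_dim=24, pfam_dim=20, latent_dim=32, n_targets=3, dropout=0.3):
        super().__init__()
        self.encoder_e = mlp([env_dim, 128, latent_dim], dropout)
        self.encoder_p = mlp([pfam_dim, 256, 128, latent_dim], dropout)
        self.predictor = nn.Sequential(nn.Linear(latent_dim, 64), nn.ReLU(), nn.Linear(64, n_targets))

    def forward(self, env, pfam=None):
        z_env = self.encoder_e(env)
        y_hat = self.predictor(z_env)  # assumption: the predictor reads the environment embedding
        z_pfam = self.encoder_p(pfam) if pfam is not None else None  # Pfam branch is training-only
        return z_env, z_pfam, y_hat


def variance_term(z, gamma=1.0, eps=1e-4):
    """Hinge on the per-dimension std across the batch (penalizes collapse toward a constant)."""
    std = torch.sqrt(z.var(dim=0) + eps)
    return F.relu(gamma - std).mean()


def covariance_term(z):
    """Sum of squared off-diagonal covariance entries, divided by the embedding dimension."""
    n, d = z.shape
    z = z - z.mean(dim=0)
    cov = (z.T @ z) / (n - 1)
    off_diag = cov - torch.diag(torch.diag(cov))
    return off_diag.pow(2).sum() / d


def vicreg_loss(z_env, z_pfam, y_hat, y, lam_inv=25.0, lam_var=25.0, lam_cov=1.0, alpha=1.0):
    """Weighted VICReg terms on both embeddings plus the productivity prediction MSE."""
    inv = F.mse_loss(z_env, z_pfam)
    var = variance_term(z_env) + variance_term(z_pfam)
    cov = covariance_term(z_env) + covariance_term(z_pfam)
    pred = F.mse_loss(y_hat, y)
    return lam_inv * inv + lam_var * var + lam_cov * cov + alpha * pred


if __name__ == "__main__":
    torch.manual_seed(0)
    model = WorldModelSketch()
    print(sum(p.numel() for p in model.parameters()))  # 53187

    # One training step on random stand-in data (batch_size=128, lr, weight_decay and
    # grad_clip from the hyperparameter table; Adam and clipping by norm are assumptions)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    env, pfam, y = torch.randn(128, 24), torch.randn(128, 20), torch.randn(128, 3)
    z_env, z_pfam, y_hat = model(env, pfam)
    loss = vicreg_loss(z_env, z_pfam, y_hat, y)
    opt.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    opt.step()

    # Environment-only inference: no Pfam input is needed
    model.eval()
    with torch.no_grad():
        _, _, preds = model(torch.randn(5, 24))
    print(preds.shape)  # torch.Size([5, 3])
```

With these layer sizes the sketch has exactly 53,187 trainable parameters, matching the total reported above, which suggests the layer list is complete. Because the Pfam encoder only feeds the loss, it can be dropped at inference time without affecting the predictions.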
## Files

### Fold Checkpoints (leave-one-basin-out spatial CV)

Two training runs are provided:

- `world_model_fold_*_20260127_110243.pt` — Initial configuration (latent_dim=16)
- `world_model_fold_*_20260127_111754.pt` — Best configuration from the hyperparameter sweep (latent_dim=32)

Each run has six folds: Arctic, Atlantic, Indian, Mediterranean, Pacific, Southern.

### Configuration

- `phase2_best_config.json` — Hyperparameter sweep results (54 configurations, 3 seeds each)

## Hyperparameters (Best Config)

| Parameter | Value |
|-----------|-------|
| latent_dim | 32 |
| dropout | 0.3 |
| λ_invariance | 25.0 |
| λ_variance | 25.0 |
| λ_covariance | 1.0 |
| pred_alpha | 1.0 |
| learning_rate | 0.001 |
| weight_decay | 1e-4 |
| batch_size | 128 |
| max_epochs | 300 |
| patience | 30 |
| grad_clip | 1.0 |

## Usage

```python
import torch

# Load one fold checkpoint (best-config run, latent_dim=32)
ckpt = torch.load("world_model_fold_Atlantic_20260127_111754.pt", map_location="cpu")

# ckpt["model_state_dict"] holds the weights of the full WorldModel.
# Rebuilding the model requires the WorldModel class from the training codebase:
#   model = WorldModel(...)  # constructor arguments are defined by that codebase
#   model.load_state_dict(ckpt["model_state_dict"])
#   model.eval()
```

## Dataset

- **1,810 ocean samples** with co-located environmental and Pfam profiles
- **24 environmental variables** (oceanographic and atmospheric, from Google Earth Engine (GEE))
- **20 Pfam module features** (aggregated from 9,466 Pfam domains via co-occurrence clustering)
- **3 productivity targets** (chlorophyll-a, POC, NFLH)
- **Spatial cross-validation**: Leave-one-basin-out (6 ocean basins)

## Related Models

- [GreenGenomicsLab/algaGPT](https://huggingface.co/GreenGenomicsLab/algaGPT) — AlgaGPT protein classification
- [GreenGenomicsLab/LA4SR-Pythia70m-b-ckpt-55000](https://huggingface.co/GreenGenomicsLab/LA4SR-Pythia70m-b-ckpt-55000) — LA4SR-Pythia classification
- [GreenGenomicsLab/TARA-ELF-NET](https://huggingface.co/GreenGenomicsLab/TARA-ELF-NET) — Deep bidirectional env↔pfam models
- [GreenGenomicsLab/TARA-XGBoost-Bidirectional](https://huggingface.co/GreenGenomicsLab/TARA-XGBoost-Bidirectional) — XGBoost bidirectional models

## Citation

LA4SR
classification models:

> Nelson DR, Jaiswal AK, Ismail NS, Mystikou A, Salehi-Ashtiani K. *Patterns*. 2024;6(11).

## License

Apache 2.0