---
language: en
license: apache-2.0
tags:
  - marine-biology
  - metagenomics
  - environmental-modeling
  - protein-domains
  - tara-oceans
  - vicreg
  - joint-embedding
  - self-supervised-learning
  - pytorch
library_name: pytorch
pipeline_tag: tabular-regression
---

# TARA-WorldModel-VICReg

Joint environment-genome embedding model for marine ecosystem productivity prediction, trained with Variance-Invariance-Covariance Regularization (VICReg) on TARA Oceans data.

## Model Description

The World Model learns a shared latent space that aligns environmental context (satellite-derived variables) with microalgal protein domain composition (Pfam module abundances), then predicts marine productivity (chlorophyll-a, POC, NFLH) from the joint embedding.

### Architecture

```
Training:
  env (24 dims) → EncoderE(128 → 32) → z_env ─┐
                                                 ├→ VICReg loss
  pfam (20 dims) → EncoderP(256 → 128 → 32) → z_pfam
                                                 └→ Predictor(64 → 3) → productivity

Inference (environment-only):
  env → EncoderE → z_env → Predictor → [chl-a, POC, NFLH]
```

- **EncoderE**: Linear(24→128) + BN + ReLU + Dropout(0.3) → Linear(128→32) + BN + ReLU + Dropout(0.3)
- **EncoderP**: Linear(20→256) + BN + ReLU + Dropout(0.3) → Linear(256→128) + BN + ReLU + Dropout(0.3) → Linear(128→32) + BN + ReLU + Dropout(0.3)
- **Predictor**: Linear(32→64) + ReLU → Linear(64→3)
- **Total parameters**: 53,187
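
The stated parameter total can be reproduced from the layer list above. The following is a minimal arithmetic sketch, assuming every Linear layer has a bias and every BN layer contributes a learnable scale and shift (running statistics are buffers and are not counted). The helper names are illustrative and are not taken from the training codebase.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


def batchnorm_params(n_features: int) -> int:
    """Learnable scale and shift of a batch-norm layer."""
    return 2 * n_features


def mlp_params(dims: list[int], batchnorm: bool) -> int:
    """Parameters of a stack of Linear layers, each optionally followed by BN."""
    total = 0
    for n_in, n_out in zip(dims, dims[1:]):
        total += linear_params(n_in, n_out)
        if batchnorm:
            total += batchnorm_params(n_out)
    return total


encoder_e = mlp_params([24, 128, 32], batchnorm=True)       # environment encoder
encoder_p = mlp_params([20, 256, 128, 32], batchnorm=True)  # Pfam encoder
predictor = mlp_params([32, 64, 3], batchnorm=False)        # productivity head
total = encoder_e + encoder_p + predictor

print(encoder_e, encoder_p, predictor, total)  # 7648 43232 2307 53187
```

Under these assumptions the Pfam encoder accounts for roughly 81% of the parameters, and the total matches the figure given above.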

### VICReg Loss

Non-contrastive self-supervised alignment (Bardes et al., ICLR 2022):
- **Invariance**: MSE between co-located env/pfam embeddings (λ=25)
- **Variance**: Hinge loss preventing embedding collapse (λ=25)
- **Covariance**: Off-diagonal penalty decorrelating dimensions (λ=1)
- **Prediction**: MSE on productivity targets (α=1)
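
The four terms above can be sketched in plain Python to make their roles concrete. This is an illustration only, not the training code: it assumes the invariance term averages over all embedding elements, a variance target of 1 with a small `eps` inside the square root, and covariance scaled by the embedding dimension, as in the VICReg paper. The function names are hypothetical.

```python
import math


def _variance_and_covariance(z, eps=1e-4):
    """Return (variance hinge loss, off-diagonal covariance loss) for one batch."""
    n, d = len(z), len(z[0])
    means = [sum(col) / n for col in zip(*z)]
    centered = [[v - m for v, m in zip(row, means)] for row in z]
    cov = [[sum(r[i] * r[j] for r in centered) / (n - 1) for j in range(d)]
           for i in range(d)]
    # Hinge on the per-dimension standard deviation, target value 1.
    var_loss = sum(max(0.0, 1.0 - math.sqrt(cov[i][i] + eps)) for i in range(d)) / d
    # Squared off-diagonal covariance entries, scaled by the dimension.
    cov_loss = sum(cov[i][j] ** 2 for i in range(d) for j in range(d) if i != j) / d
    return var_loss, cov_loss


def vicreg_loss(z_env, z_pfam, lam_inv=25.0, lam_var=25.0, lam_cov=1.0):
    """Weighted VICReg objective for two batches of paired embeddings."""
    n, d = len(z_env), len(z_env[0])
    invariance = sum((a - b) ** 2
                     for row_a, row_b in zip(z_env, z_pfam)
                     for a, b in zip(row_a, row_b)) / (n * d)
    var_e, cov_e = _variance_and_covariance(z_env)
    var_p, cov_p = _variance_and_covariance(z_pfam)
    return lam_inv * invariance + lam_var * (var_e + var_p) + lam_cov * (cov_e + cov_p)


def total_loss(z_env, z_pfam, pred, target, alpha=1.0):
    """VICReg alignment plus the supervised productivity MSE, weighted by alpha."""
    mse = sum((p - t) ** 2
              for row_p, row_t in zip(pred, target)
              for p, t in zip(row_p, row_t)) / (len(pred) * len(pred[0]))
    return vicreg_loss(z_env, z_pfam) + alpha * mse


# Spread-out embeddings score better than collapsed ones.
z = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
collapsed = [[0.5, 0.5]] * 4
print(vicreg_loss(z, z) < vicreg_loss(collapsed, collapsed))  # True
```

The comparison at the end shows why the variance term matters: two identical but constant embeddings have zero invariance loss, yet the hinge penalizes them heavily for collapsing.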

## Performance

The joint embedding improves POC prediction over the environment-only baseline (R² 0.422 → 0.532, a 26% relative improvement). Chlorophyll-a and NFLH are better predicted from environment alone, because both are measured directly by satellite.
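
As a quick check of the arithmetic, the relative improvement is the absolute R² gain divided by the baseline R². The variable names below are illustrative.

```python
r2_env_only = 0.422  # environment-only baseline (POC)
r2_joint = 0.532     # joint embedding (POC)

absolute_gain = r2_joint - r2_env_only
relative_gain = absolute_gain / r2_env_only

print(f"absolute: {absolute_gain:.3f}, relative: {relative_gain:.1%}")
# absolute: 0.110, relative: 26.1%
```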

## Files

### Fold Checkpoints (leave-one-basin-out spatial CV)

Two training runs are provided:
- `world_model_fold_*_20260127_110243.pt`: Initial configuration (latent_dim=16)
- `world_model_fold_*_20260127_111754.pt`: Best configuration from hyperparameter sweep (latent_dim=32)

Six folds per run: Arctic, Atlantic, Indian, Mediterranean, Pacific, Southern.
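
The checkpoint filenames can be generated from the pattern and the basin list. This sketch assumes the `*` in the pattern is replaced by the basin name capitalized exactly as listed above; only the Atlantic filename is confirmed by the Usage section, so verify the others against the repository. `checkpoint_name` and the run labels are hypothetical.

```python
BASINS = ["Arctic", "Atlantic", "Indian", "Mediterranean", "Pacific", "Southern"]
RUNS = {
    "initial": "20260127_110243",  # latent_dim=16
    "best": "20260127_111754",     # latent_dim=32
}


def checkpoint_name(basin: str, run: str = "best") -> str:
    """Build the checkpoint filename for one held-out basin and training run."""
    if basin not in BASINS:
        raise ValueError(f"unknown basin: {basin!r}")
    return f"world_model_fold_{basin}_{RUNS[run]}.pt"


for basin in BASINS:
    print(checkpoint_name(basin))
```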

### Configuration

- `phase2_best_config.json`: Hyperparameter sweep results (54 configurations, 3 seeds each)

## Hyperparameters (Best Config)

| Parameter | Value |
|-----------|-------|
| latent_dim | 32 |
| dropout | 0.3 |
| λ_invariance | 25.0 |
| λ_variance | 25.0 |
| λ_covariance | 1.0 |
| pred_alpha | 1.0 |
| learning_rate | 0.001 |
| weight_decay | 1e-4 |
| batch_size | 128 |
| max_epochs | 300 |
| patience | 30 |
| grad_clip | 1.0 |
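
For use in code, the table can be captured as a small immutable config object. This is a sketch: `BestConfig` and `should_stop` are hypothetical names, and the early-stopping rule shown (stop once `patience` epochs have passed since the best validation loss) is an assumption about how `patience` is applied, not a description of the training code.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BestConfig:
    latent_dim: int = 32
    dropout: float = 0.3
    lambda_invariance: float = 25.0
    lambda_variance: float = 25.0
    lambda_covariance: float = 1.0
    pred_alpha: float = 1.0
    learning_rate: float = 1e-3
    weight_decay: float = 1e-4
    batch_size: int = 128
    max_epochs: int = 300
    patience: int = 30
    grad_clip: float = 1.0


def should_stop(val_losses: list[float], patience: int) -> bool:
    """True once `patience` epochs have passed since the best validation loss."""
    if not val_losses:
        return False
    best_epoch = val_losses.index(min(val_losses))
    return len(val_losses) - 1 - best_epoch >= patience


cfg = BestConfig()
history = [1.0] + [2.0] * 30  # best loss at epoch 0, then 30 worse epochs
print(should_stop(history, cfg.patience))  # True
```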

## Usage

```python
import torch

# Load fold checkpoint
ckpt = torch.load("world_model_fold_Atlantic_20260127_111754.pt", map_location="cpu")

# ckpt contains model_state_dict for the full WorldModel.
# Rebuilding the model requires the WorldModel class from the training codebase.
state_dict = ckpt["model_state_dict"]
```

## Dataset

- **1,810 ocean samples** with co-located environment and Pfam profiles
- **24 environmental variables** (GEE oceanographic/atmospheric)
- **20 Pfam module features** (aggregated from 9,466 domains via co-occurrence clustering)
- **3 productivity targets** (chlorophyll-a, POC, NFLH)
- **Spatial cross-validation**: Leave-one-basin-out (6 ocean basins)
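
Leave-one-basin-out cross-validation holds out every sample from one basin as the test set and trains on the rest, so each fold measures generalization to an unseen ocean region. A minimal sketch with toy labels follows; the function name and the label list are illustrative and not part of the released code.

```python
from collections import defaultdict


def leave_one_basin_out(basins):
    """Yield (held_out_basin, train_indices, test_indices) for each basin."""
    by_basin = defaultdict(list)
    for idx, basin in enumerate(basins):
        by_basin[basin].append(idx)
    for held_out in sorted(by_basin):
        test_idx = by_basin[held_out]
        train_idx = sorted(i for b, idxs in by_basin.items()
                           if b != held_out for i in idxs)
        yield held_out, train_idx, test_idx


# Toy labels; the real dataset has 1,810 samples across six basins.
labels = ["Arctic", "Atlantic", "Atlantic", "Pacific", "Indian", "Pacific"]
for basin, train_idx, test_idx in leave_one_basin_out(labels):
    print(basin, train_idx, test_idx)
```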

## Related Models

- [GreenGenomicsLab/algaGPT](https://huggingface.co/GreenGenomicsLab/algaGPT): AlgaGPT protein classification
- [GreenGenomicsLab/LA4SR-Pythia70m-b-ckpt-55000](https://huggingface.co/GreenGenomicsLab/LA4SR-Pythia70m-b-ckpt-55000): LA4SR-Pythia classification
- [GreenGenomicsLab/TARA-ELF-NET](https://huggingface.co/GreenGenomicsLab/TARA-ELF-NET): Deep bidirectional env↔pfam models
- [GreenGenomicsLab/TARA-XGBoost-Bidirectional](https://huggingface.co/GreenGenomicsLab/TARA-XGBoost-Bidirectional): XGBoost bidirectional models

## Citation

LA4SR classification models:
> Nelson DR, Jaiswal AK, Ismail NS, Mystikou A, Salehi-Ashtiani K. *Patterns*. 2024;6(11).

## License

Apache 2.0