# GeneSetCLIP: Contrastive Pretraining for Gene Set–Text Alignment
A CLIP-style contrastive model that aligns **biological text descriptions** with **gene-set representations**, trained on MSigDB v2024.1 (human + mouse).
Given a text query like *"type I interferon signaling"*, the model retrieves the corresponding gene set, and vice versa.
## Architecture
```
TEXT SIDE                          GENE SET SIDE
─────────────────────              ──────────────────────────
"Genes up-regulated in             {STAT1, IRF7, ISG15,
 response to IFN-α..."              OAS1, MX1, IFIT1, ...}
          │                                  │
          ▼                                  ▼
BioLORD-2023 (frozen)              GSFM (fine-tuned, lr/10)
      [768-dim]                          [256-dim]
          │                                  │
          ▼                                  ▼
text_proj (trainable)              gene_proj (trainable)
  768 → 512 → 256                    256 → 256 → 256
          │                                  │
          ▼                                  ▼
   z_text [256]                       z_gene [256]
          │                                  │
          └────────── L2-normalize ──────────┘
                          │
                          ▼
            InfoNCE loss (τ learnable)
```
### Components
| Component | Model | Dim | Training |
|-----------|-------|-----|----------|
| **Gene encoder** | [GSFM](https://huggingface.co/maayanlab/gsfm-rummagene) (MLP autoencoder, Set model) | 256 | Fine-tuned at 1/10 LR |
| **Text encoder** | [BioLORD-2023](https://huggingface.co/FremyCompany/BioLORD-2023) (MPNet-base) | 768 | Frozen |
| **Gene projection** | MLP: 256 → 256 → 256 + LayerNorm | 256 | Trained |
| **Text projection** | MLP: 768 → 512 → 256 + LayerNorm | 256 | Trained |
### Why these encoders?
- **GSFM**: Purpose-built gene-set encoder from the Ma'ayan Lab. Takes variable-length gene sets as input (multi-hot encoding → MLP), producing permutation-invariant 256-dim embeddings. Pretrained on Rummagene (gene sets mined from PubMed tables).
- **BioLORD-2023**: Ontology-grounded biomedical sentence embeddings. Trained on UMLS concept name-synonym pairs plus LLM-generated definitions, a format that closely mirrors MSigDB gene set descriptions (name + definition anchored in GO/KEGG/Reactome).
## Training Data
**MSigDB v2024.1**: 50,896 gene set–text pairs from the Molecular Signatures Database.
| Split | Collections | Pairs | Purpose |
|-------|-------------|-------|---------|
| Train | C2, C5, C8, C1, M2, M5, M8, M1 | 38,622 | Curated, GO, cell type signatures |
| Val | C3, C4, M3 | 6,766 | Regulatory targets, computational |
| Test | H, C6, C7, MH | 5,508 | Hallmarks, oncogenic, immunologic |
Each pair consists of:
- **Text**: `[Collection: H] [Species: human]\nHALLMARK APOPTOSIS\nGenes mediating programmed cell death by activation of caspases.`
- **Genes**: `["CASP3", "CASP6", "TP53", "BAX", ...]`
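The text side of a pair can be assembled with a small helper like the following (an illustrative sketch; `format_text` is a hypothetical name, not necessarily the function used in `data_processing.py`):

```python
def format_text(collection, species, name, description):
    # Build the "[Collection] [Species]\nNAME\ndescription" layout shown above.
    header = f"[Collection: {collection}] [Species: {species}]"
    title = name.replace("_", " ")  # MSigDB set names use underscores
    return f"{header}\n{title}\n{description}"

text = format_text(
    "H", "human", "HALLMARK_APOPTOSIS",
    "Genes mediating programmed cell death by activation of caspases.")
```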
Data augmentation: 20% gene dropout (randomly remove genes each epoch).
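One way to implement this augmentation (a sketch assuming per-gene dropout each epoch, always keeping at least one gene; `gene_dropout` is a hypothetical helper, not necessarily what the training scripts do):

```python
import random

def gene_dropout(genes, p=0.2, rng=random):
    # Drop each gene independently with probability p; never return an empty set.
    kept = [g for g in genes if rng.random() >= p]
    return kept if kept else [rng.choice(genes)]

augmented = gene_dropout(["CASP3", "CASP6", "TP53", "BAX", "CASP9"])
```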
## Training Recipe
Based on [ProtST](https://arxiv.org/abs/2301.12040) (ICML 2023), adapted for gene sets:
| Parameter | Value |
|-----------|-------|
| Loss | Symmetric InfoNCE (NT-Xent) |
| Temperature | 0.07 (learnable, clamped [0.01, 1.0]) |
| Batch size | 256 |
| LR (projections) | 1e-4 |
| LR (gene encoder) | 1e-5 (10x lower) |
| LR (text encoder) | 0 (frozen) |
| Optimizer | AdamW (weight_decay=0.01) |
| Schedule | 500-step warmup → cosine decay |
| Epochs | 50 (early stopping, patience=10) |
| Gene dropout | 20% |
| Max gene set size | 512 |
| Hardware | 1Γ T4 GPU (16GB) |
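The loss in the table above can be sketched as follows (batch embeddings assumed L2-normalized; `symmetric_info_nce` is an illustrative name, not the training code's):

```python
import torch
import torch.nn.functional as F

def symmetric_info_nce(z_text, z_gene, log_temperature):
    # Matched text/gene pairs sit on the diagonal of the similarity matrix.
    tau = log_temperature.exp().clamp(0.01, 1.0)   # learnable temperature, clamped
    logits = (z_text @ z_gene.T) / tau             # [B, B] cosine similarities / tau
    targets = torch.arange(z_text.size(0))
    loss_t2g = F.cross_entropy(logits, targets)    # text -> gene retrieval
    loss_g2t = F.cross_entropy(logits.T, targets)  # gene -> text retrieval
    return 0.5 * (loss_t2g + loss_g2t)

# Toy batch of random normalized embeddings:
B, D = 8, 256
z_t = F.normalize(torch.randn(B, D), dim=-1)
z_g = F.normalize(torch.randn(B, D), dim=-1)
loss = symmetric_info_nce(z_t, z_g, torch.log(torch.tensor([0.07])))
```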
## Quick Start
### Installation
```bash
pip install torch sentence-transformers huggingface_hub safetensors lightning
GIT_LFS_SKIP_SMUDGE=1 pip install "git+https://huggingface.co/maayanlab/gsfm"
```
### Inference
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from gsfm import GSFM, Vocab
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download
# Load gene encoder + vocab
gene_encoder = GSFM.from_pretrained("maayanlab/gsfm-rummagene")
vocab = Vocab.from_pretrained("maayanlab/gsfm-rummagene")
gene_encoder.eval()
# Load text encoder
text_encoder = SentenceTransformer("FremyCompany/BioLORD-2023")
# Load projection heads
clip_path = hf_hub_download("AliSaadatV/GeneSetCLIP", "clip_model.pt")
class ProjectionHead(nn.Module):
def __init__(self, d_in, d_h, d_out):
super().__init__()
self.net = nn.Sequential(
nn.Linear(d_in, d_h), nn.GELU(), nn.Dropout(0.1),
nn.Linear(d_h, d_out), nn.LayerNorm(d_out))
def forward(self, x): return self.net(x)
class GeneSetCLIP(nn.Module):
def __init__(self):
super().__init__()
self.log_temperature = nn.Parameter(torch.zeros(1))
self.text_proj = ProjectionHead(768, 512, 256)
self.gene_proj = ProjectionHead(256, 256, 256)
clip_model = GeneSetCLIP()
clip_model.load_state_dict(torch.load(clip_path, map_location="cpu", weights_only=True))
clip_model.eval()
# --- Encode a gene set ---
genes = ["STAT1", "IRF7", "ISG15", "OAS1", "MX1", "IFIT1"]
gene_ids = torch.tensor([vocab(genes)])
with torch.no_grad():
gene_emb = gene_encoder.encode(gene_ids)
z_gene = F.normalize(clip_model.gene_proj(gene_emb), dim=-1)
# --- Encode text queries ---
queries = [
"Interferon alpha response genes",
"Apoptosis signaling",
"Cell cycle regulation",
]
text_embs = text_encoder.encode(queries, convert_to_tensor=True)
with torch.no_grad():
z_text = F.normalize(clip_model.text_proj(text_embs), dim=-1)
# --- Compute similarities ---
sims = (z_gene @ z_text.T).squeeze()
for q, s in zip(queries, sims):
print(f" {s.item():.3f} {q}")
# Expected: highest similarity for "Interferon alpha response genes"
```
## Training from Scratch
### 1. Process MSigDB data
```bash
python data_processing.py
```
This downloads all MSigDB GMT files and scrapes descriptions.
### 2. Train
```bash
# Self-contained (downloads data from Hub automatically)
python train_job.py
# Or with local data
python train.py
```
### 3. On HF Jobs (GPU)
```python
from huggingface_hub import HfApi
# Submit as HF Job with GPU
# See train_job.py for the self-contained script
```
## Downstream Applications
1. **Zero-shot gene set annotation**: Embed a gene list from an experiment → find nearest text descriptions
2. **Cross-modal search**: Text query → gene sets, or gene list → pathway descriptions
3. **Gene set similarity**: Compare gene sets via embedding cosine similarity (captures functional similarity beyond gene overlap)
4. **Cell type annotation**: Embed cell marker gene sets → match to cell type text descriptions
5. **Biological RAG**: Use MSigDB embeddings as retrieval corpus for LLM-based reasoning
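Application 3 can be sketched as follows; placeholder random embeddings stand in for real `z_gene` vectors, which in practice come from `clip_model.gene_proj(gene_encoder.encode(...))` as in the Inference section:

```python
import torch
import torch.nn.functional as F

def jaccard(a, b):
    # Plain gene-overlap similarity, for comparison.
    a, b = set(a), set(b)
    return len(a & b) / len(a | b)

def embedding_similarity(z_a, z_b):
    # Cosine similarity between two gene-set embeddings; unlike Jaccard,
    # it can score disjoint-but-functionally-related sets as similar.
    return F.cosine_similarity(z_a, z_b, dim=-1).item()

# Two interferon-related sets with zero gene overlap:
set_a = ["STAT1", "IRF7", "ISG15"]
set_b = ["OAS1", "MX1", "IFIT1"]
print(jaccard(set_a, set_b))  # 0.0 by overlap alone

# Placeholder embeddings (real ones would come from the trained model):
z_a = F.normalize(torch.randn(1, 256), dim=-1)
z_b = F.normalize(torch.randn(1, 256), dim=-1)
sim = embedding_similarity(z_a, z_b)
```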
## Key References
- [ProtST](https://arxiv.org/abs/2301.12040) (ICML 2023) β Protein-text contrastive alignment
- [MoleculeSTM](https://arxiv.org/abs/2212.10789) (Nature MI 2024) β Molecule-text alignment
- [LangCell](https://arxiv.org/abs/2405.06708) β Cell-text contrastive with MSigDB pathways
- [BioLORD-2023](https://arxiv.org/abs/2311.16075) (JAMIA 2024) β Biomedical sentence embeddings
- [Set Transformer](https://arxiv.org/abs/1810.00825) β Permutation-invariant set encoding
## Files
| File | Description |
|------|-------------|
| `clip_model.pt` | Trained projection heads (text + gene) |
| `gene_encoder.pt` | Fine-tuned GSFM gene encoder |
| `config.json` | Training configuration |
| `vocab.json` | Gene symbol → token ID mapping |
| `test_results.json` | Evaluation metrics on test set |
| `train_job.py` | Self-contained training script (for HF Jobs) |
| `train.py` | Modular training script |
| `data_processing.py` | MSigDB data download + processing |
## License
- Code: MIT
- GSFM model: BSD-3-Clause
- BioLORD-2023: Other (requires UMLS account)
- MSigDB data: [Creative Commons Attribution 4.0](https://www.gsea-msigdb.org/gsea/msigdb/licenses.jsp)