# GeneSetCLIP: Contrastive Pretraining for Gene Set–Text Alignment
A CLIP-style contrastive model that aligns **biological text descriptions** with **gene-set representations**, trained on MSigDB v2024.1 (human + mouse).
Given a text query like *"type I interferon signaling"*, the model retrieves the corresponding gene set, and vice versa.
## Architecture
```
      TEXT SIDE                        GENE SET SIDE
─────────────────────          ──────────────────────────
"Genes up-regulated in          {STAT1, IRF7, ISG15,
 response to IFN-α..."           OAS1, MX1, IFIT1, ...}
          │                               │
          ▼                               ▼
BioLORD-2023 (frozen)           GSFM (fine-tuned, lr/10)
      [768-dim]                       [256-dim]
          │                               │
          ▼                               ▼
text_proj (trainable)           gene_proj (trainable)
  768 → 512 → 256                 256 → 256 → 256
          │                               │
          ▼                               ▼
   z_text [256]                    z_gene [256]
          │                               │
          └──────── L2-normalize ─────────┘
                        │
                        ▼
            InfoNCE loss (τ learnable)
```
### Components
| Component | Model | Dim | Training |
|-----------|-------|-----|----------|
| **Gene encoder** | [GSFM](https://huggingface.co/maayanlab/gsfm-rummagene) (MLP autoencoder, Set model) | 256 | Fine-tuned at 1/10 LR |
| **Text encoder** | [BioLORD-2023](https://huggingface.co/FremyCompany/BioLORD-2023) (MPNet-base) | 768 | Frozen |
| **Gene projection** | MLP: 256 → 256 → 256 + LayerNorm | 256 | Trained |
| **Text projection** | MLP: 768 → 512 → 256 + LayerNorm | 256 | Trained |
### Why these encoders?
- **GSFM**: Purpose-built gene-set encoder from the Ma'ayan Lab. Takes variable-length gene sets as input (multi-hot encoding → MLP), producing permutation-invariant 256-dim embeddings. Pretrained on Rummagene (gene sets mined from PubMed tables).
- **BioLORD-2023**: Ontology-grounded biomedical sentence embeddings. Trained on UMLS concept name-synonym pairs plus LLM-generated definitions, a structure that closely mirrors MSigDB gene set descriptions (name + definition anchored in GO/KEGG/Reactome).
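The multi-hot encoding that feeds GSFM's MLP can be illustrated in a few lines (an illustrative sketch, not GSFM's actual code):

```python
import torch

def multi_hot(gene_ids, vocab_size):
    # Variable-length gene set -> fixed-size 0/1 vector. The order of
    # gene_ids does not matter, so the representation is permutation-invariant.
    v = torch.zeros(vocab_size)
    v[torch.tensor(gene_ids, dtype=torch.long)] = 1.0
    return v
```

Duplicates and ordering are absorbed by the indexing, which is what makes the downstream MLP a set encoder rather than a sequence encoder.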
## Training Data
**MSigDB v2024.1**: 50,896 gene set–text pairs from the Molecular Signatures Database.
| Split | Collections | Pairs | Purpose |
|-------|-------------|-------|---------|
| Train | C2, C5, C8, C1, M2, M5, M8, M1 | 38,622 | Curated, GO, cell type signatures |
| Val | C3, C4, M3 | 6,766 | Regulatory targets, computational |
| Test | H, C6, C7, MH | 5,508 | Hallmarks, oncogenic, immunologic |
Each pair consists of:
- **Text**: `[Collection: H] [Species: human]\nHALLMARK APOPTOSIS\nGenes mediating programmed cell death by activation of caspases.`
- **Genes**: `["CASP3", "CASP6", "TP53", "BAX", ...]`
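The text side of each pair can be assembled with a one-line formatter (a sketch; the exact field layout is inferred from the example above):

```python
def format_pair_text(collection, species, name, description):
    # Template inferred from the example pair above:
    # "[Collection: H] [Species: human]\nHALLMARK APOPTOSIS\n<description>"
    return f"[Collection: {collection}] [Species: {species}]\n{name}\n{description}"
```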
Data augmentation: 20% gene dropout (randomly remove genes each epoch).
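The 20% gene dropout can be sketched as follows (a hypothetical helper, not the repo's implementation):

```python
import random

def gene_dropout(genes, p=0.2, min_keep=1, rng=random):
    # Drop each gene independently with probability p, keeping at least
    # min_keep genes so the set never becomes empty.
    kept = [g for g in genes if rng.random() >= p]
    return kept if len(kept) >= min_keep else genes[:min_keep]
```

Resampling the kept genes every epoch gives the model a slightly different view of each gene set, which discourages memorizing exact memberships.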
## Training Recipe
Based on [ProtST](https://arxiv.org/abs/2301.12040) (ICML 2023), adapted for gene sets:
| Parameter | Value |
|-----------|-------|
| Loss | Symmetric InfoNCE (NT-Xent) |
| Temperature | 0.07 (learnable, clamped [0.01, 1.0]) |
| Batch size | 256 |
| LR (projections) | 1e-4 |
| LR (gene encoder) | 1e-5 (10x lower) |
| LR (text encoder) | 0 (frozen) |
| Optimizer | AdamW (weight_decay=0.01) |
| Schedule | 500-step warmup → cosine decay |
| Epochs | 50 (early stopping, patience=10) |
| Gene dropout | 20% |
| Max gene set size | 512 |
| Hardware | 1× T4 GPU (16GB) |
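The symmetric InfoNCE objective with the learnable, clamped temperature from the table can be sketched as (a minimal sketch; variable names are illustrative):

```python
import torch
import torch.nn.functional as F

def symmetric_info_nce(z_text, z_gene, log_temperature):
    # z_text, z_gene: [B, 256], already L2-normalized.
    tau = log_temperature.exp().clamp(0.01, 1.0)      # clamp per the recipe
    logits = (z_text @ z_gene.T) / tau                # [B, B] similarity matrix
    targets = torch.arange(z_text.size(0), device=z_text.device)
    loss_t2g = F.cross_entropy(logits, targets)       # text -> gene set
    loss_g2t = F.cross_entropy(logits.T, targets)     # gene set -> text
    return 0.5 * (loss_t2g + loss_g2t)
```

Matched pairs sit on the diagonal of `logits`; each row (and column) is treated as a classification problem over the batch, so batch size directly controls the number of negatives.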
## Quick Start
### Installation
```bash
pip install torch sentence-transformers huggingface_hub safetensors lightning
GIT_LFS_SKIP_SMUDGE=1 pip install "git+https://huggingface.co/maayanlab/gsfm"
```
### Inference
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from gsfm import GSFM, Vocab
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download

# Load gene encoder + vocab
gene_encoder = GSFM.from_pretrained("maayanlab/gsfm-rummagene")
vocab = Vocab.from_pretrained("maayanlab/gsfm-rummagene")
gene_encoder.eval()

# Load text encoder
text_encoder = SentenceTransformer("FremyCompany/BioLORD-2023")

# Load projection heads
clip_path = hf_hub_download("AliSaadatV/GeneSetCLIP", "clip_model.pt")

class ProjectionHead(nn.Module):
    def __init__(self, d_in, d_h, d_out):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, d_h), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(d_h, d_out), nn.LayerNorm(d_out))

    def forward(self, x):
        return self.net(x)

class GeneSetCLIP(nn.Module):
    def __init__(self):
        super().__init__()
        self.log_temperature = nn.Parameter(torch.zeros(1))
        self.text_proj = ProjectionHead(768, 512, 256)
        self.gene_proj = ProjectionHead(256, 256, 256)

clip_model = GeneSetCLIP()
clip_model.load_state_dict(torch.load(clip_path, map_location="cpu", weights_only=True))
clip_model.eval()

# --- Encode a gene set ---
genes = ["STAT1", "IRF7", "ISG15", "OAS1", "MX1", "IFIT1"]
gene_ids = torch.tensor([vocab(genes)])
with torch.no_grad():
    gene_emb = gene_encoder.encode(gene_ids)
    z_gene = F.normalize(clip_model.gene_proj(gene_emb), dim=-1)

# --- Encode text queries ---
queries = [
    "Interferon alpha response genes",
    "Apoptosis signaling",
    "Cell cycle regulation",
]
text_embs = text_encoder.encode(queries, convert_to_tensor=True)
with torch.no_grad():
    z_text = F.normalize(clip_model.text_proj(text_embs), dim=-1)

# --- Compute similarities ---
sims = (z_gene @ z_text.T).squeeze()
for q, s in zip(queries, sims):
    print(f"{s.item():.3f}  {q}")
# Expected: highest similarity for "Interferon alpha response genes"
```
## Training from Scratch
### 1. Process MSigDB data
```bash
python data_processing.py
```
This downloads all MSigDB GMT files and scrapes descriptions.
### 2. Train
```bash
# Self-contained (downloads data from Hub automatically)
python train_job.py
# Or with local data
python train.py
```
### 3. On HF Jobs (GPU)
```python
from huggingface_hub import HfApi
# Submit as HF Job with GPU
# See train_job.py for the self-contained script
```
## Downstream Applications
1. **Zero-shot gene set annotation**: Embed a gene list from an experiment → find nearest text descriptions
2. **Cross-modal search**: Text query → gene sets, or gene list → pathway descriptions
3. **Gene set similarity**: Compare gene sets via embedding cosine similarity (captures functional similarity beyond gene overlap)
4. **Cell type annotation**: Embed cell marker gene sets → match to cell type text descriptions
5. **Biological RAG**: Use MSigDB embeddings as retrieval corpus for LLM-based reasoning
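Applications 1 and 2 reduce to a nearest-neighbour search over precomputed embeddings; a minimal sketch (function and argument names are illustrative):

```python
import torch

def top_k_annotations(z_gene, z_texts, descriptions, k=5):
    # z_gene: [256] query embedding; z_texts: [N, 256] corpus of description
    # embeddings. Both are L2-normalized, so the dot product is the cosine
    # similarity.
    sims = z_texts @ z_gene
    scores, idx = sims.topk(min(k, len(descriptions)))
    return [(descriptions[i], s.item()) for i, s in zip(idx.tolist(), scores)]
```

Swapping which side is the query and which is the corpus gives text → gene-set retrieval instead.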
## Key References
- [ProtST](https://arxiv.org/abs/2301.12040) (ICML 2023): protein–text contrastive alignment
- [MoleculeSTM](https://arxiv.org/abs/2212.10789) (Nature MI 2024): molecule–text alignment
- [LangCell](https://arxiv.org/abs/2405.06708): cell–text contrastive learning with MSigDB pathways
- [BioLORD-2023](https://arxiv.org/abs/2311.16075) (JAMIA 2024): biomedical sentence embeddings
- [Set Transformer](https://arxiv.org/abs/1810.00825): permutation-invariant set encoding
## Files
| File | Description |
|------|-------------|
| `clip_model.pt` | Trained projection heads (text + gene) |
| `gene_encoder.pt` | Fine-tuned GSFM gene encoder |
| `config.json` | Training configuration |
| `vocab.json` | Gene symbol → token ID mapping |
| `test_results.json` | Evaluation metrics on test set |
| `train_job.py` | Self-contained training script (for HF Jobs) |
| `train.py` | Modular training script |
| `data_processing.py` | MSigDB data download + processing |
## License
- Code: MIT
- GSFM model: BSD-3-Clause
- BioLORD-2023: Other (requires UMLS account)
- MSigDB data: [Creative Commons Attribution 4.0](https://www.gsea-msigdb.org/gsea/msigdb/licenses.jsp)