# GeneSetCLIP: Contrastive Pretraining for Gene Set–Text Alignment

A CLIP-style contrastive model that aligns **biological text descriptions** with **gene-set representations**, trained on MSigDB v2024.1 (human + mouse).

Given a text query like *"type I interferon signaling"*, the model retrieves the corresponding gene set, and vice versa.
## Architecture

```
      TEXT SIDE                       GENE SET SIDE
─────────────────────           ──────────────────────────
"Genes up-regulated in           {STAT1, IRF7, ISG15,
 response to IFN-α..."            OAS1, MX1, IFIT1, ...}
          │                                │
          ▼                                ▼
BioLORD-2023 (frozen)           GSFM (fine-tuned, lr/10)
      [768-dim]                        [256-dim]
          │                                │
          ▼                                ▼
text_proj (trainable)            gene_proj (trainable)
  768 → 512 → 256                  256 → 256 → 256
          │                                │
          ▼                                ▼
    z_text [256]                     z_gene [256]
          │                                │
          └──────── L2-normalize ──────────┘
                        │
                        ▼
            InfoNCE loss (τ learnable)
```

### Components

| Component | Model | Dim | Training |
|-----------|-------|-----|----------|
| **Gene encoder** | [GSFM](https://huggingface.co/maayanlab/gsfm-rummagene) (MLP autoencoder, set model) | 256 | Fine-tuned at 1/10 LR |
| **Text encoder** | [BioLORD-2023](https://huggingface.co/FremyCompany/BioLORD-2023) (MPNet-base) | 768 | Frozen |
| **Gene projection** | MLP: 256 → 256 → 256 + LayerNorm | 256 | Trained |
| **Text projection** | MLP: 768 → 512 → 256 + LayerNorm | 256 | Trained |

### Why these encoders?

- **GSFM**: Purpose-built gene-set encoder from the Ma'ayan Lab. Takes variable-length gene sets as input (multi-hot encoding → MLP), producing permutation-invariant 256-dim embeddings. Pretrained on Rummagene (gene sets extracted from PubMed tables).
- **BioLORD-2023**: Ontology-grounded biomedical sentence embeddings. Trained on UMLS concept name–synonym pairs plus LLM-generated definitions, which are structurally similar to MSigDB gene-set descriptions (a name plus a definition anchored in GO/KEGG/Reactome).
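The multi-hot input is what makes the gene encoder permutation-invariant by construction: a gene set maps to the same fixed-length vector regardless of gene order. A toy sketch (the real model uses its own `Vocab`; the tiny vocabulary here is purely illustrative):

```python
import torch

def multi_hot(genes, vocab):
    """Encode a gene set as a fixed-length multi-hot vector (toy illustration)."""
    x = torch.zeros(len(vocab))
    for g in genes:
        if g in vocab:  # out-of-vocabulary genes are simply skipped
            x[vocab[g]] = 1.0
    return x

vocab = {"STAT1": 0, "IRF7": 1, "ISG15": 2, "OAS1": 3}
a = multi_hot(["STAT1", "IRF7"], vocab)
b = multi_hot(["IRF7", "STAT1"], vocab)  # different order, same set
assert torch.equal(a, b)                 # permutation-invariant
```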

## Training Data

**MSigDB v2024.1** – 50,896 gene set–text pairs from the Molecular Signatures Database.

| Split | Collections | Pairs | Purpose |
|-------|-------------|-------|---------|
| Train | C2, C5, C8, C1, M2, M5, M8, M1 | 38,622 | Curated, GO, cell type signatures |
| Val | C3, C4, M3 | 6,766 | Regulatory targets, computational |
| Test | H, C6, C7, MH | 5,508 | Hallmarks, oncogenic, immunologic |

Each pair consists of:
- **Text**: `[Collection: H] [Species: human]\nHALLMARK APOPTOSIS\nGenes mediating programmed cell death by activation of caspases.`
- **Genes**: `["CASP3", "CASP6", "TP53", "BAX", ...]`

Data augmentation: 20% gene dropout (each epoch, every gene in a set is independently removed with probability 0.2).
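The dropout augmentation can be sketched as follows (a hypothetical helper, not the actual training code, which may differ in details such as how empty sets are handled):

```python
import random

def gene_dropout(genes, p=0.2, rng=random):
    """Drop each gene independently with probability p; always keep at least one."""
    kept = [g for g in genes if rng.random() >= p]
    return kept if kept else [rng.choice(genes)]

random.seed(0)
genes = ["CASP3", "CASP6", "TP53", "BAX", "BID", "FAS"]
augmented = gene_dropout(genes)  # a random subset of the original set
print(augmented)
```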

## Training Recipe

Based on the [ProtST](https://arxiv.org/abs/2301.12040) (ICML 2023) recipe, adapted for gene sets:

| Parameter | Value |
|-----------|-------|
| Loss | Symmetric InfoNCE (NT-Xent) |
| Temperature | 0.07 (learnable, clamped to [0.01, 1.0]) |
| Batch size | 256 |
| LR (projections) | 1e-4 |
| LR (gene encoder) | 1e-5 (10× lower) |
| LR (text encoder) | 0 (frozen) |
| Optimizer | AdamW (weight_decay=0.01) |
| Schedule | 500-step warmup → cosine decay |
| Epochs | 50 (early stopping, patience=10) |
| Gene dropout | 20% |
| Max gene set size | 512 |
| Hardware | 1× T4 GPU (16GB) |
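The loss in the table is the symmetric InfoNCE objective over a batch of paired embeddings. A minimal sketch with the recipe's temperature initialization and clamping (function and variable names are illustrative, not taken from the training script):

```python
import torch
import torch.nn.functional as F

def symmetric_info_nce(z_text, z_gene, log_temperature):
    """Symmetric InfoNCE (NT-Xent): positives sit on the diagonal of the
    batch similarity matrix; average the text->gene and gene->text losses."""
    tau = log_temperature.exp().clamp(0.01, 1.0)  # clamp as in the recipe
    logits = (z_text @ z_gene.T) / tau            # [B, B] cosine sims / tau
    targets = torch.arange(z_text.size(0))        # matched pairs: index i <-> i
    loss_t2g = F.cross_entropy(logits, targets)
    loss_g2t = F.cross_entropy(logits.T, targets)
    return 0.5 * (loss_t2g + loss_g2t)

torch.manual_seed(0)
z_t = F.normalize(torch.randn(8, 256), dim=-1)   # stand-in text embeddings
z_g = F.normalize(torch.randn(8, 256), dim=-1)   # stand-in gene-set embeddings
log_tau = torch.log(torch.tensor([0.07]))        # temperature initialized at 0.07
loss = symmetric_info_nce(z_t, z_g, log_tau)
print(loss.item())
```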

## Quick Start

### Installation
```bash
pip install torch sentence-transformers huggingface_hub safetensors lightning
GIT_LFS_SKIP_SMUDGE=1 pip install "git+https://huggingface.co/maayanlab/gsfm"
```

### Inference
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from gsfm import GSFM, Vocab
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download

# Load gene encoder + vocab
gene_encoder = GSFM.from_pretrained("maayanlab/gsfm-rummagene")
vocab = Vocab.from_pretrained("maayanlab/gsfm-rummagene")
gene_encoder.eval()

# Load text encoder
text_encoder = SentenceTransformer("FremyCompany/BioLORD-2023")

# Load projection heads
clip_path = hf_hub_download("AliSaadatV/GeneSetCLIP", "clip_model.pt")

class ProjectionHead(nn.Module):
    def __init__(self, d_in, d_h, d_out):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, d_h), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(d_h, d_out), nn.LayerNorm(d_out))
    def forward(self, x):
        return self.net(x)

class GeneSetCLIP(nn.Module):
    def __init__(self):
        super().__init__()
        self.log_temperature = nn.Parameter(torch.zeros(1))
        self.text_proj = ProjectionHead(768, 512, 256)
        self.gene_proj = ProjectionHead(256, 256, 256)

clip_model = GeneSetCLIP()
clip_model.load_state_dict(torch.load(clip_path, map_location="cpu", weights_only=True))
clip_model.eval()

# --- Encode a gene set ---
genes = ["STAT1", "IRF7", "ISG15", "OAS1", "MX1", "IFIT1"]
gene_ids = torch.tensor([vocab(genes)])
with torch.no_grad():
    gene_emb = gene_encoder.encode(gene_ids)
    z_gene = F.normalize(clip_model.gene_proj(gene_emb), dim=-1)

# --- Encode text queries ---
queries = [
    "Interferon alpha response genes",
    "Apoptosis signaling",
    "Cell cycle regulation",
]
text_embs = text_encoder.encode(queries, convert_to_tensor=True)
with torch.no_grad():
    z_text = F.normalize(clip_model.text_proj(text_embs), dim=-1)

# --- Compute similarities ---
sims = (z_gene @ z_text.T).squeeze()
for q, s in zip(queries, sims):
    print(f"  {s.item():.3f}  {q}")
# Expected: highest similarity for "Interferon alpha response genes"
```

## Training from Scratch

### 1. Process MSigDB data
```bash
python data_processing.py
```
This downloads all MSigDB GMT files and scrapes the accompanying descriptions.

### 2. Train
```bash
# Self-contained (downloads data from the Hub automatically)
python train_job.py

# Or with local data
python train.py
```

### 3. On HF Jobs (GPU)
```python
from huggingface_hub import HfApi
# Submit as an HF Job with a GPU flavor
# See train_job.py for the self-contained script
```

## Downstream Applications

1. **Zero-shot gene set annotation**: Embed a gene list from an experiment → find the nearest text descriptions
2. **Cross-modal search**: Text query → gene sets, or gene list → pathway descriptions
3. **Gene set similarity**: Compare gene sets via embedding cosine similarity (captures functional similarity beyond raw gene overlap)
4. **Cell type annotation**: Embed cell marker gene sets → match to cell type text descriptions
5. **Biological RAG**: Use MSigDB embeddings as a retrieval corpus for LLM-based reasoning
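Applications 1 and 2 reduce to nearest-neighbor search in the shared embedding space. A minimal sketch, using random placeholder vectors in place of real `z_text`/`z_gene` embeddings and invented corpus names:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Placeholder corpus: in practice, z_text embeddings of MSigDB descriptions
corpus = F.normalize(torch.randn(100, 256), dim=-1)
names = [f"gene_set_{i}" for i in range(100)]

# Placeholder query: in practice, z_gene of an experimental gene list
query = F.normalize(torch.randn(1, 256), dim=-1)

# Cosine similarity = dot product of L2-normalized vectors; take the top 5
sims = (query @ corpus.T).squeeze(0)
top = sims.topk(5)
for score, idx in zip(top.values, top.indices):
    print(f"{score.item():.3f}  {names[idx]}")
```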

## Key References

- [ProtST](https://arxiv.org/abs/2301.12040) (ICML 2023) – Protein–text contrastive alignment
- [MoleculeSTM](https://arxiv.org/abs/2212.10789) (Nature MI 2024) – Molecule–text alignment
- [LangCell](https://arxiv.org/abs/2405.06708) – Cell–text contrastive learning with MSigDB pathways
- [BioLORD-2023](https://arxiv.org/abs/2311.16075) (JAMIA 2024) – Biomedical sentence embeddings
- [Set Transformer](https://arxiv.org/abs/1810.00825) – Permutation-invariant set encoding

## Files

| File | Description |
|------|-------------|
| `clip_model.pt` | Trained projection heads (text + gene) |
| `gene_encoder.pt` | Fine-tuned GSFM gene encoder |
| `config.json` | Training configuration |
| `vocab.json` | Gene symbol → token ID mapping |
| `test_results.json` | Evaluation metrics on the test set |
| `train_job.py` | Self-contained training script (for HF Jobs) |
| `train.py` | Modular training script |
| `data_processing.py` | MSigDB data download + processing |

## License

- Code: MIT
- GSFM model: BSD-3-Clause
- BioLORD-2023: Other (requires a UMLS account)
- MSigDB data: [Creative Commons Attribution 4.0](https://www.gsea-msigdb.org/gsea/msigdb/licenses.jsp)