# GeneSetCLIP: Contrastive Pretraining for Gene Set–Text Alignment

A CLIP-style contrastive model that aligns **biological text descriptions** with **gene-set representations**, trained on MSigDB v2024.1 (human + mouse). Given a text query like *"type I interferon signaling"*, the model retrieves the corresponding gene set — and vice versa.

## Architecture

```
        TEXT SIDE                         GENE SET SIDE
─────────────────────────         ──────────────────────────
"Genes up-regulated in            {STAT1, IRF7, ISG15,
 response to IFN-α..."             OAS1, MX1, IFIT1, ...}
           │                                 │
           ▼                                 ▼
 BioLORD-2023 (frozen)            GSFM (fine-tuned, lr/10)
       [768-dim]                         [256-dim]
           │                                 │
           ▼                                 ▼
 text_proj (trainable)            gene_proj (trainable)
    768 → 512 → 256                  256 → 256 → 256
           │                                 │
           ▼                                 ▼
     z_text [256]                      z_gene [256]
           │                                 │
           └────── L2-normalize ─────────────┘
                         │
                         ▼
            InfoNCE loss (τ learnable)
```

### Components

| Component | Model | Dim | Training |
|-----------|-------|-----|----------|
| **Gene encoder** | [GSFM](https://huggingface.co/maayanlab/gsfm-rummagene) (MLP autoencoder, Set model) | 256 | Fine-tuned at 1/10 LR |
| **Text encoder** | [BioLORD-2023](https://huggingface.co/FremyCompany/BioLORD-2023) (MPNet-base) | 768 | Frozen |
| **Gene projection** | MLP: 256 → 256 → 256 + LayerNorm | 256 | Trained |
| **Text projection** | MLP: 768 → 512 → 256 + LayerNorm | 256 | Trained |

### Why these encoders?

- **GSFM**: Purpose-built gene-set encoder from the Ma'ayan Lab. Takes variable-length gene sets as input (multi-hot encoding → MLP) and produces permutation-invariant 256-dim embeddings. Pretrained on Rummagene (gene sets mined from PubMed article tables).
- **BioLORD-2023**: Ontology-grounded biomedical sentence embeddings. Trained on UMLS concept name–synonym pairs plus LLM-generated definitions — structurally very close to MSigDB gene set descriptions (a name plus a definition anchored in GO/KEGG/Reactome).

## Training Data

**MSigDB v2024.1** — 50,896 gene set–text pairs from the Molecular Signatures Database.

| Split | Collections | Pairs | Purpose |
|-------|-------------|-------|---------|
| Train | C2, C5, C8, C1, M2, M5, M8, M1 | 38,622 | Curated, GO, cell type signatures |
| Val | C3, C4, M3 | 6,766 | Regulatory targets, computational |
| Test | H, C6, C7, MH | 5,508 | Hallmarks, oncogenic, immunologic |

Each pair consists of:

- **Text**: `[Collection: H] [Species: human]\nHALLMARK APOPTOSIS\nGenes mediating programmed cell death by activation of caspases.`
- **Genes**: `["CASP3", "CASP6", "TP53", "BAX", ...]`

Data augmentation: 20% gene dropout (randomly remove 20% of the genes from each set every epoch; see the sketch below).
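A minimal sketch of that augmentation, assuming per-sample dropout applied on the fly each epoch; the function name and the `min_keep` guard are illustrative, not the repository's actual code:

```python
import random

def gene_dropout(genes: list[str], p: float = 0.2, min_keep: int = 1) -> list[str]:
    """Randomly drop ~p of the genes in a set; re-sampled every epoch.
    Illustrative sketch, not necessarily the repository's implementation."""
    kept = [g for g in genes if random.random() >= p]
    # Guard so a small set is never emptied entirely (assumption).
    return kept if len(kept) >= min_keep else random.sample(genes, min_keep)

# Applied independently each epoch, e.g. inside a Dataset's __getitem__:
print(gene_dropout(["CASP3", "CASP6", "TP53", "BAX", "BCL2"]))
```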
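The training objective (full recipe below) is symmetric InfoNCE over in-batch negatives with a learnable, clamped temperature. A minimal sketch of how such a loss can be computed on a batch of paired embeddings; an illustration, not the repository's exact code:

```python
import torch
import torch.nn.functional as F

def symmetric_info_nce(z_text: torch.Tensor, z_gene: torch.Tensor,
                       log_temperature: torch.Tensor) -> torch.Tensor:
    """Symmetric InfoNCE (NT-Xent) over in-batch negatives.
    z_text, z_gene: L2-normalized [B, 256] embeddings of matched pairs.
    Illustrative sketch; the repository's implementation may differ."""
    tau = log_temperature.exp().clamp(0.01, 1.0)   # learnable τ, clamped
    logits = z_text @ z_gene.T / tau               # [B, B] similarity matrix
    targets = torch.arange(z_text.size(0), device=z_text.device)
    loss_t2g = F.cross_entropy(logits, targets)    # text → gene set
    loss_g2t = F.cross_entropy(logits.T, targets)  # gene set → text
    return 0.5 * (loss_t2g + loss_g2t)
```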
## Training Recipe

The recipe follows [ProtST](https://arxiv.org/abs/2301.12040) (ICML 2023), adapted for gene sets; the loss is the symmetric InfoNCE sketched above:

| Parameter | Value |
|-----------|-------|
| Loss | Symmetric InfoNCE (NT-Xent) |
| Temperature | 0.07 at init (learnable, clamped to [0.01, 1.0]) |
| Batch size | 256 |
| LR (projections) | 1e-4 |
| LR (gene encoder) | 1e-5 (10× lower) |
| LR (text encoder) | 0 (frozen) |
| Optimizer | AdamW (weight_decay=0.01) |
| Schedule | 500-step warmup → cosine decay |
| Epochs | 50 (early stopping, patience=10) |
| Gene dropout | 20% |
| Max gene set size | 512 |
| Hardware | 1× T4 GPU (16 GB) |

## Quick Start

### Installation

```bash
pip install torch sentence-transformers huggingface_hub safetensors lightning
GIT_LFS_SKIP_SMUDGE=1 pip install "git+https://huggingface.co/maayanlab/gsfm"
```

### Inference

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from gsfm import GSFM, Vocab
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download

# Load gene encoder + vocab
gene_encoder = GSFM.from_pretrained("maayanlab/gsfm-rummagene")
vocab = Vocab.from_pretrained("maayanlab/gsfm-rummagene")
gene_encoder.eval()

# Load text encoder
text_encoder = SentenceTransformer("FremyCompany/BioLORD-2023")

# Load projection heads
clip_path = hf_hub_download("AliSaadatV/GeneSetCLIP", "clip_model.pt")

class ProjectionHead(nn.Module):
    def __init__(self, d_in, d_h, d_out):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, d_h),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(d_h, d_out),
            nn.LayerNorm(d_out),
        )

    def forward(self, x):
        return self.net(x)

class GeneSetCLIP(nn.Module):
    def __init__(self):
        super().__init__()
        self.log_temperature = nn.Parameter(torch.zeros(1))
        self.text_proj = ProjectionHead(768, 512, 256)
        self.gene_proj = ProjectionHead(256, 256, 256)

clip_model = GeneSetCLIP()
clip_model.load_state_dict(torch.load(clip_path, map_location="cpu", weights_only=True))
clip_model.eval()

# --- Encode a gene set ---
genes = ["STAT1", "IRF7", "ISG15", "OAS1", "MX1", "IFIT1"]
gene_ids = torch.tensor([vocab(genes)])
with torch.no_grad():
    gene_emb = gene_encoder.encode(gene_ids)
    z_gene = F.normalize(clip_model.gene_proj(gene_emb), dim=-1)

# --- Encode text queries ---
queries = [
    "Interferon alpha response genes",
    "Apoptosis signaling",
    "Cell cycle regulation",
]
text_embs = text_encoder.encode(queries, convert_to_tensor=True)
with torch.no_grad():
    z_text = F.normalize(clip_model.text_proj(text_embs), dim=-1)

# --- Compute similarities ---
sims = (z_gene @ z_text.T).squeeze()
for q, s in zip(queries, sims):
    print(f"  {s.item():.3f}  {q}")
# Expected: highest similarity for "Interferon alpha response genes"
```

## Training from Scratch

### 1. Process MSigDB data

```bash
python data_processing.py
```

This downloads all MSigDB GMT files and scrapes the gene set descriptions.

### 2. Train

```bash
# Self-contained (downloads data from the Hub automatically)
python train_job.py

# Or with local data
python train.py
```

### 3. On HF Jobs (GPU)

```python
from huggingface_hub import HfApi

# Submit train_job.py as an HF Job with a GPU.
# See train_job.py for the self-contained script.
```

## Downstream Applications

1. **Zero-shot gene set annotation**: Embed a gene list from an experiment → find the nearest text descriptions (see the sketch after this list)
2. **Cross-modal search**: Text query → gene sets, or gene list → pathway descriptions
3. **Gene set similarity**: Compare gene sets via embedding cosine similarity (captures functional similarity beyond gene overlap)
4. **Cell type annotation**: Embed cell-marker gene sets → match to cell type text descriptions
5. **Biological RAG**: Use MSigDB embeddings as a retrieval corpus for LLM-based reasoning
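Continuing the Inference snippet above, a sketch of application 1: embed a corpus of descriptions with the text tower and rank them against an experimental gene set. The three-entry corpus here is a hypothetical stand-in for the full MSigDB text corpus:

```python
# Zero-shot annotation sketch: reuses clip_model, text_encoder, and z_gene
# from the Inference snippet above. The tiny corpus below is a hypothetical
# stand-in for the full set of MSigDB descriptions.
descriptions = [
    "Genes up-regulated in response to alpha interferon proteins.",
    "Genes mediating programmed cell death by activation of caspases.",
    "Genes involved in the G2/M checkpoint of the cell cycle.",
]
corpus_embs = text_encoder.encode(descriptions, convert_to_tensor=True)
with torch.no_grad():
    z_corpus = F.normalize(clip_model.text_proj(corpus_embs), dim=-1)

# Rank descriptions for the interferon gene set embedded earlier.
scores = (z_gene @ z_corpus.T).squeeze(0)
for idx in scores.argsort(descending=True):
    print(f"{scores[idx].item():.3f}  {descriptions[idx]}")
```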
## Key References

- [ProtST](https://arxiv.org/abs/2301.12040) (ICML 2023) — Protein–text contrastive alignment
- [MoleculeSTM](https://arxiv.org/abs/2212.10789) (Nature MI 2024) — Molecule–text alignment
- [LangCell](https://arxiv.org/abs/2405.06708) — Cell–text contrastive learning with MSigDB pathways
- [BioLORD-2023](https://arxiv.org/abs/2311.16075) (JAMIA 2024) — Biomedical sentence embeddings
- [Set Transformer](https://arxiv.org/abs/1810.00825) — Permutation-invariant set encoding

## Files

| File | Description |
|------|-------------|
| `clip_model.pt` | Trained projection heads (text + gene) |
| `gene_encoder.pt` | Fine-tuned GSFM gene encoder |
| `config.json` | Training configuration |
| `vocab.json` | Gene symbol → token ID mapping |
| `test_results.json` | Evaluation metrics on the test set |
| `train_job.py` | Self-contained training script (for HF Jobs) |
| `train.py` | Modular training script |
| `data_processing.py` | MSigDB data download + processing |

## License

- Code: MIT
- GSFM model: BSD-3-Clause
- BioLORD-2023: Other (requires a UMLS account)
- MSigDB data: [Creative Commons Attribution 4.0](https://www.gsea-msigdb.org/gsea/msigdb/licenses.jsp)