# GeneSetCLIP: Contrastive Pretraining for Gene Set–Text Alignment

A CLIP-style contrastive model that aligns **biological text descriptions** with **gene-set representations**, trained on MSigDB v2024.1 (human + mouse).

Given a text query like *"type I interferon signaling"*, the model retrieves the corresponding gene set, and vice versa.

## Architecture

```
TEXT SIDE                               GENE SET SIDE
─────────────────────                   ──────────────────────────
"Genes up-regulated in                  {STAT1, IRF7, ISG15,
 response to IFN-α..."                  OAS1, MX1, IFIT1, ...}
        │                                        │
        ▼                                        ▼
 BioLORD-2023 (frozen)                  GSFM (fine-tuned, lr/10)
 [768-dim]                              [256-dim]
        │                                        │
        ▼                                        ▼
 text_proj (trainable)                  gene_proj (trainable)
 768 → 512 → 256                        256 → 256 → 256
        │                                        │
        ▼                                        ▼
   z_text [256]                         z_gene [256]
        │                                        │
        └──────────── L2-normalize ──────────────┘
                      │
                      ▼
              InfoNCE loss (τ learnable)
```

### Components

| Component | Model | Dim | Training |
|-----------|-------|-----|----------|
| **Gene encoder** | [GSFM](https://huggingface.co/maayanlab/gsfm-rummagene) (MLP autoencoder, Set model) | 256 | Fine-tuned at 1/10 LR |
| **Text encoder** | [BioLORD-2023](https://huggingface.co/FremyCompany/BioLORD-2023) (MPNet-base) | 768 | Frozen |
| **Gene projection** | MLP: 256 → 256 → 256 + LayerNorm | 256 | Trained |
| **Text projection** | MLP: 768 → 512 → 256 + LayerNorm | 256 | Trained |

### Why these encoders?

- **GSFM**: Purpose-built gene-set encoder from the Ma'ayan Lab. Takes variable-length gene sets as input (multi-hot encoding → MLP, illustrated below), producing permutation-invariant 256-dim embeddings. Pretrained on Rummagene (gene sets mined from PubMed tables).
- **BioLORD-2023**: Ontology-grounded biomedical sentence embeddings, trained on UMLS concept name-synonym pairs plus LLM-generated definitions. That structure is identical to MSigDB gene set descriptions (a name plus a definition anchored in GO/KEGG/Reactome).
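
To make GSFM's multi-hot input concrete, here is a minimal toy sketch with a hypothetical five-gene vocabulary (the real mapping comes from `vocab.json`). Because the encoding only records set membership, any ordering of the same genes produces the same vector:

```python
import torch

V = 5                      # toy vocabulary size (GSFM's real vocabulary is much larger)
gene_ids = [0, 3]          # token IDs for a two-gene set
x = torch.zeros(V)
x[gene_ids] = 1.0          # multi-hot: membership only, gene order is irrelevant
print(x)                   # tensor([1., 0., 0., 1., 0.])
```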

## Training Data

**MSigDB v2024.1**: 50,896 gene set–text pairs from the Molecular Signatures Database.

| Split | Collections | Pairs | Purpose |
|-------|-------------|-------|---------|
| Train | C2, C5, C8, C1, M2, M5, M8, M1 | 38,622 | Curated, GO, cell type signatures |
| Val | C3, C4, M3 | 6,766 | Regulatory targets, computational |
| Test | H, C6, C7, MH | 5,508 | Hallmarks, oncogenic, immunologic |

Each pair consists of:
- **Text**: `[Collection: H] [Species: human]\nHALLMARK APOPTOSIS\nGenes mediating programmed cell death by activation of caspases.`
- **Genes**: `["CASP3", "CASP6", "TP53", "BAX", ...]`
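
A hypothetical helper for composing the text field in this format (the actual construction lives in `data_processing.py`):

```python
def format_text(collection: str, species: str, name: str, description: str) -> str:
    # Tagged header, then the set name, then its definition, as in the example above.
    return f"[Collection: {collection}] [Species: {species}]\n{name}\n{description}"
```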

Data augmentation: 20% gene dropout (randomly remove genes each epoch).
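
A minimal sketch of this augmentation, assuming each gene is dropped independently with probability 0.2 (the actual implementation is in `train.py`):

```python
import random

def gene_dropout(genes: list[str], p: float = 0.2) -> list[str]:
    """Drop each gene with probability p, re-sampled every epoch."""
    kept = [g for g in genes if random.random() >= p]
    return kept if kept else [random.choice(genes)]  # never return an empty set
```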

## Training Recipe

Based on the recipe of [ProtST](https://arxiv.org/abs/2301.12040) (ICML 2023), adapted for gene sets:

| Parameter | Value |
|-----------|-------|
| Loss | Symmetric InfoNCE (NT-Xent) |
| Temperature | 0.07 (learnable, clamped [0.01, 1.0]) |
| Batch size | 256 |
| LR (projections) | 1e-4 |
| LR (gene encoder) | 1e-5 (10x lower) |
| LR (text encoder) | 0 (frozen) |
| Optimizer | AdamW (weight_decay=0.01) |
| Schedule | 500-step warmup → cosine decay |
| Epochs | 50 (early stopping, patience=10) |
| Gene dropout | 20% |
| Max gene set size | 512 |
| Hardware | 1× T4 GPU (16GB) |
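
For reference, a minimal sketch of the symmetric InfoNCE objective with the learnable, clamped temperature described above (the actual training loop is in `train.py`):

```python
import torch
import torch.nn.functional as F

def symmetric_info_nce(z_text: torch.Tensor, z_gene: torch.Tensor,
                       log_temperature: torch.Tensor) -> torch.Tensor:
    """z_text, z_gene: L2-normalized [B, 256] embeddings of aligned pairs;
    positives sit on the diagonal of the [B, B] similarity matrix."""
    tau = log_temperature.exp().clamp(0.01, 1.0)   # learnable τ, clamped to [0.01, 1.0]
    logits = (z_text @ z_gene.T) / tau             # [B, B] scaled similarities
    targets = torch.arange(z_text.size(0), device=z_text.device)
    loss_t2g = F.cross_entropy(logits, targets)    # text -> gene set direction
    loss_g2t = F.cross_entropy(logits.T, targets)  # gene set -> text direction
    return 0.5 * (loss_t2g + loss_g2t)
```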

## Quick Start

### Installation
```bash
pip install torch sentence-transformers huggingface_hub safetensors lightning
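# GIT_LFS_SKIP_SMUDGE=1 skips pulling the large LFS weight files during the clone;
# only the gsfm package code is needed here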
GIT_LFS_SKIP_SMUDGE=1 pip install "git+https://huggingface.co/maayanlab/gsfm"
```

### Inference
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from gsfm import GSFM, Vocab
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download

# Load gene encoder + vocab
gene_encoder = GSFM.from_pretrained("maayanlab/gsfm-rummagene")
vocab = Vocab.from_pretrained("maayanlab/gsfm-rummagene")
gene_encoder.eval()

# Load text encoder
text_encoder = SentenceTransformer("FremyCompany/BioLORD-2023")

# Load projection heads
clip_path = hf_hub_download("AliSaadatV/GeneSetCLIP", "clip_model.pt")

class ProjectionHead(nn.Module):
    def __init__(self, d_in, d_h, d_out):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, d_h), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(d_h, d_out), nn.LayerNorm(d_out))
    def forward(self, x): return self.net(x)

class GeneSetCLIP(nn.Module):
    def __init__(self):
        super().__init__()
        self.log_temperature = nn.Parameter(torch.zeros(1))
        self.text_proj = ProjectionHead(768, 512, 256)
        self.gene_proj = ProjectionHead(256, 256, 256)

clip_model = GeneSetCLIP()
clip_model.load_state_dict(torch.load(clip_path, map_location="cpu", weights_only=True))
clip_model.eval()

# --- Encode a gene set ---
genes = ["STAT1", "IRF7", "ISG15", "OAS1", "MX1", "IFIT1"]
gene_ids = torch.tensor([vocab(genes)])
with torch.no_grad():
    gene_emb = gene_encoder.encode(gene_ids)
    z_gene = F.normalize(clip_model.gene_proj(gene_emb), dim=-1)

# --- Encode text queries ---
queries = [
    "Interferon alpha response genes",
    "Apoptosis signaling",
    "Cell cycle regulation",
]
text_embs = text_encoder.encode(queries, convert_to_tensor=True).cpu()  # keep on CPU to match the projection heads
with torch.no_grad():
    z_text = F.normalize(clip_model.text_proj(text_embs), dim=-1)

# --- Compute similarities ---
sims = (z_gene @ z_text.T).squeeze()
for q, s in zip(queries, sims):
    print(f"  {s.item():.3f}  {q}")
# Expected: highest similarity for "Interferon alpha response genes"
```

## Training from Scratch

### 1. Process MSigDB data
```bash
python data_processing.py
```
This downloads all MSigDB GMT files and scrapes descriptions.

### 2. Train
```bash
# Self-contained (downloads data from Hub automatically)
python train_job.py

# Or with local data
python train.py
```

### 3. On HF Jobs (GPU)
```python
from huggingface_hub import HfApi
# Submit as HF Job with GPU
# See train_job.py for the self-contained script
```

## Downstream Applications

1. **Zero-shot gene set annotation**: Embed a gene list from an experiment → find the nearest text descriptions (see the sketch after this list)
2. **Cross-modal search**: Text query → gene sets, or gene list → pathway descriptions
3. **Gene set similarity**: Compare gene sets via embedding cosine similarity (captures functional similarity beyond gene overlap)
4. **Cell type annotation**: Embed cell marker gene sets → match to cell type text descriptions
5. **Biological RAG**: Use MSigDB embeddings as a retrieval corpus for LLM-based reasoning
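
A hedged sketch of application 1, reusing `clip_model`, `text_encoder`, and `z_gene` from the inference example above; `all_descriptions` stands in for a list of MSigDB description strings you have prepared:

```python
import torch
import torch.nn.functional as F

# Precompute the retrieval corpus once, then rank it against any query gene set.
with torch.no_grad():
    corpus = text_encoder.encode(all_descriptions, convert_to_tensor=True).cpu()
    z_corpus = F.normalize(clip_model.text_proj(corpus), dim=-1)  # [N, 256]

scores = (z_corpus @ z_gene.T).squeeze(-1)  # cosine similarities, shape [N]
top = torch.topk(scores, k=5)
for score, idx in zip(top.values, top.indices):
    print(f"{score.item():.3f}  {all_descriptions[idx]}")
```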

## Key References

- [ProtST](https://arxiv.org/abs/2301.12040) (ICML 2023): protein-text contrastive alignment
- [MoleculeSTM](https://arxiv.org/abs/2212.10789) (Nature MI 2024): molecule-text alignment
- [LangCell](https://arxiv.org/abs/2405.06708): cell-text contrastive learning with MSigDB pathways
- [BioLORD-2023](https://arxiv.org/abs/2311.16075) (JAMIA 2024): biomedical sentence embeddings
- [Set Transformer](https://arxiv.org/abs/1810.00825): permutation-invariant set encoding

## Files

| File | Description |
|------|-------------|
| `clip_model.pt` | Trained projection heads (text + gene) |
| `gene_encoder.pt` | Fine-tuned GSFM gene encoder |
| `config.json` | Training configuration |
| `vocab.json` | Gene symbol → token ID mapping |
| `test_results.json` | Evaluation metrics on test set |
| `train_job.py` | Self-contained training script (for HF Jobs) |
| `train.py` | Modular training script |
| `data_processing.py` | MSigDB data download + processing |

## License

- Code: MIT
- GSFM model: BSD-3-Clause
- BioLORD-2023: Other (requires UMLS account)
- MSigDB data: [Creative Commons Attribution 4.0](https://www.gsea-msigdb.org/gsea/msigdb/licenses.jsp)