""" GeneSetCLIP: Contrastive pretraining to align gene-set embeddings with text descriptions. Architecture: - Gene encoder: GSFM (MLP autoencoder, 256-dim) from maayanlab/gsfm-rummagene - Text encoder: BioLORD-2023 (768-dim, frozen) from FremyCompany/BioLORD-2023 - Projection heads: text 768->256, gene 256->256 - Loss: Symmetric InfoNCE with learnable temperature Training recipe based on ProtST (ICML 2023) adapted for gene sets: - Freeze text encoder, fine-tune gene encoder at 1/10 LR - Gene dropout augmentation (20%) - Large batch for InfoNCE (256-512) """ import os import json import math import random import time from pathlib import Path from collections import defaultdict from dataclasses import dataclass, field import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader import numpy as np from huggingface_hub import HfApi # ============================================================ # Configuration # ============================================================ @dataclass class Config: # Model gene_model_id: str = "maayanlab/gsfm-rummagene" text_model_id: str = "FremyCompany/BioLORD-2023" shared_dim: int = 256 gene_dim: int = 256 text_dim: int = 768 proj_hidden_dim: int = 512 proj_dropout: float = 0.1 # Training batch_size: int = 256 lr: float = 1e-4 gene_encoder_lr: float = 1e-5 # 10x lower for pretrained gene encoder weight_decay: float = 0.01 warmup_steps: int = 500 max_epochs: int = 50 patience: int = 10 # early stopping patience temperature_init: float = 0.07 learnable_temperature: bool = True gene_dropout_rate: float = 0.2 # augmentation: randomly drop genes max_gene_set_size: int = 512 # pad/truncate gene sets to this # Data data_dir: str = "/app/data/processed" output_dir: str = "/app/output" hub_model_id: str = "AliSaadatV/GeneSetCLIP" # Hardware device: str = "cuda" if torch.cuda.is_available() else "cpu" num_workers: int = 4 mixed_precision: bool = True # Logging log_every: int = 10 eval_every: int = 1 # epochs save_every: int = 5 # epochs # ============================================================ # Dataset # ============================================================ class GeneSetTextDataset(Dataset): """Dataset of (text, gene_set) pairs for contrastive learning.""" def __init__(self, jsonl_path: str, vocab: dict, max_genes: int = 512, gene_dropout: float = 0.0, pad_idx: int = 1): self.records = [] with open(jsonl_path) as f: for line in f: self.records.append(json.loads(line)) self.vocab = vocab self.max_genes = max_genes self.gene_dropout = gene_dropout self.pad_idx = pad_idx def __len__(self): return len(self.records) def __getitem__(self, idx): record = self.records[idx] text = record["text"] genes = record["genes"] # Tokenize genes token_ids = [self.vocab.get(g, 0) for g in genes] # 0 = UNK # Gene dropout augmentation if self.gene_dropout > 0 and self.training_mode: n_keep = max(3, int(len(token_ids) * (1 - self.gene_dropout))) if n_keep < len(token_ids): token_ids = random.sample(token_ids, n_keep) # Truncate if too long if len(token_ids) > self.max_genes: token_ids = random.sample(token_ids, self.max_genes) # Pad n_genes = len(token_ids) if n_genes < self.max_genes: token_ids = token_ids + [self.pad_idx] * (self.max_genes - n_genes) return { "text": text, "gene_ids": torch.tensor(token_ids, dtype=torch.long), "n_genes": n_genes, "id": record["id"], } @property def training_mode(self): return self.gene_dropout > 0 def collate_fn(batch): """Custom collate: stack gene tensors, keep texts as list.""" return { 
"text": [item["text"] for item in batch], "gene_ids": torch.stack([item["gene_ids"] for item in batch]), "n_genes": torch.tensor([item["n_genes"] for item in batch]), "ids": [item["id"] for item in batch], } # ============================================================ # Model # ============================================================ class ProjectionHead(nn.Module): """MLP projection head with LayerNorm.""" def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.1): super().__init__() self.net = nn.Sequential( nn.Linear(input_dim, hidden_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim, output_dim), nn.LayerNorm(output_dim), ) def forward(self, x): return self.net(x) class GeneSetCLIP(nn.Module): """ Contrastive model aligning gene-set embeddings with text embeddings. Components: - gene_encoder: GSFM (pretrained, fine-tuned at low LR) - text_encoder: BioLORD-2023 (frozen) - text_proj: 768 -> 256 - gene_proj: 256 -> 256 - log_temperature: learnable scalar """ def __init__(self, config: Config): super().__init__() self.config = config # Temperature parameter self.log_temperature = nn.Parameter( torch.log(torch.tensor(config.temperature_init)), requires_grad=config.learnable_temperature, ) # Projection heads self.text_proj = ProjectionHead( config.text_dim, config.proj_hidden_dim, config.shared_dim, config.proj_dropout ) self.gene_proj = ProjectionHead( config.gene_dim, config.shared_dim, config.shared_dim, config.proj_dropout ) @property def temperature(self): return torch.clamp(self.log_temperature.exp(), min=0.01, max=1.0) def forward(self, gene_emb, text_emb): """ Args: gene_emb: (B, 256) from GSFM.encode() text_emb: (B, 768) from BioLORD Returns: loss, metrics_dict """ # Project to shared space z_gene = F.normalize(self.gene_proj(gene_emb), dim=-1) z_text = F.normalize(self.text_proj(text_emb), dim=-1) # Compute similarities tau = self.temperature logits = z_gene @ z_text.T / tau # (B, B) # Symmetric InfoNCE B = logits.size(0) labels = torch.arange(B, device=logits.device) loss_g2t = F.cross_entropy(logits, labels) loss_t2g = F.cross_entropy(logits.T, labels) loss = (loss_g2t + loss_t2g) / 2 # Metrics with torch.no_grad(): # Accuracy: is the correct pair the top-1? g2t_acc = (logits.argmax(dim=1) == labels).float().mean() t2g_acc = (logits.T.argmax(dim=1) == labels).float().mean() avg_acc = (g2t_acc + t2g_acc) / 2 metrics = { "loss": loss.item(), "g2t_acc": g2t_acc.item(), "t2g_acc": t2g_acc.item(), "avg_acc": avg_acc.item(), "temperature": tau.item(), } return loss, z_gene, z_text, metrics def get_embeddings(self, gene_emb=None, text_emb=None): """Get normalized projected embeddings.""" z_gene = z_text = None if gene_emb is not None: z_gene = F.normalize(self.gene_proj(gene_emb), dim=-1) if text_emb is not None: z_text = F.normalize(self.text_proj(text_emb), dim=-1) return z_gene, z_text # ============================================================ # Evaluation # ============================================================ @torch.no_grad() def evaluate_retrieval(model, gene_encoder, text_encoder, dataloader, device, config): """ Evaluate text-to-gene and gene-to-text retrieval. Returns recall@k metrics and loss. 
""" model.eval() gene_encoder.eval() all_z_gene = [] all_z_text = [] all_ids = [] total_loss = 0 n_batches = 0 for batch in dataloader: gene_ids = batch["gene_ids"].to(device) texts = batch["text"] # Encode genes gene_emb = gene_encoder.encode(gene_ids) # Encode text text_emb = text_encoder.encode(texts, convert_to_tensor=True, show_progress_bar=False) if text_emb.device != device: text_emb = text_emb.to(device) text_emb = text_emb.clone() # Project loss, z_gene, z_text, metrics = model(gene_emb, text_emb) total_loss += loss.item() n_batches += 1 all_z_gene.append(z_gene.cpu()) all_z_text.append(z_text.cpu()) all_ids.extend(batch["ids"]) all_z_gene = torch.cat(all_z_gene, dim=0) all_z_text = torch.cat(all_z_text, dim=0) N = len(all_z_gene) # Compute full similarity matrix sim = all_z_gene @ all_z_text.T # (N, N) # Retrieval metrics labels = torch.arange(N) def recall_at_k(sim_matrix, labels, k): topk = sim_matrix.topk(k, dim=1).indices correct = (topk == labels.unsqueeze(1)).any(dim=1) return correct.float().mean().item() def mrr(sim_matrix, labels): ranks = (sim_matrix.argsort(dim=1, descending=True) == labels.unsqueeze(1)).nonzero()[:, 1] + 1 return (1.0 / ranks.float()).mean().item() results = { "loss": total_loss / max(n_batches, 1), "n_samples": N, # Gene-to-Text retrieval "g2t_R@1": recall_at_k(sim, labels, 1), "g2t_R@5": recall_at_k(sim, labels, 5), "g2t_R@10": recall_at_k(sim, labels, 10), "g2t_MRR": mrr(sim, labels), # Text-to-Gene retrieval "t2g_R@1": recall_at_k(sim.T, labels, 1), "t2g_R@5": recall_at_k(sim.T, labels, 5), "t2g_R@10": recall_at_k(sim.T, labels, 10), "t2g_MRR": mrr(sim.T, labels), } # Average metrics results["avg_R@1"] = (results["g2t_R@1"] + results["t2g_R@1"]) / 2 results["avg_R@5"] = (results["g2t_R@5"] + results["t2g_R@5"]) / 2 results["avg_R@10"] = (results["g2t_R@10"] + results["t2g_R@10"]) / 2 results["avg_MRR"] = (results["g2t_MRR"] + results["t2g_MRR"]) / 2 model.train() return results # ============================================================ # Training Loop # ============================================================ def get_optimizer(model, gene_encoder, config): """Set up optimizer with different LRs for different components.""" param_groups = [ # Projection heads + temperature {"params": list(model.text_proj.parameters()) + list(model.gene_proj.parameters()) + [model.log_temperature], "lr": config.lr, "weight_decay": config.weight_decay}, # Gene encoder (lower LR) {"params": gene_encoder.parameters(), "lr": config.gene_encoder_lr, "weight_decay": config.weight_decay}, ] return torch.optim.AdamW(param_groups) def get_scheduler(optimizer, config, total_steps): """Warmup + cosine decay scheduler.""" def lr_lambda(step): if step < config.warmup_steps: return step / max(config.warmup_steps, 1) progress = (step - config.warmup_steps) / max(total_steps - config.warmup_steps, 1) return 0.5 * (1 + math.cos(math.pi * progress)) return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) def train(config: Config): """Full training pipeline.""" import trackio print("=" * 70) print("GeneSetCLIP Training") print("=" * 70) print(f"Device: {config.device}") print(f"Batch size: {config.batch_size}") print(f"Max epochs: {config.max_epochs}") os.makedirs(config.output_dir, exist_ok=True) # ---- Load GSFM ---- print("\nLoading GSFM gene encoder...") from gsfm import GSFM, Vocab vocab_obj = Vocab.from_pretrained(config.gene_model_id) gene_encoder = GSFM.from_pretrained(config.gene_model_id) gene_encoder.to(config.device) gene_encoder.train() # Build vocab dict 


def train(config: Config):
    """Full training pipeline."""
    import trackio

    print("=" * 70)
    print("GeneSetCLIP Training")
    print("=" * 70)
    print(f"Device: {config.device}")
    print(f"Batch size: {config.batch_size}")
    print(f"Max epochs: {config.max_epochs}")

    os.makedirs(config.output_dir, exist_ok=True)

    # ---- Load GSFM ----
    print("\nLoading GSFM gene encoder...")
    from gsfm import GSFM, Vocab
    vocab_obj = Vocab.from_pretrained(config.gene_model_id)
    gene_encoder = GSFM.from_pretrained(config.gene_model_id)
    gene_encoder.to(config.device)
    gene_encoder.train()

    # Build vocab dict for dataset
    vocab_dict = {token: i for i, token in enumerate(vocab_obj.vocab)}
    print(f" GSFM vocab: {len(vocab_dict)} genes, d_model=256")

    # ---- Load BioLORD ----
    print("Loading BioLORD text encoder (frozen)...")
    from sentence_transformers import SentenceTransformer
    text_encoder = SentenceTransformer(config.text_model_id, device=config.device)
    # Freeze all text encoder parameters
    for param in text_encoder.parameters():
        param.requires_grad = False
    text_encoder.eval()
    print(f" BioLORD dim: {config.text_dim}")

    # ---- Build model ----
    print("Building GeneSetCLIP model...")
    model = GeneSetCLIP(config).to(config.device)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    gene_params = sum(p.numel() for p in gene_encoder.parameters())
    print(f" Projection head params: {total_params:,}")
    print(f" Gene encoder params: {gene_params:,}")

    # ---- Load data ----
    print("\nLoading datasets...")
    train_ds = GeneSetTextDataset(
        os.path.join(config.data_dir, "train.jsonl"), vocab_dict,
        max_genes=config.max_gene_set_size,
        gene_dropout=config.gene_dropout_rate,
    )
    val_ds = GeneSetTextDataset(
        os.path.join(config.data_dir, "val.jsonl"), vocab_dict,
        max_genes=config.max_gene_set_size,
        gene_dropout=0.0,  # no augmentation for val
    )
    test_ds = GeneSetTextDataset(
        os.path.join(config.data_dir, "test.jsonl"), vocab_dict,
        max_genes=config.max_gene_set_size,
        gene_dropout=0.0,
    )
    print(f" Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}")

    train_loader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True,
                              collate_fn=collate_fn, num_workers=config.num_workers,
                              pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=config.batch_size, shuffle=False,
                            collate_fn=collate_fn, num_workers=config.num_workers)
    test_loader = DataLoader(test_ds, batch_size=config.batch_size, shuffle=False,
                             collate_fn=collate_fn, num_workers=config.num_workers)

    steps_per_epoch = len(train_loader)
    total_steps = steps_per_epoch * config.max_epochs
    print(f" Steps/epoch: {steps_per_epoch}, Total steps: {total_steps}")

    # ---- Optimizer ----
    optimizer = get_optimizer(model, gene_encoder, config)
    scheduler = get_scheduler(optimizer, config, total_steps)

    # Mixed precision
    scaler = torch.amp.GradScaler('cuda') if config.mixed_precision and config.device == "cuda" else None

    # ---- Tracking ----
    trackio.init(
        project="GeneSetCLIP",
        name=f"bs{config.batch_size}_lr{config.lr}_temp{config.temperature_init}",
    )
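
    # Note on batch size: with symmetric InfoNCE, every (gene set, text) pair in a
    # batch is contrasted against the other batch_size - 1 pairs in each direction,
    # so at batch_size=256 each example sees 255 in-batch negatives. Shrinking the
    # batch (e.g. for debugging) therefore also weakens the contrastive signal.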

    # ---- Training ----
    best_val_mrr = 0
    patience_counter = 0
    global_step = 0

    for epoch in range(1, config.max_epochs + 1):
        model.train()
        gene_encoder.train()
        epoch_loss = 0
        epoch_acc = 0
        n_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            gene_ids = batch["gene_ids"].to(config.device)
            texts = batch["text"]

            # Encode genes (with gradient)
            gene_emb = gene_encoder.encode(gene_ids)

            # Encode text (no gradient, frozen) - clone to exit inference mode
            with torch.no_grad():
                text_emb = text_encoder.encode(texts, convert_to_tensor=True, show_progress_bar=False)
                if text_emb.device != torch.device(config.device):
                    text_emb = text_emb.to(config.device)
            text_emb = text_emb.clone()  # exit inference mode for autograd

            # Forward + loss
            if scaler is not None:
                with torch.amp.autocast('cuda'):
                    loss, _, _, metrics = model(gene_emb, text_emb)
                optimizer.zero_grad()
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(
                    list(model.parameters()) + list(gene_encoder.parameters()), max_norm=1.0
                )
                scaler.step(optimizer)
                scaler.update()
            else:
                loss, _, _, metrics = model(gene_emb, text_emb)
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    list(model.parameters()) + list(gene_encoder.parameters()), max_norm=1.0
                )
                optimizer.step()

            scheduler.step()
            global_step += 1
            epoch_loss += metrics["loss"]
            epoch_acc += metrics["avg_acc"]
            n_batches += 1

            if global_step % config.log_every == 0:
                lr_proj = optimizer.param_groups[0]["lr"]
                lr_gene = optimizer.param_groups[1]["lr"]
                print(f" Step {global_step:5d} | "
                      f"Loss: {metrics['loss']:.4f} | "
                      f"Acc: {metrics['avg_acc']:.3f} | "
                      f"τ: {metrics['temperature']:.4f} | "
                      f"LR: {lr_proj:.2e}/{lr_gene:.2e}")
                trackio.log({
                    "train/loss": metrics["loss"],
                    "train/g2t_acc": metrics["g2t_acc"],
                    "train/t2g_acc": metrics["t2g_acc"],
                    "train/avg_acc": metrics["avg_acc"],
                    "train/temperature": metrics["temperature"],
                    "train/lr_proj": lr_proj,
                    "train/lr_gene": lr_gene,
                    "step": global_step,
                })

        avg_loss = epoch_loss / max(n_batches, 1)
        avg_acc = epoch_acc / max(n_batches, 1)
        print(f"\nEpoch {epoch}/{config.max_epochs} | "
              f"Train Loss: {avg_loss:.4f} | Train Acc: {avg_acc:.3f}")

        # ---- Validation ----
        if epoch % config.eval_every == 0:
            print(" Evaluating on validation set...")
            val_results = evaluate_retrieval(
                model, gene_encoder, text_encoder, val_loader, config.device, config
            )
            print(f" Val Loss: {val_results['loss']:.4f} | "
                  f"Val R@1: {val_results['avg_R@1']:.3f} | "
                  f"Val R@5: {val_results['avg_R@5']:.3f} | "
                  f"Val R@10: {val_results['avg_R@10']:.3f} | "
                  f"Val MRR: {val_results['avg_MRR']:.3f}")
            trackio.log({
                "val/loss": val_results["loss"],
                "val/g2t_R@1": val_results["g2t_R@1"],
                "val/g2t_R@5": val_results["g2t_R@5"],
                "val/t2g_R@1": val_results["t2g_R@1"],
                "val/t2g_R@5": val_results["t2g_R@5"],
                "val/avg_R@1": val_results["avg_R@1"],
                "val/avg_R@5": val_results["avg_R@5"],
                "val/avg_R@10": val_results["avg_R@10"],
                "val/avg_MRR": val_results["avg_MRR"],
                "epoch": epoch,
            })

            # Early stopping
            if val_results["avg_MRR"] > best_val_mrr:
                best_val_mrr = val_results["avg_MRR"]
                patience_counter = 0
                # Save best model
                save_checkpoint(model, gene_encoder, optimizer, config, epoch, val_results, is_best=True)
                print(f" ✓ New best! MRR: {best_val_mrr:.4f}")
            else:
                patience_counter += 1
                print(f" No improvement ({patience_counter}/{config.patience})")
                if patience_counter >= config.patience:
                    print(f" Early stopping at epoch {epoch}")
                    break

        # Periodic save
        if epoch % config.save_every == 0:
            save_checkpoint(model, gene_encoder, optimizer, config, epoch, {})

    # ---- Final Test Evaluation ----
    print("\n" + "=" * 70)
    print("Final evaluation on test set (H, C6, C7 collections)...")

    # Load best model
    best_path = os.path.join(config.output_dir, "best_model")
    if os.path.exists(best_path):
        model.load_state_dict(torch.load(os.path.join(best_path, "clip_model.pt"), map_location=config.device))
        gene_encoder.load_state_dict(torch.load(os.path.join(best_path, "gene_encoder.pt"), map_location=config.device))
        print("Loaded best model checkpoint")

    test_results = evaluate_retrieval(
        model, gene_encoder, text_encoder, test_loader, config.device, config
    )
    print(f"\nTest Results:")
    print(f" Loss: {test_results['loss']:.4f}")
    print(f" G→T R@1: {test_results['g2t_R@1']:.3f} R@5: {test_results['g2t_R@5']:.3f} R@10: {test_results['g2t_R@10']:.3f}")
    print(f" T→G R@1: {test_results['t2g_R@1']:.3f} R@5: {test_results['t2g_R@5']:.3f} R@10: {test_results['t2g_R@10']:.3f}")
    print(f" Avg R@1: {test_results['avg_R@1']:.3f} R@5: {test_results['avg_R@5']:.3f} MRR: {test_results['avg_MRR']:.3f}")

    trackio.log({"test/" + k: v for k, v in test_results.items()})

    # Save test results
    with open(os.path.join(config.output_dir, "test_results.json"), "w") as f:
        json.dump(test_results, f, indent=2)

    # ---- Push to Hub ----
    print("\nPushing model to Hub...")
    push_to_hub(model, gene_encoder, vocab_dict, config, test_results)

    print("Done!")
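

# On-disk layout produced by train() and the helpers below (paths relative to
# config.output_dir):
#   best_model/           clip_model.pt, gene_encoder.pt, optimizer.pt, config.json, metrics.json
#   checkpoint_epochN/    same files, written every config.save_every epochs
#   hub_upload/           files staged for the Hub push (weights, vocab.json, test_results.json, README.md)
#   test_results.json     final test-set retrieval metrics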

def save_checkpoint(model, gene_encoder, optimizer, config, epoch, metrics, is_best=False):
    """Save model checkpoint."""
    save_dir = os.path.join(config.output_dir, "best_model" if is_best else f"checkpoint_epoch{epoch}")
    os.makedirs(save_dir, exist_ok=True)

    torch.save(model.state_dict(), os.path.join(save_dir, "clip_model.pt"))
    torch.save(gene_encoder.state_dict(), os.path.join(save_dir, "gene_encoder.pt"))
    torch.save(optimizer.state_dict(), os.path.join(save_dir, "optimizer.pt"))

    # Save config + metrics
    with open(os.path.join(save_dir, "config.json"), "w") as f:
        json.dump(vars(config), f, indent=2)
    with open(os.path.join(save_dir, "metrics.json"), "w") as f:
        json.dump(metrics, f, indent=2)
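
# Hypothetical companion to save_checkpoint (not called anywhere in this script):
# restores the state dicts written above, e.g. to resume training or re-evaluate a
# specific checkpoint. Assumes the same Config/architecture used at save time.
def _load_checkpoint(model, gene_encoder, optimizer, ckpt_dir, device):
    model.load_state_dict(torch.load(os.path.join(ckpt_dir, "clip_model.pt"), map_location=device))
    gene_encoder.load_state_dict(torch.load(os.path.join(ckpt_dir, "gene_encoder.pt"), map_location=device))
    if optimizer is not None:
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_dir, "optimizer.pt"), map_location=device))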

def push_to_hub(model, gene_encoder, vocab_dict, config, test_results):
    """Push trained model to HuggingFace Hub."""
    api = HfApi()

    # Create repo if needed
    try:
        api.create_repo(config.hub_model_id, exist_ok=True)
    except Exception as e:
        print(f" Warning creating repo: {e}")

    save_dir = os.path.join(config.output_dir, "hub_upload")
    os.makedirs(save_dir, exist_ok=True)

    # Save model files
    torch.save(model.state_dict(), os.path.join(save_dir, "clip_model.pt"))
    torch.save(gene_encoder.state_dict(), os.path.join(save_dir, "gene_encoder.pt"))

    # Save config
    with open(os.path.join(save_dir, "config.json"), "w") as f:
        json.dump(vars(config), f, indent=2)

    # Save vocab
    with open(os.path.join(save_dir, "vocab.json"), "w") as f:
        json.dump(vocab_dict, f)

    # Save test results
    with open(os.path.join(save_dir, "test_results.json"), "w") as f:
        json.dump(test_results, f, indent=2)

    # Create README
    readme = f"""# GeneSetCLIP

Contrastive model aligning gene-set embeddings with biomedical text descriptions.

## Architecture

- **Gene encoder**: GSFM (MaayanLab, MLP autoencoder, 256-dim)
- **Text encoder**: BioLORD-2023 (768-dim, frozen during training)
- **Projection heads**: Map both modalities to a shared 256-dim space
- **Loss**: Symmetric InfoNCE with learnable temperature

## Training Data

- **MSigDB v2024.1** (Human + Mouse): ~50,000 gene set-text pairs
- Collections: H, C1-C8 (Human), MH, M1-M8 (Mouse)

## Test Results (H, C6, C7 collections)

| Metric | Gene→Text | Text→Gene | Average |
|--------|-----------|-----------|---------|
| R@1 | {test_results.get('g2t_R@1', 0):.3f} | {test_results.get('t2g_R@1', 0):.3f} | {test_results.get('avg_R@1', 0):.3f} |
| R@5 | {test_results.get('g2t_R@5', 0):.3f} | {test_results.get('t2g_R@5', 0):.3f} | {test_results.get('avg_R@5', 0):.3f} |
| R@10 | {test_results.get('g2t_R@10', 0):.3f} | {test_results.get('t2g_R@10', 0):.3f} | {test_results.get('avg_R@10', 0):.3f} |
| MRR | {test_results.get('g2t_MRR', 0):.3f} | {test_results.get('t2g_MRR', 0):.3f} | {test_results.get('avg_MRR', 0):.3f} |

## Usage

```python
import torch
from gsfm import GSFM, Vocab
from sentence_transformers import SentenceTransformer

# Load models
gene_encoder = GSFM.from_pretrained("maayanlab/gsfm-rummagene")
text_encoder = SentenceTransformer("FremyCompany/BioLORD-2023")
vocab = Vocab.from_pretrained("maayanlab/gsfm-rummagene")

# Load GeneSetCLIP projection heads
# (download clip_model.pt from this repo)
from model import GeneSetCLIP, Config
clip_model = GeneSetCLIP(Config())
clip_model.load_state_dict(torch.load("clip_model.pt"))
clip_model.eval()

# Encode a gene set
genes = ["TP53", "BRCA1", "EGFR", "MYC"]
gene_ids = torch.tensor([vocab(genes)])
with torch.no_grad():
    gene_emb = gene_encoder.encode(gene_ids)
    z_gene, _ = clip_model.get_embeddings(gene_emb=gene_emb)

# Encode text
text_emb = text_encoder.encode(["Tumor suppressor genes"], convert_to_tensor=True)
with torch.no_grad():
    _, z_text = clip_model.get_embeddings(text_emb=text_emb)

# Compute similarity
similarity = (z_gene @ z_text.T).item()
print(f"Similarity: {{similarity:.3f}}")
```

## Config

```json
{json.dumps(vars(config), indent=2)}
```
"""

    with open(os.path.join(save_dir, "README.md"), "w") as f:
        f.write(readme)

    # Upload
    api.upload_folder(
        folder_path=save_dir,
        repo_id=config.hub_model_id,
        commit_message="Upload GeneSetCLIP model",
    )
    print(f" Pushed to https://huggingface.co/{config.hub_model_id}")


# ============================================================
# Main
# ============================================================

if __name__ == "__main__":
    config = Config()
    train(config)