---
license: mit
library_name: pytorch
tags:
- biology
- single-cell
- T-cell
- TCR
- immunology
- scRNA-seq
- multimodal
- cross-attention
- ensemble
datasets:
- custom
metrics:
- accuracy
- f1
pipeline_tag: tabular-classification
language:
- en
---
# Multimodal T-Cell Functional State Classifier
A multimodal deep learning ensemble for predicting T-cell functional states from scRNA-seq data. Integrates gene expression (3,000 HVGs), TCR sequences (via [TCR-BERT](https://github.com/wukevin/tcr-bert)), and V/J gene usage through bidirectional cross-attention fusion.
**89.6% accuracy** | **macro F1 0.88** | **7 functional states** | **top-5 ensemble**
**GitHub**: [polinavd/multimodal-tcell-classifier](https://github.com/polinavd/multimodal-tcell-classifier)
## Model Description
This repository contains the weights for a top-5 ensemble of `FullGenesVJClassifier` models. Each model takes three input modalities:
- **Gene expression**: 3,000 highly variable genes (learned dimensionality reduction, no PCA)
- **TCR sequences**: CDR3-alpha and CDR3-beta encoded via TCR-BERT (768-dim CLS embeddings)
- **V/J gene usage**: one-hot encoded TRAV/TRAJ/TRBV/TRBJ segments (161-dim)
Cross-attention fusion allows each modality to attend to the others before classification into 7 functional states: Effector, Exhausted, Memory, Naive, Proliferating, Th_effector, Treg.
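The 161-dim V/J input can be read as four one-hot blocks (TRAV, TRAJ, TRBV, TRBJ) concatenated end to end. The sketch below assumes that layout and uses a tiny hypothetical vocabulary (`VOCAB`, `encode_vj` are illustrative names). The real segment lists, their ordering and the handling of unseen genes are defined by `vj_encoder.pkl` and may differ.

```python
import numpy as np

# Hypothetical, truncated vocabularies; the real encoder spans 161 segments in total.
VOCAB = {
    "TRAV": ["TRAV1-2", "TRAV12-1", "TRAV21"],
    "TRAJ": ["TRAJ33", "TRAJ12"],
    "TRBV": ["TRBV20-1", "TRBV7-9", "TRBV5-1"],
    "TRBJ": ["TRBJ2-7", "TRBJ1-1"],
}


def encode_vj(genes: dict) -> np.ndarray:
    """Concatenate one one-hot block per segment type.

    In this sketch, missing or unknown genes leave their block all-zero.
    """
    blocks = []
    for segment, names in VOCAB.items():
        block = np.zeros(len(names), dtype=np.float32)
        gene = genes.get(segment)
        if gene in names:
            block[names.index(gene)] = 1.0
        blocks.append(block)
    return np.concatenate(blocks)


vec = encode_vj({"TRAV": "TRAV21", "TRAJ": "TRAJ12", "TRBV": "TRBV7-9", "TRBJ": "TRBJ2-7"})
print(vec.shape, vec.sum())  # (10,) 4.0
```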
### Architecture
```
GEX (3000) → [Linear 512 → GELU → Linear hidden] + ResidualBlock → (hidden,)
TCR-α/β (768) + VJ context (64) → [Linear hidden] + ResidualBlock → (hidden,)
VJ (161) → [Linear hidden] → (hidden,)
Cross-Attention Fusion:
  GEX (1 token) ↔ TCR-α + TCR-β + VJ (3 tokens)
  4 attention heads, LayerNorm, residual connections
→ Concat (4 × hidden) → ResidualBlock → Linear → 7 classes
```
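The diagram can be made concrete with a compact PyTorch sketch. This is not the repository's `FullGenesVJClassifier`: the `ResidualBlock` design, the projection that produces the 64-dim VJ context, the shared TCR encoder for α and β, and the use of one `nn.MultiheadAttention` per direction are all assumptions. It shows the GEX token attending over the three TCR/VJ tokens, those tokens attending back to GEX, and the four resulting vectors being concatenated for classification.

```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Assumed design: pre-norm MLP with a skip connection."""

    def __init__(self, dim: int, dropout: float = 0.3):
        super().__init__()
        self.net = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim), nn.GELU(), nn.Dropout(dropout))

    def forward(self, x):
        return x + self.net(x)


class FusionSketch(nn.Module):
    def __init__(self, hidden=512, heads=4, n_classes=7, dropout=0.3):
        super().__init__()
        self.gex = nn.Sequential(
            nn.Linear(3000, 512), nn.GELU(), nn.Linear(512, hidden), ResidualBlock(hidden, dropout)
        )
        self.vj_ctx = nn.Linear(161, 64)  # assumed source of the 64-dim VJ context
        self.tcr = nn.Sequential(nn.Linear(768 + 64, hidden), ResidualBlock(hidden, dropout))  # shared for alpha/beta (assumption)
        self.vj = nn.Linear(161, hidden)
        self.gex_to_rest = nn.MultiheadAttention(hidden, heads, dropout=dropout, batch_first=True)
        self.rest_to_gex = nn.MultiheadAttention(hidden, heads, dropout=dropout, batch_first=True)
        self.norm_g = nn.LayerNorm(hidden)
        self.norm_r = nn.LayerNorm(hidden)
        self.head = nn.Sequential(ResidualBlock(4 * hidden, dropout), nn.Linear(4 * hidden, n_classes))

    def forward(self, gex, tcr_a, tcr_b, vj):
        ctx = self.vj_ctx(vj)
        g = self.gex(gex).unsqueeze(1)  # (B, 1, hidden)
        rest = torch.stack(
            [self.tcr(torch.cat([tcr_a, ctx], dim=-1)),
             self.tcr(torch.cat([tcr_b, ctx], dim=-1)),
             self.vj(vj)],
            dim=1,
        )  # (B, 3, hidden)
        # GEX token queries the TCR-alpha, TCR-beta and VJ tokens, and they query it back.
        g_att, _ = self.gex_to_rest(g, rest, rest)
        r_att, _ = self.rest_to_gex(rest, g, g)
        g = self.norm_g(g + g_att)  # residual connection + LayerNorm
        rest = self.norm_r(rest + r_att)
        fused = torch.cat([g, rest], dim=1).flatten(1)  # (B, 4 * hidden)
        return self.head(fused)


model = FusionSketch().eval()  # eval() disables dropout
with torch.no_grad():
    logits = model(torch.randn(2, 3000), torch.randn(2, 768), torch.randn(2, 768), torch.randn(2, 161))
print(logits.shape)  # torch.Size([2, 7])
```

Varying `hidden`, `heads` and `dropout` in this sketch mirrors the knobs listed in the ensemble table; `heads=8` still divides `hidden=512` evenly, as `nn.MultiheadAttention` requires.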
### Ensemble Diversity
| Model | Hidden dim | Heads | Dropout | Accuracy |
|---|---|---|---|---|
| m7_lr3e4 | 512 | 4 | 0.30 | 88.8% |
| m2_h512_s2 | 512 | 4 | 0.30 | 88.3% |
| m6_highdrop | 512 | 4 | 0.35 | 88.3% |
| m4_8heads | 512 | 8 | 0.30 | 88.3% |
| m1_h512 | 512 | 4 | 0.30 | 88.1% |
Ensemble averaging of these 5 models yields **89.6% accuracy** (macro F1 0.88).
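The card does not state exactly what is averaged. A common choice, assumed here, is soft voting: average each model's softmax probabilities and take the argmax per cell. The class order in `CLASSES` mirrors scikit-learn's `LabelEncoder`, which sorts labels alphabetically; verify it against `label_encoder.pkl` before relying on it. `ensemble_predict` is a hypothetical helper name.

```python
import numpy as np

CLASSES = ["Effector", "Exhausted", "Memory", "Naive", "Proliferating", "Th_effector", "Treg"]


def softmax(logits: np.ndarray) -> np.ndarray:
    z = logits - logits.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def ensemble_predict(per_model_logits: list) -> tuple:
    """Soft voting: mean of per-model class probabilities, then argmax per cell."""
    probs = np.mean([softmax(x) for x in per_model_logits], axis=0)  # (n_cells, 7)
    labels = [CLASSES[i] for i in probs.argmax(axis=1)]
    return probs, labels


rng = np.random.default_rng(0)
logits = [rng.normal(size=(3, 7)) for _ in range(5)]  # 5 models x 3 cells, random stand-ins
probs, labels = ensemble_predict(logits)
print(probs.shape, labels)
```

Averaging probabilities rather than hard votes lets a confident model outweigh several uncertain ones and avoids ties with an odd or even number of members.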
## Intended Use
Classification of T-cell functional states from paired scRNA-seq + TCR-seq data. Designed for research use in immunology, immuno-oncology, and single-cell analysis pipelines.
**Not intended** for clinical decision-making or diagnostic use.
## Training Data
**136,667 T-cells** (after QC filtering) from 4 public scRNA-seq datasets:
| Dataset | Platform | Cells* | Tissue |
|---|---|---|---|
| GSE144469 | 10x Genomics | ~60,000 | Colitis (colon) |
| GSE179994 | 10x Genomics | ~77,000 | PBMC (exhaustion study) |
| GSE181061 | 10x Genomics | ~31,000 | ccRCC (tumor-infiltrating) |
| GSE108989 | Smart-seq2 | ~12,000 | CRC (tumor + blood) |
*Cell counts are pre-QC; 136,667 cells remain after quality control filtering.
Preprocessing: QC → normalization (scanpy) → 3,000 HVGs → Harmony batch correction → CDR3/V/J extraction via scirpy.
## Evaluation
### Per-Class Performance (Test Set)
| Class | Precision | Recall | F1 | Support |
|---|---|---|---|---|
| Effector | 0.91 | 0.92 | 0.91 | 6,685 |
| Exhausted | 0.84 | 0.82 | 0.83 | 2,245 |
| Memory | 0.89 | 0.88 | 0.89 | 4,979 |
| Naive | 0.87 | 0.85 | 0.86 | 2,441 |
| Proliferating | 0.92 | 0.89 | 0.90 | 764 |
| Th_effector | 0.76 | 0.74 | 0.75 | 393 |
| Treg | 0.93 | 0.94 | 0.94 | 2,329 |
### Ablation Study
| Configuration | Accuracy |
|---|---|
| TCR-only (BERT embeddings) | 33.7% |
| GEX-only (PCA-50) | 69.9% |
| Multimodal (PCA-50, concat) | 79.3% |
| End-to-end BERT fine-tuning | 77.4% |
| Hybrid + VJ + PCA-200 | 84.9% |
| **Ensemble + VJ + 3000 genes** | **89.6%** |
## How to Use
### Quick Start
```bash
git clone https://github.com/polinavd/multimodal-tcell-classifier.git
cd multimodal-tcell-classifier
pip install -r requirements.txt
python predict_report.py --input your_data.h5ad --output ./results
```
Model weights (~300 MB) are downloaded automatically from this Hugging Face repository on first run.
Output: interactive HTML report, predictions.csv, annotated .h5ad.
### Manual Weight Download
```python
from huggingface_hub import snapshot_download
snapshot_download("VirialyD/tcell-classifier", local_dir="./weights")
```
## Files
| File | Description |
|---|---|
| `m1_h512.pt` | Model 1: hidden=512, heads=4, dropout=0.30 |
| `m2_h512_s2.pt` | Model 2: hidden=512, heads=4, dropout=0.30, seed=2 |
| `m4_8heads.pt` | Model 4: hidden=512, heads=8, dropout=0.30 |
| `m6_highdrop.pt` | Model 6: hidden=512, heads=4, dropout=0.35 |
| `m7_lr3e4.pt` | Model 7: hidden=512, heads=4, dropout=0.30, lr=3e-4 |
| `results.json` | Individual model metrics and ensemble config |
| `label_encoder.pkl` | sklearn LabelEncoder for 7 functional states |
| `vj_encoder.pkl` | V/J gene one-hot encoder (161-dim) |
## Technical Details
- **Optimizer**: AdamW (lr=2e-4, weight_decay=0.02)
- **Schedule**: Cosine annealing with 5% linear warmup
- **Loss**: CrossEntropyLoss with balanced class weights + label smoothing (0.03)
- **Mixed precision**: FP16 with gradient clipping (max_norm=1.0)
- **Early stopping**: on validation macro F1 (patience=12)
- **Hardware**: NVIDIA RTX 5070 (8 GB VRAM)
## Limitations
- Trained on human T-cells only; not validated on other species or non-T immune cells.
- Requires paired scRNA-seq + TCR-seq data (CDR3 alpha/beta + V/J genes).
- Gene expression input must be from the same 3,000 HVG feature space. The preprocessing pipeline handles this, but heavily divergent protocols may reduce accuracy.
- Th_effector class has the lowest performance (F1 0.75), likely because it is the rarest class (only 393 cells in the test set).
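To check how far a new dataset is from the 3,000-HVG feature space, one option is to reorder its columns to the training gene list, zero-fill absent genes and report the fraction found. This is a sketch under assumptions: `align_to_training_genes` is a hypothetical helper, it expects a dense matrix, the training gene list would come from the model's artifacts, and zero-filling is a simplification that the repository's preprocessing may not use.

```python
import numpy as np


def align_to_training_genes(X: np.ndarray, genes: list, training_genes: list):
    """Return a (n_cells, len(training_genes)) matrix in training gene order.

    Genes absent from `genes` become all-zero columns; extra genes are dropped.
    Also returns the fraction of training genes that were found (coverage).
    """
    col = {g: i for i, g in enumerate(genes)}
    out = np.zeros((X.shape[0], len(training_genes)), dtype=X.dtype)
    found = 0
    for j, g in enumerate(training_genes):
        i = col.get(g)
        if i is not None:
            out[:, j] = X[:, i]
            found += 1
    return out, found / len(training_genes)


X = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
aligned, coverage = align_to_training_genes(X, ["CD8A", "GZMB", "MKI67"], ["GZMB", "FOXP3", "CD8A"])
print(aligned, coverage)
```

A low coverage value is an early warning that predictions on that dataset may be less reliable.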
## Citation
```bibtex
@software{shirokikh2026multimodal,
author = {Shirokikh, Polina},
title = {Multimodal T-Cell Functional State Classifier},
year = {2026},
url = {https://github.com/polinavd/multimodal-tcell-classifier}
}
```
## License
MIT License. See [LICENSE](https://github.com/polinavd/multimodal-tcell-classifier/blob/main/LICENSE) for details.