---
license: mit
library_name: pytorch
tags:
- biology
- single-cell
- T-cell
- TCR
- immunology
- scRNA-seq
- multimodal
- cross-attention
- ensemble
datasets:
- custom
metrics:
- accuracy
- f1
pipeline_tag: tabular-classification
language:
- en
---

# Multimodal T-Cell Functional State Classifier

A multimodal deep learning ensemble for predicting T-cell functional states from scRNA-seq data. Integrates gene expression (3,000 HVGs), TCR sequences (via [TCR-BERT](https://github.com/wukevin/tcr-bert)), and V/J gene usage through bidirectional cross-attention fusion.

**89.6% accuracy** | **macro F1 0.88** | **7 functional states** | **top-5 ensemble**

**GitHub**: [polinavd/multimodal-tcell-classifier](https://github.com/polinavd/multimodal-tcell-classifier)

## Model Description

This repository contains the weights for a top-5 ensemble of `FullGenesVJClassifier` models. Each model takes three input modalities:

- **Gene expression**: 3,000 highly variable genes (learned dimensionality reduction, no PCA)
- **TCR sequences**: CDR3-alpha and CDR3-beta encoded via TCR-BERT (768-dim CLS embeddings)
- **V/J gene usage**: one-hot encoded TRAV/TRAJ/TRBV/TRBJ segments (161-dim)

Cross-attention fusion allows each modality to attend to the others before classification into 7 functional states: Effector, Exhausted, Memory, Naive, Proliferating, Th_effector, Treg.

### Architecture

```
GEX (3000) → [Linear 512 → GELU → Linear hidden] + ResidualBlock → (hidden,)
TCR-α/β (768) + VJ context (64) → [Linear hidden] + ResidualBlock → (hidden,)
VJ (161) → [Linear hidden] → (hidden,)

Cross-Attention Fusion:
  GEX (1 token) ↔ TCR-α + TCR-β + VJ (3 tokens)
  4 attention heads, LayerNorm, residual connections

→ Concat (4 × hidden) → ResidualBlock → Linear → 7 classes
```
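
The fusion step can be sketched in PyTorch. This is a simplified, hypothetical reimplementation following the diagram (one attention direction shown for brevity); the actual `FullGenesVJClassifier` lives in the GitHub repo.

```python
# Simplified sketch of the cross-attention fusion in PyTorch.
# Names and details are illustrative, not the released implementation.
import torch
import torch.nn as nn

class CrossAttentionFusion(nn.Module):
    def __init__(self, hidden=512, heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden, heads, batch_first=True)
        self.norm = nn.LayerNorm(hidden)

    def forward(self, gex, tcr_a, tcr_b, vj):
        # Each input: (batch, hidden). TCR-α, TCR-β and VJ form a
        # 3-token context sequence; GEX is a single query token.
        # (The full model is bidirectional: context tokens also attend
        # back to GEX. Only one direction is shown here.)
        ctx = torch.stack([tcr_a, tcr_b, vj], dim=1)   # (B, 3, H)
        q = gex.unsqueeze(1)                           # (B, 1, H)
        fused, _ = self.attn(q, ctx, ctx)              # GEX attends to context
        q = self.norm(q + fused)                       # residual + LayerNorm
        # Concatenate all four token representations for the classifier head.
        return torch.cat([q.squeeze(1), tcr_a, tcr_b, vj], dim=-1)  # (B, 4H)

fusion = CrossAttentionFusion()
out = fusion(*(torch.randn(8, 512) for _ in range(4)))
print(out.shape)  # torch.Size([8, 2048])
```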

### Ensemble Diversity

| Model | Hidden dim | Heads | Dropout | Acc |
|---|---|---|---|---|
| m7_lr3e4 | 512 | 4 | 0.30 | 88.8% |
| m2_h512_s2 | 512 | 4 | 0.30 | 88.3% |
| m6_highdrop | 512 | 4 | 0.35 | 88.3% |
| m4_8heads | 512 | 8 | 0.30 | 88.3% |
| m1_h512 | 512 | 4 | 0.30 | 88.1% |

Ensemble averaging of these 5 models yields **89.6% accuracy** (macro F1 0.88).
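
Ensemble prediction is plain probability averaging. A sketch, with tiny linear stand-ins replacing the five real checkpoints:

```python
# Top-5 ensemble sketch: average each model's softmax probabilities,
# then take the argmax. Tiny Linear layers stand in for the checkpoints.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
models = [torch.nn.Linear(16, 7) for _ in range(5)]  # stand-ins for the 5 models
x = torch.randn(4, 16)                               # 4 cells, toy features

with torch.no_grad():
    probs = torch.stack([F.softmax(m(x), dim=-1) for m in models]).mean(dim=0)

pred = probs.argmax(dim=-1)  # index into the 7 functional states
print(probs.shape, pred.shape)  # torch.Size([4, 7]) torch.Size([4])
```

Averaging in probability space (rather than logit space) keeps each model's confidence calibrated contribution; the averaged rows still sum to 1.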

## Intended Use

Classification of T-cell functional states from paired scRNA-seq + TCR-seq data. Designed for research use in immunology, immuno-oncology, and single-cell analysis pipelines.

**Not intended** for clinical decision-making or diagnostic use.

## Training Data

**136,667 T-cells** (after QC filtering) from 4 public scRNA-seq datasets:

| Dataset | Platform | Cells* | Tissue |
|---|---|---|---|
| GSE144469 | 10x Genomics | ~60,000 | Colitis (colon) |
| GSE179994 | 10x Genomics | ~77,000 | PBMC (exhaustion study) |
| GSE181061 | 10x Genomics | ~31,000 | ccRCC (tumor-infiltrating) |
| GSE108989 | Smart-seq2 | ~12,000 | CRC (tumor + blood) |

*Cell counts are pre-QC; 136,667 cells remain after quality control filtering.

Preprocessing: QC → normalization (scanpy) → 3,000 HVGs → Harmony batch correction → CDR3/V/J extraction via scirpy.

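The HVG step keeps the 3,000 most variable genes. Conceptually, selection by dispersion looks like the numpy sketch below; the actual pipeline uses scanpy's `highly_variable_genes`, and this stand-in only illustrates the idea.

```python
# Conceptual numpy sketch of highly-variable-gene selection: rank genes
# by dispersion (variance/mean) on log-normalized counts, keep the top k.
# The real pipeline uses scanpy's highly_variable_genes on real data.
import numpy as np

rng = np.random.default_rng(0)
X = rng.poisson(lam=2.0, size=(500, 10_000)).astype(float)  # cells x genes (toy)

# Library-size normalize to 1e4 counts per cell, then log1p,
# mirroring standard scanpy preprocessing.
X = np.log1p(X / X.sum(axis=1, keepdims=True) * 1e4)

mean = X.mean(axis=0)
disp = X.var(axis=0) / np.maximum(mean, 1e-12)  # dispersion per gene

k = 3000
hvg_idx = np.argsort(disp)[-k:]  # indices of the 3,000 top-dispersion genes
X_hvg = X[:, hvg_idx]
print(X_hvg.shape)  # (500, 3000)
```
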
## Evaluation

### Per-Class Performance (Test Set)

| Class | Precision | Recall | F1 | Support |
|---|---|---|---|---|
| Effector | 0.91 | 0.92 | 0.91 | 6,685 |
| Exhausted | 0.84 | 0.82 | 0.83 | 2,245 |
| Memory | 0.89 | 0.88 | 0.89 | 4,979 |
| Naive | 0.87 | 0.85 | 0.86 | 2,441 |
| Proliferating | 0.92 | 0.89 | 0.90 | 764 |
| Th_effector | 0.76 | 0.74 | 0.75 | 393 |
| Treg | 0.93 | 0.94 | 0.94 | 2,329 |

### Ablation Study

| Configuration | Accuracy |
|---|---|
| TCR-only (BERT embeddings) | 33.7% |
| GEX-only (PCA-50) | 69.9% |
| Multimodal (PCA-50, concat) | 79.3% |
| End-to-end BERT fine-tuning | 77.4% |
| Hybrid + VJ + PCA-200 | 84.9% |
| **Ensemble + VJ + 3000 genes** | **89.6%** |

## How to Use

### Quick Start

```bash
git clone https://github.com/polinavd/multimodal-tcell-classifier.git
cd multimodal-tcell-classifier
pip install -r requirements.txt
python predict_report.py --input your_data.h5ad --output ./results
```

Model weights (~300 MB) are downloaded automatically from this Hugging Face repo on first run.

Output: an interactive HTML report, `predictions.csv`, and an annotated `.h5ad` file.

### Manual Weight Download

```python
from huggingface_hub import snapshot_download

snapshot_download("VirialyD/tcell-classifier", local_dir="./weights")
```

## Files

| File | Description |
|---|---|
| `m1_h512.pt` | Model 1: hidden=512, heads=4, dropout=0.30 |
| `m2_h512_s2.pt` | Model 2: hidden=512, heads=4, dropout=0.30, seed=2 |
| `m4_8heads.pt` | Model 4: hidden=512, heads=8, dropout=0.30 |
| `m6_highdrop.pt` | Model 6: hidden=512, heads=4, dropout=0.35 |
| `m7_lr3e4.pt` | Model 7: hidden=512, heads=4, dropout=0.30, lr=3e-4 |
| `results.json` | Individual model metrics and ensemble config |
| `label_encoder.pkl` | sklearn LabelEncoder for the 7 functional states |
| `vj_encoder.pkl` | V/J gene one-hot encoder (161-dim) |

## Technical Details

- **Optimizer**: AdamW (lr=2e-4, weight_decay=0.02)
- **Schedule**: cosine annealing with 5% linear warmup
- **Loss**: CrossEntropyLoss with balanced class weights + label smoothing (0.03)
- **Mixed precision**: FP16 with gradient clipping (max_norm=1.0)
- **Early stopping**: on validation macro F1 (patience=12)
- **Hardware**: NVIDIA RTX 5070 (8 GB VRAM)
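
The training setup above can be sketched in a few lines of PyTorch. The model is a stand-in, the class weights are placeholders, and AMP (FP16) is omitted so the sketch runs on CPU.

```python
# Sketch of the training recipe: AdamW, cosine schedule with 5% linear
# warmup, class-weighted cross-entropy with label smoothing, gradient
# clipping. A Linear layer stands in for the classifier; AMP is omitted.
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(32, 7)  # stand-in for the real classifier
opt = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=0.02)

total_steps, warmup = 1000, 50  # warmup = 5% of total steps

def lr_lambda(step):
    if step < warmup:
        return step / warmup                    # linear warmup
    t = (step - warmup) / (total_steps - warmup)
    return 0.5 * (1 + math.cos(math.pi * t))    # cosine decay

sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)

class_weights = torch.ones(7)  # placeholder for balanced class weights
loss_fn = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.03)

# One toy optimization step.
x, y = torch.randn(16, 32), torch.randint(0, 7, (16,))
loss = loss_fn(model(x), y)
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
opt.step()
sched.step()
print(float(loss))
```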

## Limitations

- Trained on human T-cells only; not validated on other species or on non-T immune cells.
- Requires paired scRNA-seq + TCR-seq data (CDR3 alpha/beta + V/J genes).
- Gene expression input must come from the same 3,000-HVG feature space. The preprocessing pipeline handles this, but heavily divergent protocols may reduce accuracy.
- The Th_effector class has the lowest performance (F1 0.75), likely because it is the rarest class (393 cells in the test set).

## Citation

```bibtex
@software{shirokikh2026multimodal,
  author = {Shirokikh, Polina},
  title  = {Multimodal T-Cell Functional State Classifier},
  year   = {2026},
  url    = {https://github.com/polinavd/multimodal-tcell-classifier}
}
```

## License

MIT License; see [LICENSE](https://github.com/polinavd/multimodal-tcell-classifier/blob/main/LICENSE) for details.