Multimodal T-Cell Functional State Classifier
A multimodal deep learning ensemble for predicting T-cell functional states from paired scRNA-seq and TCR-seq data. It integrates gene expression (3,000 highly variable genes, HVGs), TCR sequences (embedded with TCR-BERT), and V/J gene usage through bidirectional cross-attention fusion.
89.6% accuracy | macro F1 0.88 | 7 functional states | top-5 ensemble
GitHub: polinavd/multimodal-tcell-classifier
Model Description
This repository contains the weights for a top-5 ensemble of FullGenesVJClassifier models. Each model takes three input modalities:
- Gene expression: 3,000 highly variable genes (learned dimensionality reduction, no PCA)
- TCR sequences: CDR3-alpha and CDR3-beta encoded via TCR-BERT (768-dim CLS embeddings)
- V/J gene usage: one-hot encoded TRAV/TRAJ/TRBV/TRBJ segments (161-dim)
Cross-attention fusion allows each modality to attend to the others before classification into 7 functional states: Effector, Exhausted, Memory, Naive, Proliferating, Th_effector, Treg.
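The V/J input can be pictured as four one-hot blocks, one per segment type, concatenated into a single vector. The sketch below assumes that layout (segment order TRAV, TRAJ, TRBV, TRBJ; an all-zero block for a missing or unseen gene). The actual 161-dim layout is defined by `vj_encoder.pkl`, and the helper names `build_vj_vocab` and `encode_vj` are hypothetical.

```python
from typing import Dict, List, Optional

# Assumed segment order; the real layout of the 161 dims lives in vj_encoder.pkl.
SEGMENTS = ("TRAV", "TRAJ", "TRBV", "TRBJ")


def build_vj_vocab(cells: List[Dict[str, Optional[str]]]) -> Dict[str, List[str]]:
    """Collect the sorted set of observed gene names for each segment type."""
    return {seg: sorted({c[seg] for c in cells if c.get(seg)}) for seg in SEGMENTS}


def encode_vj(cell: Dict[str, Optional[str]], vocab: Dict[str, List[str]]) -> List[float]:
    """Concatenate one one-hot block per segment; a missing or unseen gene leaves its block all zeros."""
    vec: List[float] = []
    for seg in SEGMENTS:
        genes = vocab[seg]
        block = [0.0] * len(genes)
        gene = cell.get(seg)
        if gene in genes:
            block[genes.index(gene)] = 1.0
        vec.extend(block)
    return vec


cells = [
    {"TRAV": "TRAV1-2", "TRAJ": "TRAJ33", "TRBV": "TRBV6-4", "TRBJ": "TRBJ2-1"},
    {"TRAV": "TRAV12-1", "TRAJ": "TRAJ33", "TRBV": "TRBV20-1", "TRBJ": None},
]
vocab = build_vj_vocab(cells)
print(encode_vj(cells[0], vocab))  # [1.0, 0.0, 1.0, 0.0, 1.0, 1.0]
```

With the full training vocabulary, the vector length is the sum of the four vocabulary sizes, which the card reports as 161.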
Architecture
```text
GEX (3000) → [Linear 512 → GELU → Linear hidden] + ResidualBlock → (hidden,)
TCR-α/β (768) + VJ context (64) → [Linear hidden] + ResidualBlock → (hidden,)
VJ (161) → [Linear hidden] → (hidden,)

Cross-Attention Fusion:
  GEX (1 token) ↔ TCR-α + TCR-β + VJ (3 tokens)
  4 attention heads, LayerNorm, residual connections
  → Concat (4 × hidden) → ResidualBlock → Linear → 7 classes
```
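The diagram above can be sketched in PyTorch as follows. This is a minimal illustration under stated assumptions, not the repository's `FullGenesVJClassifier`: the internal form of `ResidualBlock`, the 161→64 linear projection that produces the VJ context, a single TCR encoder shared by the α and β chains, and the class names `CrossAttentionFusion` and `FusionClassifierSketch` are all hypothetical.

```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Assumed form: pre-norm MLP with a skip connection (the card does not define it)."""

    def __init__(self, dim: int, dropout: float = 0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim), nn.Linear(dim, dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.net(x)


class CrossAttentionFusion(nn.Module):
    """GEX token attends to the [TCR-a, TCR-b, VJ] tokens, and those tokens attend back to GEX."""

    def __init__(self, hidden: int, heads: int = 4, dropout: float = 0.3):
        super().__init__()
        self.gex_to_ctx = nn.MultiheadAttention(hidden, heads, dropout=dropout, batch_first=True)
        self.ctx_to_gex = nn.MultiheadAttention(hidden, heads, dropout=dropout, batch_first=True)
        self.norm_gex = nn.LayerNorm(hidden)
        self.norm_ctx = nn.LayerNorm(hidden)

    def forward(self, gex: torch.Tensor, ctx: torch.Tensor) -> torch.Tensor:
        # gex: (B, 1, H), ctx: (B, 3, H)
        g, _ = self.gex_to_ctx(gex, ctx, ctx)
        c, _ = self.ctx_to_gex(ctx, gex, gex)
        gex = self.norm_gex(gex + g)  # residual + LayerNorm
        ctx = self.norm_ctx(ctx + c)
        return torch.cat([gex, ctx], dim=1)  # (B, 4, H)


class FusionClassifierSketch(nn.Module):
    def __init__(self, n_genes=3000, tcr_dim=768, vj_dim=161, vj_ctx=64,
                 hidden=512, heads=4, n_classes=7, dropout=0.3):
        super().__init__()
        self.gex = nn.Sequential(
            nn.Linear(n_genes, 512), nn.GELU(), nn.Linear(512, hidden), ResidualBlock(hidden, dropout)
        )
        self.vj_ctx = nn.Linear(vj_dim, vj_ctx)  # hypothetical source of the 64-dim VJ context
        self.tcr = nn.Sequential(nn.Linear(tcr_dim + vj_ctx, hidden), ResidualBlock(hidden, dropout))
        self.vj = nn.Linear(vj_dim, hidden)
        self.fusion = CrossAttentionFusion(hidden, heads, dropout)
        self.head = nn.Sequential(ResidualBlock(4 * hidden, dropout), nn.Linear(4 * hidden, n_classes))

    def forward(self, x_gex, tcr_a, tcr_b, x_vj):
        ctx = self.vj_ctx(x_vj)
        g = self.gex(x_gex).unsqueeze(1)                 # (B, 1, H)
        a = self.tcr(torch.cat([tcr_a, ctx], dim=-1))    # (B, H), encoder shared by both chains
        b = self.tcr(torch.cat([tcr_b, ctx], dim=-1))
        v = self.vj(x_vj)
        tokens = self.fusion(g, torch.stack([a, b, v], dim=1))
        return self.head(tokens.flatten(1))              # (B, 4H) -> (B, n_classes)


model = FusionClassifierSketch(hidden=64, heads=4).eval()
logits = model(torch.randn(2, 3000), torch.randn(2, 768), torch.randn(2, 768), torch.randn(2, 161))
print(logits.shape)  # torch.Size([2, 7])
```

`nn.MultiheadAttention` requires the hidden size to be divisible by the number of heads, which holds for both ensemble settings (512 with 4 or 8 heads).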
Ensemble Diversity
| Model | Hidden dim | Heads | Dropout | Other settings | Accuracy |
|---|---|---|---|---|---|
| m7_lr3e4 | 512 | 4 | 0.30 | lr=3e-4 | 88.8% |
| m2_h512_s2 | 512 | 4 | 0.30 | seed=2 | 88.3% |
| m6_highdrop | 512 | 4 | 0.35 | - | 88.3% |
| m4_8heads | 512 | 8 | 0.30 | - | 88.3% |
| m1_h512 | 512 | 4 | 0.30 | - | 88.1% |
Ensemble averaging of these 5 models yields 89.6% accuracy (macro F1 0.88).
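The card does not say whether probabilities or logits are averaged. The sketch below assumes soft voting: convert each model's logits to softmax probabilities, average them across models, then take the argmax. The function names are illustrative, not the repository's API.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax (subtract the max before exponentiating)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def ensemble_predict(per_model_logits: List[Sequence[float]]) -> Tuple[int, List[float]]:
    """Soft voting: average class probabilities across models and return (argmax, mean probs)."""
    probs = [softmax(l) for l in per_model_logits]
    n_models, n_classes = len(probs), len(probs[0])
    mean = [sum(p[k] for p in probs) / n_models for k in range(n_classes)]
    best = max(range(n_classes), key=mean.__getitem__)
    return best, mean


# One confident model outvotes two lukewarm ones; hard majority voting would pick class 1.
best, mean = ensemble_predict([[4.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]])
print(best)  # 0
```

Soft voting keeps each model's confidence, which is one reason averaging can beat the best single model even when the members differ only slightly.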
Intended Use
Classification of T-cell functional states from paired scRNA-seq + TCR-seq data. Designed for research use in immunology, immuno-oncology, and single-cell analysis pipelines.
Not intended for clinical decision-making or diagnostic use.
Training Data
136,667 T-cells (after QC filtering) from 4 public scRNA-seq datasets:
| Dataset | Platform | Cells* | Tissue |
|---|---|---|---|
| GSE144469 | 10x Genomics | ~60,000 | Colitis (colon) |
| GSE179994 | 10x Genomics | ~77,000 | PBMC (exhaustion study) |
| GSE181061 | 10x Genomics | ~31,000 | ccRCC (tumor-infiltrating) |
| GSE108989 | Smart-seq2 | ~12,000 | CRC (tumor + blood) |
*Cell counts are pre-QC; 136,667 cells remain after quality control filtering.
Preprocessing: QC → normalization (scanpy) → selection of 3,000 HVGs → Harmony batch correction → CDR3/V/J extraction via scirpy.
Evaluation
Per-Class Performance (Test Set)
| Class | Precision | Recall | F1 | Support |
|---|---|---|---|---|
| Effector | 0.91 | 0.92 | 0.91 | 6,685 |
| Exhausted | 0.84 | 0.82 | 0.83 | 2,245 |
| Memory | 0.89 | 0.88 | 0.89 | 4,979 |
| Naive | 0.87 | 0.85 | 0.86 | 2,441 |
| Proliferating | 0.92 | 0.89 | 0.90 | 764 |
| Th_effector | 0.76 | 0.74 | 0.75 | 393 |
| Treg | 0.93 | 0.94 | 0.94 | 2,329 |
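The per-class columns follow the standard definitions (the same ones `sklearn.metrics.classification_report` uses): precision = TP / predicted count, recall = TP / true count (support), and F1 is their harmonic mean. Macro F1 is the unweighted mean over classes, so a small class such as Th_effector counts as much as Effector. The pure-Python sketch below is illustrative; `per_class_report` is a hypothetical helper.

```python
from collections import Counter
from typing import Dict, Hashable, Sequence, Tuple


def per_class_report(y_true: Sequence[Hashable], y_pred: Sequence[Hashable]) -> Tuple[Dict, float]:
    """Per-class precision/recall/F1/support plus the unweighted (macro) mean F1."""
    labels = sorted(set(y_true) | set(y_pred))
    tp = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    pred_count, true_count = Counter(y_pred), Counter(y_true)
    report = {}
    for c in labels:
        prec = tp[c] / pred_count[c] if pred_count[c] else 0.0
        rec = tp[c] / true_count[c] if true_count[c] else 0.0
        f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
        report[c] = {"precision": prec, "recall": rec, "f1": f1, "support": true_count[c]}
    macro_f1 = sum(r["f1"] for r in report.values()) / len(report)
    return report, macro_f1


report, macro = per_class_report(["Treg", "Treg", "Naive", "Naive"], ["Treg", "Naive", "Naive", "Naive"])
print(round(macro, 4))  # 0.7333
```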
Ablation Study
| Configuration | Accuracy |
|---|---|
| TCR-only (BERT embeddings) | 33.7% |
| GEX-only (PCA-50) | 69.9% |
| Multimodal (PCA-50, concat) | 79.3% |
| End-to-end BERT fine-tuning | 77.4% |
| Hybrid + VJ + PCA-200 | 84.9% |
| Ensemble + VJ + 3000 genes | 89.6% |
How to Use
Quick Start
```bash
git clone https://github.com/polinavd/multimodal-tcell-classifier.git
cd multimodal-tcell-classifier
pip install -r requirements.txt
python predict_report.py --input your_data.h5ad --output ./results
```
Model weights (~300 MB) are downloaded automatically from this Hugging Face repo on first run.
Outputs: an interactive HTML report, predictions.csv, and an annotated .h5ad file.
Manual Weight Download
```python
from huggingface_hub import snapshot_download

# Download all files in the model repo into ./weights
snapshot_download("VirialyD/tcell-classifier", local_dir="./weights")
```
Files
| File | Description |
|---|---|
| m1_h512.pt | Model 1: hidden=512, heads=4, dropout=0.30 |
| m2_h512_s2.pt | Model 2: hidden=512, heads=4, dropout=0.30, seed=2 |
| m4_8heads.pt | Model 4: hidden=512, heads=8, dropout=0.30 |
| m6_highdrop.pt | Model 6: hidden=512, heads=4, dropout=0.35 |
| m7_lr3e4.pt | Model 7: hidden=512, heads=4, dropout=0.30, lr=3e-4 |
| results.json | Individual model metrics and ensemble config |
| label_encoder.pkl | sklearn LabelEncoder for the 7 functional states |
| vj_encoder.pkl | V/J gene one-hot encoder (161-dim) |
Technical Details
- Optimizer: AdamW (lr=2e-4, weight_decay=0.02)
- Schedule: Cosine annealing with 5% linear warmup
- Loss: CrossEntropyLoss with balanced class weights + label smoothing (0.03)
- Mixed precision: FP16 with gradient clipping (max_norm=1.0)
- Early stopping: on validation macro F1 (patience=12)
- Hardware: NVIDIA RTX 5070 (8 GB VRAM)
Limitations
- Trained on human T-cells only; not validated on other species or non-T immune cells.
- Requires paired scRNA-seq + TCR-seq data (CDR3 alpha/beta + V/J genes).
- Gene expression input must be mapped to the same 3,000-HVG feature space used in training. The preprocessing pipeline handles this mapping, but data from substantially different protocols may still reduce accuracy.
- The Th_effector class has the lowest performance (F1 0.75), likely because it is the smallest class (393 cells in the test set).
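Mapping a query dataset onto the training feature space amounts to reordering its columns to the training gene list. The sketch below assumes that genes absent from the query are zero-filled and reports the fraction of training genes that were found; how the repository's pipeline actually handles missing genes is not documented here, and `align_to_training_genes` is a hypothetical helper. Low coverage is a practical warning sign for the accuracy loss described above.

```python
from typing import List, Sequence, Tuple


def align_to_training_genes(
    matrix: List[List[float]], query_genes: Sequence[str], training_genes: Sequence[str]
) -> Tuple[List[List[float]], float]:
    """Reorder columns to training gene order, zero-fill missing genes, and return coverage."""
    col = {g: i for i, g in enumerate(query_genes)}
    idx = [col.get(g) for g in training_genes]  # None where the query lacks a training gene
    aligned = [[row[i] if i is not None else 0.0 for i in idx] for row in matrix]
    coverage = sum(i is not None for i in idx) / len(training_genes)
    return aligned, coverage


aligned, cov = align_to_training_genes(
    [[1.0, 2.0, 3.0]], ["FOXP3", "CD8A", "IL7R"], ["CD8A", "GZMB", "FOXP3", "MKI67"]
)
print(aligned, cov)  # [[2.0, 0.0, 1.0, 0.0]] 0.5
```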
Citation
```bibtex
@software{shirokikh2025multimodal,
  author = {Shirokikh, Polina},
  title = {Multimodal T-Cell Functional State Classifier},
  year = {2025},
  url = {https://github.com/polinavd/multimodal-tcell-classifier}
}
```
License
MIT License. See LICENSE for details.