Multimodal T-Cell Functional State Classifier

A multimodal deep learning ensemble for predicting T-cell functional states from scRNA-seq data. Integrates gene expression (3,000 HVGs), TCR sequences (via TCR-BERT), and V/J gene usage through bidirectional cross-attention fusion.

89.6% accuracy | macro F1 0.88 | 7 functional states | top-5 ensemble

GitHub: polinavd/multimodal-tcell-classifier

Model Description

This repository contains the weights for a top-5 ensemble of FullGenesVJClassifier models. Each model takes three input modalities:

  • Gene expression: 3,000 highly variable genes (learned dimensionality reduction, no PCA)
  • TCR sequences: CDR3-alpha and CDR3-beta encoded via TCR-BERT (768-dim CLS embeddings)
  • V/J gene usage: one-hot encoded TRAV/TRAJ/TRBV/TRBJ segments (161-dim)
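The sketch below illustrates the 161-dim V/J input. It concatenates one one-hot block per locus (TRAV, TRAJ, TRBV, TRBJ). The tiny vocabularies and the function name `encode_vj` are hypothetical placeholders. The real column layout is whatever `vj_encoder.pkl` defines. The card does not document how that encoder treats unseen genes, so this sketch assumes they produce an all-zero block.

```python
from typing import Dict, List, Optional

# Hypothetical per-locus vocabularies. The shipped vj_encoder.pkl defines the
# real layout (161 columns in total across the four loci).
VOCAB: Dict[str, List[str]] = {
    "TRAV": ["TRAV1-2", "TRAV12-1", "TRAV21"],
    "TRAJ": ["TRAJ33", "TRAJ42"],
    "TRBV": ["TRBV5-1", "TRBV7-9", "TRBV20-1"],
    "TRBJ": ["TRBJ1-1", "TRBJ2-7"],
}


def encode_vj(genes: Dict[str, Optional[str]]) -> List[int]:
    """Concatenate one one-hot block per locus.

    A missing or unseen gene yields an all-zero block for that locus
    (an assumption about how unknown segments are handled).
    """
    vec: List[int] = []
    for locus, names in VOCAB.items():
        block = [0] * len(names)
        gene = genes.get(locus)
        if gene in names:
            block[names.index(gene)] = 1
        vec.extend(block)
    return vec


cell = {"TRAV": "TRAV1-2", "TRAJ": "TRAJ33", "TRBV": "TRBV20-1", "TRBJ": None}
print(encode_vj(cell))  # [1, 0, 0, 1, 0, 0, 0, 1, 0, 0]
```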

Cross-attention fusion allows each modality to attend to the others before classification into 7 functional states: Effector, Exhausted, Memory, Naive, Proliferating, Th_effector, Treg.

Architecture

GEX (3000) → [Linear 512 → GELU → Linear hidden] + ResidualBlock → (hidden,)
TCR-α/β (768) + VJ context (64) → [Linear hidden] + ResidualBlock → (hidden,)
VJ (161) → [Linear hidden] → (hidden,)

Cross-Attention Fusion:
  GEX (1 token) ↔ TCR-α + TCR-β + VJ (3 tokens)
  4 attention heads, LayerNorm, residual connections

→ Concat (4 × hidden) → ResidualBlock → Linear → 7 classes
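The following NumPy sketch shows the fusion step. It assumes the single GEX token attends to the three TCR-α, TCR-β and V/J tokens, and that those three tokens attend back to GEX. Each direction uses multi-head attention plus a residual connection and LayerNorm.

This is not the repository's PyTorch module:

- The weights are random, untrained, and redrawn on every call.
- The LayerNorm has no learnable scale or bias.
- The per-modality encoders and ResidualBlocks are omitted.
- The names `cross_attention` and `fuse` are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)


def layer_norm(x, eps=1e-5):
    """Normalize each token (row) to zero mean and unit variance; no learnable affine."""
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def softmax(x):
    x = x - x.max(-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(-1, keepdims=True)


def cross_attention(q_tokens, kv_tokens, w, n_heads):
    """Multi-head attention of q_tokens (Lq, d) over kv_tokens (Lk, d), then residual + LayerNorm."""
    lq, d = q_tokens.shape
    dh = d // n_heads
    q = (q_tokens @ w["q"]).reshape(lq, n_heads, dh).transpose(1, 0, 2)   # (H, Lq, dh)
    k = (kv_tokens @ w["k"]).reshape(-1, n_heads, dh).transpose(1, 0, 2)  # (H, Lk, dh)
    v = (kv_tokens @ w["v"]).reshape(-1, n_heads, dh).transpose(1, 0, 2)  # (H, Lk, dh)
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(dh))                # (H, Lq, Lk)
    out = (attn @ v).transpose(1, 0, 2).reshape(lq, d) @ w["o"]           # merge heads
    return layer_norm(q_tokens + out)


def init(d):
    """Random (untrained) projection matrices for one attention block."""
    return {n: rng.normal(scale=d ** -0.5, size=(d, d)) for n in ("q", "k", "v", "o")}


def fuse(gex, tcr_a, tcr_b, vj, n_heads=4):
    d = gex.shape[-1]
    g = gex[None, :]                                    # (1, d): one GEX token
    others = np.stack([tcr_a, tcr_b, vj])               # (3, d): TCR-α, TCR-β, V/J tokens
    g_out = cross_attention(g, others, init(d), n_heads)       # GEX attends to TCR/VJ
    o_out = cross_attention(others, g, init(d), n_heads)       # TCR/VJ attend to GEX
    return np.concatenate([g_out, o_out]).reshape(-1)          # (4 * d,)


hidden = 64  # 512 in the released models
gex, tcr_a, tcr_b, vj = (rng.normal(size=hidden) for _ in range(4))
fused = fuse(gex, tcr_a, tcr_b, vj)
print(fused.shape)  # (256,) == 4 * hidden
```

The concatenated vector has length 4 × hidden, which matches the "Concat (4 × hidden)" stage in the diagram.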

Ensemble Diversity

Model          Hidden dim   Heads   Dropout   Accuracy
m7_lr3e4       512          4       0.30      88.8%
m2_h512_s2     512          4       0.30      88.3%
m6_highdrop    512          4       0.35      88.3%
m4_8heads      512          8       0.30      88.3%
m1_h512        512          4       0.30      88.1%

Ensemble averaging of these 5 models yields 89.6% accuracy (macro F1 0.88).
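The card does not say whether the ensemble averages softmax probabilities or raw logits. The sketch below assumes soft voting over probabilities. `ensemble_predict` is a hypothetical helper, and the logits are made-up values for three classes rather than the real seven.

```python
import math
from typing import List, Tuple


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def ensemble_predict(per_model_logits: List[List[float]]) -> Tuple[int, List[float]]:
    """Average per-model class probabilities; return (argmax index, mean probabilities)."""
    probs = [softmax(l) for l in per_model_logits]
    n_models, n_classes = len(probs), len(probs[0])
    mean = [sum(p[c] for p in probs) / n_models for c in range(n_classes)]
    return max(range(n_classes), key=mean.__getitem__), mean


# Five hypothetical models, three classes.
logits = [[2.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 3.0, 0.0], [2.0, 0.0, 0.0]]
label, mean_probs = ensemble_predict(logits)
print(label, [round(p, 3) for p in mean_probs])
```

With a scikit-learn `LabelEncoder` such as the shipped `label_encoder.pkl`, `inverse_transform` maps the winning index back to a state name.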

Intended Use

Classification of T-cell functional states from paired scRNA-seq + TCR-seq data. Designed for research use in immunology, immuno-oncology, and single-cell analysis pipelines.

Not intended for clinical decision-making or diagnostic use.

Training Data

136,667 T-cells (after QC filtering) from 4 public scRNA-seq datasets:

Dataset      Platform       Cells*    Tissue
GSE144469    10x Genomics   ~60,000   Colitis (colon)
GSE179994    10x Genomics   ~77,000   PBMC (exhaustion study)
GSE181061    10x Genomics   ~31,000   ccRCC (tumor-infiltrating)
GSE108989    Smart-seq2     ~12,000   CRC (tumor + blood)

*Cell counts are pre-QC; 136,667 cells remain after quality control filtering.

Preprocessing: QC → normalization (scanpy) → 3,000 HVGs → Harmony batch correction → CDR3/V/J extraction via scirpy.
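For the normalization step, the usual scanpy recipe is `sc.pp.normalize_total` followed by `sc.pp.log1p`. The dependency-free sketch below reproduces that arithmetic for a single cell. The target sum of 10,000 counts is an assumption, because the card does not state the value used.

```python
import math
from typing import List


def normalize_log1p(counts: List[float], target_sum: float = 1e4) -> List[float]:
    """Scale one cell's raw counts so they sum to target_sum, then apply natural log1p."""
    total = sum(counts)
    if total == 0:
        return [0.0] * len(counts)  # empty cell: nothing to scale
    return [math.log1p(c * target_sum / total) for c in counts]


print(normalize_log1p([1, 1, 2, 0]))
```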

Evaluation

Per-Class Performance (Test Set)

Class           Precision   Recall   F1     Support
Effector        0.91        0.92     0.91   6,685
Exhausted       0.84        0.82     0.83   2,245
Memory          0.89        0.88     0.89   4,979
Naive           0.87        0.85     0.86   2,441
Proliferating   0.92        0.89     0.90   764
Th_effector     0.76        0.74     0.75   393
Treg            0.93        0.94     0.94   2,329
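Each F1 value is the harmonic mean of that row's precision and recall. The column can therefore be checked against the other two, up to rounding of the two-decimal inputs. Macro F1 is the unweighted mean of the per-class F1 scores, so a small class such as Th_effector counts as much as Effector. The helper name `f1` is illustrative.

```python
def f1(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (0.0 when both are zero)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Precision/recall pairs copied from the table above.
rows = {"Exhausted": (0.84, 0.82), "Naive": (0.87, 0.85), "Th_effector": (0.76, 0.74)}
for name, (p, r) in rows.items():
    print(f"{name}: F1 = {f1(p, r):.2f}")
```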

Ablation Study

Configuration                  Accuracy
TCR-only (BERT embeddings)     33.7%
GEX-only (PCA-50)              69.9%
Multimodal (PCA-50, concat)    79.3%
End-to-end BERT fine-tuning    77.4%
Hybrid + VJ + PCA-200          84.9%
Ensemble + VJ + 3000 genes     89.6%

How to Use

Quick Start

git clone https://github.com/polinavd/multimodal-tcell-classifier.git
cd multimodal-tcell-classifier
pip install -r requirements.txt
python predict_report.py --input your_data.h5ad --output ./results

Model weights (~300 MB) are downloaded automatically from this Hugging Face repo on first run.

Output: interactive HTML report, predictions.csv, annotated .h5ad.

Manual Weight Download

from huggingface_hub import snapshot_download
snapshot_download("VirialyD/tcell-classifier", local_dir="./weights")

Files

File                Description
m1_h512.pt          Model 1: hidden=512, heads=4, dropout=0.30
m2_h512_s2.pt       Model 2: hidden=512, heads=4, dropout=0.30, seed=2
m4_8heads.pt        Model 4: hidden=512, heads=8, dropout=0.30
m6_highdrop.pt      Model 6: hidden=512, heads=4, dropout=0.35
m7_lr3e4.pt         Model 7: hidden=512, heads=4, dropout=0.30, lr=3e-4
results.json        Individual model metrics and ensemble config
label_encoder.pkl   sklearn LabelEncoder for 7 functional states
vj_encoder.pkl      V/J gene one-hot encoder (161-dim)
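After a manual download, a quick sanity check is to confirm that every file in the table above is present. The helper `missing_weight_files` is illustrative and not part of the repository. Its default directory, `./weights`, matches the `snapshot_download` call.

```python
from pathlib import Path
from typing import List

# File names copied from the table above.
EXPECTED_FILES = [
    "m1_h512.pt", "m2_h512_s2.pt", "m4_8heads.pt", "m6_highdrop.pt", "m7_lr3e4.pt",
    "results.json", "label_encoder.pkl", "vj_encoder.pkl",
]


def missing_weight_files(weights_dir: str = "./weights") -> List[str]:
    """Return the expected file names that are not present in weights_dir."""
    root = Path(weights_dir)
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]


missing = missing_weight_files()
print("all files present" if not missing else f"missing: {missing}")
```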

Technical Details

  • Optimizer: AdamW (lr=2e-4, weight_decay=0.02)
  • Schedule: Cosine annealing with 5% linear warmup
  • Loss: CrossEntropyLoss with balanced class weights + label smoothing (0.03)
  • Mixed precision: FP16 with gradient clipping (max_norm=1.0)
  • Early stopping: on validation macro F1 (patience=12)
  • Hardware: NVIDIA RTX 5070 (8 GB VRAM)
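The schedule and class-weighting lines above can be made concrete as follows. This is a sketch, not the training script, and the function names are hypothetical. It rests on two assumptions:

- Warmup ramps linearly to the base LR over the first 5% of steps, and cosine annealing then decays to 0.
- "Balanced" means scikit-learn's convention n_samples / (n_classes × count_c).

```python
import math
from collections import Counter
from typing import Dict, Hashable, Iterable


def lr_at(step: int, total_steps: int, base_lr: float = 2e-4,
          warmup_frac: float = 0.05, min_lr: float = 0.0) -> float:
    """Linear warmup over the first warmup_frac of steps, then cosine decay to min_lr."""
    warmup = max(1, int(total_steps * warmup_frac))
    if step < warmup:
        return base_lr * (step + 1) / warmup
    progress = (step - warmup) / max(1, total_steps - warmup)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


def balanced_class_weights(labels: Iterable[Hashable]) -> Dict[Hashable, float]:
    """scikit-learn style 'balanced' weights: n_samples / (n_classes * count_c)."""
    counts = Counter(labels)
    n, k = sum(counts.values()), len(counts)
    return {c: n / (k * cnt) for c, cnt in counts.items()}


print([f"{lr_at(s, 1000):.2e}" for s in (0, 49, 525, 999)])
print(balanced_class_weights(["Effector"] * 6 + ["Treg"] * 2))
```

Rare classes receive proportionally larger weights, which is the usual motivation for pairing balanced weights with an imbalanced label set like the one in the per-class table.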

Limitations

  • Trained on human T-cells only; not validated on other species or non-T immune cells.
  • Requires paired scRNA-seq + TCR-seq data (CDR3 alpha/beta + V/J genes).
  • Gene expression input must be mapped onto the same 3,000-HVG feature space used in training. The preprocessing pipeline handles this, but data from heavily divergent protocols may reduce accuracy.
  • The Th_effector class has the lowest performance (F1 0.75), likely because it is the rarest class (only 393 cells in the test set).
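The feature-space requirement amounts to reindexing each cell's expression onto the training gene list, in that list's order. The sketch below assumes genes missing from a new dataset are filled with zeros and returns them for inspection. `align_to_hvg` and the three-gene list are hypothetical stand-ins for the project's 3,000-gene list and preprocessing code.

```python
from typing import Dict, List, Tuple


def align_to_hvg(cell: Dict[str, float], hvg_genes: List[str]) -> Tuple[List[float], List[str]]:
    """Reorder a {gene: value} mapping onto hvg_genes.

    Genes absent from the input become 0.0 and are returned so coverage can be
    checked; genes outside hvg_genes are dropped.
    """
    vector = [float(cell.get(g, 0.0)) for g in hvg_genes]
    missing = [g for g in hvg_genes if g not in cell]
    return vector, missing


hvg = ["CD8A", "PDCD1", "FOXP3"]  # placeholder for the 3,000 training HVGs
vec, missing = align_to_hvg({"FOXP3": 2.1, "CD8A": 0.4, "MALAT1": 5.0}, hvg)
print(vec, missing)  # [0.4, 0.0, 2.1] ['PDCD1']
```

A large fraction of missing genes could indicate that the input protocol diverges from the training data.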

Citation

@software{shirokikh2025multimodal,
  author = {Shirokikh, Polina},
  title = {Multimodal T-Cell Functional State Classifier},
  year = {2025},
  url = {https://github.com/polinavd/multimodal-tcell-classifier}
}

License

MIT License. See LICENSE for details.
