---
license: mit
language:
- en
tags:
- genomics
- bioinformatics
- classification
- immunology
- cll
- ighv
- fttransformer
- tabular
library_name: pytorch
---
# IGH Classification – FT-Transformer
Pre-trained **Feature Tokenizer Transformer (FT-Transformer)** models for classifying whole-genome sequencing reads as IGH (immunoglobulin heavy chain) or non-IGH. The models are trained on a combination of real CLL (chronic lymphocytic leukemia) patient data and synthetic V(D)J recombination sequences.
- **GitHub repository:** [acri-nb/igh_classification](https://github.com/acri-nb/igh_classification)
- **Paper:** Darmendre J. *et al.* *Machine learning-based classification of IGHV mutation status in CLL from whole-genome sequencing data.* (submitted)
---
## Model description
Each checkpoint is a FT-Transformer (`FTTransformer` class in `models.py`) trained on 464 numerical descriptors extracted from 100 bp sequencing reads:
| Feature group | Description |
|---|---|
| NAC | Nucleic acid composition (mononucleotide frequencies) |
| DNC | Dinucleotide composition |
| TNC | Trinucleotide composition |
| kGap (di/tri) | k-spaced di-/tri-nucleotide frequencies |
| ORF | Open reading frame features |
| Fickett | Fickett score |
| Shannon entropy | 5-mer entropy |
| Fourier binary | Binary Fourier transform |
| Fourier complex | Complex Fourier transform |
| Tsallis entropy | Tsallis entropy (q = 2.3) |
**Binary classification:** TP (IGH read) vs. TN (non-IGH read)
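As an illustration of the descriptor groups above, here is a minimal sketch (not the repository's actual pipeline; the function names are ours) computing three of them for a 100 bp read: mononucleotide composition, dinucleotide composition, and 5-mer Shannon entropy:

```python
import math
from collections import Counter
from itertools import product

def nucleotide_composition(seq: str) -> list[float]:
    # NAC: frequencies of A, C, G, T in the read.
    n = len(seq)
    return [seq.count(b) / n for b in "ACGT"]

def dinucleotide_composition(seq: str) -> list[float]:
    # DNC: frequencies of all 16 overlapping dinucleotides.
    pairs = Counter(seq[i:i + 2] for i in range(len(seq) - 1))
    total = max(len(seq) - 1, 1)
    return [pairs["".join(p)] / total for p in product("ACGT", repeat=2)]

def shannon_entropy(seq: str, k: int = 5) -> float:
    # Shannon entropy over overlapping k-mers (k = 5 as in the table above).
    kmers = Counter(seq[i:i + k] for i in range(len(seq) - k + 1))
    total = sum(kmers.values())
    return -sum((c / total) * math.log2(c / total) for c in kmers.values())

read = "ACGT" * 25  # a toy 100 bp read
features = (nucleotide_composition(read)
            + dinucleotide_composition(read)
            + [shannon_entropy(read)])
```

The full 464-dimensional vector additionally includes the ORF, Fickett, Fourier, Tsallis, and kGap descriptors; see the GitHub repository for the complete extraction code.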
---
## Available checkpoints
This repository contains **61 pre-trained checkpoints** organized under two experimental approaches:
### 1. `fixed_total_size/` – Fixed global training set size (N_global_fixe)
The total training set is fixed at **598,709 sequences**. The proportion of real versus synthetic data is varied in steps of 10%, from 0% real (fully synthetic) to 100% real (no synthetic).
```
fixed_total_size/
├── transformer_real0/best_model.pt     (0% real, 100% synthetic)
├── transformer_real10/best_model.pt    (10% real, 90% synthetic)
├── transformer_real20/best_model.pt
├── ...
└── transformer_real100/best_model.pt   (100% real, 0% synthetic)
```
**Key finding:** Performance collapses when synthetic data exceeds 60% of the training set. A minimum of 50% real data is required for meaningful results.
### 2. `progressive_training/` – Fixed real data size with synthetic augmentation (N_real_fixe)
The real data size is held constant; synthetic data is added at increasing percentages (10%–100% of the real data size). This approach is systematically evaluated across 5 real data sizes.
```
progressive_training/
├── real_050000/
│   ├── synth_010pct_005000/best_model.pt   (50K real + 5K synthetic)
│   ├── synth_020pct_010000/best_model.pt
│   ├── ...
│   └── synth_100pct_050000/best_model.pt   (50K real + 50K synthetic)
├── real_100000/   (100K real, 10 synthetic proportions)
├── real_150000/   (150K real, 10 synthetic proportions) ← best results
├── real_200000/   (200K real, 10 synthetic proportions)
└── real_213100/   (213K real, 10 synthetic proportions)
```
**Key finding:** Synthetic augmentation monotonically improves performance. Best results plateau at ≥ 70% synthetic augmentation.
---
## Best model
The recommended checkpoint for production use is:
```
progressive_training/real_150000/synth_100pct_150000/best_model.pt
```
| Metric | Value |
|---|---|
| Balanced accuracy | 97.5% |
| F1-score | 95.6% |
| ROC-AUC | 99.7% |
| PR-AUC | 99.3% |
*Evaluated on a held-out patient test set of 173,100 reads (119,349 TN, 53,751 TP) from CLL patients and the ICGC-CLL Genome cohort.*
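The metrics above can be reproduced with standard scikit-learn utilities. A minimal sketch on toy labels (the arrays below are stand-ins, not the held-out patient data):

```python
import numpy as np
from sklearn.metrics import (
    balanced_accuracy_score, f1_score, roc_auc_score, average_precision_score,
)

# Toy labels and predicted probabilities standing in for model output.
y_true = np.array([0, 0, 1, 1, 1, 0])
y_prob = np.array([0.1, 0.4, 0.8, 0.9, 0.6, 0.2])
y_pred = (y_prob >= 0.5).astype(int)

metrics = {
    "balanced_accuracy": balanced_accuracy_score(y_true, y_pred),
    "f1": f1_score(y_true, y_pred),
    "roc_auc": roc_auc_score(y_true, y_prob),            # ranking-based, uses probabilities
    "pr_auc": average_precision_score(y_true, y_prob),   # PR-AUC as average precision
}
```

Note that ROC-AUC and PR-AUC are computed from the continuous probabilities, while balanced accuracy and F1 use the thresholded predictions.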
---
## Usage
### Installation
```bash
pip install torch scikit-learn pandas numpy
git clone https://github.com/acri-nb/igh_classification.git
```
### Loading a checkpoint
```python
import sys
import torch

sys.path.insert(0, "/path/to/igh_classification")
from models import FTTransformer

checkpoint = torch.load(
    "progressive_training/real_150000/synth_100pct_150000/best_model.pt",
    map_location="cpu",
    weights_only=False,
)
model = FTTransformer(
    input_dim=checkpoint["input_dim"],
    hidden_dims=checkpoint["hidden_dims"],
    dropout=checkpoint.get("dropout", 0.3),
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
```
### Running inference
```python
import pandas as pd
import torch

# Load your feature CSV (output of the preprocessing pipeline)
df = pd.read_csv("features_extracted.csv")
X = df.values.astype("float32")

# Transform with the RobustScaler that was fitted on the TRAINING data
# (reloaded from disk); never fit a new scaler on inference reads.
X_scaled = scaler.transform(X)

X_tensor = torch.tensor(X_scaled, dtype=torch.float32)
with torch.no_grad():
    logits = model(X_tensor)
    probs = torch.sigmoid(logits).squeeze()
    predictions = (probs >= 0.5).int()
```
> **Note:** The scaler must be fitted on the **training data** and saved alongside the model. The `DeepBioClassifier` class in the GitHub repository handles this automatically during training.
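If you are not training through `DeepBioClassifier`, one common way to persist the training-time scaler is `joblib` (a sketch under our own assumptions; the repository may store it differently, and the file name below is hypothetical):

```python
import joblib
import numpy as np
from sklearn.preprocessing import RobustScaler

# During training: fit on the training matrix and persist next to the checkpoint.
X_train = np.random.RandomState(0).normal(size=(1000, 464)).astype("float32")
scaler = RobustScaler().fit(X_train)
joblib.dump(scaler, "robust_scaler.joblib")

# At inference: reload and transform only -- never re-fit on new reads,
# or the feature distribution will no longer match what the model saw.
scaler = joblib.load("robust_scaler.joblib")
X_new = np.random.RandomState(1).normal(size=(8, 464)).astype("float32")
X_scaled = scaler.transform(X_new)
```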
---
## Training details
| Parameter | Value |
|---|---|
| Architecture | FT-Transformer |
| Input features | 464 |
| Hidden dimensions | [512, 256, 128, 64] |
| Dropout | 0.3 |
| Optimizer | AdamW |
| Learning rate | 1e-3 |
| Scheduler | Cosine Annealing with Warm Restarts |
| Loss | Focal Loss |
| Epochs | 150 (with early stopping, patience = 50) |
| Batch size | 256 |
| Feature normalization | RobustScaler |
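A hedged sketch of how the optimizer, scheduler, and loss from the table might be wired together in PyTorch. The focal-loss hyperparameters (`gamma`, `alpha`), the `T_0` restart period, and the stand-in linear model are our assumptions; the card only names the components:

```python
import torch
import torch.nn as nn

class BinaryFocalLoss(nn.Module):
    # Generic binary focal loss (Lin et al. 2017); gamma/alpha values
    # are assumptions -- the card only states "Focal Loss".
    def __init__(self, gamma: float = 2.0, alpha: float = 0.25):
        super().__init__()
        self.gamma, self.alpha = gamma, alpha

    def forward(self, logits, targets):
        bce = nn.functional.binary_cross_entropy_with_logits(
            logits, targets, reduction="none")
        p_t = torch.exp(-bce)  # probability assigned to the true class
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        return (alpha_t * (1 - p_t) ** self.gamma * bce).mean()

model = nn.Sequential(nn.Linear(464, 1))  # stand-in for FTTransformer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)
criterion = BinaryFocalLoss()

# One training step on a toy batch (batch size 256, 464 features).
x = torch.randn(256, 464)
y = torch.randint(0, 2, (256, 1)).float()
loss = criterion(model(x), y)
loss.backward()
optimizer.step()
scheduler.step()
```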
---
## Citation
If you use these weights or the associated code, please cite:
```bibtex
@article{gayap2025igh,
  title   = {Machine learning-based classification of IGHV mutation status
             in CLL from whole-genome sequencing data},
  author  = {Gayap, Hadrien and others},
  journal = {(submitted)},
  year    = {2025}
}
```
---
## License
MIT License. See [LICENSE](https://github.com/acri-nb/igh_classification/blob/main/LICENSE) for details.