---
license: mit
language:
- en
tags:
- genomics
- bioinformatics
- classification
- immunology
- cll
- ighv
- fttransformer
- tabular
library_name: pytorch
---
# IGH Classification β€” FT-Transformer
Pre-trained **Feature Tokenizer Transformer (FT-Transformer)** models for classifying whole-genome sequencing reads as IGH (immunoglobulin heavy chain) or non-IGH. The models are trained on a combination of real CLL (chronic lymphocytic leukemia) patient data and synthetic V(D)J recombination sequences.
- **GitHub repository:** [acri-nb/igh_classification](https://github.com/acri-nb/igh_classification)
- **Paper:** Darmendre J. *et al.* *Machine learning-based classification of IGHV mutation status in CLL from whole-genome sequencing data.* (submitted)
---
## Model description
Each checkpoint is an FT-Transformer (`FTTransformer` class in `models.py`) trained on 464 numerical descriptors extracted from 100 bp sequencing reads:
| Feature group | Description |
|---|---|
| NAC | Nucleic acid composition (mononucleotide frequencies) |
| DNC | Dinucleotide composition |
| TNC | Trinucleotide composition |
| kGap (di/tri) | k-gapped dinucleotide and trinucleotide frequencies |
| ORF | Open reading frame features |
| Fickett | Fickett score |
| Shannon entropy | 5-mer entropy |
| Fourier binary | Binary Fourier transform |
| Fourier complex | Complex Fourier transform |
| Tsallis entropy | Tsallis entropy (q = 2.3) |
**Binary classification:** TP (IGH read) vs. TN (non-IGH read)
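As an illustrative sketch only (the repository's actual extraction pipeline lives in the GitHub repo), three of the simpler feature groups can be computed from a 100 bp read like this:

```python
from collections import Counter
from itertools import product
import math

def nac(seq):
    """NAC: frequency of each single nucleotide A, C, G, T."""
    n = len(seq)
    return {b: seq.count(b) / n for b in "ACGT"}

def dnc(seq):
    """DNC: frequency of each of the 16 possible dinucleotides."""
    pairs = Counter(seq[i:i + 2] for i in range(len(seq) - 1))
    total = len(seq) - 1
    return {"".join(p): pairs["".join(p)] / total for p in product("ACGT", repeat=2)}

def shannon_entropy_kmers(seq, k=5):
    """Shannon entropy of the observed k-mer distribution (the card uses k = 5)."""
    kmers = Counter(seq[i:i + k] for i in range(len(seq) - k + 1))
    total = sum(kmers.values())
    return -sum((c / total) * math.log2(c / total) for c in kmers.values())

read = "ACGT" * 25  # a toy 100 bp read
print(nac(read)["A"])                 # 0.25
print(len(dnc(read)))                 # 16
print(shannon_entropy_kmers(read))    # 2.0 (only 4 distinct 5-mers, uniform)
```

The real feature vector concatenates all ten groups into 464 descriptors; the ORF, Fickett, and Fourier features require more involved computations than shown here.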
---
## Available checkpoints
This repository contains **61 pre-trained checkpoints** organized under two experimental approaches:
### 1. `fixed_total_size/` β€” Fixed global training set size (N_global_fixe)
The total training set is fixed at **598,709 sequences**. The proportion of real versus synthetic data is varied in steps of 10%, from 0% real (fully synthetic) to 100% real (no synthetic).
```
fixed_total_size/
β”œβ”€β”€ transformer_real0/best_model.pt (0% real, 100% synthetic)
β”œβ”€β”€ transformer_real10/best_model.pt (10% real, 90% synthetic)
β”œβ”€β”€ transformer_real20/best_model.pt
β”œβ”€β”€ ...
└── transformer_real100/best_model.pt (100% real, 0% synthetic)
```
**Key finding:** Performance collapses when synthetic data exceeds 60% of the training set. A minimum of 50% real data is required for meaningful results.
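Assuming the directory layout shown above, a checkpoint for a given real-data percentage can be located with a small helper (the function name is ours, not part of the repository):

```python
from pathlib import Path

def fixed_size_checkpoint(real_pct: int, root: str = ".") -> Path:
    """Path to a fixed_total_size checkpoint for real_pct in {0, 10, ..., 100}."""
    return Path(root) / "fixed_total_size" / f"transformer_real{real_pct}" / "best_model.pt"

print(fixed_size_checkpoint(50).as_posix())
# fixed_total_size/transformer_real50/best_model.pt
```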
### 2. `progressive_training/` β€” Fixed real data size with synthetic augmentation (N_real_fixe)
The real data size is held constant; synthetic data is added at increasing percentages (10%–100% of the real data size). This approach is systematically evaluated across 5 real data sizes.
```
progressive_training/
β”œβ”€β”€ real_050000/
β”‚ β”œβ”€β”€ synth_010pct_005000/best_model.pt (50K real + 5K synthetic)
β”‚ β”œβ”€β”€ synth_020pct_010000/best_model.pt
β”‚ β”œβ”€β”€ ...
β”‚ └── synth_100pct_050000/best_model.pt (50K real + 50K synthetic)
β”œβ”€β”€ real_100000/ (100K real, 10 synthetic proportions)
β”œβ”€β”€ real_150000/ (150K real, 10 synthetic proportions) ← best results
β”œβ”€β”€ real_200000/ (200K real, 10 synthetic proportions)
└── real_213100/ (213K real, 10 synthetic proportions)
```
**Key finding:** Synthetic augmentation improves performance monotonically; results plateau once augmentation reaches β‰₯ 70% of the real data size.
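The directory names in this layout encode the real-set size and the synthetic set both as a percentage and as an absolute count. A sketch of a path builder matching that convention (the helper is ours, not part of the repository):

```python
from pathlib import Path

def progressive_checkpoint(real_size: int, synth_pct: int, root: str = ".") -> Path:
    """Checkpoint path in the progressive_training layout shown above,
    e.g. real_150000/synth_100pct_150000/best_model.pt."""
    synth_count = real_size * synth_pct // 100
    return (Path(root) / "progressive_training"
            / f"real_{real_size:06d}"                      # zero-padded to 6 digits
            / f"synth_{synth_pct:03d}pct_{synth_count:06d}"
            / "best_model.pt")

print(progressive_checkpoint(150_000, 100).as_posix())
# progressive_training/real_150000/synth_100pct_150000/best_model.pt
```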
---
## Best model
The recommended checkpoint for production use is:
```
progressive_training/real_150000/synth_100pct_150000/best_model.pt
```
| Metric | Value |
|---|---|
| Balanced accuracy | 97.5% |
| F1-score | 95.6% |
| ROC-AUC | 99.7% |
| PR-AUC | 99.3% |
*Evaluated on a held-out patient test set of 173,100 reads (119,349 TN, 53,751 TP) from CLL patients and the ICGC-CLL Genome cohort.*
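The four reported metrics can be reproduced from model outputs with scikit-learn (shown here on toy labels and scores, with PR-AUC approximated by average precision):

```python
import numpy as np
from sklearn.metrics import (average_precision_score, balanced_accuracy_score,
                             f1_score, roc_auc_score)

# Toy labels and scores standing in for the held-out test set (1 = TP / IGH read).
y_true = np.array([1, 0, 1, 1, 0, 0, 1, 0])
y_prob = np.array([0.9, 0.2, 0.8, 0.6, 0.3, 0.1, 0.7, 0.4])
y_pred = (y_prob >= 0.5).astype(int)  # same 0.5 threshold as the usage example

bal_acc = balanced_accuracy_score(y_true, y_pred)  # threshold-based metrics
f1 = f1_score(y_true, y_pred)
roc_auc = roc_auc_score(y_true, y_prob)            # ranking metrics use raw scores
pr_auc = average_precision_score(y_true, y_prob)
print(bal_acc, f1, roc_auc, pr_auc)  # all 1.0 on this perfectly separated toy data
```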
---
## Usage
### Installation
```bash
pip install torch scikit-learn pandas numpy
git clone https://github.com/acri-nb/igh_classification.git
```
### Loading a checkpoint
```python
import sys
import torch

# Make the cloned repository importable so the model class can be found.
sys.path.insert(0, "/path/to/igh_classification")
from models import FTTransformer

checkpoint = torch.load(
    "progressive_training/real_150000/synth_100pct_150000/best_model.pt",
    map_location="cpu",
    weights_only=False,  # the checkpoint stores Python objects, not just tensors
)

# Rebuild the architecture from the hyperparameters saved in the checkpoint.
model = FTTransformer(
    input_dim=checkpoint["input_dim"],
    hidden_dims=checkpoint["hidden_dims"],
    dropout=checkpoint.get("dropout", 0.3),
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
```
### Running inference
```python
import joblib  # assumes the training scaler was persisted with joblib; see note below
import pandas as pd
import torch

# Load your feature CSV (output of the preprocessing pipeline).
df = pd.read_csv("features_extracted.csv")
X = df.values.astype("float32")

# Do NOT fit a new scaler on inference data: reload the RobustScaler that was
# fitted on the training data (the file name here is illustrative).
scaler = joblib.load("scaler.joblib")
X_scaled = scaler.transform(X)
X_tensor = torch.tensor(X_scaled, dtype=torch.float32)

with torch.no_grad():
    logits = model(X_tensor)
    probs = torch.sigmoid(logits).squeeze()
predictions = (probs >= 0.5).int()
```
> **Note:** The scaler must be fitted on the **training data** and saved alongside the model. The `DeepBioClassifier` class in the GitHub repository handles this automatically during training.
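A minimal sketch of that save/reload pattern, using `joblib` (which ships as a scikit-learn dependency); the file name and toy data are ours, not defined by the repository:

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.preprocessing import RobustScaler

# Fit the scaler on (toy) training features, then persist it next to the model.
X_train = np.random.default_rng(0).normal(size=(1000, 464)).astype("float32")
scaler = RobustScaler().fit(X_train)

path = os.path.join(tempfile.gettempdir(), "igh_scaler.joblib")
joblib.dump(scaler, path)

# At inference time: reload and call transform() only, never fit again.
scaler_loaded = joblib.load(path)
X_new_scaled = scaler_loaded.transform(X_train[:5])
```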
---
## Training details
| Parameter | Value |
|---|---|
| Architecture | FT-Transformer |
| Input features | 464 |
| Hidden dimensions | [512, 256, 128, 64] |
| Dropout | 0.3 |
| Optimizer | AdamW |
| Learning rate | 1e-3 |
| Scheduler | Cosine Annealing with Warm Restarts |
| Loss | Focal Loss |
| Epochs | 150 (with early stopping, patience = 50) |
| Batch size | 256 |
| Feature normalization | RobustScaler |
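A hedged sketch of how this training configuration can be wired up in PyTorch. Focal loss is not in the PyTorch standard library, so a binary variant is written out here; its `alpha` and `gamma` values are illustrative defaults, not values confirmed by this card, and `nn.Linear` stands in for the actual FT-Transformer:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BinaryFocalLoss(nn.Module):
    """Focal loss on binary logits; alpha/gamma are assumed defaults."""
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma

    def forward(self, logits, targets):
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p_t = torch.exp(-bce)  # model's probability for the true class
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        return (alpha_t * (1 - p_t) ** self.gamma * bce).mean()

model = nn.Linear(464, 1)  # stand-in for the FT-Transformer (464 input features)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)
criterion = BinaryFocalLoss()

# One illustrative step on a random batch of 256 (the card's batch size).
x = torch.randn(256, 464)
y = torch.randint(0, 2, (256, 1)).float()
loss = criterion(model(x), y)
loss.backward()
optimizer.step()
scheduler.step()  # advance the cosine warm-restart schedule once per epoch/step
```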
---
## Citation
If you use these weights or the associated code, please cite:
```bibtex
@article{gayap2025igh,
  title   = {Machine learning-based classification of IGHV mutation status
             in CLL from whole-genome sequencing data},
  author  = {Gayap, Hadrien and others},
  journal = {(submitted)},
  year    = {2025}
}
```
---
## License
MIT License. See [LICENSE](https://github.com/acri-nb/igh_classification/blob/main/LICENSE) for details.