# IGH Classification – FT-Transformer
Pre-trained Feature Tokenizer Transformer (FT-Transformer) models for classifying whole-genome sequencing reads as IGH (immunoglobulin heavy chain) or non-IGH. The models are trained on a combination of real CLL (chronic lymphocytic leukemia) patient data and synthetic V(D)J recombination sequences.
- GitHub repository: acri-nb/igh_classification
- Paper: Darmendre J. et al. Machine learning-based classification of IGHV mutation status in CLL from whole-genome sequencing data. (submitted)
## Model description
Each checkpoint is an FT-Transformer (the `FTTransformer` class in `models.py`) trained on 464 numerical descriptors extracted from 100 bp sequencing reads:
| Feature group | Description |
|---|---|
| NAC | Nucleic acid composition (mononucleotide frequencies) |
| DNC | Dinucleotide composition |
| TNC | Trinucleotide composition |
| kGap (di/tri) | k-spaced k-mer frequencies |
| ORF | Open reading frame features |
| Fickett | Fickett score |
| Shannon entropy | 5-mer entropy |
| Fourier binary | Binary Fourier transform |
| Fourier complex | Complex Fourier transform |
| Tsallis entropy | Tsallis entropy (q = 2.3) |
Binary classification: TP (IGH read) vs. TN (non-IGH read)
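Two of the simpler feature groups can be sketched in a few lines of Python (an illustration only; the repository's preprocessing pipeline is the reference implementation, and its exact normalizations may differ):

```python
from collections import Counter
import math

def nucleotide_composition(seq):
    """NAC-style feature: fraction of each base A, C, G, T in the read."""
    counts = Counter(seq.upper())
    n = len(seq)
    return {base: counts.get(base, 0) / n for base in "ACGT"}

def shannon_entropy_kmers(seq, k=5):
    """Shannon entropy (bits) over the distribution of overlapping k-mers."""
    kmers = [seq[i:i + k] for i in range(len(seq) - k + 1)]
    counts = Counter(kmers)
    total = len(kmers)
    return -sum((c / total) * math.log2(c / total) for c in counts.values())

read = "ACGT" * 25  # a toy 100 bp read
print(nucleotide_composition(read))  # {'A': 0.25, 'C': 0.25, 'G': 0.25, 'T': 0.25}
print(shannon_entropy_kmers(read))   # 2.0 — only 4 distinct 5-mers, uniformly distributed
```

A real read yields a much richer 5-mer distribution, so its entropy sits closer to the maximum for the observed k-mer count.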
## Available checkpoints
This repository contains 61 pre-trained checkpoints organized under two experimental approaches:
### 1. `fixed_total_size/` – fixed global training-set size (`N_global_fixe`)
The total training set is fixed at 598,709 sequences. The proportion of real versus synthetic data is varied in steps of 10%, from 0% real (fully synthetic) to 100% real (no synthetic).
```
fixed_total_size/
├── transformer_real0/best_model.pt      (0% real, 100% synthetic)
├── transformer_real10/best_model.pt     (10% real, 90% synthetic)
├── transformer_real20/best_model.pt
├── ...
└── transformer_real100/best_model.pt    (100% real, 0% synthetic)
```
**Key finding:** Performance collapses when synthetic data exceeds 60% of the training set. A minimum of 50% real data is required for meaningful results.
### 2. `progressive_training/` – fixed real-data size with synthetic augmentation (`N_real_fixe`)
The real data size is held constant while synthetic data is added at increasing percentages (10%–100% of the real-data size). This approach is evaluated systematically across five real-data sizes.
```
progressive_training/
├── real_050000/
│   ├── synth_010pct_005000/best_model.pt   (50K real + 5K synthetic)
│   ├── synth_020pct_010000/best_model.pt
│   ├── ...
│   └── synth_100pct_050000/best_model.pt   (50K real + 50K synthetic)
├── real_100000/   (100K real, 10 synthetic proportions)
├── real_150000/   (150K real, 10 synthetic proportions)  ← best results
├── real_200000/   (200K real, 10 synthetic proportions)
└── real_213100/   (213K real, 10 synthetic proportions)
```
**Key finding:** Synthetic augmentation monotonically improves performance. Best results plateau at ≥ 70% synthetic augmentation.
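The checkpoint directories follow a regular pattern, `real_<N>/synth_<P>pct_<N·P/100>` with zero-padded counts, so a path can be built from the real-data size and augmentation percentage (a sketch inferred from the directory layout above; verify against the actual repository contents):

```python
def progressive_checkpoint(real_size: int, synth_pct: int) -> str:
    """Build a checkpoint path for the progressive-training sweep.

    real_size: number of real reads (e.g. 150_000)
    synth_pct: synthetic augmentation as a percentage of real_size (10-100)
    """
    synth_size = real_size * synth_pct // 100
    return (f"progressive_training/real_{real_size:06d}/"
            f"synth_{synth_pct:03d}pct_{synth_size:06d}/best_model.pt")

# The recommended best model: 150K real reads + 100% synthetic augmentation.
print(progressive_checkpoint(150_000, 100))
# progressive_training/real_150000/synth_100pct_150000/best_model.pt
```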
## Best model
The recommended checkpoint for production use is:
```
progressive_training/real_150000/synth_100pct_150000/best_model.pt
```
| Metric | Value |
|---|---|
| Balanced accuracy | 97.5% |
| F1-score | 95.6% |
| ROC-AUC | 99.7% |
| PR-AUC | 99.3% |
Evaluated on a held-out patient test set of 173,100 reads (119,349 TN, 53,751 TP) from CLL patients and the ICGC-CLL Genome cohort.
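The four reported metrics can all be computed with scikit-learn from predicted probabilities and labels (a minimal sketch on synthetic stand-in data; `y_true` and `y_prob` would come from the held-out test set):

```python
import numpy as np
from sklearn.metrics import (balanced_accuracy_score, f1_score,
                             roc_auc_score, average_precision_score)

# Toy stand-ins for held-out labels and model probabilities.
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=1000)
y_prob = np.clip(y_true * 0.8 + rng.normal(0.1, 0.2, size=1000), 0.0, 1.0)
y_pred = (y_prob >= 0.5).astype(int)  # same 0.5 threshold as the inference example

print("Balanced accuracy:", balanced_accuracy_score(y_true, y_pred))
print("F1-score:         ", f1_score(y_true, y_pred))
print("ROC-AUC:          ", roc_auc_score(y_true, y_prob))  # uses probabilities
print("PR-AUC:           ", average_precision_score(y_true, y_prob))
```

Note that ROC-AUC and PR-AUC are threshold-free (they take probabilities), while balanced accuracy and F1 depend on the 0.5 decision threshold.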
## Usage

### Installation

```bash
pip install torch scikit-learn pandas numpy
git clone https://github.com/acri-nb/igh_classification.git
```
### Loading a checkpoint

```python
import sys
import torch

# Make the repository's modules importable.
sys.path.insert(0, "/path/to/igh_classification")
from models import FTTransformer

# Load the full checkpoint dict (weights_only=False: it also stores hyperparameters).
checkpoint = torch.load(
    "progressive_training/real_150000/synth_100pct_150000/best_model.pt",
    map_location="cpu",
    weights_only=False,
)

# Rebuild the architecture from the saved hyperparameters, then load the weights.
model = FTTransformer(
    input_dim=checkpoint["input_dim"],
    hidden_dims=checkpoint["hidden_dims"],
    dropout=checkpoint.get("dropout", 0.3),
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
```
### Running inference

```python
import pandas as pd
import torch
from sklearn.preprocessing import RobustScaler

# Load the feature CSV produced by the preprocessing pipeline (464 columns).
df = pd.read_csv("features_extracted.csv")
X = df.values.astype("float32")

# CAUTION: for correct results, load the RobustScaler that was fitted on the
# training data and call scaler.transform(X). Re-fitting on inference data,
# as below, only approximates the training-time normalization.
scaler = RobustScaler()
X_scaled = scaler.fit_transform(X)

# `model` is the FTTransformer loaded in the previous section.
X_tensor = torch.tensor(X_scaled, dtype=torch.float32)
with torch.no_grad():
    logits = model(X_tensor)
    probs = torch.sigmoid(logits).squeeze()
    predictions = (probs >= 0.5).int()
```
**Note:** The scaler must be fitted on the training data and saved alongside the model. The `DeepBioClassifier` class in the GitHub repository handles this automatically during training.
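One simple way to persist a training-fitted scaler next to the checkpoint is `joblib` (a sketch; the `DeepBioClassifier` class mentioned above is the authoritative mechanism, and the repository may store the scaler differently):

```python
import joblib
import numpy as np
from sklearn.preprocessing import RobustScaler

# Toy stand-ins for the 464-feature matrices.
rng = np.random.default_rng(0)
X_train = rng.normal(size=(100, 464))  # training features
X_new = rng.normal(size=(10, 464))     # features seen at inference time

# At training time: fit on the training features and save next to the model.
scaler = RobustScaler().fit(X_train)
joblib.dump(scaler, "robust_scaler.joblib")

# At inference time: load and transform — never re-fit — the new features.
loaded = joblib.load("robust_scaler.joblib")
X_scaled = loaded.transform(X_new)
```

This guarantees the medians and interquartile ranges used at inference are exactly those seen during training.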
## Training details
| Parameter | Value |
|---|---|
| Architecture | FT-Transformer |
| Input features | 464 |
| Hidden dimensions | [512, 256, 128, 64] |
| Dropout | 0.3 |
| Optimizer | AdamW |
| Learning rate | 1e-3 |
| Scheduler | Cosine Annealing with Warm Restarts |
| Loss | Focal Loss |
| Epochs | 150 (with early stopping, patience = 50) |
| Batch size | 256 |
| Feature normalization | RobustScaler |
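The focal loss in the table down-weights easy, well-classified examples relative to plain binary cross-entropy, which helps with the TP/TN class imbalance. A minimal binary version follows (a sketch with commonly used defaults α = 0.25, γ = 2; the repository's implementation and hyperparameters may differ):

```python
import torch
import torch.nn.functional as F

def binary_focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """Binary focal loss on raw logits.

    Reduces to alpha-weighted BCE when gamma=0; larger gamma suppresses
    the contribution of examples the model already classifies confidently.
    """
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = torch.exp(-bce)  # model's probability for the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()

logits = torch.tensor([2.0, -1.5, 0.3])
targets = torch.tensor([1.0, 0.0, 1.0])
print(binary_focal_loss(logits, targets))
```

With γ = 0 and α = 0.5 this is exactly half the mean BCE, which is a convenient sanity check when wiring it into a training loop.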
## Citation
If you use these weights or the associated code, please cite:
```bibtex
@article{gayap2025igh,
  title   = {Machine learning-based classification of IGHV mutation status
             in CLL from whole-genome sequencing data},
  author  = {Gayap, Hadrien and others},
  journal = {(submitted)},
  year    = {2025}
}
```
## License
MIT License. See LICENSE for details.