---
license: mit
language:
- en
tags:
- genomics
- bioinformatics
- classification
- immunology
- cll
- ighv
- fttransformer
- tabular
library_name: pytorch
---

# IGH Classification – FT-Transformer

Pre-trained **FT-Transformer (Feature Tokenizer + Transformer)** models for classifying whole-genome sequencing reads as IGH (immunoglobulin heavy chain) or non-IGH. The models are trained on a combination of real CLL (chronic lymphocytic leukemia) patient data and synthetic V(D)J recombination sequences.

- **GitHub repository:** [acri-nb/igh_classification](https://github.com/acri-nb/igh_classification)
- **Paper:** Darmendre J. *et al.* *Machine learning-based classification of IGHV mutation status in CLL from whole-genome sequencing data.* (submitted)

---

## Model description

Each checkpoint is an FT-Transformer (the `FTTransformer` class in `models.py`) trained on 464 numerical descriptors extracted from 100 bp sequencing reads:

| Feature group | Description |
|---|---|
| NAC | Nucleic acid composition (single-nucleotide frequencies) |
| DNC | Dinucleotide composition |
| TNC | Trinucleotide composition |
| kGap (di/tri) | k-gapped di-/trinucleotide frequencies |
| ORF | Open reading frame features |
| Fickett | Fickett score |
| Shannon entropy | 5-mer entropy |
| Fourier binary | Binary Fourier transform |
| Fourier complex | Complex Fourier transform |
| Tsallis entropy | Tsallis entropy (q = 2.3) |

**Binary classification:** TP (IGH read) vs. TN (non-IGH read)
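
To make two of the descriptor groups concrete, here is a minimal sketch of the 5-mer Shannon entropy and the Tsallis entropy (q = 2.3) for a single read. This is illustrative only; the actual feature-extraction pipeline in the GitHub repository may differ in windowing, log base, and normalization.

```python
import math
from collections import Counter

def shannon_entropy_kmers(seq: str, k: int = 5) -> float:
    """Shannon entropy (bits) of the k-mer frequency distribution of a read."""
    kmers = [seq[i:i + k] for i in range(len(seq) - k + 1)]
    counts = Counter(kmers)
    n = len(kmers)
    return -sum((c / n) * math.log2(c / n) for c in counts.values())

def tsallis_entropy_kmers(seq: str, k: int = 5, q: float = 2.3) -> float:
    """Tsallis entropy of the k-mer distribution; q = 2.3 as in the model card."""
    kmers = [seq[i:i + k] for i in range(len(seq) - k + 1)]
    counts = Counter(kmers)
    n = len(kmers)
    return (1.0 - sum((c / n) ** q for c in counts.values())) / (q - 1.0)
```

Both entropies are zero for a homopolymer read (a single repeated k-mer) and grow with k-mer diversity, which is what makes them useful for separating repetitive from recombined regions.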

---

## Available checkpoints

This repository contains **61 pre-trained checkpoints**, organized under two experimental approaches:

### 1. `fixed_total_size/` – Fixed global training-set size (N_global_fixe)

The total training set is fixed at **598,709 sequences**. The proportion of real versus synthetic data is varied in steps of 10%, from 0% real (fully synthetic) to 100% real (no synthetic).

```
fixed_total_size/
├── transformer_real0/best_model.pt      (0% real, 100% synthetic)
├── transformer_real10/best_model.pt     (10% real, 90% synthetic)
├── transformer_real20/best_model.pt
├── ...
└── transformer_real100/best_model.pt    (100% real, 0% synthetic)
```

**Key finding:** Performance collapses when synthetic data exceeds 60% of the training set. A minimum of 50% real data is required for meaningful results.
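
The real/synthetic split behind each `transformer_realXX` checkpoint can be sketched as below. The exact rounding used in the repository is an assumption; only the fixed total of 598,709 and the 10% steps come from this card.

```python
TOTAL = 598_709  # fixed global training-set size

def split_counts(real_pct: int, total: int = TOTAL) -> tuple[int, int]:
    """Number of real vs. synthetic reads for a given real-data percentage."""
    n_real = round(total * real_pct / 100)
    return n_real, total - n_real

# One line per checkpoint in fixed_total_size/
for pct in range(0, 101, 10):
    n_real, n_synth = split_counts(pct)
    print(f"transformer_real{pct}: {n_real} real + {n_synth} synthetic")
```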

### 2. `progressive_training/` – Fixed real-data size with synthetic augmentation (N_real_fixe)

The real-data size is held constant; synthetic data is added at increasing percentages (10%–100% of the real-data size). This approach is systematically evaluated across five real-data sizes.

```
progressive_training/
├── real_050000/
│   ├── synth_010pct_005000/best_model.pt   (50K real + 5K synthetic)
│   ├── synth_020pct_010000/best_model.pt
│   ├── ...
│   └── synth_100pct_050000/best_model.pt   (50K real + 50K synthetic)
├── real_100000/   (100K real, 10 synthetic proportions)
├── real_150000/   (150K real, 10 synthetic proportions) ← best results
├── real_200000/   (200K real, 10 synthetic proportions)
└── real_213100/   (213K real, 10 synthetic proportions)
```

**Key finding:** Synthetic augmentation monotonically improves performance. The best results plateau at ≥ 70% synthetic augmentation.
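
The checkpoint paths follow the zero-padded naming shown in the tree above, which a small helper can reconstruct:

```python
def checkpoint_path(n_real: int, synth_pct: int) -> str:
    """Path of a progressive_training checkpoint, following the naming
    convention shown in the directory tree (zero-padded counts)."""
    n_synth = n_real * synth_pct // 100
    return (f"progressive_training/real_{n_real:06d}/"
            f"synth_{synth_pct:03d}pct_{n_synth:06d}/best_model.pt")
```

For example, `checkpoint_path(150_000, 100)` yields the recommended best-model path from the next section.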

---

## Best model

The recommended checkpoint for production use is:

```
progressive_training/real_150000/synth_100pct_150000/best_model.pt
```

| Metric | Value |
|---|---|
| Balanced accuracy | 97.5% |
| F1-score | 95.6% |
| ROC-AUC | 99.7% |
| PR-AUC | 99.3% |

*Evaluated on a held-out patient test set of 173,100 reads (119,349 TN, 53,751 TP) from CLL patients and the ICGC-CLL Genome cohort.*
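
All four metrics are standard scikit-learn calls; the sketch below uses placeholder labels and scores rather than the actual test-set predictions, just to show which function pairs with which metric (threshold-based vs. score-based).

```python
import numpy as np
from sklearn.metrics import (average_precision_score, balanced_accuracy_score,
                             f1_score, roc_auc_score)

y_true = np.array([1, 0, 1, 1, 0, 0])              # placeholder labels (TP=1, TN=0)
y_prob = np.array([0.9, 0.2, 0.8, 0.6, 0.3, 0.1])  # placeholder sigmoid outputs
y_pred = (y_prob >= 0.5).astype(int)               # thresholded predictions

# Thresholded metrics take hard predictions; AUC metrics take raw scores.
print("balanced accuracy:", balanced_accuracy_score(y_true, y_pred))
print("F1-score:", f1_score(y_true, y_pred))
print("ROC-AUC:", roc_auc_score(y_true, y_prob))
print("PR-AUC:", average_precision_score(y_true, y_prob))
```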

---

## Usage

### Installation

```bash
pip install torch scikit-learn pandas numpy
git clone https://github.com/acri-nb/igh_classification.git
```

### Loading a checkpoint

```python
import sys
import torch

sys.path.insert(0, "/path/to/igh_classification")
from models import FTTransformer

# weights_only=False is required because the checkpoint stores Python objects
# besides the state dict; only load checkpoints from a trusted source.
checkpoint = torch.load(
    "progressive_training/real_150000/synth_100pct_150000/best_model.pt",
    map_location="cpu",
    weights_only=False,
)

model = FTTransformer(
    input_dim=checkpoint["input_dim"],
    hidden_dims=checkpoint["hidden_dims"],
    dropout=checkpoint.get("dropout", 0.3),
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
```

### Running inference

```python
import pandas as pd
import torch
from sklearn.preprocessing import RobustScaler

# Load your feature CSV (output of the preprocessing pipeline)
df = pd.read_csv("features_extracted.csv")
X = df.values.astype("float32")

# Placeholder only: in production, apply the RobustScaler that was fitted on
# the training data (see the note below) instead of refitting on new reads.
scaler = RobustScaler()
X_scaled = scaler.fit_transform(X)

X_tensor = torch.tensor(X_scaled, dtype=torch.float32)

with torch.no_grad():
    logits = model(X_tensor)
    probs = torch.sigmoid(logits).squeeze()
    predictions = (probs >= 0.5).int()
```

> **Note:** The scaler must be fitted on the **training data** and saved alongside the model. The `DeepBioClassifier` class in the GitHub repository handles this automatically during training.
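
If you are not using `DeepBioClassifier`, a manual sketch of persisting and reusing the fitted scaler looks like this. The file name and the random placeholder features are assumptions for illustration only.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.preprocessing import RobustScaler

rng = np.random.default_rng(0)
X_train = rng.random((1000, 464), dtype=np.float32)  # placeholder training features

# Training time: fit on the training data and persist next to the checkpoint.
scaler = RobustScaler().fit(X_train)
path = os.path.join(tempfile.gettempdir(), "robust_scaler.joblib")  # hypothetical location
joblib.dump(scaler, path)

# Inference time: reload the fitted scaler and transform -- never refit on new reads.
scaler = joblib.load(path)
X_new = rng.random((8, 464), dtype=np.float32)
X_scaled = scaler.transform(X_new)
```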

---

## Training details

| Parameter | Value |
|---|---|
| Architecture | FT-Transformer |
| Input features | 464 |
| Hidden dimensions | [512, 256, 128, 64] |
| Dropout | 0.3 |
| Optimizer | AdamW |
| Learning rate | 1e-3 |
| Scheduler | Cosine annealing with warm restarts |
| Loss | Focal loss |
| Epochs | 150 (with early stopping, patience = 50) |
| Batch size | 256 |
| Feature normalization | RobustScaler |
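
For reference, the optimizer, scheduler, and loss rows above can be wired together as in the sketch below. The focal-loss hyperparameters (`gamma`, `alpha`) and the restart period `T_0` are common defaults, not values taken from this card, and `nn.Linear` stands in for the real `FTTransformer`.

```python
import torch
import torch.nn as nn

class FocalLoss(nn.Module):
    """Binary focal loss on logits; gamma/alpha here are illustrative defaults."""
    def __init__(self, gamma: float = 2.0, alpha: float = 0.25):
        super().__init__()
        self.gamma, self.alpha = gamma, alpha

    def forward(self, logits, targets):
        bce = nn.functional.binary_cross_entropy_with_logits(
            logits, targets, reduction="none")
        p_t = torch.exp(-bce)  # probability assigned to the true class
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        return (alpha_t * (1 - p_t) ** self.gamma * bce).mean()

model = nn.Linear(464, 1)  # stand-in for FTTransformer (464 input features)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)
criterion = FocalLoss()
```

In a training loop, `scheduler.step()` would be called once per epoch (or per batch with a fractional epoch argument) after `optimizer.step()`.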

---

## Citation

If you use these weights or the associated code, please cite:

```bibtex
@article{gayap2025igh,
  title   = {Machine learning-based classification of IGHV mutation status
             in CLL from whole-genome sequencing data},
  author  = {Gayap, Hadrien and others},
  journal = {(submitted)},
  year    = {2025}
}
```

---

## License

MIT License. See [LICENSE](https://github.com/acri-nb/igh_classification/blob/main/LICENSE) for details.