---
license: mit
language:
- en
tags:
- genomics
- bioinformatics
- classification
- immunology
- cll
- ighv
- fttransformer
- tabular
library_name: pytorch
---

# IGH Classification — FT-Transformer Pre-trained

**Feature Tokenizer Transformer (FT-Transformer)** models for classifying whole-genome sequencing reads as IGH (immunoglobulin heavy chain) or non-IGH. The models are trained on a combination of real CLL (chronic lymphocytic leukemia) patient data and synthetic V(D)J recombination sequences.

- **GitHub repository:** [acri-nb/igh_classification](https://github.com/acri-nb/igh_classification)
- **Paper:** Darmendre J. *et al.* *Machine learning-based classification of IGHV mutation status in CLL from whole-genome sequencing data.* (submitted)

---

## Model description

Each checkpoint is an FT-Transformer (`FTTransformer` class in `models.py`) trained on 464 numerical descriptors extracted from 100 bp sequencing reads:

| Feature group | Description |
|---|---|
| NAC | Nucleic acid composition (single-nucleotide frequencies) |
| DNC | Dinucleotide composition |
| TNC | Trinucleotide composition |
| kGap (di/tri) | k-spaced k-mer frequencies |
| ORF | Open reading frame features |
| Fickett | Fickett score |
| Shannon entropy | 5-mer entropy |
| Fourier binary | Binary Fourier transform |
| Fourier complex | Complex Fourier transform |
| Tsallis entropy | Tsallis entropy (q = 2.3) |

**Binary classification:** TP (IGH read) vs. TN (non-IGH read)

---

## Available checkpoints

This repository contains **61 pre-trained checkpoints** organized under two experimental approaches:

### 1. `fixed_total_size/` — Fixed global training set size (N_global_fixe)

The total training set is fixed at **598,709 sequences**. The proportion of real versus synthetic data is varied in steps of 10%, from 0% real (fully synthetic) to 100% real (no synthetic).
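The per-step counts implied by this sweep can be sketched quickly (the rounding convention below is an assumption, not taken from the repository):

```python
# Real/synthetic counts for each 10% step of the fixed-total-size experiment.
TOTAL = 598_709  # fixed global training set size

splits = {}
for pct_real in range(0, 101, 10):
    n_real = round(TOTAL * pct_real / 100)   # assumed rounding convention
    splits[pct_real] = (n_real, TOTAL - n_real)

print(splits[0])    # → (0, 598709): fully synthetic
print(splits[100])  # → (598709, 0): no synthetic
```

Every split sums back to the fixed total, which is the defining property of this experimental arm.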
```
fixed_total_size/
├── transformer_real0/best_model.pt      (0% real, 100% synthetic)
├── transformer_real10/best_model.pt     (10% real, 90% synthetic)
├── transformer_real20/best_model.pt
├── ...
└── transformer_real100/best_model.pt    (100% real, 0% synthetic)
```

**Key finding:** Performance collapses when synthetic data exceeds 60% of the training set; a minimum of 50% real data is required for meaningful results.

### 2. `progressive_training/` — Fixed real data size with synthetic augmentation (N_real_fixe)

The real data size is held constant; synthetic data is added at increasing percentages (10%–100% of the real data size). This approach is evaluated systematically across 5 real data sizes.

```
progressive_training/
├── real_050000/
│   ├── synth_010pct_005000/best_model.pt   (50K real + 5K synthetic)
│   ├── synth_020pct_010000/best_model.pt
│   ├── ...
│   └── synth_100pct_050000/best_model.pt   (50K real + 50K synthetic)
├── real_100000/   (100K real, 10 synthetic proportions)
├── real_150000/   (150K real, 10 synthetic proportions) ← best results
├── real_200000/   (200K real, 10 synthetic proportions)
└── real_213100/   (213K real, 10 synthetic proportions)
```

**Key finding:** Synthetic augmentation monotonically improves performance; the best results plateau at ≥ 70% synthetic augmentation.
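The directory layout above is regular enough to construct checkpoint paths programmatically. A hypothetical helper (the function name and parameters are illustrative, not part of the repository's API; the path pattern is read off the trees above):

```python
def checkpoint_path(n_real: int, synth_pct: int) -> str:
    """Build the path of a progressive_training checkpoint.

    n_real    -- number of real training reads (e.g. 50_000)
    synth_pct -- synthetic augmentation as a percentage of n_real (10..100)
    """
    n_synth = n_real * synth_pct // 100
    return (f"progressive_training/real_{n_real:06d}/"
            f"synth_{synth_pct:03d}pct_{n_synth:06d}/best_model.pt")

path = checkpoint_path(50_000, 10)
# → "progressive_training/real_050000/synth_010pct_005000/best_model.pt"
```

Zero-padding the counts to six digits and the percentage to three digits reproduces the naming used in the checkpoint trees.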
---

## Best model

The recommended checkpoint for production use is:

```
progressive_training/real_150000/synth_100pct_150000/best_model.pt
```

| Metric | Value |
|---|---|
| Balanced accuracy | 97.5% |
| F1-score | 95.6% |
| ROC-AUC | 99.7% |
| PR-AUC | 99.3% |

*Evaluated on a held-out patient test set of 173,100 reads (119,349 TN, 53,751 TP) from CLL patients and the ICGC-CLL Genome cohort.*

---

## Usage

### Installation

```bash
pip install torch scikit-learn pandas numpy
git clone https://github.com/acri-nb/igh_classification.git
```

### Loading a checkpoint

```python
import sys

import torch

sys.path.insert(0, "/path/to/igh_classification")
from models import FTTransformer

checkpoint = torch.load(
    "progressive_training/real_150000/synth_100pct_150000/best_model.pt",
    map_location="cpu",
    weights_only=False,  # checkpoint stores metadata, not just tensors
)

model = FTTransformer(
    input_dim=checkpoint["input_dim"],
    hidden_dims=checkpoint["hidden_dims"],
    dropout=checkpoint.get("dropout", 0.3),
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
```

### Running inference

```python
import pandas as pd
import torch

# Load your feature CSV (output of the preprocessing pipeline)
df = pd.read_csv("features_extracted.csv")
X = df.values.astype("float32")

# `scaler` must be the RobustScaler fitted on the TRAINING data (see note
# below). Do not refit on inference data: the feature distributions would
# no longer match those seen during training.
X_scaled = scaler.transform(X)

X_tensor = torch.tensor(X_scaled, dtype=torch.float32)
with torch.no_grad():
    logits = model(X_tensor)
    probs = torch.sigmoid(logits).squeeze()
    predictions = (probs >= 0.5).int()
```

> **Note:** The scaler must be fitted on the **training data** and saved alongside the model. The `DeepBioClassifier` class in the GitHub repository handles this automatically during training.
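A minimal sketch of persisting that training-time scaler with joblib (the filename `robust_scaler.joblib` is an illustrative choice, not a file shipped with the checkpoints, and the random training matrix is a placeholder for real extracted features):

```python
import joblib
import numpy as np
from sklearn.preprocessing import RobustScaler

X_train = np.random.rand(1000, 464).astype("float32")  # placeholder features

scaler = RobustScaler().fit(X_train)         # fit ONCE, on training data
joblib.dump(scaler, "robust_scaler.joblib")  # save next to best_model.pt

# At inference time: load and transform only, never refit.
scaler = joblib.load("robust_scaler.joblib")
X_scaled = scaler.transform(X_train)
```

Persisting the fitted scaler this way guarantees that inference-time features are centered and scaled with the medians and IQRs estimated from the training set.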
---

## Training details

| Parameter | Value |
|---|---|
| Architecture | FT-Transformer |
| Input features | 464 |
| Hidden dimensions | [512, 256, 128, 64] |
| Dropout | 0.3 |
| Optimizer | AdamW |
| Learning rate | 1e-3 |
| Scheduler | Cosine annealing with warm restarts |
| Loss | Focal loss |
| Epochs | 150 (with early stopping, patience = 50) |
| Batch size | 256 |
| Feature normalization | RobustScaler |

---

## Citation

If you use these weights or the associated code, please cite:

```bibtex
@article{gayap2025igh,
  title   = {Machine learning-based classification of IGHV mutation status in CLL from whole-genome sequencing data},
  author  = {Gayap, Hadrien and others},
  journal = {(submitted)},
  year    = {2025}
}
```

---

## License

MIT License. See [LICENSE](https://github.com/acri-nb/igh_classification/blob/main/LICENSE) for details.
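As an appendix, the focal loss listed in the training details can be sketched numerically (the α and γ values are assumptions for illustration; the repository's actual implementation may differ):

```python
import numpy as np

def focal_loss(p, y, alpha=0.25, gamma=2.0, eps=1e-7):
    """Binary focal loss on predicted probabilities p and labels y (0/1).

    FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t), where p_t is the
    probability assigned to the true class. The (1 - p_t)**gamma factor
    down-weights easy, well-classified examples; with gamma=0 and a
    symmetric alpha=0.5 it reduces to 0.5 * binary cross-entropy.
    """
    p = np.clip(p, eps, 1 - eps)
    p_t = np.where(y == 1, p, 1 - p)
    alpha_t = np.where(y == 1, alpha, 1 - alpha)
    return -alpha_t * (1 - p_t) ** gamma * np.log(p_t)

p = np.array([0.9, 0.2, 0.7])  # predicted probabilities of the TP class
y = np.array([1, 0, 1])        # ground-truth labels
loss = focal_loss(p, y).mean()
```

The down-weighting of easy examples is what makes this loss a common choice for imbalanced read-classification problems like TP vs. TN here.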