IGH Classification – FT-Transformer

Pre-trained Feature Tokenizer Transformer (FT-Transformer) models for classifying whole-genome sequencing reads as IGH (immunoglobulin heavy chain) or non-IGH. The models are trained on a combination of real CLL (chronic lymphocytic leukemia) patient data and synthetic V(D)J recombination sequences.

  • GitHub repository: acri-nb/igh_classification
  • Paper: Darmendre J. et al. Machine learning-based classification of IGHV mutation status in CLL from whole-genome sequencing data. (submitted)

Model description

Each checkpoint is an FT-Transformer (the FTTransformer class in models.py) trained on 464 numerical descriptors extracted from 100 bp sequencing reads:

Feature group     Description
NAC               Nucleic acid composition (mononucleotide frequencies)
DNC               Dinucleotide composition
TNC               Trinucleotide composition
kGap (di/tri)     k-spaced k-mer frequencies
ORF               Open reading frame features
Fickett           Fickett score
Shannon entropy   5-mer Shannon entropy
Fourier binary    Binary Fourier transform
Fourier complex   Complex Fourier transform
Tsallis entropy   Tsallis entropy (q = 2.3)

Binary classification: TP (IGH read) vs. TN (non-IGH read)
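The full 464-feature pipeline lives in the repository's preprocessing code; as a rough illustration of the composition and entropy groups above, here is a minimal sketch (my own identifiers, not the repository's implementation):

```python
import math
from collections import Counter
from itertools import product

def kmer_composition(seq, k):
    """Relative frequency of every k-mer over the alphabet ACGT."""
    counts = Counter(seq[i:i + k] for i in range(len(seq) - k + 1))
    total = max(sum(counts.values()), 1)
    return {"".join(p): counts["".join(p)] / total
            for p in product("ACGT", repeat=k)}

def shannon_entropy(seq, k=5):
    """Shannon entropy (bits) of the observed k-mer distribution."""
    counts = Counter(seq[i:i + k] for i in range(len(seq) - k + 1))
    total = sum(counts.values())
    return -sum((c / total) * math.log2(c / total) for c in counts.values())

read = "ACGT" * 25              # toy 100 bp read
nac = kmer_composition(read, 1)  # NAC: 4 features
dnc = kmer_composition(read, 2)  # DNC: 16 features
tnc = kmer_composition(read, 3)  # TNC: 64 features
```

The remaining groups (ORF, Fickett, Fourier, Tsallis) follow the same pattern of turning a read into fixed-length numeric vectors that are then concatenated.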


Available checkpoints

This repository contains 61 pre-trained checkpoints organized under two experimental approaches:

1. fixed_total_size/ – Fixed global training set size (N_global_fixe)

The total training set is fixed at 598,709 sequences. The proportion of real versus synthetic data is varied in steps of 10%, from 0% real (fully synthetic) to 100% real (no synthetic).
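Because the total is fixed, the real/synthetic counts for any proportion follow directly from the chosen percentage (illustrative arithmetic only):

```python
N_TOTAL = 598_709  # fixed global training set size

def split_sizes(real_pct):
    """Number of real vs. synthetic reads for a given real-data percentage."""
    n_real = round(N_TOTAL * real_pct / 100)
    return n_real, N_TOTAL - n_real

for pct in (0, 50, 100):
    n_real, n_synth = split_sizes(pct)
    print(f"{pct:3d}% real -> {n_real:,} real / {n_synth:,} synthetic")
```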

fixed_total_size/
├── transformer_real0/best_model.pt     (0% real, 100% synthetic)
├── transformer_real10/best_model.pt    (10% real, 90% synthetic)
├── transformer_real20/best_model.pt
├── ...
└── transformer_real100/best_model.pt   (100% real, 0% synthetic)

Key finding: Performance collapses when synthetic data exceeds 60% of the training set. A minimum of 50% real data is required for meaningful results.

2. progressive_training/ – Fixed real data size with synthetic augmentation (N_real_fixe)

The real data size is held constant while synthetic data is added in increasing proportions (10%–100% of the real data size). This approach is evaluated systematically across five real data sizes.

progressive_training/
├── real_050000/
│   ├── synth_010pct_005000/best_model.pt    (50K real + 5K synthetic)
│   ├── synth_020pct_010000/best_model.pt
│   ├── ...
│   └── synth_100pct_050000/best_model.pt    (50K real + 50K synthetic)
├── real_100000/  (100K real, 10 synthetic proportions)
├── real_150000/  (150K real, 10 synthetic proportions)  ← best results
├── real_200000/  (200K real, 10 synthetic proportions)
└── real_213100/  (213K real, 10 synthetic proportions)

Key finding: Synthetic augmentation monotonically improves performance. Best results plateau at ≥ 70% synthetic augmentation.
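The checkpoint directory names encode the augmentation level. Assuming the naming scheme shown in the tree above (synth_&lt;pct&gt;pct_&lt;count&gt;), a small helper for selecting checkpoints programmatically might look like this (the helper itself is hypothetical, not part of the repository):

```python
import re

def parse_synth_dir(name):
    """Extract (synthetic %, synthetic read count) from a directory name
    such as 'synth_070pct_105000'. Naming scheme inferred from the tree."""
    m = re.fullmatch(r"synth_(\d{3})pct_(\d+)", name)
    if m is None:
        raise ValueError(f"unexpected checkpoint directory: {name}")
    return int(m.group(1)), int(m.group(2))

pct, n_synth = parse_synth_dir("synth_100pct_150000")
# pct == 100, n_synth == 150000
```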


Best model

The recommended checkpoint for production use is:

progressive_training/real_150000/synth_100pct_150000/best_model.pt

Metric             Value
Balanced accuracy  97.5%
F1-score           95.6%
ROC-AUC            99.7%
PR-AUC             99.3%

Evaluated on a held-out patient test set of 173,100 reads (119,349 TN, 53,751 TP) from CLL patients and the ICGC-CLL Genome cohort.
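Given held-out labels and predicted probabilities, the four reported metrics map directly onto scikit-learn functions. The arrays below are toy stand-ins, not the actual evaluation set:

```python
import numpy as np
from sklearn.metrics import (average_precision_score, balanced_accuracy_score,
                             f1_score, roc_auc_score)

# Toy stand-ins for held-out labels and model probabilities
y_true = np.array([0, 0, 1, 1, 1, 0])
y_prob = np.array([0.1, 0.4, 0.8, 0.9, 0.6, 0.3])
y_pred = (y_prob >= 0.5).astype(int)

print("balanced accuracy:", balanced_accuracy_score(y_true, y_pred))
print("F1-score:         ", f1_score(y_true, y_pred))
print("ROC-AUC:          ", roc_auc_score(y_true, y_prob))
print("PR-AUC:           ", average_precision_score(y_true, y_prob))
```

Note that balanced accuracy and F1 are computed from thresholded predictions, while ROC-AUC and PR-AUC use the raw probabilities.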


Usage

Installation

pip install torch scikit-learn pandas numpy
git clone https://github.com/acri-nb/igh_classification.git

Loading a checkpoint

import torch
import sys
sys.path.insert(0, "/path/to/igh_classification")
from models import FTTransformer

checkpoint = torch.load(
    "progressive_training/real_150000/synth_100pct_150000/best_model.pt",
    map_location="cpu",
    weights_only=False,  # the checkpoint stores hyperparameters, not just tensors
)

# Rebuild the architecture from the hyperparameters saved in the checkpoint
model = FTTransformer(
    input_dim=checkpoint["input_dim"],
    hidden_dims=checkpoint["hidden_dims"],
    dropout=checkpoint.get("dropout", 0.3),
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # disable dropout for deterministic inference

Running inference

import joblib
import pandas as pd
import torch

# Load your feature CSV (output of the preprocessing pipeline)
df = pd.read_csv("features_extracted.csv")
X = df.values.astype("float32")

# Apply the RobustScaler fitted on the training data; re-fitting a scaler
# on inference data would shift the feature distribution the model expects.
# "scaler.pkl" is a hypothetical filename: the repository's DeepBioClassifier
# saves the fitted scaler during training (see note below).
scaler = joblib.load("scaler.pkl")
X_scaled = scaler.transform(X)

X_tensor = torch.tensor(X_scaled, dtype=torch.float32)

with torch.no_grad():
    logits = model(X_tensor)
    probs = torch.sigmoid(logits).squeeze()
    predictions = (probs >= 0.5).int()

Note: The scaler must be fitted on the training data and saved alongside the model. The DeepBioClassifier class in the GitHub repository handles this automatically during training.


Training details

Parameter              Value
Architecture           FT-Transformer
Input features         464
Hidden dimensions      [512, 256, 128, 64]
Dropout                0.3
Optimizer              AdamW
Learning rate          1e-3
Scheduler              cosine annealing with warm restarts
Loss                   focal loss
Epochs                 150 (with early stopping, patience = 50)
Batch size             256
Feature normalization  RobustScaler
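The loss row above refers to focal loss; a common binary formulation on raw logits (Lin et al., 2017) can be sketched as follows. The gamma/alpha defaults here are that paper's, not necessarily the repository's settings:

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, gamma=2.0, alpha=0.25):
    """Binary focal loss: alpha_t * (1 - p_t)^gamma * BCE, averaged.
    Down-weights easy examples so training focuses on hard ones."""
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = torch.exp(-bce)  # probability assigned to the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()
```

With gamma = 0 and alpha = 0.5 this reduces to plain (halved) binary cross-entropy; larger gamma suppresses the contribution of well-classified reads, which helps with the TN/TP class imbalance noted in the evaluation set.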

Citation

If you use these weights or the associated code, please cite:

@article{gayap2025igh,
  title   = {Machine learning-based classification of IGHV mutation status
             in CLL from whole-genome sequencing data},
  author  = {Gayap, Hadrien and others},
  journal = {(submitted)},
  year    = {2025}
}

License

MIT License. See LICENSE for details.
