---
license: mit
language:
  - en
tags:
  - genomics
  - bioinformatics
  - classification
  - immunology
  - cll
  - ighv
  - fttransformer
  - tabular
library_name: pytorch
---

# IGH Classification – FT-Transformer

Pre-trained **Feature Tokenizer Transformer (FT-Transformer)** models for classifying whole-genome sequencing reads as IGH (immunoglobulin heavy chain) or non-IGH. The models are trained on a combination of real CLL (chronic lymphocytic leukemia) patient data and synthetic V(D)J recombination sequences.

- **GitHub repository:** [acri-nb/igh_classification](https://github.com/acri-nb/igh_classification)
- **Paper:** Darmendre J. *et al.* *Machine learning-based classification of IGHV mutation status in CLL from whole-genome sequencing data.* (submitted)

---

## Model description

Each checkpoint is a FT-Transformer (`FTTransformer` class in `models.py`) trained on 464 numerical descriptors extracted from 100 bp sequencing reads:

| Feature group | Description |
|---|---|
| NAC | Nucleic acid composition (mononucleotide frequencies) |
| DNC | Dinucleotide composition |
| TNC | Trinucleotide composition |
| kGap (di/tri) | k-gapped dinucleotide and trinucleotide frequencies |
| ORF | Open reading frame features |
| Fickett | Fickett score |
| Shannon entropy | 5-mer entropy |
| Fourier binary | Binary Fourier transform |
| Fourier complex | Complex Fourier transform |
| Tsallis entropy | Tsallis entropy (q = 2.3) |

**Binary classification:** TP (IGH read) vs. TN (non-IGH read)
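
The composition-style features above are standard k-mer statistics. As a rough illustration (not the repository's actual feature-extraction code), the trinucleotide composition (TNC) and the 5-mer Shannon entropy of a 100 bp read could be computed like this:

```python
import math
from collections import Counter
from itertools import product

def kmer_composition(seq: str, k: int) -> list[float]:
    """Frequency of each possible k-mer (lexicographic order) in the read."""
    counts = Counter(seq[i:i + k] for i in range(len(seq) - k + 1))
    total = max(len(seq) - k + 1, 1)
    return [counts["".join(p)] / total for p in product("ACGT", repeat=k)]

def shannon_entropy(seq: str, k: int = 5) -> float:
    """Shannon entropy (in bits) of the read's k-mer distribution."""
    counts = Counter(seq[i:i + k] for i in range(len(seq) - k + 1))
    total = sum(counts.values())
    return -sum((c / total) * math.log2(c / total) for c in counts.values())

read = "ACGT" * 25                   # toy 100 bp read
tnc = kmer_composition(read, 3)      # 64 trinucleotide frequencies
entropy = shannon_entropy(read, 5)   # low entropy: the read is a pure repeat
```

The full 464-dimensional vector concatenates all ten feature groups; the repository's preprocessing pipeline emits it as a CSV, which the inference example below consumes.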

---

## Available checkpoints

This repository contains **61 pre-trained checkpoints** organized under two experimental approaches:

### 1. `fixed_total_size/` – Fixed global training set size (N_global_fixe)

The total training set is fixed at **598,709 sequences**. The proportion of real versus synthetic data is varied in steps of 10%, from 0% real (fully synthetic) to 100% real (no synthetic).

```
fixed_total_size/
├── transformer_real0/best_model.pt     (0% real, 100% synthetic)
├── transformer_real10/best_model.pt    (10% real, 90% synthetic)
├── transformer_real20/best_model.pt
├── ...
└── transformer_real100/best_model.pt   (100% real, 0% synthetic)
```

**Key finding:** Performance collapses once synthetic data exceeds 60% of the training set; at least 50% real data is required for meaningful results.

### 2. `progressive_training/` – Fixed real data size with synthetic augmentation (N_real_fixe)

The real data size is held constant; synthetic data is added at increasing percentages (10%–100% of the real data size). This approach is systematically evaluated across 5 real data sizes.

```
progressive_training/
├── real_050000/
│   ├── synth_010pct_005000/best_model.pt    (50K real + 5K synthetic)
│   ├── synth_020pct_010000/best_model.pt
│   ├── ...
│   └── synth_100pct_050000/best_model.pt    (50K real + 50K synthetic)
├── real_100000/  (100K real, 10 synthetic proportions)
├── real_150000/  (150K real, 10 synthetic proportions)  ← best results
├── real_200000/  (200K real, 10 synthetic proportions)
└── real_213100/  (213K real, 10 synthetic proportions)
```

**Key finding:** Synthetic augmentation monotonically improves performance. Best results plateau at ≥ 70% synthetic augmentation.

---

## Best model

The recommended checkpoint for production use is:

```
progressive_training/real_150000/synth_100pct_150000/best_model.pt
```

| Metric | Value |
|---|---|
| Balanced accuracy | 97.5% |
| F1-score | 95.6% |
| ROC-AUC | 99.7% |
| PR-AUC | 99.3% |

*Evaluated on a held-out patient test set of 173,100 reads (119,349 TN, 53,751 TP) from CLL patients and the ICGC-CLL Genome cohort.*
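
These metrics can be recomputed from any checkpoint's predictions with scikit-learn. The snippet below uses toy labels and scores purely to show which function corresponds to each row of the table; substitute your model's outputs on the held-out set:

```python
import numpy as np
from sklearn.metrics import (balanced_accuracy_score, f1_score,
                             roc_auc_score, average_precision_score)

# Toy labels/scores standing in for real model outputs.
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=1000)
y_score = np.clip(y_true * 0.8 + rng.normal(0.1, 0.2, size=1000), 0, 1)
y_pred = (y_score >= 0.5).astype(int)

print(f"Balanced accuracy: {balanced_accuracy_score(y_true, y_pred):.3f}")
print(f"F1-score:          {f1_score(y_true, y_pred):.3f}")
print(f"ROC-AUC:           {roc_auc_score(y_true, y_score):.3f}")
print(f"PR-AUC:            {average_precision_score(y_true, y_score):.3f}")
```

Note that balanced accuracy and F1 depend on the 0.5 decision threshold, while ROC-AUC and PR-AUC are computed from the raw scores.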

---

## Usage

### Installation

```bash
pip install torch scikit-learn pandas numpy
git clone https://github.com/acri-nb/igh_classification.git
```

### Loading a checkpoint

```python
import sys
import torch

# Make the repository's models.py importable
sys.path.insert(0, "/path/to/igh_classification")
from models import FTTransformer

# weights_only=False because the checkpoint stores hyperparameters
# alongside the state dict
checkpoint = torch.load(
    "progressive_training/real_150000/synth_100pct_150000/best_model.pt",
    map_location="cpu",
    weights_only=False,
)

# Rebuild the architecture from the saved hyperparameters, then load weights
model = FTTransformer(
    input_dim=checkpoint["input_dim"],
    hidden_dims=checkpoint["hidden_dims"],
    dropout=checkpoint.get("dropout", 0.3),
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
```

### Running inference

```python
import joblib
import pandas as pd
import torch

# Load your feature CSV (output of the preprocessing pipeline)
df = pd.read_csv("features_extracted.csv")
X = df.values.astype("float32")

# Apply the RobustScaler fitted on the training data (do not refit at
# inference time); the filename here is illustrative
scaler = joblib.load("robust_scaler.joblib")
X_scaled = scaler.transform(X)

X_tensor = torch.tensor(X_scaled, dtype=torch.float32)

with torch.no_grad():
    logits = model(X_tensor)
    probs = torch.sigmoid(logits).squeeze()
    predictions = (probs >= 0.5).int()
```

> **Note:** The scaler must be fitted on the **training data** and saved alongside the model. The `DeepBioClassifier` class in the GitHub repository handles this automatically during training.
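
If you train outside `DeepBioClassifier`, one simple way to satisfy this is to persist the fitted scaler with `joblib` (the filename and the random training matrix below are illustrative):

```python
import joblib
import numpy as np
from sklearn.preprocessing import RobustScaler

# At training time: fit on the training feature matrix, then persist.
X_train = np.random.default_rng(0).normal(size=(1000, 464)).astype("float32")
scaler = RobustScaler().fit(X_train)
joblib.dump(scaler, "robust_scaler.joblib")  # hypothetical filename

# At inference time: reload and apply (transform only, never refit).
scaler = joblib.load("robust_scaler.joblib")
X_new = np.random.default_rng(1).normal(size=(8, 464)).astype("float32")
X_scaled = scaler.transform(X_new)
```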

---

## Training details

| Parameter | Value |
|---|---|
| Architecture | FT-Transformer |
| Input features | 464 |
| Hidden dimensions | [512, 256, 128, 64] |
| Dropout | 0.3 |
| Optimizer | AdamW |
| Learning rate | 1e-3 |
| Scheduler | Cosine Annealing with Warm Restarts |
| Loss | Focal Loss |
| Epochs | 150 (with early stopping, patience = 50) |
| Batch size | 256 |
| Feature normalization | RobustScaler |
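
The repository's exact focal-loss parameters are not listed in this card; the sketch below shows a standard binary focal loss wired up with the optimizer and scheduler from the table above (`alpha`, `gamma`, and `T_0` are assumed values, and the linear layer merely stands in for the FT-Transformer):

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """Binary focal loss: down-weights easy examples via (1 - p_t)^gamma."""
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = torch.exp(-bce)  # probability assigned to the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()

# AdamW + cosine annealing with warm restarts, as in the table above.
model = torch.nn.Linear(464, 1)  # stand-in for the FT-Transformer
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=10)
```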

---

## Citation

If you use these weights or the associated code, please cite:

```bibtex
@article{gayap2025igh,
  title   = {Machine learning-based classification of IGHV mutation status
             in CLL from whole-genome sequencing data},
  author  = {Gayap, Hadrien and others},
  journal = {(submitted)},
  year    = {2025}
}
```

---

## License

MIT License. See [LICENSE](https://github.com/acri-nb/igh_classification/blob/main/LICENSE) for details.