gthai committed · Commit 6834b13 (0 parents): initial commit
Files changed (2):

1. .gitattributes (+1, -0)
2. README.md (+200, -0)
.gitattributes ADDED:

*.pt filter=lfs diff=lfs merge=lfs -text
README.md ADDED:

---
license: mit
language:
- en
tags:
- genomics
- bioinformatics
- classification
- immunology
- cll
- ighv
- fttransformer
- tabular
library_name: pytorch
---

# IGH Classification – FT-Transformer

Pre-trained **Feature Tokenizer Transformer (FT-Transformer)** models for classifying whole-genome sequencing reads as IGH (immunoglobulin heavy chain) or non-IGH. The models were trained on a combination of real CLL (chronic lymphocytic leukemia) patient data and synthetic V(D)J recombination sequences.

- **GitHub repository:** [acri-nb/igh_classification](https://github.com/acri-nb/igh_classification)
- **Paper:** Gayap H. *et al.* *Machine learning-based classification of IGHV mutation status in CLL from whole-genome sequencing data.* (submitted)

---

## Model description

Each checkpoint is an FT-Transformer (the `FTTransformer` class in `models.py`) trained on 464 numerical descriptors extracted from 100 bp sequencing reads:

| Feature group | Description |
|---|---|
| NAC | Nucleic acid composition (mononucleotide frequencies) |
| DNC | Dinucleotide composition |
| TNC | Trinucleotide composition |
| kGap (di/tri) | k-gapped di-/tri-nucleotide frequencies |
| ORF | Open reading frame features |
| Fickett | Fickett score |
| Shannon entropy | 5-mer Shannon entropy |
| Fourier binary | Binary Fourier transform |
| Fourier complex | Complex Fourier transform |
| Tsallis entropy | Tsallis entropy (q = 2.3) |

**Binary classification:** TP (IGH read) vs. TN (non-IGH read)
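The full 464-feature extraction pipeline lives in the GitHub repository. Purely as an illustration of what two of these descriptor groups compute, here is a minimal sketch of dinucleotide composition (DNC) and 5-mer Shannon entropy; the function names and details are ours, not the repository's:

```python
from collections import Counter
from itertools import product
from math import log2

def dinucleotide_composition(read: str) -> list[float]:
    """DNC: frequency of each of the 16 dinucleotides in a read."""
    pairs = [read[i:i + 2] for i in range(len(read) - 1)]
    counts = Counter(pairs)
    total = len(pairs)
    return [counts["".join(p)] / total for p in product("ACGT", repeat=2)]

def kmer_shannon_entropy(read: str, k: int = 5) -> float:
    """Shannon entropy (bits) of the read's k-mer distribution, here k = 5."""
    kmers = [read[i:i + k] for i in range(len(read) - k + 1)]
    counts = Counter(kmers)
    total = len(kmers)
    return -sum((c / total) * log2(c / total) for c in counts.values())

read = "ACGT" * 25  # a toy 100 bp read
dnc = dinucleotide_composition(read)      # 16 frequencies summing to 1
entropy = kmer_shannon_entropy(read)      # low for repetitive reads
```

Repetitive reads yield low entropy, while diverse V(D)J junctions push it toward the maximum, which is one reason entropy features help separate IGH from background reads.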

---

## Available checkpoints

This repository contains **61 pre-trained checkpoints** organized under two experimental approaches:

### 1. `fixed_total_size/` – Fixed global training set size (N_global_fixe)

The total training set is fixed at **598,709 sequences**. The proportion of real versus synthetic data is varied in steps of 10%, from 0% real (fully synthetic) to 100% real (no synthetic).

```
fixed_total_size/
├── transformer_real0/best_model.pt      (0% real, 100% synthetic)
├── transformer_real10/best_model.pt     (10% real, 90% synthetic)
├── transformer_real20/best_model.pt
├── ...
└── transformer_real100/best_model.pt    (100% real, 0% synthetic)
```

**Key finding:** Performance collapses when synthetic data exceeds 60% of the training set; at least 50% real data is required for meaningful results.
65
+
66
+ ### 2. `progressive_training/` β€” Fixed real data size with synthetic augmentation (N_real_fixe)
67
+
68
+ The real data size is held constant; synthetic data is added at increasing percentages (10%–100% of the real data size). This approach is systematically evaluated across 5 real data sizes.
69
+
70
+ ```
71
+ progressive_training/
72
+ β”œβ”€β”€ real_050000/
73
+ β”‚ β”œβ”€β”€ synth_010pct_005000/best_model.pt (50K real + 5K synthetic)
74
+ β”‚ β”œβ”€β”€ synth_020pct_010000/best_model.pt
75
+ β”‚ β”œβ”€β”€ ...
76
+ β”‚ └── synth_100pct_050000/best_model.pt (50K real + 50K synthetic)
77
+ β”œβ”€β”€ real_100000/ (100K real, 10 synthetic proportions)
78
+ β”œβ”€β”€ real_150000/ (150K real, 10 synthetic proportions) ← best results
79
+ β”œβ”€β”€ real_200000/ (200K real, 10 synthetic proportions)
80
+ └── real_213100/ (213K real, 10 synthetic proportions)
81
+ ```
82
+
83
+ **Key finding:** Synthetic augmentation monotonically improves performance. Best results plateau at β‰₯ 70% synthetic augmentation.
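Both layouts end in a `best_model.pt` leaf, so every checkpoint can be enumerated with a single `pathlib` glob. A small convenience sketch (not part of the repository) that builds a mock of the layout above and walks it:

```python
import tempfile
from pathlib import Path

def list_checkpoints(root) -> list[Path]:
    """Recursively collect every best_model.pt under a checkpoint root."""
    return sorted(Path(root).rglob("best_model.pt"))

# Build a tiny mock of the progressive_training layout, then enumerate it
tmp = Path(tempfile.mkdtemp())
for sub in ["real_050000/synth_010pct_005000", "real_050000/synth_100pct_050000"]:
    d = tmp / "progressive_training" / sub
    d.mkdir(parents=True)
    (d / "best_model.pt").touch()

ckpts = list_checkpoints(tmp)  # sorted paths, one per checkpoint
```

Run against a local clone of this repository, the same call yields all 61 checkpoint paths.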

---

## Best model

The recommended checkpoint for production use is:

```
progressive_training/real_150000/synth_100pct_150000/best_model.pt
```

| Metric | Value |
|---|---|
| Balanced accuracy | 97.5% |
| F1-score | 95.6% |
| ROC-AUC | 99.7% |
| PR-AUC | 99.3% |

*Evaluated on a held-out patient test set of 173,100 reads (119,349 TN, 53,751 TP) from CLL patients and the ICGC-CLL Genome cohort.*
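All four metrics can be recomputed with standard scikit-learn calls once you have per-read probabilities; the arrays below are toy stand-ins, not the actual test set:

```python
import numpy as np
from sklearn.metrics import (
    balanced_accuracy_score, f1_score,
    roc_auc_score, average_precision_score,
)

# Toy labels/probabilities standing in for the held-out test set
y_true = np.array([0, 0, 0, 1, 1, 1])
y_prob = np.array([0.1, 0.2, 0.7, 0.4, 0.8, 0.9])
y_pred = (y_prob >= 0.5).astype(int)

bal_acc = balanced_accuracy_score(y_true, y_pred)   # mean per-class recall
f1 = f1_score(y_true, y_pred)
roc_auc = roc_auc_score(y_true, y_prob)             # threshold-free ranking
pr_auc = average_precision_score(y_true, y_prob)    # PR-AUC (average precision)
```

With the ~2:1 TN:TP imbalance of the test set, balanced accuracy and PR-AUC are more informative than raw accuracy, which is presumably why they are reported here.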

---

## Usage

### Installation

```bash
pip install torch scikit-learn pandas numpy
git clone https://github.com/acri-nb/igh_classification.git
```

### Loading a checkpoint

```python
import sys

import torch

# Make the cloned repository importable
sys.path.insert(0, "/path/to/igh_classification")
from models import FTTransformer

checkpoint = torch.load(
    "progressive_training/real_150000/synth_100pct_150000/best_model.pt",
    map_location="cpu",
    weights_only=False,  # the checkpoint stores metadata, not just tensors
)

model = FTTransformer(
    input_dim=checkpoint["input_dim"],
    hidden_dims=checkpoint["hidden_dims"],
    dropout=checkpoint.get("dropout", 0.3),
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
```

### Running inference

```python
import joblib
import pandas as pd
import torch

# Load your feature CSV (output of the preprocessing pipeline)
df = pd.read_csv("features_extracted.csv")
X = df.values.astype("float32")

# Load the RobustScaler that was fitted on the training data and apply it.
# Fitting a new scaler on inference data would shift the feature
# distribution the model was trained on. (The file name is illustrative.)
scaler = joblib.load("scaler.joblib")
X_scaled = scaler.transform(X)

X_tensor = torch.tensor(X_scaled, dtype=torch.float32)

with torch.no_grad():
    logits = model(X_tensor)
    probs = torch.sigmoid(logits).squeeze()
    predictions = (probs >= 0.5).int()
```

> **Note:** The scaler must be fitted on the **training data** and saved alongside the model; at inference time it is only applied, never re-fitted. The `DeepBioClassifier` class in the GitHub repository handles this automatically during training.
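If you train outside `DeepBioClassifier`, the same scaler persistence can be done by hand with `joblib`; the matrix and file name below are illustrative:

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.preprocessing import RobustScaler

# Training side: fit on the 464-feature training matrix, then persist
X_train = np.random.default_rng(0).normal(size=(100, 464)).astype("float32")
scaler = RobustScaler().fit(X_train)
path = os.path.join(tempfile.mkdtemp(), "scaler.joblib")
joblib.dump(scaler, path)

# Inference side: reload and apply (never re-fit) to new reads
loaded = joblib.load(path)
X_new = np.random.default_rng(1).normal(size=(10, 464)).astype("float32")
X_scaled = loaded.transform(X_new)
```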

---

## Training details

| Parameter | Value |
|---|---|
| Architecture | FT-Transformer |
| Input features | 464 |
| Hidden dimensions | [512, 256, 128, 64] |
| Dropout | 0.3 |
| Optimizer | AdamW |
| Learning rate | 1e-3 |
| Scheduler | Cosine annealing with warm restarts |
| Loss | Focal loss |
| Epochs | 150 (early stopping, patience = 50) |
| Batch size | 256 |
| Feature normalization | RobustScaler |
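Focal loss down-weights well-classified reads so training focuses on hard examples, which suits the TN-heavy class balance here. The card does not state the α/γ hyper-parameters, so the common defaults below are assumptions, and this NumPy sketch illustrates the formula rather than reproducing the repository's implementation:

```python
import numpy as np

def binary_focal_loss(y_true, logits, alpha=0.25, gamma=2.0):
    """FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), averaged over reads.

    alpha/gamma are assumed defaults, not values from the model card.
    """
    p = 1.0 / (1.0 + np.exp(-logits))            # sigmoid probability of TP
    p_t = np.where(y_true == 1, p, 1.0 - p)      # probability of the true class
    alpha_t = np.where(y_true == 1, alpha, 1.0 - alpha)
    return float(np.mean(-alpha_t * (1.0 - p_t) ** gamma * np.log(p_t)))
```

With `gamma=0` and `alpha=0.5` the expression reduces to half the standard binary cross-entropy; raising `gamma` shrinks the contribution of confident correct predictions.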

---

## Citation

If you use these weights or the associated code, please cite:

```bibtex
@article{gayap2025igh,
  title   = {Machine learning-based classification of IGHV mutation status
             in CLL from whole-genome sequencing data},
  author  = {Gayap, Hadrien and others},
  journal = {(submitted)},
  year    = {2025}
}
```

---

## License

MIT License. See [LICENSE](https://github.com/acri-nb/igh_classification/blob/main/LICENSE) for details.