nileshhanotia committed on
Commit
007a33f
·
verified ·
1 Parent(s): 5c8b5a1

Upload train_splice_cnn_v4.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_splice_cnn_v4.py +278 -0
train_splice_cnn_v4.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================
2
+ # MutationPredictorCNN_v4 Training Script (401 bp FASTA)
3
+ # Proper sequence-based training
4
+ # ============================================================
5
+
6
+ import argparse
7
+ import numpy as np
8
+ import pandas as pd
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.utils.data import Dataset, DataLoader
12
+ from sklearn.metrics import roc_auc_score
13
+ import pysam
14
+ from tqdm import tqdm
15
+ import os
16
+
17
+ # ============================================================
18
+ # Arguments
19
+ # ============================================================
20
+
21
# Command-line interface: three mandatory paths plus optional hyperparameters.
parser = argparse.ArgumentParser()

# Mandatory inputs/outputs; values are kept as plain strings.
for required_flag in ("--train_csv", "--fasta", "--output_model"):
    parser.add_argument(required_flag, required=True)

# Optional tuning knobs, declared as (flag, converter, default).
for flag, converter, fallback in (
    ("--epochs", int, 30),
    ("--batch_size", int, 256),
    ("--num_workers", int, 8),
    ("--lr", float, 0.001),
):
    parser.add_argument(flag, type=converter, default=fallback)

# Parsed once at import time; the rest of the script reads from `args`.
args = parser.parse_args()
33
+
34
+ # ============================================================
35
+ # Config
36
+ # ============================================================
37
+
38
# Number of reference bases fetched around each variant position.
WINDOW = 401

# Bases taken on each side of the variant (integer half of the window).
HALF = WINDOW // 2

# Width of the encoded sequence matrix: two columns fewer than the window.
SEQ_LEN = WINDOW - 2

# Device that the model and every batch tensor are moved to.
DEVICE = "cpu"

# Open the reference genome once, up front, so a bad path fails immediately.
print("Loading FASTA...")
fasta = pysam.FastaFile(args.fasta)
45
+
46
+ # ============================================================
47
+ # Encoding
48
+ # ============================================================
49
+
50
# Row index of each canonical base inside a one-hot block (A=0, C=1, G=2, T=3).
BASE_MAP = {base: row for row, base in enumerate("ACGT")}

# Watson-Crick complement of each base; "N" maps to itself.
COMP = dict(zip("ATCGN", "TAGCN"))
52
+
53
# FASTA handles keyed by process id.  The handle opened at import time is
# registered for the importing process.  A DataLoader worker created with
# fork() inherits that handle's file descriptor, and therefore shares one file
# offset with the parent and with every sibling worker, so concurrent
# seek+read sequences could return bases from the wrong location.  Each new
# process therefore opens its own handle on first use.
_FASTA_HANDLES = {os.getpid(): fasta}


def _process_fasta():
    """Return a FASTA handle owned by the calling process, opening it if needed."""
    pid = os.getpid()
    handle = _FASTA_HANDLES.get(pid)
    if handle is None:
        handle = pysam.FastaFile(args.fasta)
        _FASTA_HANDLES[pid] = handle
    return handle


def fetch_seq(chrom, pos):
    """Fetch the upper-cased reference window centred on ``pos``.

    The window spans ``HALF`` bases on each side of ``pos``
    (``2 * HALF + 1`` bases in total).  ``pos`` is presumably 1-based: the
    ``- 1`` below converts it to the 0-based start that pysam expects
    (TODO confirm against the training CSV).

    Args:
        chrom: Contig name as stored in the training table.  It is tried
            verbatim first and then with a ``"chr"`` prefix.
        pos: Variant position on ``chrom``.

    Returns:
        The fetched sequence in upper case, or ``None`` when neither contig
        spelling can be fetched.  A window clipped at a contig end may come
        back shorter than ``WINDOW``; callers are expected to check the length.
    """
    start = pos - HALF - 1
    end = pos + HALF

    handle = _process_fasta()

    for contig in (str(chrom), "chr" + str(chrom)):
        try:
            return handle.fetch(contig, start, end).upper()
        except Exception:
            # Best effort: an unknown contig or invalid coordinates just means
            # "try the next spelling".  Catching Exception instead of using a
            # bare ``except`` keeps this behaviour while letting
            # KeyboardInterrupt and SystemExit propagate.
            continue

    return None
65
+
66
+
67
def encode_seq(seq):
    """Encode a fetched window as an ``(11, SEQ_LEN)`` float32 matrix.

    Column ``i`` describes ``seq[i + 1]``, so the first base of the window
    never gets a column of its own.  Rows:

    * 0-3: one-hot of the base (A, C, G, T); all zero for any other symbol.
    * 4-7: one-hot of the complementary base.
    * 8: position relative to the window centre, ``(i + 1 - HALF) / HALF``.
    * 9: 1 when the dinucleotide starting at this base is ``"GT"``.
    * 10: 1 when the dinucleotide starting at this base is ``"AG"``.

    Args:
        seq: Upper-case nucleotide string.  Positions past its end are
            treated as ``"N"``.

    Returns:
        ``numpy.ndarray`` of shape ``(11, SEQ_LEN)`` and dtype ``float32``.
    """
    encoded = np.zeros((11, SEQ_LEN), dtype=np.float32)
    seq_size = len(seq)

    for col in range(SEQ_LEN):
        src = col + 1  # index into the fetched window
        nucleotide = seq[src] if src < seq_size else "N"

        fwd_row = BASE_MAP.get(nucleotide)
        if fwd_row is not None:
            encoded[fwd_row, col] = 1
            rev_row = BASE_MAP.get(COMP[nucleotide])
            if rev_row is not None:
                encoded[4 + rev_row, col] = 1

        encoded[8, col] = (src - HALF) / HALF

        # Slicing never raises, so a short or empty tail simply fails to match.
        dinucleotide = seq[src:src + 2]
        if dinucleotide == "GT":
            encoded[9, col] = 1
        elif dinucleotide == "AG":
            encoded[10, col] = 1

    return encoded
91
+
92
+
93
def mut_onehot(ref, alt):
    """One-hot encode a single-base substitution as a 12-element vector.

    The slots follow the order A>C, A>G, A>T, C>A, C>G, C>T, G>A, G>C, G>T,
    T>A, T>C, T>G.  Anything that is not one of those twelve upper-case
    substitutions (identical bases, indels, lower case, other symbols)
    yields an all-zero vector.

    Args:
        ref: Reference allele.
        alt: Alternate allele.

    Returns:
        ``numpy.ndarray`` of shape ``(12,)`` and dtype ``float32``.
    """
    # Every ordered pair of distinct bases, in alphabetical order.
    substitutions = [
        f"{src}>{dst}" for src in "ACGT" for dst in "ACGT" if src != dst
    ]

    onehot = np.zeros(len(substitutions), dtype=np.float32)

    try:
        slot = substitutions.index(f"{ref}>{alt}")
    except ValueError:
        # Not a recognised substitution: leave the vector empty.
        return onehot

    onehot[slot] = 1
    return onehot
110
+
111
+
112
+ # ============================================================
113
+ # Dataset
114
+ # ============================================================
115
+
116
class SpliceDataset(Dataset):
    """Map-style dataset that turns variant rows into model inputs.

    Each row of the wrapped frame must expose ``chrom``, ``pos``, ``ref``,
    ``alt`` and ``label``.  Items are produced lazily: the reference window is
    fetched and encoded when the item is requested.
    """

    def __init__(self, df):
        """Store a copy of ``df`` whose index is renumbered from zero."""
        self.df = df.reset_index(drop=True)

    def __len__(self):
        """Return the number of variants."""
        return len(self.df)

    def __getitem__(self, idx):
        """Build the five tensors for the variant at position ``idx``.

        Returns:
            A tuple ``(seq, mut, region, splice, label)`` of tensors: the
            encoded sequence window, the substitution one-hot, a 2-element and
            a 3-element all-zero placeholder, and the scalar label.
        """
        record = self.df.iloc[idx]

        window = fetch_seq(record.chrom, int(record.pos))

        # A missing or truncated window is replaced by an all-"N" window, so
        # the item still has the expected shape.
        usable = window is not None and len(window) == WINDOW
        if not usable:
            window = "N" * WINDOW

        features = (
            encode_seq(window),
            mut_onehot(record.ref, record.alt),
            np.zeros(2, dtype=np.float32),  # region features: always zero here
            np.zeros(3, dtype=np.float32),  # splice features: always zero here
            float(record.label),
        )

        return tuple(torch.tensor(item) for item in features)
150
+
151
+
152
+ # ============================================================
153
+ # Model
154
+ # ============================================================
155
+
156
class MutationPredictorCNN_v4(nn.Module):
    """CNN over the encoded sequence, fused with three small side inputs.

    A three-layer 1-D convolutional trunk is averaged over the sequence axis
    into a 256-vector.  The mutation, region and splice vectors each pass
    through their own linear layer (32, 8 and 16 units).  The four parts are
    concatenated (312 features) and reduced by a three-layer MLP to a single
    logit per example.
    """

    def __init__(self):
        super().__init__()

        # Sequence trunk; the padding keeps the sequence length unchanged.
        self.conv1 = nn.Conv1d(11, 64, kernel_size=7, padding=3)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)

        # Average over the sequence axis so any input length gives 256 features.
        self.pool = nn.AdaptiveAvgPool1d(1)

        # Side-input embeddings.
        self.mut_fc = nn.Linear(12, 32)
        self.region_fc = nn.Linear(2, 8)
        self.splice_fc = nn.Linear(3, 16)

        # Classifier head over the concatenated features.
        fused_width = 256 + 32 + 8 + 16
        self.fc1 = nn.Linear(fused_width, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, seq, mut, region, splice):
        """Return one raw logit per example, shape ``(batch, 1)``.

        Args:
            seq: ``(batch, 11, length)`` encoded sequence.
            mut: ``(batch, 12)`` substitution one-hot.
            region: ``(batch, 2)`` region features.
            splice: ``(batch, 3)`` splice features.
        """
        feats = seq
        for conv in (self.conv1, self.conv2, self.conv3):
            feats = self.relu(conv(feats))

        pooled = self.pool(feats).squeeze(-1)

        mut_emb = self.relu(self.mut_fc(mut))
        region_emb = self.relu(self.region_fc(region))
        splice_emb = self.relu(self.splice_fc(splice))

        fused = torch.cat((pooled, mut_emb, region_emb, splice_emb), dim=1)

        # Dropout is applied after the first dense layer only.
        hidden = self.dropout(self.relu(self.fc1(fused)))
        hidden = self.relu(self.fc2(hidden))

        return self.fc3(hidden)
198
+
199
+
200
+ # ============================================================
201
+ # Load dataset
202
+ # ============================================================
203
+
204
# Everything that actually runs training sits behind the entry-point guard.
# DataLoader workers started with the "spawn" or "forkserver" methods
# re-import this module; without the guard every worker would try to load the
# data and start training again instead of serving batches.
if __name__ == "__main__":

    print("Loading dataset...")

    df = pd.read_csv(args.train_csv)

    train_ds = SpliceDataset(df)

    train_dl = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
    )

    # Create the checkpoint directory up front, so a missing folder is not
    # discovered only after the first epoch has finished.
    checkpoint_dir = os.path.dirname(args.output_model)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    # ============================================================
    # Train
    # ============================================================

    model = MutationPredictorCNN_v4().to(DEVICE)

    # The model emits raw logits; the sigmoid is folded into the loss.
    criterion = nn.BCEWithLogitsLoss()

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=args.lr,
    )

    best_auc = 0

    for epoch in range(args.epochs):

        model.train()

        losses = []
        probs = []
        labels = []

        for seq, mut, region, splice, label in train_dl:

            seq = seq.to(DEVICE)
            mut = mut.to(DEVICE)
            region = region.to(DEVICE)
            splice = splice.to(DEVICE)
            # (batch,) -> (batch, 1) to match the shape of the logits.
            label = label.to(DEVICE).unsqueeze(1)

            optimizer.zero_grad()

            logits = model(seq, mut, region, splice)

            loss = criterion(logits, label)

            loss.backward()

            optimizer.step()

            losses.append(loss.item())

            # Flatten to 1-D so the metric receives plain score/label vectors.
            probs.extend(torch.sigmoid(logits).detach().cpu().numpy().ravel())
            labels.extend(label.cpu().numpy().ravel())

        # NOTE(review): this AUC is computed on the training batches, with
        # dropout active and the weights changing during the epoch, so
        # "best model" means best training AUC. There is no held-out split.
        auc = roc_auc_score(labels, probs)

        print(f"Epoch {epoch+1}/{args.epochs} Loss={np.mean(losses):.4f} AUC={auc:.4f}")

        if auc > best_auc:

            best_auc = auc

            torch.save(
                model.state_dict(),
                args.output_model,
            )

            print("Saved best model")

    print("Training complete.")