cla1r3 committed on
Commit
15a4e7c
·
verified ·
1 Parent(s): a3d5117

Upload neural_network.py

Browse files
Files changed (1) hide show
  1. neural_network.py +640 -0
neural_network.py ADDED
@@ -0,0 +1,640 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-
"""neural network

Automatically generated by Colab.

Original file is located at
https://colab.research.google.com/drive/13Vym7d6JDkWLa9cv9p8h_amR_3uUnGp9
"""

# Cell A: Upload training dataset google sheets (CSV file)
# NOTE: Colab-only interactive upload widget. The file chosen here must be the
# CSV later referenced by CSV_PATH in Cell C ("trainingdataset - Sheet 1.csv").
from google.colab import files
import pandas as pd
import io

# `uploaded` holds the uploaded file contents keyed by filename; the rest of
# the script reads the CSV from disk rather than from this dict.
uploaded = files.upload()

# Cell B: Define liability predictor model
import torch
import torch.nn as nn
21
class LiabilityPredictor(nn.Module):
    """MLP regression head: (batch, input_dim) embedding -> output_dim liability scores.

    Layer stack: optional LayerNorm on the input, then for each hidden width a
    Linear -> (LayerNorm) -> activation -> (Dropout) block, finishing with a
    Linear projection to ``output_dim``. All Linear layers get Xavier-uniform
    weights and zero biases (stable init for small-data regression).
    """

    # Maps an activation name to its PyTorch layer class.
    _ACTIVATIONS = {"relu": nn.ReLU, "gelu": nn.GELU, "silu": nn.SiLU}

    def __init__(
        self,
        input_dim: int = 640,
        output_dim: int = 4,
        hidden_dims=(128, 64),
        dropout: float = 0.10,
        activation: str = "gelu",
        use_layernorm: bool = True,
    ):
        super().__init__()

        act_cls = self._ACTIVATIONS.get(activation.lower())
        if act_cls is None:
            raise ValueError(f"Unknown activation='{activation}'. Use 'relu', 'gelu', or 'silu'.")

        blocks = []
        if use_layernorm:
            blocks.append(nn.LayerNorm(input_dim))

        width = input_dim
        for hidden in hidden_dims:
            blocks.append(nn.Linear(width, hidden))
            if use_layernorm:
                blocks.append(nn.LayerNorm(hidden))
            blocks.append(act_cls())
            if dropout and dropout > 0:
                blocks.append(nn.Dropout(dropout))
            width = hidden

        # Final regression head.
        blocks.append(nn.Linear(width, output_dim))
        self.net = nn.Sequential(*blocks)

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for every Linear layer."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MLP. Accepts a single (features,) vector or a (batch, features) matrix."""
        if x.dim() == 1:
            x = x.unsqueeze(0)  # (640,) -> (1, 640)
        if x.dim() != 2:
            raise ValueError(f"Expected x to have shape (batch, features). Got {tuple(x.shape)}")
        # Cast defensively to float32 before the first Linear.
        return self.net(x.float())
79
+
80
# Cell C: Create dataset
import torch
from torch.utils.data import Dataset
import pandas as pd
from transformers import AutoModel, AutoTokenizer
import numpy as np

# Small 8M-parameter ESM-2 protein language model used as a frozen embedder.
MODEL_NAME = "facebook/esm2_t6_8M_UR50D"
# Must match the filename uploaded in Cell A.
CSV_PATH = "trainingdataset - Sheet 1.csv"

df = pd.read_csv(CSV_PATH)

# The four regression targets predicted by the model, in fixed order.
target_cols = ['polyreactivity', 'hydrophobicity', 'aggregation', 'charge_patch']
for col in target_cols:
    # Coerce non-numeric entries to NaN so they are dropped below.
    df[col] = pd.to_numeric(df[col], errors='coerce')

# Keep only rows with both chains and all four targets present.
df = df.dropna(subset=['VH','VL'] + target_cols).reset_index(drop=True)

y = df[target_cols].values
print("Target order:", target_cols)
print("Rows kept:", len(df))

# Load ESM-2 (downloads weights on first run); eval mode since it is used
# only for feature extraction, never fine-tuned.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
esm_model = AutoModel.from_pretrained(MODEL_NAME)
esm_model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
esm_model.to(device)


# Per-token embedding width (320 for esm2_t6_8M, so VH+VL concat gives 640).
hidden_size = esm_model.config.hidden_size
112
+
113
def embed_sequences_meanpool_scoring_style(seqs, batch_size=8):
    """Embed each unique sequence with ESM-2 and mean-pool over real tokens.

    Uses the module-level ``tokenizer``, ``esm_model`` and ``device``.
    Returns a dict mapping sequence string -> 1-D CPU tensor of size
    ``hidden_size``. Duplicates in ``seqs`` are embedded only once.
    """
    # Deduplicate while preserving first-seen order.
    pending = list(dict.fromkeys(seqs))
    embeddings = {}

    for start in range(0, len(pending), batch_size):
        chunk = pending[start:start + batch_size]

        encoded = tokenizer(
            chunk,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

        # Frozen feature extraction: no autograd graph needed.
        with torch.inference_mode():
            outputs = esm_model(**encoded)

        hidden = outputs.last_hidden_state
        mask = encoded["attention_mask"].float()

        # Masked mean over the sequence dimension (padding excluded);
        # clamp avoids division by zero for an all-padding row.
        summed = (hidden * mask.unsqueeze(-1)).sum(dim=1)
        mean_pooled = summed / mask.sum(dim=1).clamp(min=1).unsqueeze(-1)
        mean_pooled = mean_pooled.detach().cpu()

        for sequence, vector in zip(chunk, mean_pooled):
            embeddings[sequence] = vector

    return embeddings
143
+
144
# Embed every VH and VL sequence once (the helper deduplicates internally).
all_seqs = df["VH"].tolist() + df["VL"].tolist()
seq_to_vec = embed_sequences_meanpool_scoring_style(all_seqs, batch_size=8)

# Build the per-antibody feature matrix: one row = [VH embedding | VL embedding].
X_tensors = []
for _, row in df.iterrows():
    vh_vec = seq_to_vec[row["VH"]]
    vl_vec = seq_to_vec[row["VL"]]

    # Sanity-check each pooled embedding is a flat (hidden_size,) vector.
    assert vh_vec.shape == (hidden_size,), f"VH vec shape {vh_vec.shape} != ({hidden_size},)"
    assert vl_vec.shape == (hidden_size,), f"VL vec shape {vl_vec.shape} != ({hidden_size},)"

    # Concatenate VH + VL
    combined_vec = torch.cat([vh_vec, vl_vec], dim=0)  # (2 * hidden_size,) = (640,)
    X_tensors.append(combined_vec)

# (n_samples, 2 * hidden_size) numpy array of model inputs.
X = torch.stack(X_tensors, dim=0).numpy()
assert X.shape[1] == 2 * hidden_size, f"Expected {2*hidden_size} features, got {X.shape[1]}"

# Rows of X and y must line up one-to-one.
assert X.shape[0] == y.shape[0], f"X rows {X.shape[0]} != y rows {y.shape[0]}"
163
+
164
# Create dataset object
class AntibodyDataset(Dataset):
    """Pairs feature rows X with target rows y as float32 tensors for a DataLoader."""

    def __init__(self, X, y):
        # Convert once up front so __getitem__ is a cheap index.
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
175
+
176
+
177
# Wrap the full feature/target arrays; note Cell D re-splits X/y itself and
# does not use this object directly.
dataset = AntibodyDataset(X, y)

print(
    f"Dataset created: {len(dataset)} samples | "
    f"X shape: {X.shape} | y shape: {y.shape}"
)

# Double-check: eyeball the first row (the 'name' column is optional in the CSV).
print("First name:", df["name"].iloc[0] if "name" in df.columns else "(no 'name' column)")
print("First y row:", y[0])
187
+
188
# Cell D (REPLACEMENT): Evaluation and training data using five-fold CV
# NOTE: '!pip' is an IPython/Colab shell escape, not plain Python — this cell
# only runs inside a notebook environment.
!pip -q install scikit-learn

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import KFold
197
+
198
# Dataset wrapper (raw y stored; z-scoring is done per fold)
class AntibodyDatasetRaw(Dataset):
    """Float32 tensor view over numpy features/targets; no scaling applied here."""

    def __init__(self, X_np, y_np):
        self.X = torch.tensor(X_np, dtype=torch.float32)
        self.y = torch.tensor(y_np, dtype=torch.float32)

    def __len__(self):
        # Number of rows in the feature tensor.
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
207
+
208
def mae_rmse_r2(y_true, y_pred):
    """Per-target MAE, RMSE and R^2 for (n_samples, n_targets) arrays.

    Each returned value is a length-n_targets array (metrics computed down
    axis 0). The tiny epsilon on the total sum of squares guards against a
    zero-variance target column.
    """
    residual = y_pred - y_true
    mae = np.abs(residual).mean(axis=0)
    rmse = np.sqrt((residual ** 2).mean(axis=0))
    ss_res = (residual ** 2).sum(axis=0)
    ss_tot = ((y_true - y_true.mean(axis=0)) ** 2).sum(axis=0) + 1e-12
    return mae, rmse, 1.0 - ss_res / ss_tot
216
+
217
def train_one_fold(X_train, y_train_raw, X_val, y_val_raw,
                   hidden_dims=(128,64), dropout=0.10,
                   batch_size=16, max_epochs=200,
                   lr=3e-4, weight_decay=1e-4,
                   patience=12, min_delta=1e-4):
    """Train a LiabilityPredictor on one CV fold and evaluate on its validation split.

    Targets are z-scored with TRAIN-fold statistics only (no leakage), so the
    optimized loss is MSE in z-space; the returned metrics are converted back
    to raw target units. Uses the module-level globals `device`,
    `LiabilityPredictor`, `AntibodyDatasetRaw` and `mae_rmse_r2`.

    Returns:
        (mae, rmse, r2): per-target NN metrics on the validation fold (raw units).
        (b_mae, b_rmse, b_r2): same metrics for a predict-the-train-mean baseline.
        (train_loss_hist, val_loss_hist): per-epoch z-space MSE histories.
    """


    # ----- z-score targets using TRAIN only (no leakage) -----
    y_mean = y_train_raw.mean(axis=0)
    y_std = y_train_raw.std(axis=0) + 1e-8  # epsilon avoids divide-by-zero

    y_train_z = (y_train_raw - y_mean) / y_std
    y_val_z = (y_val_raw - y_mean) / y_std

    train_ds = AntibodyDatasetRaw(X_train, y_train_z)
    val_ds = AntibodyDatasetRaw(X_val, y_val_z)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    # ----- model -----
    model = LiabilityPredictor(
        input_dim=X_train.shape[1],
        hidden_dims=hidden_dims,
        dropout=dropout
    ).to(device)

    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Halve LR after 3 epochs without val improvement, down to a 1e-5 floor.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=3, min_lr=1e-5
    )

    best_val = float("inf")
    best_state = None
    bad = 0        # consecutive epochs without (sufficient) improvement
    best_ep = 0    # NOTE(review): written nowhere below — unused

    def epoch_loss(loader, train: bool):
        """One pass over `loader`; steps the optimizer when train=True.
        Returns the sample-weighted mean MSE (z-space)."""
        model.train() if train else model.eval()
        total, n = 0.0, 0
        for xb, yb in loader:
            xb = xb.to(device)
            yb = yb.to(device)

            if train:
                optimizer.zero_grad()

            # Grad graph only built during training passes.
            with torch.set_grad_enabled(train):
                pred = model(xb)
                loss = loss_fn(pred, yb)
                if train:
                    loss.backward()
                    optimizer.step()

            bs = xb.size(0)
            total += loss.item() * bs
            n += bs
        return total / max(n, 1)

    @torch.no_grad()
    def predict_val_raw():
        """Predict the validation split and un-z-score back to raw target units."""
        model.eval()
        preds_z = []
        for xb, _ in val_loader:
            xb = xb.to(device)
            pz = model(xb).cpu().numpy()
            preds_z.append(pz)
        preds_z = np.vstack(preds_z)
        return preds_z * y_std + y_mean

    # training loop with early stopping on validation loss
    train_loss_hist = []
    val_loss_hist = []

    for ep in range(1, max_epochs + 1):
        tr = epoch_loss(train_loader, True)
        va = epoch_loss(val_loader, False)
        train_loss_hist.append(tr)
        val_loss_hist.append(va)

        scheduler.step(va)

        if va < best_val - min_delta:
            # New best: snapshot weights on CPU so the best epoch can be restored.
            best_val = va
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                break

    # Restore best-epoch weights. The first epoch always improves on inf, so
    # best_state is set unless the very first val loss was NaN/inf.
    model.load_state_dict(best_state)

    # Predictions in raw units + metrics
    y_pred_raw = predict_val_raw()
    mae, rmse, r2 = mae_rmse_r2(y_val_raw, y_pred_raw)

    # Baseline: Predict TRAIN mean in raw units
    base_pred = np.tile(y_mean.reshape(1,-1), (y_val_raw.shape[0], 1))
    b_mae, b_rmse, b_r2 = mae_rmse_r2(y_val_raw, base_pred)

    return (mae, rmse, r2), (b_mae, b_rmse, b_r2), (train_loss_hist, val_loss_hist)
320
+
321
+
322
# Run 5-fold CV over the full (X, y) arrays built in Cell C.
X_np = X.astype(np.float32)
y_np = y.astype(np.float32)

# Fixed seed so fold assignment is reproducible across reruns.
kf = KFold(n_splits=5, shuffle=True, random_state=42)

fold_metrics = []    # per fold: (mae, rmse, r2) for the NN, raw units
fold_baseline = []   # per fold: same metrics for the train-mean baseline
fold_histories = []  # per fold: (train_loss_hist, val_loss_hist) in z-space

for fold, (tr_idx, va_idx) in enumerate(kf.split(X_np), start=1):
    X_tr, X_va = X_np[tr_idx], X_np[va_idx]
    y_tr, y_va = y_np[tr_idx], y_np[va_idx]

    (mae, rmse, r2), (b_mae, b_rmse, b_r2), (tr_hist, va_hist) = train_one_fold(
        X_tr, y_tr, X_va, y_va,
        hidden_dims=(128,64),
        dropout=0.10,
        batch_size=16,
        max_epochs=200,
        lr=3e-4,
        weight_decay=1e-4,
        patience=12
    )

    fold_metrics.append((mae, rmse, r2))
    fold_baseline.append((b_mae, b_rmse, b_r2))
    fold_histories.append((tr_hist, va_hist))

    # Per-fold report: NN vs baseline, keyed by target name.
    print(f"\nFold {fold}/5")
    print(" NN MAE :", dict(zip(target_cols, mae)))
    print(" NN R2 :", dict(zip(target_cols, r2)))
    print(" BASE MAE:", dict(zip(target_cols, b_mae)))
    print(" BASE R2 :", dict(zip(target_cols, b_r2)))

print("\nDone. Run Cell E for plots + summary + final training.")
358
+
359
# Cell E: Post-CV plots + conclusion stats + Train final deployment model + Save
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd  # Import pandas for nice tables

# 1) CV summary plots + conclusions
K = len(fold_metrics)   # number of folds (5)
T = len(target_cols)    # number of targets (4)

# Stack per-fold metric arrays into (K, T) matrices: rows = folds, cols = targets.
nn_mae = np.stack([m[0] for m in fold_metrics], axis=0)  # (K,4)
nn_rmse = np.stack([m[1] for m in fold_metrics], axis=0)
nn_r2 = np.stack([m[2] for m in fold_metrics], axis=0)

# Same layout for the train-mean baseline.
b_mae = np.stack([m[0] for m in fold_baseline], axis=0)
b_rmse = np.stack([m[1] for m in fold_baseline], axis=0)
b_r2 = np.stack([m[2] for m in fold_baseline], axis=0)
379
+
380
def mean_std(a):
    """Mean and (population) std over axis 0 — i.e. across folds of a (K, T) array."""
    return np.mean(a, axis=0), np.std(a, axis=0)
382
+
383
# Mean/std across folds for each metric (per-target vectors of length T).
nn_mae_m, nn_mae_s = mean_std(nn_mae)
nn_r2_m, nn_r2_s = mean_std(nn_r2)
b_mae_m, b_mae_s = mean_std(b_mae)
b_r2_m, b_r2_s = mean_std(b_r2)

# Grouped bar chart layout: one group per target, NN vs baseline side by side.
x = np.arange(T)
w = 0.35

plt.figure()
plt.bar(x - w/2, nn_mae_m, yerr=nn_mae_s, width=w, label="NN")
plt.bar(x + w/2, b_mae_m, yerr=b_mae_s, width=w, label="Baseline")
plt.xticks(x, target_cols, rotation=30, ha="right")
plt.ylabel("MAE (raw units)")
plt.title("5-Fold CV: MAE per target (mean ± std)")
plt.legend()
plt.show()

plt.figure()
plt.bar(x - w/2, nn_r2_m, yerr=nn_r2_s, width=w, label="NN")
plt.bar(x + w/2, b_r2_m, yerr=b_r2_s, width=w, label="Baseline")
plt.xticks(x, target_cols, rotation=30, ha="right")
plt.ylabel("R²")
plt.title("5-Fold CV: R² per target (mean ± std)")
plt.legend()
plt.show()

# Worst-target MAE: because you need all four good
nn_worst_mae = nn_mae.max(axis=1)  # per fold, the weakest target's MAE
b_worst_mae = b_mae.max(axis=1)

# NOTE: `display` below is the IPython/Colab builtin for rendering DataFrames.
print("Worst-target MAE across folds:")
worst_mae_df = pd.DataFrame({
    'Metric': ['NN worst-MAE mean ± std', 'BASE worst-MAE mean ± std'],
    'Value': [f"{nn_worst_mae.mean():.4f} ± {nn_worst_mae.std():.4f}", f"{b_worst_mae.mean():.4f} ± {b_worst_mae.std():.4f}"]
})
display(worst_mae_df)

# One row per target: NN vs baseline, MAE and R² as "mean±std" strings.
print("\nPer-target summary (mean ± std):")
per_target_summary_data = []
for i, t in enumerate(target_cols):
    per_target_summary_data.append({
        'Target': t,
        'NN MAE': f"{nn_mae_m[i]:.4f}±{nn_mae_s[i]:.4f}",
        'NN R2': f"{nn_r2_m[i]:.4f}±{nn_r2_s[i]:.4f}",
        'BASE MAE': f"{b_mae_m[i]:.4f}±{b_mae_s[i]:.4f}",
        'BASE R2': f"{b_r2_m[i]:.4f}±{b_r2_s[i]:.4f}"
    })
per_target_df = pd.DataFrame(per_target_summary_data)
display(per_target_df)

# Headline numbers: metrics averaged over the four targets.
print("\nOverall (mean across targets):")
overall_summary_data = [
    {
        'Model': 'NN',
        'MAE_mean': f"{nn_mae_m.mean():.4f} ± {nn_mae_s.mean():.4f}",
        'R2_mean': f"{nn_r2_m.mean():.4f} ± {nn_r2_s.mean():.4f}"
    },
    {
        'Model': 'BASE',
        'MAE_mean': f"{b_mae_m.mean():.4f} ± {b_mae_s.mean():.4f}",
        'R2_mean': f"{b_r2_m.mean():.4f} ± {b_r2_s.mean():.4f}"
    }
]
overall_df = pd.DataFrame(overall_summary_data)
display(overall_df)
448
+
449
# NOTE(review): the imports below duplicate earlier cells (Colab cells are
# independent); train_test_split is imported but never used in this file.
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader


import numpy as np
import matplotlib.pyplot as plt

# Safety checks
if "fold_histories" not in globals() or len(fold_histories) == 0:
    raise ValueError("fold_histories not found or empty. Make sure you appended (tr_hist, va_hist) inside the CV fold loop.")

# Determine the minimum number of epochs ran across folds (due to early stopping)
min_len = min(len(tr) for tr, _ in fold_histories)
print("CV folds:", len(fold_histories))
print("Min epochs across folds (truncate to this):", min_len)
print("Epochs per fold:", [len(tr) for tr, _ in fold_histories])

# Truncate each fold to min_len so curves align by epoch index
tr_mat = np.array([tr[:min_len] for tr, _ in fold_histories], dtype=np.float32)  # shape: (K, min_len)
va_mat = np.array([va[:min_len] for _, va in fold_histories], dtype=np.float32)  # shape: (K, min_len)

# Compute mean ± std across folds for each epoch
tr_mean = tr_mat.mean(axis=0)
tr_std = tr_mat.std(axis=0)

va_mean = va_mat.mean(axis=0)
va_std = va_mat.std(axis=0)

# Plot mean curves with ±1 std shading
x = np.arange(1, min_len + 1)

plt.figure()
plt.plot(x, tr_mean, label="CV train loss (mean)")
plt.plot(x, va_mean, label="CV val loss (mean)")
plt.fill_between(x, tr_mean - tr_std, tr_mean + tr_std, alpha=0.2)
plt.fill_between(x, va_mean - va_std, va_mean + va_std, alpha=0.2)

plt.xlabel("Epoch")
plt.ylabel("MSE in z-space")
plt.title("5-Fold CV Learning Curves (truncated to min epoch, mean ± std)")
# With z-scored targets, always predicting 0 gives MSE ≈ 1 — a visual floor.
plt.axhline(1.0, linestyle=":", label="z-space baseline (~1.0)")
plt.legend()
plt.show()
497
+
498
+
499
# Train deployable model on ALL data
X_all = X.astype(np.float32)
y_all = y.astype(np.float32)

# z-score statistics over the FULL dataset; saved later in the artifact so
# predictions can be un-scaled at inference time.
y_mean_full = y_all.mean(axis=0)
y_std_full = y_all.std(axis=0) + 1e-8  # epsilon avoids divide-by-zero
y_z_full = (y_all - y_mean_full) / y_std_full
506
+
507
class AntibodyDatasetZ(Dataset):
    """Dataset over features and ALREADY z-scored targets (scaling happens upstream)."""

    def __init__(self, X_np, y_z_np):
        self.X = torch.tensor(X_np, dtype=torch.float32)
        self.y = torch.tensor(y_z_np, dtype=torch.float32)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
516
+
517
+
518
ds_full = AntibodyDatasetZ(X_all, y_z_full)
loader_full = DataLoader(ds_full, batch_size=16, shuffle=True)

# Same architecture as the CV folds; trained on ALL rows for deployment.
final_model = LiabilityPredictor(input_dim=640, hidden_dims=(128,64), dropout=0.10).to(device)
# NOTE(review): lr=1e-4 here vs 3e-4 in CV, and no scheduler/early stopping —
# confirm this difference is intentional.
optimizer_final = optim.Adam(final_model.parameters(), lr=1e-4, weight_decay=1e-4)

# Train for the shortest epoch count observed across CV folds (proxy for
# a safe budget with no held-out set to early-stop on).
epochs_final = min_len

loss_hist_full = []


loss_fn = nn.MSELoss()

final_model.train()
for ep in range(1, epochs_final+1):
    total, n = 0.0, 0
    for xb, yb in loader_full:
        xb, yb = xb.to(device), yb.to(device)
        optimizer_final.zero_grad()
        pred = final_model(xb)
        loss = loss_fn(pred, yb)
        loss.backward()
        optimizer_final.step()
        # Sample-weighted running loss for the epoch average.
        total += loss.item() * xb.size(0)
        n += xb.size(0)
    loss_epoch = total / max(n, 1)
    loss_hist_full.append(loss_epoch)
    if ep % 10 == 0 or ep == 1:
        print(f"[FINAL-ALL] Epoch {ep:03d} | train_loss(zMSE) {loss_epoch:.4f}")
547
+
548
+ import numpy as np
549
def movavg(x, w=7):
    """Moving average of width ``w`` ('valid' convolution).

    Returns the input as an array, unaveraged, when it is shorter than the
    window. NOTE(review): defined here but not called anywhere below.
    """
    values = np.array(x)
    if len(values) < w:
        return values
    kernel = np.ones(w) / w
    return np.convolve(values, kernel, mode="valid")
553
+
554
# Training curve for the all-data model (no validation curve exists here).
plt.figure()
plt.plot(np.arange(1, epochs_final+1), loss_hist_full, label="train loss (all data)")
plt.xlabel("Epoch")
plt.ylabel("MSE in z-space")
plt.title("Deployable Model Training Curve (ALL data)")
plt.legend()
plt.show()

# Bundle weights with the scaling stats needed to un-z-score predictions.
# NOTE(review): this dict is never saved — Cell F saves a near-duplicate
# `artifact` instead; confirm which one is canonical.
final_artifacts = {
    "state_dict": final_model.state_dict(),
    "y_mean": y_mean_full,
    "y_std": y_std_full,
    "target_cols": target_cols,
    "trained_on": "ALL_DATA_FINAL_MODEL_CELL_E",
    "epochs_final": epochs_final,
}
570
+
571
# Cell F: Plot graphs to visualise loss and accuracy
import numpy as np
import matplotlib.pyplot as plt
import torch

print("y_mean:", y_mean_full)
print("y_std:", y_std_full)

final_model.eval()

y_true_z_list = []
y_pred_z_list = []

# NOTE(review): loader_full is the TRAINING loader over all data (and it
# shuffles), so these are in-sample predictions, not held-out validation.
with torch.no_grad():
    for xb, yb in loader_full:
        xb = xb.to(device)

        pred_z = final_model(xb).cpu().numpy()  # (batch, 4) in z-space
        y_pred_z_list.append(pred_z)

        y_true_z_list.append(yb.numpy())  # (batch, 4) in z-space

y_true_z = np.vstack(y_true_z_list)
y_pred_z = np.vstack(y_pred_z_list)

# Unscale HERE: back from z-space to raw target units.
y_true = y_true_z * y_std_full + y_mean_full
y_pred = y_pred_z * y_std_full + y_mean_full
599
+
600
def pearsonr(a, b):
    """Pearson correlation of two equal-length 1-D arrays.

    The epsilon in the denominator guards against a zero-variance input
    (returns ~0 instead of dividing by zero).
    """
    da = a - np.mean(a)
    db = b - np.mean(b)
    denom = np.sqrt(np.dot(da, da) * np.dot(db, db)) + 1e-12
    return float(np.dot(da, db) / denom)
604
+
605
def spearmanr(a, b):
    """Spearman rank correlation: Pearson correlation of the ranks.

    NOTE(review): the double-argsort assigns ordinal 0..n-1 ranks, so ties are
    broken by position rather than averaged (differs from scipy for tied data).
    """
    rank_a = np.argsort(np.argsort(a)).astype(float)
    rank_b = np.argsort(np.argsort(b)).astype(float)
    return pearsonr(rank_a, rank_b)
609
+
610
# One true-vs-predicted scatter per target, with the y=x identity line.
# NOTE(review): titles say "(val)" but y_true/y_pred above come from
# loader_full (all training data) — these are in-sample correlations.
for j, name in enumerate(target_cols):
    p = pearsonr(y_true[:, j], y_pred[:, j])
    s = spearmanr(y_true[:, j], y_pred[:, j])

    plt.figure()
    plt.scatter(y_true[:, j], y_pred[:, j])
    # Identity line spanning the combined data range.
    lo = min(y_true[:, j].min(), y_pred[:, j].min())
    hi = max(y_true[:, j].max(), y_pred[:, j].max())
    plt.plot([lo, hi], [lo, hi], linestyle="--")
    plt.xlabel(f"True {name}")
    plt.ylabel(f"Predicted {name}")
    plt.title(f"{name} (val) R={p:.2f} ρ={s:.2f}")
    plt.show()

import torch

# Deployment artifact: weights plus everything needed to rebuild the model
# (architecture hyperparameters) and un-scale its outputs (mean/std).
artifact = {
    "state_dict": final_model.state_dict(),
    "y_mean": y_mean_full,
    "y_std": y_std_full,
    "target_cols": target_cols,
    "input_dim": 640,
    "hidden_dims": (128, 64),
    "dropout": 0.10,
}

torch.save(artifact, "liability_predictor.pt")
print("Saved:", "liability_predictor.pt")

# Colab-only: trigger a browser download of the saved checkpoint.
from google.colab import files
files.download("liability_predictor.pt")