cla1r3 committed on
Commit
92d62a8
·
verified ·
1 Parent(s): a5fab65

Upload neural_network.py

Browse files
Files changed (1) hide show
  1. neural_network.py +688 -0
neural_network.py ADDED
@@ -0,0 +1,688 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """neural network
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/13Vym7d6JDkWLa9cv9p8h_amR_3uUnGp9
8
+ """
9
+
10
# Cell A: Upload training dataset google sheets (CSV file)
from google.colab import files
import pandas as pd
import io

# Opens the Colab file-picker widget; the chosen CSV is later read by name
# in Cell C via CSV_PATH ("trainingdataset - Sheet 1.csv").
uploaded = files.upload()
16
+
17
+ # Cell B: Define liability predictor model
18
+ import torch
19
+ import torch.nn as nn
20
+
21
class LiabilityPredictor(nn.Module):
    """Feed-forward regressor for antibody developability liabilities.

    Maps a concatenated VH+VL embedding (default 320 + 320 = 640 features)
    to one continuous score per liability (default 4 outputs). Architecture:
    optional LayerNorm on the input, then Linear -> (LayerNorm) -> activation
    -> (Dropout) per hidden width, and a final Linear projection.
    """

    def __init__(
        self,
        input_dim: int = 640,        # 320 from VH + 320 from VL
        output_dim: int = 4,         # one output per liability
        hidden_dims=(128, 64),       # widths of the two hidden layers
        dropout: float = 0.10,       # randomly zeroes units in training (anti-overfit)
        activation: str = "gelu",    # smooth non-linearity; 'relu'/'gelu'/'silu'
        use_layernorm: bool = True,  # stabilises training
    ):
        super().__init__()

        # Resolve the activation name to a layer class up front so a typo
        # fails loudly at construction time rather than mid-training.
        activation_classes = {"relu": nn.ReLU, "gelu": nn.GELU, "silu": nn.SiLU}
        act_layer = activation_classes.get(activation.lower())
        if act_layer is None:
            raise ValueError(f"Unknown activation='{activation}'. Use 'relu', 'gelu', or 'silu'.")

        stack = []
        if use_layernorm:
            stack.append(nn.LayerNorm(input_dim))

        width_in = input_dim
        for width_out in hidden_dims:
            stack.append(nn.Linear(width_in, width_out))
            if use_layernorm:
                stack.append(nn.LayerNorm(width_out))
            stack.append(act_layer())
            if dropout and dropout > 0:
                stack.append(nn.Dropout(dropout))
            width_in = width_out

        # Final projection to one score per liability.
        stack.append(nn.Linear(width_in, output_dim))
        self.net = nn.Sequential(*stack)

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights, zero biases — stable for small-data regression."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MLP. Accepts a single (features,) vector or a (batch, features) matrix."""
        if x.dim() == 1:
            x = x.unsqueeze(0)  # (640,) -> (1, 640): promote to a batch of one
        if x.dim() != 2:
            raise ValueError(f"Expected x to have shape (batch, features). Got {tuple(x.shape)}")
        # Cast guards against float64/half inputs from numpy conversions.
        return self.net(x.float())

    def enable_mc_dropout(self):
        """
        Optional: call before inference if you later want MC-dropout uncertainty.
        Re-enables only the Dropout layers, leaving every other module in its
        current (typically eval) mode.
        """
        for module in self.modules():
            if isinstance(module, nn.Dropout):
                module.train()
88
+
89
# Cell C: Create dataset
import torch
from torch.utils.data import Dataset
import pandas as pd
from transformers import AutoModel, AutoTokenizer
import numpy as np

MODEL_NAME = "facebook/esm2_t6_8M_UR50D"
CSV_PATH = "trainingdataset - Sheet 1.csv"

df = pd.read_csv(CSV_PATH)

# Coerce the four liability columns to numeric; unparseable cells become NaN
# and those rows are dropped just below.
target_cols = ['polyreactivity', 'hydrophobicity', 'aggregation', 'charge_patch']
for col in target_cols:
    df[col] = pd.to_numeric(df[col], errors='coerce')

df = df.dropna(subset=['VH','VL'] + target_cols).reset_index(drop=True)

y = df[target_cols].values
# FIX: later cells (5-fold CV, final deployment training, Option A) read
# `y_raw`, which was never defined anywhere in this script and raised a
# NameError. Alias the raw (un-normalised) targets here.
y_raw = y
print("Target order:", target_cols)
print("Rows kept:", len(df))

# Load ESM-2 (frozen; used only as a sequence embedder, never fine-tuned here)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
esm_model = AutoModel.from_pretrained(MODEL_NAME)
esm_model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
esm_model.to(device)

# Per-token embedding width (320 for esm2_t6_8M); VH+VL concat gives 2x this.
hidden_size = esm_model.config.hidden_size
120
+
121
+ #Embedding function
122
def embed_sequences_meanpool_residues_only(seqs, batch_size=8):
    """
    Returns a dict: {seq_string: torch.Tensor(shape=(hidden_size,), on CPU)}
    Mean-pools token embeddings over residues ONLY (excludes special tokens like CLS/EOS).
    Uses attention_mask to ignore padding.

    NOTE(review): relies on the module-level globals `tokenizer`, `esm_model`
    and `device` created in Cell C. Assumes the tokenizer places a BOS/CLS
    token at position 0 and an EOS token at the last non-pad position (true
    for the ESM-2 tokenizer) — confirm if the embedder is ever swapped.
    """
    # Deduplicate while preserving order, so each unique sequence is embedded once.
    unique_seqs = list(dict.fromkeys(seqs))

    seq_to_vec = {}
    for i in range(0, len(unique_seqs), batch_size):
        batch_seqs = unique_seqs[i:i + batch_size]

        # Right-pad the batch to a common length; no truncation (antibody
        # variable domains are short enough for the model's context).
        tokenized = tokenizer(
            batch_seqs,
            return_tensors="pt",
            padding=True,
            truncation=False,
        )
        tokenized = {k: v.to(device) for k, v in tokenized.items()}

        # Frozen embedder: no autograd graph needed.
        with torch.inference_mode():
            out = esm_model(**tokenized)

        token_emb = out.last_hidden_state
        attn = tokenized["attention_mask"].long()

        # Start from the padding mask, then zero out the special tokens too.
        mask = attn.clone()

        # Drop the BOS/CLS token at position 0.
        mask[:, 0] = 0

        # Remove EOS at the last real token position for each sequence
        lengths = attn.sum(dim=1)  # (B,) counts real tokens incl CLS/EOS
        eos_idx = (lengths - 1).clamp(min=0)  # index of last real token
        row_idx = torch.arange(mask.size(0), device=device)
        mask[row_idx, eos_idx] = 0

        # Mean pool over remaining (residue) tokens
        denom = mask.sum(dim=1).clamp(min=1).unsqueeze(-1)  # (B, 1); clamp avoids 0-division
        pooled = (token_emb * mask.unsqueeze(-1)).sum(dim=1) / denom  # (B, H)

        # Store on CPU so the cache does not hold GPU memory.
        pooled = pooled.detach().cpu()
        for s, v in zip(batch_seqs, pooled):
            seq_to_vec[s] = v

    return seq_to_vec
168
+
169
#Embed all VH and VL sequences. Embeds each unique sequence once.
all_seqs = df["VH"].tolist() + df["VL"].tolist()
seq_to_vec = embed_sequences_meanpool_residues_only(all_seqs, batch_size=8)

# Build one feature row per antibody: the VH embedding followed by the VL
# embedding, looked up from the shared cache above.
X_tensors = []
for _, row in df.iterrows():
    vh_vec = seq_to_vec[row["VH"]]
    vl_vec = seq_to_vec[row["VL"]]

    # Each pooled embedding must be a single hidden_size vector.
    assert vh_vec.shape == (hidden_size,), f"VH vec shape {vh_vec.shape} != ({hidden_size},)"
    assert vl_vec.shape == (hidden_size,), f"VL vec shape {vl_vec.shape} != ({hidden_size},)"

    #Concatenate VH + VL
    combined_vec = torch.cat([vh_vec, vl_vec], dim=0)  # (640,)
    X_tensors.append(combined_vec)

# (N, 2*hidden_size) feature matrix as numpy, consumed by the Dataset below.
X = torch.stack(X_tensors, dim=0).numpy()
assert X.shape[1] == 2 * hidden_size, f"Expected {2*hidden_size} features, got {X.shape[1]}"

# Features and targets must stay row-aligned.
assert X.shape[0] == y.shape[0], f"X rows {X.shape[0]} != y rows {y.shape[0]}"
189
+
190
+ #Create dataset object
191
class AntibodyDataset(Dataset):
    """Thin Dataset over the (X, y) arrays, converted once to float32 tensors."""

    def __init__(self, X, y):
        # Convert up front so __getitem__ is a plain tensor index.
        self.X, self.y = (torch.tensor(arr, dtype=torch.float32) for arr in (X, y))

    def __len__(self):
        # Number of antibodies (rows).
        return self.X.shape[0]

    def __getitem__(self, idx):
        # One (features, targets) pair.
        return self.X[idx], self.y[idx]
201
+
202
+
203
# Wrap the embeddings and raw targets for use with a DataLoader.
dataset = AntibodyDataset(X, y)

print(
    f"Dataset created: {len(dataset)} samples | "
    f"X shape: {X.shape} | y shape: {y.shape}"
)

# double-check: eyeball the first row to confirm alignment with the CSV.
print("First name:", df["name"].iloc[0] if "name" in df.columns else "(no 'name' column)")
print("First y row:", y[0])
213
+
214
+ # Cell F (REPLACEMENT): 5-Fold Cross-Validation (with early stopping) + Baseline comparison
215
+ !pip -q install scikit-learn
216
+
217
+ import numpy as np
218
+ import torch
219
+ import torch.nn as nn
220
+ import torch.optim as optim
221
+ from torch.utils.data import Dataset, DataLoader
222
+ from sklearn.model_selection import KFold
223
+
224
+ # ---- Dataset wrapper (raw y stored; z-scoring is done per fold) ----
225
class AntibodyDatasetRaw(Dataset):
    """Dataset over numpy (X, y) arrays; targets are stored exactly as given
    (z-scoring happens per CV fold, outside this class)."""

    def __init__(self, X_np, y_np):
        # One-time float32 conversion keeps __getitem__ allocation-free.
        self.X = torch.tensor(X_np, dtype=torch.float32)
        self.y = torch.tensor(y_np, dtype=torch.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
233
+
234
def mae_rmse_r2(y_true, y_pred):
    """Per-target MAE, RMSE and R² for (N, T) arrays; returns three (T,) arrays."""
    residual = y_pred - y_true
    mae = np.abs(residual).mean(axis=0)
    rmse = np.sqrt((residual ** 2).mean(axis=0))
    # R² = 1 - SS_res / SS_tot; the tiny epsilon keeps a constant target
    # column from dividing by zero.
    ss_res = ((y_true - y_pred) ** 2).sum(axis=0)
    centered = y_true - y_true.mean(axis=0)
    ss_tot = (centered ** 2).sum(axis=0) + 1e-12
    r2 = 1.0 - (ss_res / ss_tot)
    return mae, rmse, r2
242
+
243
def train_one_fold(X_train, y_train_raw, X_val, y_val_raw,
                   hidden_dims=(128,64), dropout=0.10,
                   batch_size=16, max_epochs=200,
                   lr=3e-4, weight_decay=1e-4,
                   patience=12, min_delta=1e-4):
    """Train one CV fold with early stopping and return fold metrics.

    Returns ((mae, rmse, r2), (b_mae, b_rmse, b_r2)): neural-net metrics and
    a predict-the-train-mean baseline, both in raw target units.

    NOTE(review): uses the module-level `device`, `LiabilityPredictor`,
    `AntibodyDatasetRaw` and `mae_rmse_r2` defined earlier in this file.
    """

    # ----- z-score targets using TRAIN only (no leakage) -----
    y_mean = y_train_raw.mean(axis=0)
    y_std = y_train_raw.std(axis=0) + 1e-8  # epsilon guards constant columns

    y_train_z = (y_train_raw - y_mean) / y_std
    y_val_z = (y_val_raw - y_mean) / y_std  # val scaled with TRAIN stats

    train_ds = AntibodyDatasetRaw(X_train, y_train_z)
    val_ds = AntibodyDatasetRaw(X_val, y_val_z)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    # ----- model -----
    model = LiabilityPredictor(
        input_dim=X_train.shape[1],
        hidden_dims=hidden_dims,
        dropout=dropout
    ).to(device)

    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Halve the LR whenever validation loss plateaus for 3 epochs.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=3, min_lr=1e-5
    )

    # Early-stopping state: best validation loss seen, its weights, and the
    # count of consecutive non-improving epochs.
    best_val = float("inf")
    best_state = None
    bad = 0

    def epoch_loss(loader, train: bool):
        # One pass over `loader`; performs optimizer steps when train=True.
        # Returns the sample-weighted mean loss (z-space MSE).
        model.train() if train else model.eval()
        total, n = 0.0, 0
        for xb, yb in loader:
            xb = xb.to(device)
            yb = yb.to(device)

            if train:
                optimizer.zero_grad()

            # Grad tracking only during training passes.
            with torch.set_grad_enabled(train):
                pred = model(xb)
                loss = loss_fn(pred, yb)
                if train:
                    loss.backward()
                    optimizer.step()

            bs = xb.size(0)
            total += loss.item() * bs  # weight by batch size (last batch may be smaller)
            n += bs
        return total / max(n, 1)

    @torch.no_grad()
    def predict_val_raw():
        # Predict the validation set and undo the TRAIN z-scoring.
        model.eval()
        preds_z = []
        for xb, _ in val_loader:
            xb = xb.to(device)
            pz = model(xb).cpu().numpy()
            preds_z.append(pz)
        preds_z = np.vstack(preds_z)
        return preds_z * y_std + y_mean

    # ----- training loop -----
    for ep in range(1, max_epochs + 1):
        tr = epoch_loss(train_loader, True)
        va = epoch_loss(val_loader, False)
        scheduler.step(va)

        # Snapshot weights whenever val loss improves by at least min_delta.
        if va < best_val - min_delta:
            best_val = va
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                break  # early stop

    # load best (first epoch always improves on inf, so best_state is set)
    model.load_state_dict(best_state)

    # predictions in raw units + metrics
    y_pred_raw = predict_val_raw()
    mae, rmse, r2 = mae_rmse_r2(y_val_raw, y_pred_raw)

    # baseline: predict TRAIN mean in raw units
    base_pred = np.tile(y_mean.reshape(1,-1), (y_val_raw.shape[0], 1))
    b_mae, b_rmse, b_r2 = mae_rmse_r2(y_val_raw, base_pred)

    return (mae, rmse, r2), (b_mae, b_rmse, b_r2)
339
+
340
# -----------------------------
# Run 5-fold CV
# -----------------------------
X_np = X.astype(np.float32)
# FIX: this previously read `y_raw`, a name never defined in this script
# (the dataset cell produces `y`), which raised a NameError.
y_np = y.astype(np.float32)

# Fixed seed so folds are reproducible between runs.
kf = KFold(n_splits=5, shuffle=True, random_state=42)

fold_metrics = []   # per-fold NN (mae, rmse, r2) tuples
fold_baseline = []  # per-fold train-mean baseline tuples

for fold, (tr_idx, va_idx) in enumerate(kf.split(X_np), start=1):
    X_tr, X_va = X_np[tr_idx], X_np[va_idx]
    y_tr, y_va = y_np[tr_idx], y_np[va_idx]

    # Train/evaluate this fold; metrics come back in raw target units.
    (mae, rmse, r2), (b_mae, b_rmse, b_r2) = train_one_fold(
        X_tr, y_tr, X_va, y_va,
        hidden_dims=(128,64),
        dropout=0.10,
        batch_size=16,
        max_epochs=200,
        lr=3e-4,
        weight_decay=1e-4,
        patience=12
    )

    fold_metrics.append((mae, rmse, r2))
    fold_baseline.append((b_mae, b_rmse, b_r2))

    print(f"\nFold {fold}/5")
    print(" NN MAE :", dict(zip(target_cols, mae)))
    print(" NN R2 :", dict(zip(target_cols, r2)))
    print(" BASE MAE:", dict(zip(target_cols, b_mae)))
    print(" BASE R2 :", dict(zip(target_cols, b_r2)))

print("\nDone. Run Cell G for plots + summary + final training.")
376
+
377
+ # Cell G: Post-CV plots + conclusion stats + Train final deployment model + Save
378
+ import numpy as np
379
+ import matplotlib.pyplot as plt
380
+ import torch
381
+ import torch.nn as nn
382
+ import torch.optim as optim
383
+ from torch.utils.data import Dataset, DataLoader
384
+
385
# -----------------------------
# 1) CV summary plots + conclusions
# -----------------------------
# Stack the per-fold tuples from Cell F into (K folds, T targets) arrays.
K = len(fold_metrics)
T = len(target_cols)

nn_mae = np.stack([m[0] for m in fold_metrics], axis=0)  # (K,4)
nn_rmse= np.stack([m[1] for m in fold_metrics], axis=0)
nn_r2 = np.stack([m[2] for m in fold_metrics], axis=0)

b_mae = np.stack([m[0] for m in fold_baseline], axis=0)
b_rmse = np.stack([m[1] for m in fold_baseline], axis=0)
b_r2 = np.stack([m[2] for m in fold_baseline], axis=0)
398
+
399
def mean_std(a):
    """Column-wise mean and standard deviation of a (folds, targets) array."""
    return np.mean(a, axis=0), np.std(a, axis=0)
401
+
402
# Fold-wise mean ± std per target, for NN and baseline.
nn_mae_m, nn_mae_s = mean_std(nn_mae)
nn_r2_m, nn_r2_s = mean_std(nn_r2)
b_mae_m, b_mae_s = mean_std(b_mae)
b_r2_m, b_r2_s = mean_std(b_r2)

# Grouped bar charts: NN vs baseline, one bar pair per target.
x = np.arange(T)
w = 0.35

plt.figure()
plt.bar(x - w/2, nn_mae_m, yerr=nn_mae_s, width=w, label="NN")
plt.bar(x + w/2, b_mae_m, yerr=b_mae_s, width=w, label="Baseline")
plt.xticks(x, target_cols, rotation=30, ha="right")
plt.ylabel("MAE (raw units)")
plt.title("5-Fold CV: MAE per target (mean ± std)")
plt.legend()
plt.show()

plt.figure()
plt.bar(x - w/2, nn_r2_m, yerr=nn_r2_s, width=w, label="NN")
plt.bar(x + w/2, b_r2_m, yerr=b_r2_s, width=w, label="Baseline")
plt.xticks(x, target_cols, rotation=30, ha="right")
plt.ylabel("R²")
plt.title("5-Fold CV: R² per target (mean ± std)")
plt.legend()
plt.show()

# Worst-target MAE: because you need all four good
nn_worst_mae = nn_mae.max(axis=1)
b_worst_mae = b_mae.max(axis=1)
print("Worst-target MAE across folds:")
print(f" NN worst-MAE mean ± std: {nn_worst_mae.mean():.4f} ± {nn_worst_mae.std():.4f}")
print(f" BASE worst-MAE mean ± std: {b_worst_mae.mean():.4f} ± {b_worst_mae.std():.4f}")

# Tabular per-target comparison.
print("\nPer-target summary (mean ± std):")
for i, t in enumerate(target_cols):
    print(f"{t:14s} | NN MAE {nn_mae_m[i]:.4f}±{nn_mae_s[i]:.4f} R2 {nn_r2_m[i]:.4f}±{nn_r2_s[i]:.4f} "
          f"|| BASE MAE {b_mae_m[i]:.4f}±{b_mae_s[i]:.4f} R2 {b_r2_m[i]:.4f}±{b_r2_s[i]:.4f}")

print("\nOverall (mean across targets):")
print(f" NN MAE_mean {nn_mae_m.mean():.4f} ± {nn_mae_s.mean():.4f} | R2_mean {nn_r2_m.mean():.4f} ± {nn_r2_s.mean():.4f}")
print(f" BASE MAE_mean {b_mae_m.mean():.4f} ± {b_mae_s.mean():.4f} | R2_mean {b_r2_m.mean():.4f} ± {b_r2_s.mean():.4f}")
443
+
444
+ # -----------------------------
445
+ # 2) Train final model for deployment (on all data)
446
+ # -----------------------------
447
class AntibodyDatasetZ(Dataset):
    """Dataset over features plus already z-scored targets, for the final
    full-data deployment training run."""

    def __init__(self, X_np, y_z_np):
        self.X = torch.tensor(X_np, dtype=torch.float32)
        self.y = torch.tensor(y_z_np, dtype=torch.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
453
+
454
# Normalise targets on the FULL dataset for the deployment model.
# FIX: this previously read `y_raw`, which was never defined (NameError);
# the raw targets live in `y` from the dataset cell.
y_mean_full = y.mean(axis=0)
y_std_full = y.std(axis=0) + 1e-8  # epsilon guards constant columns
y_z_full = (y - y_mean_full) / y_std_full

ds_full = AntibodyDatasetZ(X.astype(np.float32), y_z_full.astype(np.float32))
loader = DataLoader(ds_full, batch_size=16, shuffle=True)

# FIX: input_dim was hard-coded to 640; derive it from the data so this cell
# keeps working if the embedder (and so the feature width) changes.
final_model = LiabilityPredictor(input_dim=X.shape[1], hidden_dims=(128,64), dropout=0.10).to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(final_model.parameters(), lr=3e-4, weight_decay=1e-4)

# Fixed epoch budget — no validation split remains, so no early stopping here.
epochs = 80
loss_hist = []

final_model.train()
for ep in range(1, epochs+1):
    total, n = 0.0, 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        pred = final_model(xb)
        loss = loss_fn(pred, yb)
        loss.backward()
        optimizer.step()
        total += loss.item() * xb.size(0)  # weight by batch size
        n += xb.size(0)
    loss_epoch = total / max(n, 1)
    loss_hist.append(loss_epoch)
    if ep % 10 == 0 or ep == 1:
        print(f"[FINAL] Epoch {ep:03d} | train_loss(zMSE) {loss_epoch:.4f}")

plt.figure()
plt.plot(np.arange(1, epochs+1), loss_hist)
plt.xlabel("Epoch")
plt.ylabel("Train MSE in z-space")
plt.title("Final Model Training Curve (for deployment)")
plt.show()

# Save: model + normalization (critical for inference)
final_artifacts = {
    "state_dict": final_model.state_dict(),
    "y_mean": y_mean_full,
    "y_std": y_std_full,
    "target_cols": target_cols,
}
torch.save(final_artifacts, "liability_predictor_final.pt")
print("Saved: liability_predictor_final.pt")
print("y_mean:", dict(zip(target_cols, y_mean_full)))
print("y_std :", dict(zip(target_cols, y_std_full)))
503
+
504
+ # Option A: Regression performance panel + baseline comparison
505
+ !pip -q install scikit-learn
506
+
507
+ import numpy as np
508
+ import pandas as pd
509
+ import matplotlib.pyplot as plt
510
+
511
+ from sklearn.linear_model import Ridge
512
+ from sklearn.ensemble import RandomForestRegressor
513
+ from sklearn.multioutput import MultiOutputRegressor
514
+ from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
515
+
516
+ # -----------------------------
517
+ # Helpers
518
+ # -----------------------------
519
def unz(y_z, y_mean, y_std):
    """Invert per-target z-scoring: map z-space values back to raw units."""
    return y_mean + y_std * y_z
521
+
522
def regression_metrics(y_true_raw, y_pred_raw, target_cols):
    """Per-target MAE / RMSE / R² table (raw units) plus a 'mean' summary row.

    Reimplemented with plain numpy for consistency with `mae_rmse_r2` above;
    the sklearn calls it replaces (mean_absolute_error / mean_squared_error /
    r2_score with multioutput='raw_values') compute exactly these quantities.
    """
    y_true_raw = np.asarray(y_true_raw, dtype=float)
    y_pred_raw = np.asarray(y_pred_raw, dtype=float)

    err = y_pred_raw - y_true_raw
    mae = np.mean(np.abs(err), axis=0)
    rmse = np.sqrt(np.mean(err**2, axis=0))

    ss_res = np.sum(err**2, axis=0)
    ss_tot = np.sum((y_true_raw - y_true_raw.mean(axis=0))**2, axis=0)
    # sklearn's convention for a constant target column: 1.0 if predicted
    # perfectly, else 0.0 (avoids a 0/0).
    with np.errstate(divide="ignore", invalid="ignore"):
        r2 = 1.0 - ss_res / ss_tot
    r2 = np.where(ss_tot > 0, r2, np.where(ss_res == 0, 1.0, 0.0))

    out = pd.DataFrame({
        "target": target_cols,
        "MAE": mae,
        "RMSE": rmse,
        "R2": r2
    })
    # Summary row indexed by the string "mean" (kept for downstream .loc use).
    out.loc["mean"] = ["mean", mae.mean(), rmse.mean(), r2.mean()]
    return out
535
+
536
@torch.no_grad()
def predict_nn_raw(loader, y_mean, y_std):
    """Collect (true, predicted) target arrays from `loader`, un-z-scored
    back to raw units with the supplied per-target mean/std.

    NOTE(review): reads the module-level globals `model` and `device`, but
    this file's CV cell keeps its model local to `train_one_fold`, so no
    global `model` exists here — confirm which cell was meant to provide it
    before running this section.
    """
    model.eval()
    preds_z = []
    trues_z = []
    for xb, yb in loader:
        xb = xb.to(device)
        pred_z = model(xb).cpu().numpy()
        preds_z.append(pred_z)
        trues_z.append(yb.numpy())  # loader targets are in z-space
    preds_z = np.vstack(preds_z)
    trues_z = np.vstack(trues_z)
    # Return both in raw units so metrics/plots are interpretable.
    return unz(trues_z, y_mean, y_std), unz(preds_z, y_mean, y_std)
549
+
550
# -----------------------------
# Prepare train/val arrays (raw y!)
# -----------------------------
# X is numpy (N,640); y_raw is numpy (N,4) from your Cell E
# NOTE(review): `train_idx`, `val_idx`, `y_raw`, `y_mean`, `y_std` and
# `val_loader` are not defined anywhere in this file — they appear to come
# from an earlier single-split Cell E/F that the 5-fold CV cell replaced.
# As written, this whole section raises NameError; confirm/restore the
# split cell before running it.
X_train = X[train_idx]
X_val = X[val_idx]
y_train_raw = y_raw[train_idx]
y_val_raw = y_raw[val_idx]

# -----------------------------
# Evaluate NN (your trained model already loaded best_state in Cell F)
# -----------------------------
y_val_true_nn, y_val_pred_nn = predict_nn_raw(val_loader, y_mean, y_std)
nn_table = regression_metrics(y_val_true_nn, y_val_pred_nn, target_cols)
print("\nNeural Network (val):")
display(nn_table)

# -----------------------------
# Baselines
# -----------------------------
# Classical baselines fitted on the same raw-unit targets for comparison.
ridge = MultiOutputRegressor(Ridge(alpha=10.0, random_state=0))
ridge.fit(X_train, y_train_raw)
y_pred_ridge = ridge.predict(X_val)
ridge_table = regression_metrics(y_val_raw, y_pred_ridge, target_cols)

rf = MultiOutputRegressor(RandomForestRegressor(
    n_estimators=600, random_state=0, min_samples_leaf=2
))
rf.fit(X_train, y_train_raw)
y_pred_rf = rf.predict(X_val)
rf_table = regression_metrics(y_val_raw, y_pred_rf, target_cols)

# -----------------------------
# Comparison summary (mean row only)
# -----------------------------
summary = pd.DataFrame({
    "Model": ["NeuralNet", "Ridge", "RandomForest"],
    "MAE_mean": [nn_table.loc["mean","MAE"], ridge_table.loc["mean","MAE"], rf_table.loc["mean","MAE"]],
    "RMSE_mean": [nn_table.loc["mean","RMSE"], ridge_table.loc["mean","RMSE"], rf_table.loc["mean","RMSE"]],
    "R2_mean": [nn_table.loc["mean","R2"], ridge_table.loc["mean","R2"], rf_table.loc["mean","R2"]],
})
print("\nModel comparison (val, mean across targets):")
display(summary)

# -----------------------------
# Predicted vs True plots for NN (per target)
# -----------------------------
for i, t in enumerate(target_cols):
    plt.figure()
    plt.scatter(y_val_true_nn[:, i], y_val_pred_nn[:, i])
    plt.xlabel(f"True {t} (raw)")
    plt.ylabel(f"Predicted {t} (raw)")
    plt.title(f"NN: Predicted vs True ({t})")
    plt.show()

# -----------------------------
# Residual histogram (per target)
# -----------------------------
res = y_val_pred_nn - y_val_true_nn
for i, t in enumerate(target_cols):
    plt.figure()
    plt.hist(res[:, i], bins=12)
    plt.xlabel(f"Residual (Pred - True) for {t}")
    plt.ylabel("Count")
    plt.title(f"NN residuals ({t})")
    plt.show()
616
+
617
+
618
+
619
# Cell G: Plot graphs to visualise loss and accuracy
import numpy as np
import matplotlib.pyplot as plt
import torch

# NOTE(review): `y_mean`, `y_std`, `model` and `val_loader` are not defined
# at module level in this file (the CV cell keeps them local to
# `train_one_fold`); as written this cell raises NameError — it presumably
# targeted an earlier single-split training cell. Confirm before running.
print("y_mean:", y_mean)
print("y_std:", y_std)

model.eval()

y_true_z_list = []
y_pred_z_list = []

# Gather z-space predictions and targets batch by batch.
with torch.no_grad():
    for xb, yb in val_loader:
        xb = xb.to(device)

        pred_z = model(xb).cpu().numpy()  # (batch, 4) in z-space
        y_pred_z_list.append(pred_z)

        y_true_z_list.append(yb.numpy())  # (batch, 4) in z-space

y_true_z = np.vstack(y_true_z_list)
y_pred_z = np.vstack(y_pred_z_list)

# ---- Unscale HERE ----
y_true = y_true_z * y_std + y_mean
y_pred = y_pred_z * y_std + y_mean
647
+
648
def pearsonr(a, b):
    """Pearson correlation of two 1-D arrays (epsilon-guarded denominator)."""
    centered_a = a - a.mean()
    centered_b = b - b.mean()
    numerator = centered_a @ centered_b
    denominator = np.sqrt((centered_a @ centered_a) * (centered_b @ centered_b)) + 1e-12
    return float(numerator / denominator)
652
+
653
def spearmanr(a, b):
    """Spearman rank correlation of two 1-D arrays.

    FIX: the original ranked with a double argsort, which gives tied values
    arbitrary distinct ranks (ordered by index). Ties now receive their
    average rank — the standard Spearman definition — and the result is
    unchanged when there are no ties.
    """
    def _average_ranks(v):
        # Stable sort keeps equal values adjacent so we can group them.
        v = np.asarray(v, dtype=float)
        order = v.argsort(kind="mergesort")
        sorted_v = v[order]
        ranks = np.empty(v.shape[0], dtype=float)
        i = 0
        while i < v.shape[0]:
            j = i
            while j + 1 < v.shape[0] and sorted_v[j + 1] == sorted_v[i]:
                j += 1
            ranks[order[i:j + 1]] = 0.5 * (i + j)  # average 0-based rank of the tie group
            i = j + 1
        return ranks

    ra = _average_ranks(a)
    rb = _average_ranks(b)
    # Pearson correlation of the ranks, same epsilon guard as pearsonr above.
    ra = ra - ra.mean()
    rb = rb - rb.mean()
    return float((ra @ rb) / (np.sqrt((ra @ ra) * (rb @ rb)) + 1e-12))
657
+
658
# Per-target scatter of predicted vs true values with correlation stats.
for j, name in enumerate(target_cols):
    p = pearsonr(y_true[:, j], y_pred[:, j])
    s = spearmanr(y_true[:, j], y_pred[:, j])

    plt.figure()
    plt.scatter(y_true[:, j], y_pred[:, j])
    # y = x reference line spanning the joint range of true and predicted.
    lo = min(y_true[:, j].min(), y_pred[:, j].min())
    hi = max(y_true[:, j].max(), y_pred[:, j].max())
    plt.plot([lo, hi], [lo, hi], linestyle="--")
    plt.xlabel(f"True {name}")
    plt.ylabel(f"Predicted {name}")
    plt.title(f"{name} (val) R={p:.2f} ρ={s:.2f}")
    plt.show()
671
+
672
+
673
+
674
import torch
from google.colab import files

# Define the path where the model will be saved
output_model_path = 'liability_predictor.pt'

# Save the best model state dictionary
# NOTE(review): `best_state` is local to train_one_fold() in this file and is
# never assigned at module level, so this line raises NameError as written —
# confirm which cell was meant to export it (or save `final_model` instead,
# as the deployment cell above already does).
torch.save(best_state, output_model_path)

print(f"Model saved successfully to {output_model_path}")

"""The model has been saved to `liability_predictor.pt` in your Colab environment. You can now download it to your local computer using the following code cell:"""

# Download the saved model to your local computer
files.download('liability_predictor.pt')