# Neural-network / neural_network.py
# cla1r3's picture — Upload neural_network.py — 15a4e7c verified
# (Hugging Face file-page residue, commented out so this file parses as Python.)
# -*- coding: utf-8 -*-
"""neural network
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/13Vym7d6JDkWLa9cv9p8h_amR_3uUnGp9
"""
# Cell A: Upload training dataset google sheets (CSV file)
from google.colab import files  # Colab-only helper; opens a browser file picker
import pandas as pd
import io
uploaded = files.upload()  # blocks until the user picks file(s); maps filename -> raw bytes
# Cell B: Define liability predictor model
import torch
import torch.nn as nn
class LiabilityPredictor(nn.Module):
    """MLP regressor mapping a pooled sequence embedding to liability scores.

    Architecture: optional LayerNorm on the input, then for every hidden
    width a Linear -> (optional LayerNorm) -> activation -> (optional
    Dropout) stack, finished by a Linear projection to ``output_dim``.
    Weights use Xavier-uniform init with zero biases.
    """

    # Supported activation names -> layer classes.
    _ACTIVATIONS = {"relu": nn.ReLU, "gelu": nn.GELU, "silu": nn.SiLU}

    def __init__(
        self,
        input_dim: int = 640,
        output_dim: int = 4,
        hidden_dims=(128, 64),
        dropout: float = 0.10,
        activation: str = "gelu",
        use_layernorm: bool = True,
    ):
        super().__init__()
        act_cls = self._ACTIVATIONS.get(activation.lower())
        if act_cls is None:
            raise ValueError(f"Unknown activation='{activation}'. Use 'relu', 'gelu', or 'silu'.")
        stack = []
        if use_layernorm:
            stack.append(nn.LayerNorm(input_dim))
        width = input_dim
        for hidden_width in hidden_dims:
            stack.append(nn.Linear(width, hidden_width))
            if use_layernorm:
                stack.append(nn.LayerNorm(hidden_width))
            stack.append(act_cls())
            if dropout and dropout > 0:
                stack.append(nn.Dropout(dropout))
            width = hidden_width
        stack.append(nn.Linear(width, output_dim))
        self.net = nn.Sequential(*stack)
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights, zero biases (stable for small-data regression)."""
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MLP; a 1-D input is treated as a single sample."""
        if x.dim() == 1:
            x = x.unsqueeze(0)  # (640,) -> (1, 640)
        if x.dim() != 2:
            raise ValueError(f"Expected x to have shape (batch, features). Got {tuple(x.shape)}")
        return self.net(x.float())
# Cell C: Create dataset
import torch
from torch.utils.data import Dataset
import pandas as pd
from transformers import AutoModel, AutoTokenizer
import numpy as np
MODEL_NAME = "facebook/esm2_t6_8M_UR50D"  # small ESM-2 protein language model (HF hub id)
CSV_PATH = "trainingdataset - Sheet 1.csv"  # must match the file uploaded in Cell A
df = pd.read_csv(CSV_PATH)
# The four liability targets the network predicts jointly.
target_cols = ['polyreactivity', 'hydrophobicity', 'aggregation', 'charge_patch']
for col in target_cols:
    # Coerce non-numeric entries to NaN so they can be dropped below.
    df[col] = pd.to_numeric(df[col], errors='coerce')
# Keep only rows with both chains (VH/VL) and all four targets present.
df = df.dropna(subset=['VH','VL'] + target_cols).reset_index(drop=True)
y = df[target_cols].values  # (n_samples, 4) raw target matrix
print("Target order:", target_cols)
print("Rows kept:", len(df))
# Load ESM-2
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
esm_model = AutoModel.from_pretrained(MODEL_NAME)
esm_model.eval()  # embeddings only; the language model is never fine-tuned
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
esm_model.to(device)
hidden_size = esm_model.config.hidden_size  # per-chain embedding width (presumably 320, since 2*320 = 640 matches the predictor's default input_dim — confirm)
def embed_sequences_meanpool_scoring_style(seqs, batch_size=8):
    """Mean-pool ESM-2 token embeddings for each unique sequence.

    Deduplicates ``seqs`` (first-seen order preserved), embeds them in
    mini-batches, and mean-pools the final hidden states over real tokens
    only (padding excluded via the attention mask).

    Returns a dict mapping sequence string -> 1-D CPU float tensor of
    length ``hidden_size``. Uses the module-level ``tokenizer``,
    ``esm_model`` and ``device``.
    """
    deduped = list(dict.fromkeys(seqs))
    embeddings = {}
    for start in range(0, len(deduped), batch_size):
        chunk = deduped[start:start + batch_size]
        encoded = tokenizer(
            chunk,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
        with torch.inference_mode():
            hidden = esm_model(**encoded).last_hidden_state
            mask = encoded["attention_mask"].float()
            # Masked sum over the sequence axis, then divide by the real-token
            # count (clamped so an all-pad row cannot divide by zero).
            summed = (hidden * mask.unsqueeze(-1)).sum(dim=1)
            pooled = summed / mask.sum(dim=1).clamp(min=1).unsqueeze(-1)
        pooled = pooled.detach().cpu()
        for sequence, vector in zip(chunk, pooled):
            embeddings[sequence] = vector
    return embeddings
# Embed every VH and VL chain once (the helper deduplicates internally).
all_seqs = df["VH"].tolist() + df["VL"].tolist()
seq_to_vec = embed_sequences_meanpool_scoring_style(all_seqs, batch_size=8)
X_tensors = []
for _, row in df.iterrows():
    vh_vec = seq_to_vec[row["VH"]]
    vl_vec = seq_to_vec[row["VL"]]
    assert vh_vec.shape == (hidden_size,), f"VH vec shape {vh_vec.shape} != ({hidden_size},)"
    assert vl_vec.shape == (hidden_size,), f"VL vec shape {vl_vec.shape} != ({hidden_size},)"
    # Concatenate VH + VL
    combined_vec = torch.cat([vh_vec, vl_vec], dim=0) # (640,)
    X_tensors.append(combined_vec)
X = torch.stack(X_tensors, dim=0).numpy()  # (n_samples, 2*hidden_size) feature matrix
assert X.shape[1] == 2 * hidden_size, f"Expected {2*hidden_size} features, got {X.shape[1]}"
assert X.shape[0] == y.shape[0], f"X rows {X.shape[0]} != y rows {y.shape[0]}"
# Create dataset object
class AntibodyDataset(Dataset):
    """Pairs the embedding matrix X with the raw target matrix y as tensors."""

    def __init__(self, X, y):
        # Copy inputs into float32 tensors ready for DataLoader batching.
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        features = self.X[idx]
        target = self.y[idx]
        return features, target
dataset = AntibodyDataset(X, y)
print(
    f"Dataset created: {len(dataset)} samples | "
    f"X shape: {X.shape} | y shape: {y.shape}"
)
# Double-check
# Spot-check the first row; 'name' is optional in the CSV.
print("First name:", df["name"].iloc[0] if "name" in df.columns else "(no 'name' column)")
print("First y row:", y[0])
# Cell D (REPLACEMENT): Evaluation and training data using five-fold CV
# !pip -q install scikit-learn   (IPython shell magic — run manually in Colab;
# left commented so this file is importable as plain Python)
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import KFold
# Dataset wrapper (raw y stored; z-scoring is done per fold)
class AntibodyDatasetRaw(Dataset):
    """Dataset over pre-computed embeddings; y is stored exactly as passed in
    (per-fold z-scoring happens outside this class)."""

    def __init__(self, X_np, y_np):
        # Copy into float32 tensors for DataLoader consumption.
        self.X = torch.tensor(X_np, dtype=torch.float32)
        self.y = torch.tensor(y_np, dtype=torch.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
def mae_rmse_r2(y_true, y_pred):
    """Per-target MAE, RMSE and R^2 for arrays shaped (samples, targets).

    Returns three 1-D arrays, one entry per target column. The R^2
    denominator carries a tiny epsilon so a constant target column cannot
    divide by zero.
    """
    residual = y_pred - y_true
    mae = np.abs(residual).mean(axis=0)
    rmse = np.sqrt((residual ** 2).mean(axis=0))
    ss_res = ((y_true - y_pred) ** 2).sum(axis=0)
    ss_tot = ((y_true - y_true.mean(axis=0)) ** 2).sum(axis=0) + 1e-12
    r2 = 1.0 - ss_res / ss_tot
    return mae, rmse, r2
def train_one_fold(X_train, y_train_raw, X_val, y_val_raw,
                   hidden_dims=(128,64), dropout=0.10,
                   batch_size=16, max_epochs=200,
                   lr=3e-4, weight_decay=1e-4,
                   patience=12, min_delta=1e-4):
    """Train a LiabilityPredictor on one CV fold and evaluate on its val split.

    Targets are z-scored with TRAIN-fold statistics only (no leakage), and
    early stopping tracks the val MSE in z-space. Uses the module-level
    ``LiabilityPredictor``, ``AntibodyDatasetRaw``, ``mae_rmse_r2`` and
    ``device``.

    Returns three tuples:
      * (mae, rmse, r2) — NN metrics on the val split, in raw target units;
      * (b_mae, b_rmse, b_r2) — same metrics for a predict-the-train-mean baseline;
      * (train_loss_hist, val_loss_hist) — per-epoch z-space MSE histories.
    """
    # ----- z-score targets using TRAIN only (no leakage) -----
    y_mean = y_train_raw.mean(axis=0)
    y_std = y_train_raw.std(axis=0) + 1e-8  # epsilon guards a constant target column
    y_train_z = (y_train_raw - y_mean) / y_std
    y_val_z = (y_val_raw - y_mean) / y_std
    train_ds = AntibodyDatasetRaw(X_train, y_train_z)
    val_ds = AntibodyDatasetRaw(X_val, y_val_z)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    # ----- model -----
    model = LiabilityPredictor(
        input_dim=X_train.shape[1],
        hidden_dims=hidden_dims,
        dropout=dropout
    ).to(device)
    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Halve the LR after 3 epochs without val-loss improvement, floor at 1e-5.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=3, min_lr=1e-5
    )
    best_val = float("inf")
    best_state = None  # CPU snapshot of the best weights seen so far
    bad = 0            # epochs since the last improvement (early-stop counter)
    best_ep = 0        # NOTE(review): assigned but never updated or read
    def epoch_loss(loader, train: bool):
        # One pass over `loader`; backprop only when train=True.
        # Returns the sample-weighted mean loss for the epoch.
        model.train() if train else model.eval()
        total, n = 0.0, 0
        for xb, yb in loader:
            xb = xb.to(device)
            yb = yb.to(device)
            if train:
                optimizer.zero_grad()
            with torch.set_grad_enabled(train):
                pred = model(xb)
                loss = loss_fn(pred, yb)
            if train:
                loss.backward()
                optimizer.step()
            bs = xb.size(0)
            total += loss.item() * bs
            n += bs
        return total / max(n, 1)
    @torch.no_grad()
    def predict_val_raw():
        # Predict the val split and un-z-score back into raw target units.
        model.eval()
        preds_z = []
        for xb, _ in val_loader:
            xb = xb.to(device)
            pz = model(xb).cpu().numpy()
            preds_z.append(pz)
        preds_z = np.vstack(preds_z)
        return preds_z * y_std + y_mean
    # training loop
    train_loss_hist = []
    val_loss_hist = []
    for ep in range(1, max_epochs + 1):
        tr = epoch_loss(train_loader, True)
        va = epoch_loss(val_loader, False)
        train_loss_hist.append(tr)
        val_loss_hist.append(va)
        scheduler.step(va)
        if va < best_val - min_delta:
            best_val = va
            # Snapshot best weights on CPU so later epochs cannot clobber them.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                break
    # best_state is set on the first epoch (va < inf), so this is safe here.
    model.load_state_dict(best_state)
    # Predictions in raw units + metrics
    y_pred_raw = predict_val_raw()
    mae, rmse, r2 = mae_rmse_r2(y_val_raw, y_pred_raw)
    # Baseline: Predict TRAIN mean in raw units
    base_pred = np.tile(y_mean.reshape(1,-1), (y_val_raw.shape[0], 1))
    b_mae, b_rmse, b_r2 = mae_rmse_r2(y_val_raw, base_pred)
    return (mae, rmse, r2), (b_mae, b_rmse, b_r2), (train_loss_hist, val_loss_hist)
# Run 5-fold CV
X_np = X.astype(np.float32)
y_np = y.astype(np.float32)
kf = KFold(n_splits=5, shuffle=True, random_state=42)  # fixed seed -> reproducible splits
fold_metrics = []    # per-fold (mae, rmse, r2) for the NN
fold_baseline = []   # per-fold (mae, rmse, r2) for the train-mean baseline
fold_histories = []  # per-fold (train_loss_hist, val_loss_hist)
for fold, (tr_idx, va_idx) in enumerate(kf.split(X_np), start=1):
    X_tr, X_va = X_np[tr_idx], X_np[va_idx]
    y_tr, y_va = y_np[tr_idx], y_np[va_idx]
    (mae, rmse, r2), (b_mae, b_rmse, b_r2), (tr_hist, va_hist) = train_one_fold(
        X_tr, y_tr, X_va, y_va,
        hidden_dims=(128,64),
        dropout=0.10,
        batch_size=16,
        max_epochs=200,
        lr=3e-4,
        weight_decay=1e-4,
        patience=12
    )
    fold_metrics.append((mae, rmse, r2))
    fold_baseline.append((b_mae, b_rmse, b_r2))
    fold_histories.append((tr_hist, va_hist))
    print(f"\nFold {fold}/5")
    print(" NN MAE :", dict(zip(target_cols, mae)))
    print(" NN R2 :", dict(zip(target_cols, r2)))
    print(" BASE MAE:", dict(zip(target_cols, b_mae)))
    print(" BASE R2 :", dict(zip(target_cols, b_r2)))
print("\nDone. Run Cell E for plots + summary + final training.")
# Cell E: Post-CV plots + conclusion stats + Train final deployment model + Save
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd # Import pandas for nice tables
# 1) CV summary plots + conclusions
K = len(fold_metrics)  # number of CV folds
T = len(target_cols)   # number of targets
# Stack the per-fold metric tuples into (K, T) arrays for NN and baseline.
nn_mae = np.stack([m[0] for m in fold_metrics], axis=0) # (K,4)
nn_rmse= np.stack([m[1] for m in fold_metrics], axis=0)
nn_r2 = np.stack([m[2] for m in fold_metrics], axis=0)
b_mae = np.stack([m[0] for m in fold_baseline], axis=0)
b_rmse = np.stack([m[1] for m in fold_baseline], axis=0)
b_r2 = np.stack([m[2] for m in fold_baseline], axis=0)
def mean_std(a):
    """Return (per-column mean, per-column std) taken over axis 0 (folds)."""
    return np.mean(a, axis=0), np.std(a, axis=0)
nn_mae_m, nn_mae_s = mean_std(nn_mae)
nn_r2_m, nn_r2_s = mean_std(nn_r2)
b_mae_m, b_mae_s = mean_std(b_mae)
b_r2_m, b_r2_s = mean_std(b_r2)
# Grouped bar charts: NN vs baseline per target, error bars = std across folds.
x = np.arange(T)
w = 0.35  # bar width; NN and baseline bars sit w apart around each tick
plt.figure()
plt.bar(x - w/2, nn_mae_m, yerr=nn_mae_s, width=w, label="NN")
plt.bar(x + w/2, b_mae_m, yerr=b_mae_s, width=w, label="Baseline")
plt.xticks(x, target_cols, rotation=30, ha="right")
plt.ylabel("MAE (raw units)")
plt.title("5-Fold CV: MAE per target (mean ± std)")
plt.legend()
plt.show()
plt.figure()
plt.bar(x - w/2, nn_r2_m, yerr=nn_r2_s, width=w, label="NN")
plt.bar(x + w/2, b_r2_m, yerr=b_r2_s, width=w, label="Baseline")
plt.xticks(x, target_cols, rotation=30, ha="right")
plt.ylabel("R²")
plt.title("5-Fold CV: R² per target (mean ± std)")
plt.legend()
plt.show()
# Worst-target MAE: because you need all four good
nn_worst_mae = nn_mae.max(axis=1)
b_worst_mae = b_mae.max(axis=1)
print("Worst-target MAE across folds:")
worst_mae_df = pd.DataFrame({
    'Metric': ['NN worst-MAE mean ± std', 'BASE worst-MAE mean ± std'],
    'Value': [f"{nn_worst_mae.mean():.4f} ± {nn_worst_mae.std():.4f}", f"{b_worst_mae.mean():.4f} ± {b_worst_mae.std():.4f}"]
})
display(worst_mae_df)  # `display` is provided by IPython/Colab, not plain Python
print("\nPer-target summary (mean ± std):")
per_target_summary_data = []
for i, t in enumerate(target_cols):
    per_target_summary_data.append({
        'Target': t,
        'NN MAE': f"{nn_mae_m[i]:.4f}±{nn_mae_s[i]:.4f}",
        'NN R2': f"{nn_r2_m[i]:.4f}±{nn_r2_s[i]:.4f}",
        'BASE MAE': f"{b_mae_m[i]:.4f}±{b_mae_s[i]:.4f}",
        'BASE R2': f"{b_r2_m[i]:.4f}±{b_r2_s[i]:.4f}"
    })
per_target_df = pd.DataFrame(per_target_summary_data)
display(per_target_df)
print("\nOverall (mean across targets):")
overall_summary_data = [
    {
        'Model': 'NN',
        'MAE_mean': f"{nn_mae_m.mean():.4f} ± {nn_mae_s.mean():.4f}",
        'R2_mean': f"{nn_r2_m.mean():.4f} ± {nn_r2_s.mean():.4f}"
    },
    {
        'Model': 'BASE',
        'MAE_mean': f"{b_mae_m.mean():.4f} ± {b_mae_s.mean():.4f}",
        'R2_mean': f"{b_r2_m.mean():.4f} ± {b_r2_s.mean():.4f}"
    }
]
overall_df = pd.DataFrame(overall_summary_data)
display(overall_df)
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
# Safety checks
if "fold_histories" not in globals() or len(fold_histories) == 0:
    raise ValueError("fold_histories not found or empty. Make sure you appended (tr_hist, va_hist) inside the CV fold loop.")
# Determine the minimum number of epochs ran across folds (due to early stopping)
min_len = min(len(tr) for tr, _ in fold_histories)
print("CV folds:", len(fold_histories))
print("Min epochs across folds (truncate to this):", min_len)
print("Epochs per fold:", [len(tr) for tr, _ in fold_histories])
# Truncate each fold to min_len so curves align by epoch index
tr_mat = np.array([tr[:min_len] for tr, _ in fold_histories], dtype=np.float32) # shape: (K, min_len)
va_mat = np.array([va[:min_len] for _, va in fold_histories], dtype=np.float32) # shape: (K, min_len)
# Compute mean ± std across folds for each epoch
tr_mean = tr_mat.mean(axis=0)
tr_std = tr_mat.std(axis=0)
va_mean = va_mat.mean(axis=0)
va_std = va_mat.std(axis=0)
# Plot mean curves with ±1 std shading
x = np.arange(1, min_len + 1)
plt.figure()
plt.plot(x, tr_mean, label="CV train loss (mean)")
plt.plot(x, va_mean, label="CV val loss (mean)")
plt.fill_between(x, tr_mean - tr_std, tr_mean + tr_std, alpha=0.2)
plt.fill_between(x, va_mean - va_std, va_mean + va_std, alpha=0.2)
plt.xlabel("Epoch")
plt.ylabel("MSE in z-space")
plt.title("5-Fold CV Learning Curves (truncated to min epoch, mean ± std)")
# Reference line: a no-skill predictor of z-scored targets sits near MSE 1.0.
plt.axhline(1.0, linestyle=":", label="z-space baseline (~1.0)")
plt.legend()
plt.show()
# Train deployable model on ALL data
X_all = X.astype(np.float32)
y_all = y.astype(np.float32)
# z-score targets over the FULL dataset (acceptable here: no held-out eval follows).
y_mean_full = y_all.mean(axis=0)
y_std_full = y_all.std(axis=0) + 1e-8  # epsilon guards a constant target column
y_z_full = (y_all - y_mean_full) / y_std_full
class AntibodyDatasetZ(Dataset):
    """Dataset of embeddings paired with already z-scored targets."""

    def __init__(self, X_np, y_z_np):
        # Materialize both arrays as float32 tensors up front.
        self.X = torch.tensor(X_np, dtype=torch.float32)
        self.y = torch.tensor(y_z_np, dtype=torch.float32)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
ds_full = AntibodyDatasetZ(X_all, y_z_full)
loader_full = DataLoader(ds_full, batch_size=16, shuffle=True)
final_model = LiabilityPredictor(input_dim=640, hidden_dims=(128,64), dropout=0.10).to(device)
# NOTE(review): final lr (1e-4) differs from the CV lr (3e-4) — confirm intentional.
optimizer_final = optim.Adam(final_model.parameters(), lr= 1e-4, weight_decay=1e-4)
epochs_final = min_len  # reuse the shortest early-stopped epoch count from CV
loss_hist_full = []
loss_fn = nn.MSELoss()
final_model.train()
for ep in range(1, epochs_final+1):
    total, n = 0.0, 0
    for xb, yb in loader_full:
        xb, yb = xb.to(device), yb.to(device)
        optimizer_final.zero_grad()
        pred = final_model(xb)
        loss = loss_fn(pred, yb)
        loss.backward()
        optimizer_final.step()
        # Sample-weighted accumulation so the epoch mean is per-sample.
        total += loss.item() * xb.size(0)
        n += xb.size(0)
    loss_epoch = total / max(n, 1)
    loss_hist_full.append(loss_epoch)
    if ep % 10 == 0 or ep == 1:
        print(f"[FINAL-ALL] Epoch {ep:03d} | train_loss(zMSE) {loss_epoch:.4f}")
import numpy as np
def movavg(x, w=7):
    """Moving average with window ``w`` (``mode='valid'``).

    Inputs shorter than the window are returned unchanged (as an array).
    """
    values = np.array(x)
    if values.size < w:
        return values
    kernel = np.full(w, 1.0 / w)
    return np.convolve(values, kernel, mode="valid")
plt.figure()
# NOTE(review): `movavg` above is defined but never applied to this curve.
plt.plot(np.arange(1, epochs_final+1), loss_hist_full, label="train loss (all data)")
plt.xlabel("Epoch")
plt.ylabel("MSE in z-space")
plt.title("Deployable Model Training Curve (ALL data)")
plt.legend()
plt.show()
# Bundle weights with the de-normalization stats needed at inference time.
# NOTE(review): this dict is never saved here; Cell F saves a separate
# `artifact` dict instead — confirm which one is canonical.
final_artifacts = {
    "state_dict": final_model.state_dict(),
    "y_mean": y_mean_full,
    "y_std": y_std_full,
    "target_cols": target_cols,
    "trained_on": "ALL_DATA_FINAL_MODEL_CELL_E",
    "epochs_final": epochs_final,
}
# Cell F: Plot graphs to visualise loss and accuracy
import numpy as np
import matplotlib.pyplot as plt
import torch
print("y_mean:", y_mean_full)
print("y_std:", y_std_full)
final_model.eval()
# Collect z-space predictions over the SAME loader the model was trained on,
# so the scatter plots below show fit quality, not generalization.
y_true_z_list = []
y_pred_z_list = []
with torch.no_grad():
    for xb, yb in loader_full:
        xb = xb.to(device)
        pred_z = final_model(xb).cpu().numpy() # (batch, 4) in z-space
        y_pred_z_list.append(pred_z)
        y_true_z_list.append(yb.numpy()) # (batch, 4) in z-space
y_true_z = np.vstack(y_true_z_list)
y_pred_z = np.vstack(y_pred_z_list)
# Unscale HERE
y_true = y_true_z * y_std_full + y_mean_full
y_pred = y_pred_z * y_std_full + y_mean_full
def pearsonr(a, b):
    """Pearson correlation of two 1-D arrays (eps guards zero variance)."""
    da = a - np.mean(a)
    db = b - np.mean(b)
    denom = np.sqrt((da @ da) * (db @ db)) + 1e-12
    return float((da @ db) / denom)
def spearmanr(a, b):
    """Spearman rank correlation of two 1-D arrays.

    Fixes the previous ``argsort().argsort()`` ranking, which assigns
    arbitrary distinct ranks to tied values; Spearman's rho is defined on
    average ("fractional") ranks, which ``scipy.stats.rankdata`` computes,
    so tied values no longer bias the estimate. Untied inputs give the
    same result as before. A tiny epsilon guards zero-variance inputs.
    """
    from scipy.stats import rankdata  # local import; scipy ships with Colab

    ra = rankdata(a)
    rb = rankdata(b)
    ra = ra - ra.mean()
    rb = rb - rb.mean()
    return float((ra @ rb) / (np.sqrt((ra @ ra) * (rb @ rb)) + 1e-12))
# One true-vs-predicted scatter per target, annotated with Pearson R and
# Spearman rho, plus an identity line for reference.
for j, name in enumerate(target_cols):
    p = pearsonr(y_true[:, j], y_pred[:, j])
    s = spearmanr(y_true[:, j], y_pred[:, j])
    plt.figure()
    plt.scatter(y_true[:, j], y_pred[:, j])
    lo = min(y_true[:, j].min(), y_pred[:, j].min())
    hi = max(y_true[:, j].max(), y_pred[:, j].max())
    plt.plot([lo, hi], [lo, hi], linestyle="--")  # y = x identity line
    plt.xlabel(f"True {name}")
    plt.ylabel(f"Predicted {name}")
    # NOTE(review): title says "(val)" but the data comes from the full
    # training loader — confirm the label.
    plt.title(f"{name} (val) R={p:.2f} ρ={s:.2f}")
    plt.show()
import torch
# Persist the model weights together with the target de-normalization stats
# and the architecture hyperparameters needed to rebuild the model at load time.
artifact = {
    "state_dict": final_model.state_dict(),
    "y_mean": y_mean_full,
    "y_std": y_std_full,
    "target_cols": target_cols,
    "input_dim": 640,
    "hidden_dims": (128, 64),
    "dropout": 0.10,
}
torch.save(artifact, "liability_predictor.pt")
print("Saved:", "liability_predictor.pt")
from google.colab import files
files.download("liability_predictor.pt")  # triggers a browser download in Colab