import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error
import joblib
import pandas as pd
import numpy as np
import os
from rdkit import Chem
import random
import math
from rdkit.Chem.Scaffolds import MurckoScaffold
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as PyGDataLoader
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
def compute_weighted_score(per_target_maes: dict, sample_counts: dict):
"""
per_target_maes: {'Tg': mae_value, 'FFV':..}
sample_counts: {'Tg': N_tg, ...}
weights = (1/sqrt(N)) normalized to sum 1
"""
keys = list(per_target_maes.keys())
inv_sqrt = np.array([1.0/math.sqrt(sample_counts[k]) for k in keys])
weights = inv_sqrt / inv_sqrt.sum()
score = sum(per_target_maes[k] * w for k,w in zip(keys, weights))
return score, dict(zip(keys, weights))
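# Worked example (hypothetical numbers): 1/sqrt(100) = 0.10 and 1/sqrt(400) = 0.05
# normalize to weights 2/3 and 1/3, so the sparser target dominates:
# >>> compute_weighted_score({'Tg': 40.0, 'Rg': 1.2}, {'Tg': 100, 'Rg': 400})
# (27.066..., {'Tg': 0.666..., 'Rg': 0.333...})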
ATOM_LIST = ["C","H","O","N","S","F","Cl","Br","I","P"]
MAX_DEGREE = 5
def atom_features(atom):
    """Encode an RDKit atom as a fixed-length float32 feature vector."""
    at = atom.GetSymbol()
    # One-hot over ATOM_LIST plus a trailing "other" slot for unknown elements.
    one_hot = [1.0 if at == a else 0.0 for a in ATOM_LIST]
    one_hot.append(0.0 if any(one_hot) else 1.0)
    charge = [float(atom.GetFormalCharge())]
    aromatic = [1.0 if atom.GetIsAromatic() else 0.0]
    # Clamp so degrees above MAX_DEGREE map to the last slot instead of all zeros.
    deg = min(atom.GetDegree(), MAX_DEGREE)
    deg_oh = [1.0 if deg == d else 0.0 for d in range(MAX_DEGREE + 1)]
    feats = one_hot + charge + aromatic + deg_oh
    return np.array(feats, dtype=np.float32)
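# Resulting feature length: len(ATOM_LIST) + 1 ("other" slot) + 1 (formal charge)
# + 1 (aromatic flag) + (MAX_DEGREE + 1) degree slots = 19.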
def mol_to_pyg_data(smiles, global_features=None, y=None):
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return None
node_feats = [atom_features(a) for a in mol.GetAtoms()]
x = torch.tensor(np.vstack(node_feats), dtype=torch.float)
edge_index = []
edge_attr = []
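    # Add each bond in both directions so the molecular graph is undirected.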
for bond in mol.GetBonds():
a1 = bond.GetBeginAtomIdx(); a2 = bond.GetEndAtomIdx()
edge_index.append([a1, a2]); edge_index.append([a2, a1])
bt = bond.GetBondType()
bond_type = [0.0,0.0,0.0,0.0]
if bt == Chem.rdchem.BondType.SINGLE: bond_type[0]=1.0
elif bt == Chem.rdchem.BondType.DOUBLE: bond_type[1]=1.0
elif bt == Chem.rdchem.BondType.TRIPLE: bond_type[2]=1.0
elif bt == Chem.rdchem.BondType.AROMATIC: bond_type[3]=1.0
edge_attr.append(bond_type); edge_attr.append(bond_type)
if len(edge_index)==0:
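        # Single-atom molecule: placeholder self-loop keeps message passing valid.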
edge_index = [[0,0]]; edge_attr = [[0,0,0,0]]
edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
edge_attr = torch.tensor(np.vstack(edge_attr), dtype=torch.float)
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
if global_features is not None:
gf = torch.tensor(global_features, dtype=torch.float)
if gf.dim()==1: gf = gf.unsqueeze(0) # shape (1, D)
data.global_feats = gf
    else:
        # Keep the placeholder 2-D (1, 1) so PyG batching stacks it consistently.
        data.global_feats = torch.zeros((1, 1), dtype=torch.float)
if y is not None:
data.y = torch.tensor([y], dtype=torch.float)
return data
def mol_to_scaffold(smiles):
try:
mol = Chem.MolFromSmiles(smiles)
if mol is None: return None
scaffold = MurckoScaffold.GetScaffoldForMol(mol)
return Chem.MolToSmiles(scaffold, isomericSmiles=False)
except Exception:
return None
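# e.g. mol_to_scaffold("CCc1ccccc1O") -> "c1ccccc1" (side chains stripped).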
def scaffold_fold_assignments(df, n_folds=5, smiles_col="SMILES", seed=42):
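    """Greedily assign Bemis-Murcko scaffold groups to n_folds folds, largest
    groups first, so molecules sharing a scaffold never span two folds."""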
random.seed(seed)
scaffolds = {}
for idx, smi in enumerate(df[smiles_col].values):
scaf = mol_to_scaffold(smi)
if scaf is None: scaf = f"EMPTY_{idx}"
scaffolds.setdefault(scaf, []).append(idx)
groups = sorted(scaffolds.items(), key=lambda x: len(x[1]), reverse=True)
fold_sizes = [0]*n_folds
fold_assign = np.zeros(len(df), dtype=int)
for scaf, idxs in groups:
f = int(np.argmin(fold_sizes))
for idx in idxs: fold_assign[idx] = f
fold_sizes[f] += len(idxs)
return fold_assign
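# Example (hypothetical frame): fold ids land in [0, n_folds); fold 0 is held
# out later as the validation split.
# >>> demo = pd.DataFrame({"SMILES": ["CCO", "c1ccccc1O", "CCCC"]})
# >>> scaffold_fold_assignments(demo, n_folds=2)
# array([...])  # one fold id per row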
class GNNWithGlobalFeats(nn.Module):
    def __init__(self, node_in_dim, edge_in_dim, global_in_dim,
                 gnn_hidden=128, n_gnn_layers=3, mlp_hidden=128, dropout=0.2,
                 conv_type='gcn'):  # conv_type: 'gcn' or 'gat'
super().__init__()
self.global_in_dim = global_in_dim
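        # edge_in_dim is accepted for interface/metadata consistency; the
        # GCNConv/GATConv layers used here do not consume edge_attr.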
self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()  # one BatchNorm per conv layer
in_dim = node_in_dim
for _ in range(n_gnn_layers):
if conv_type == 'gcn':
self.convs.append(GCNConv(in_dim, gnn_hidden))
            elif conv_type == 'gat':
                # Single-head GAT; concat=False keeps the output width at gnn_hidden.
                self.convs.append(GATConv(in_dim, gnn_hidden, heads=1, concat=False))
else:
raise ValueError("conv_type must be 'gcn' or 'gat'")
self.bns.append(nn.BatchNorm1d(gnn_hidden))
in_dim = gnn_hidden
self.pool = global_mean_pool
total_in = gnn_hidden + global_in_dim
self.mlp = nn.Sequential(
nn.Linear(total_in, mlp_hidden),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(mlp_hidden, mlp_hidden//2),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(mlp_hidden//2, 1)
)
self.dropout = nn.Dropout(dropout)
def forward(self, data):
x, edge_index = data.x, data.edge_index
for conv, bn in zip(self.convs, self.bns):
x = conv(x, edge_index)
            x = bn(x)
x = F.relu(x)
x = self.dropout(x)
        # Plain Data objects carry no batch vector; treat them as a single graph.
        batch = getattr(data, 'batch', None)
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
pooled = self.pool(x, batch)
gfeat = data.global_feats.to(pooled.dtype).to(pooled.device)
        # Reshape global features to (num_graphs, global_in_dim), tolerating flattened input.
if gfeat.dim() == 1:
if gfeat.numel() == self.global_in_dim:
gfeat = gfeat.unsqueeze(0).expand(pooled.size(0), -1)
elif gfeat.numel() == pooled.size(0) * self.global_in_dim:
gfeat = gfeat.view(pooled.size(0), self.global_in_dim)
else:
raise ValueError("Unexpected global_feats shape")
elif gfeat.dim() == 2:
if gfeat.size(0) != pooled.size(0) and gfeat.numel() == pooled.size(0) * self.global_in_dim:
gfeat = gfeat.view(pooled.size(0), self.global_in_dim)
elif gfeat.size(0) != pooled.size(0):
gfeat = gfeat.mean(dim=0, keepdim=True).expand(pooled.size(0), -1)
out = self.mlp(torch.cat([pooled, gfeat], dim=1))
return out.view(-1)
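# Minimal smoke test (a sketch; assumes 4 global descriptors):
# >>> d = mol_to_pyg_data("CCO", global_features=np.zeros(4, dtype=np.float32))
# >>> m = GNNWithGlobalFeats(node_in_dim=d.x.shape[1], edge_in_dim=4, global_in_dim=4)
# >>> m.eval(); m(d).shape  # eval() avoids BatchNorm on a lone graph -> torch.Size([1])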
def train_one_epoch(model, loader, optimizer, device, loss_fn, clip_norm=None):
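    """One optimization pass over `loader`; returns the mean per-graph loss."""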
model.train()
total_loss = 0.0
for batch in loader:
batch = batch.to(device)
optimizer.zero_grad()
pred = model(batch)
loss = loss_fn(pred, batch.y.view(-1))
loss.backward()
if clip_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
optimizer.step()
total_loss += loss.item() * batch.num_graphs
return total_loss / len(loader.dataset)
def evaluate_mae_inverse_scaled(model, loader, device, y_scaler):
"""Evaluate MAE but invert target scaling back to original units before MAE."""
model.eval()
ys = []
preds = []
with torch.no_grad():
for batch in loader:
batch = batch.to(device)
pred = model(batch).detach().cpu().numpy().tolist()
y_batch = batch.y.view(-1).detach().cpu().numpy().tolist()
preds.extend(pred)
ys.extend(y_batch)
# inverse transform (y_scaler expects 2D)
preds = np.array(preds).reshape(-1, 1)
ys = np.array(ys).reshape(-1, 1)
preds_orig = y_scaler.inverse_transform(preds).ravel()
ys_orig = y_scaler.inverse_transform(ys).ravel()
return mean_absolute_error(ys_orig, preds_orig), ys_orig, preds_orig
def randomized_smiles(smiles):
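    """Return a randomized (non-canonical) SMILES of the same molecule, or None."""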
mol = Chem.MolFromSmiles(smiles)
if mol is None: return None
return Chem.MolToSmiles(mol, doRandom=True)
def run_single_train_until_target(csv_path,
target_col,
descriptor_cols,
smiles_col="SMILES",
device='cuda' if torch.cuda.is_available() else 'cpu',
seed=42,
max_epochs=1000,
patience=30,
batch_size=32,
conv_type='gcn',
n_augment_small=1,
clip_grad_norm=5.0,
target_mae=None,
tol_rel=0.05,
tol_abs=1e-6,
verbose=True):
"""
Train one model on a single train/val split and stop early when validation MAE
is within tolerance of `target_mae` (or when early-stopping triggers).
Returns saved artifact paths and the achieved val MAE.
"""
random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
df = pd.read_csv(csv_path)
df = df[df[target_col].notna()].reset_index(drop=True)
# Create scaffold folds and pick a single validation fold (fold 0)
fold_assign = scaffold_fold_assignments(df, n_folds=5, smiles_col=smiles_col, seed=seed)
train_idx = [i for i,f in enumerate(fold_assign) if f != 0]
val_idx = [i for i,f in enumerate(fold_assign) if f == 0]
train_df = df.iloc[train_idx].reset_index(drop=True)
val_df = df.iloc[val_idx].reset_index(drop=True)
# Fit scalers on TRAIN only
desc_scaler = StandardScaler().fit(train_df[descriptor_cols].values.astype(float))
y_scaler = StandardScaler().fit(train_df[[target_col]].values.astype(float))
# Transform
X_train = desc_scaler.transform(train_df[descriptor_cols].values.astype(float))
X_val = desc_scaler.transform(val_df[descriptor_cols].values.astype(float))
y_train = y_scaler.transform(train_df[[target_col]].values.astype(float)).ravel()
y_val = y_scaler.transform(val_df[[target_col]].values.astype(float)).ravel()
# Build graph objects
train_data = []
for i in range(len(train_df)):
smi = train_df.loc[i, smiles_col]
d = mol_to_pyg_data(smi, global_features=X_train[i], y=float(y_train[i]))
if d is None:
continue
d.idx = train_idx[i]
d.orig_smiles = smi
train_data.append(d)
# optional augmentation for small train sets
if len(train_data) < 2000 and n_augment_small > 0:
aug_list = []
for d in train_data:
for _ in range(n_augment_small):
rs = randomized_smiles(d.orig_smiles)
if rs is None: continue
aug_d = mol_to_pyg_data(rs, global_features=d.global_feats.detach().cpu().numpy().ravel(), y=d.y.item())
if aug_d is None: continue
aug_d.idx = d.idx
aug_d.orig_smiles = rs
aug_list.append(aug_d)
if aug_list:
train_data += aug_list
            if verbose: print(f"Augmented train set with {len(aug_list)} randomized-SMILES copies.")
val_data = []
for i in range(len(val_df)):
smi = val_df.loc[i, smiles_col]
d = mol_to_pyg_data(smi, global_features=X_val[i], y=float(y_val[i]))
if d is None:
continue
d.idx = val_idx[i]
d.orig_smiles = smi
val_data.append(d)
# Dataloaders
train_loader = PyGDataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = PyGDataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=0)
# model
node_dim = train_data[0].x.shape[1]
edge_dim = train_data[0].edge_attr.shape[1] if hasattr(train_data[0], 'edge_attr') else 0
global_dim = train_data[0].global_feats.shape[0] if train_data[0].global_feats.dim()==1 else train_data[0].global_feats.shape[1]
model = GNNWithGlobalFeats(node_in_dim=node_dim, edge_in_dim=edge_dim, global_in_dim=global_dim,
gnn_hidden=128, n_gnn_layers=3, mlp_hidden=128, dropout=0.2,
conv_type=conv_type).to(device)
loss_fn = nn.SmoothL1Loss()
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-6)
    # NB: ReduceLROnPlateau's `verbose` kwarg is deprecated in recent PyTorch, so it is omitted here.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=7, min_lr=1e-6)
best_val_mae = float('inf')
best_state = None
no_improve = 0
best_epoch = -1
# target stopping thresholds
stop_enabled = (target_mae is not None)
if stop_enabled:
tol = max(tol_rel * float(target_mae), tol_abs)
if verbose:
print(f"Target MAE {target_mae} with tolerance {tol} (relative tol {tol_rel}, abs tol {tol_abs})")
for epoch in range(1, max_epochs+1):
train_loss = train_one_epoch(model, train_loader, optimizer, device, loss_fn, clip_norm=clip_grad_norm)
val_mae, ys_orig, preds_orig = evaluate_mae_inverse_scaled(model, val_loader, device, y_scaler)
scheduler.step(val_mae)
if val_mae < best_val_mae - 1e-6:
best_val_mae = val_mae
best_epoch = epoch
no_improve = 0
best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
else:
no_improve += 1
if verbose and (epoch % 5 == 0 or epoch == 1):
print(f"Epoch {epoch:04d} | train_loss {train_loss:.6f} | val_mae_orig {val_mae:.6f} | best_val {best_val_mae:.6f}")
        # Stop once the target MAE is reached; dropping below the target is
        # success, so test val_mae <= target + tol rather than |val_mae - target| <= tol.
        if stop_enabled and val_mae <= float(target_mae) + tol:
            if verbose:
                print(f"Stopping at epoch {epoch}: val_mae {val_mae:.6f} reached target {target_mae} (tol {tol:.6g}).")
            break
# normal early stopping
if no_improve >= patience:
if verbose:
print(f"Early stopping at epoch {epoch} (best_epoch {best_epoch} | best_val_mae {best_val_mae:.6f})")
break
# restore best
if best_state is not None:
model.load_state_dict({k: best_state[k].to(device) for k in best_state})
# final metric
val_mae_final, ys_orig, preds_orig = evaluate_mae_inverse_scaled(model, val_loader, device, y_scaler)
if verbose:
print(f"Final val MAE (orig scale): {val_mae_final:.6f} (best_epoch {best_epoch})")
# save model + scalers (single-model artifacts)
base = f"{target_col}_single"
pkg = {
'state_dict': {k: v.cpu().clone() for k, v in model.state_dict().items()},
'node_dim': node_dim, 'edge_dim': edge_dim, 'global_dim': global_dim,
'gnn_hidden': 128, 'n_gnn_layers': 3, 'mlp_hidden': 128, 'dropout': 0.2,
'conv_type': conv_type, 'val_mae_orig': float(val_mae_final)
}
model_path = f"model_{base}.pt"
desc_path = f"desc_scaler_{base}.pkl"
y_path = f"y_scaler_{base}.pkl"
joblib.dump(desc_scaler, desc_path)
joblib.dump(y_scaler, y_path)
torch.save(pkg, model_path)
if verbose:
print(f"Saved model -> {model_path}; scalers -> {desc_path}, {y_path}")
return {
'model_path': model_path,
'desc_scaler_path': desc_path,
'y_scaler_path': y_path,
'val_mae': val_mae_final,
'epoch': best_epoch,
'preds_val': preds_orig,
'y_val': ys_orig
}
if __name__ == "__main__":
dataset_dir = "/kaggle/input/augmented-polymer-data/results"
targets_to_train = {
# 'Tg': 39.6794,
# 'FFV': 0.0042,
# 'Tc': 0.0211,
'Density': 0.0182,
'Rg': 1.1638
}
results = {}
for t, goal in targets_to_train.items():
csv_path = os.path.join(dataset_dir, f"{t}_data.csv")
print(f"\n=== Training single-model for {t} aiming MAE ~ {goal} ===")
df_tmp = pd.read_csv(csv_path)
descriptor_cols = df_tmp.drop(columns=[t, 'SMILES']).columns.tolist()
joblib.dump(descriptor_cols, f"desc_cols_{t}_single.pkl")
res = run_single_train_until_target(csv_path=csv_path,
target_col=t,
descriptor_cols=descriptor_cols,
smiles_col="SMILES",
device='cuda' if torch.cuda.is_available() else 'cpu',
seed=42,
max_epochs=1000,
patience=40,
batch_size=32,
conv_type='gcn',
n_augment_small=1,
clip_grad_norm=5.0,
target_mae=goal,
tol_rel=0.05,
tol_abs=1e-6,
verbose=True)
results[t] = res
print(f"-> {t} achieved val MAE {res['val_mae']:.6f} (target {goal})")