import math
import os
import random

import joblib
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error
from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as PyGDataLoader
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
def compute_weighted_score(per_target_maes: dict, sample_counts: dict):
    """
    per_target_maes: {'Tg': mae_value, 'FFV': ...}
    sample_counts:   {'Tg': N_tg, ...}
    weights = 1/sqrt(N), normalized to sum to 1
    """
    keys = list(per_target_maes.keys())
    inv_sqrt = np.array([1.0 / math.sqrt(sample_counts[k]) for k in keys])
    weights = inv_sqrt / inv_sqrt.sum()
    score = sum(per_target_maes[k] * w for k, w in zip(keys, weights))
    return score, dict(zip(keys, weights))
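# A quick worked example (numbers are illustrative, not from the dataset):
# with counts {'Tg': 100, 'Rg': 400}, the inverse-sqrt terms are 0.1 and 0.05,
# which normalize to weights 2/3 and 1/3 -- the sparser target dominates.
#   score, w = compute_weighted_score({'Tg': 40.0, 'Rg': 1.2},
#                                     {'Tg': 100, 'Rg': 400})
#   # w     -> {'Tg': 0.666..., 'Rg': 0.333...}
#   # score -> 40.0 * 2/3 + 1.2 * 1/3 ~= 27.07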
ATOM_LIST = ["C", "H", "O", "N", "S", "F", "Cl", "Br", "I", "P"]
MAX_DEGREE = 5

def atom_features(atom):
    # one-hot over ATOM_LIST plus a trailing "other element" bit
    at = atom.GetSymbol()
    one_hot = [1.0 if at == a else 0.0 for a in ATOM_LIST]
    if not any(one_hot):
        one_hot.append(1.0)
    else:
        one_hot.append(0.0)
    charge = [atom.GetFormalCharge()]
    aromatic = [1.0 if atom.GetIsAromatic() else 0.0]
    deg = atom.GetDegree()
    deg_oh = [1.0 if deg == d else 0.0 for d in range(MAX_DEGREE + 1)]
    feats = one_hot + charge + aromatic + deg_oh
    return np.array(feats, dtype=np.float32)
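# Feature-vector length check (a sketch; the dims follow from the constants above):
# 10 symbol bits + 1 "other" bit + 1 charge + 1 aromatic + 6 degree bits = 19.
#   mol = Chem.MolFromSmiles("c1ccccc1O")        # phenol
#   atom_features(mol.GetAtomWithIdx(0)).shape   # -> (19,)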
def mol_to_pyg_data(smiles, global_features=None, y=None):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    node_feats = [atom_features(a) for a in mol.GetAtoms()]
    x = torch.tensor(np.vstack(node_feats), dtype=torch.float)
    edge_index = []
    edge_attr = []
    for bond in mol.GetBonds():
        a1 = bond.GetBeginAtomIdx()
        a2 = bond.GetEndAtomIdx()
        # store each bond in both directions (undirected graph)
        edge_index.append([a1, a2])
        edge_index.append([a2, a1])
        bt = bond.GetBondType()
        bond_type = [0.0, 0.0, 0.0, 0.0]
        if bt == Chem.rdchem.BondType.SINGLE:
            bond_type[0] = 1.0
        elif bt == Chem.rdchem.BondType.DOUBLE:
            bond_type[1] = 1.0
        elif bt == Chem.rdchem.BondType.TRIPLE:
            bond_type[2] = 1.0
        elif bt == Chem.rdchem.BondType.AROMATIC:
            bond_type[3] = 1.0
        edge_attr.append(bond_type)
        edge_attr.append(bond_type)
    if len(edge_index) == 0:
        # single-atom molecule: add a self-loop so the conv layers see an edge
        edge_index = [[0, 0]]
        edge_attr = [[0.0, 0.0, 0.0, 0.0]]
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(np.vstack(edge_attr), dtype=torch.float)
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    if global_features is not None:
        gf = torch.tensor(global_features, dtype=torch.float)
        if gf.dim() == 1:
            gf = gf.unsqueeze(0)  # shape (1, D) so PyG batches it to (B, D)
        data.global_feats = gf
    else:
        # keep the placeholder 2-D as well, for consistent batching
        data.global_feats = torch.zeros((1, 1), dtype=torch.float)
    if y is not None:
        data.y = torch.tensor([y], dtype=torch.float)
    return data
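# Sanity check (a sketch -- ethanol has 3 heavy atoms and 2 bonds):
#   d = mol_to_pyg_data("CCO", global_features=np.zeros(4, dtype=np.float32), y=1.0)
#   d.x.shape             # -> torch.Size([3, 19])
#   d.edge_index.shape    # -> torch.Size([2, 4])   (each bond stored twice)
#   d.global_feats.shape  # -> torch.Size([1, 4])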
def mol_to_scaffold(smiles):
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        scaffold = MurckoScaffold.GetScaffoldForMol(mol)
        return Chem.MolToSmiles(scaffold, isomericSmiles=False)
    except Exception:
        return None
def scaffold_fold_assignments(df, n_folds=5, smiles_col="SMILES", seed=42):
    """Greedy scaffold split: group molecules by Murcko scaffold, then assign
    each group (largest first) to the currently smallest fold."""
    random.seed(seed)
    scaffolds = {}
    for idx, smi in enumerate(df[smiles_col].values):
        scaf = mol_to_scaffold(smi)
        if scaf is None:
            scaf = f"EMPTY_{idx}"
        scaffolds.setdefault(scaf, []).append(idx)
    groups = sorted(scaffolds.items(), key=lambda x: len(x[1]), reverse=True)
    fold_sizes = [0] * n_folds
    fold_assign = np.zeros(len(df), dtype=int)
    for scaf, idxs in groups:
        f = int(np.argmin(fold_sizes))
        for idx in idxs:
            fold_assign[idx] = f
        fold_sizes[f] += len(idxs)
    return fold_assign
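# Usage sketch (assumes a DataFrame with a "SMILES" column):
#   folds = scaffold_fold_assignments(df, n_folds=5)
#   val_mask = folds == 0   # hold out fold 0, train on the rest
# Because whole scaffold groups land in a single fold, validation molecules
# never share a Murcko scaffold with training molecules.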
class GNNWithGlobalFeats(nn.Module):
    def __init__(self, node_in_dim, edge_in_dim, global_in_dim,
                 gnn_hidden=128, n_gnn_layers=3, mlp_hidden=128, dropout=0.2,
                 conv_type='gcn'):
        super().__init__()
        # edge_in_dim is accepted for artifact metadata only; neither GCNConv
        # nor the single-head GATConv below consumes edge_attr.
        self.global_in_dim = global_in_dim
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()  # BatchNorm on node features after each conv
        in_dim = node_in_dim
        for _ in range(n_gnn_layers):
            if conv_type == 'gcn':
                self.convs.append(GCNConv(in_dim, gnn_hidden))
            elif conv_type == 'gat':
                # single attention head for simplicity
                self.convs.append(GATConv(in_dim, gnn_hidden, heads=1, concat=False))
            else:
                raise ValueError("conv_type must be 'gcn' or 'gat'")
            self.bns.append(nn.BatchNorm1d(gnn_hidden))
            in_dim = gnn_hidden
        self.pool = global_mean_pool
        total_in = gnn_hidden + global_in_dim
        self.mlp = nn.Sequential(
            nn.Linear(total_in, mlp_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, mlp_hidden // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden // 2, 1),
        )
        self.dropout = nn.Dropout(dropout)
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        for conv, bn in zip(self.convs, self.bns):
            x = conv(x, edge_index)
            x = bn(x)
            x = F.relu(x)
            x = self.dropout(x)
        batch = data.batch if hasattr(data, 'batch') else torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        pooled = self.pool(x, batch)
        gfeat = data.global_feats.to(pooled.dtype).to(pooled.device)
        # coerce global_feats to shape (num_graphs, global_in_dim), whatever
        # shape batching produced
        if gfeat.dim() == 1:
            if gfeat.numel() == self.global_in_dim:
                gfeat = gfeat.unsqueeze(0).expand(pooled.size(0), -1)
            elif gfeat.numel() == pooled.size(0) * self.global_in_dim:
                gfeat = gfeat.view(pooled.size(0), self.global_in_dim)
            else:
                raise ValueError("Unexpected global_feats shape")
        elif gfeat.dim() == 2:
            if gfeat.size(0) != pooled.size(0) and gfeat.numel() == pooled.size(0) * self.global_in_dim:
                gfeat = gfeat.view(pooled.size(0), self.global_in_dim)
            elif gfeat.size(0) != pooled.size(0):
                gfeat = gfeat.mean(dim=0, keepdim=True).expand(pooled.size(0), -1)
        out = self.mlp(torch.cat([pooled, gfeat], dim=1))
        return out.view(-1)
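# Smoke test (a sketch; dims match the 19-dim atom features built above):
#   model = GNNWithGlobalFeats(node_in_dim=19, edge_in_dim=4, global_in_dim=4)
#   d = mol_to_pyg_data("CCO", global_features=np.zeros(4, dtype=np.float32), y=1.0)
#   batch = next(iter(PyGDataLoader([d, d], batch_size=2)))
#   model(batch).shape  # -> torch.Size([2]), one prediction per graph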
def train_one_epoch(model, loader, optimizer, device, loss_fn, clip_norm=None):
    model.train()
    total_loss = 0.0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = model(batch)
        loss = loss_fn(pred, batch.y.view(-1))
        loss.backward()
        if clip_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
        optimizer.step()
        total_loss += loss.item() * batch.num_graphs  # weight by graphs for a true per-sample mean
    return total_loss / len(loader.dataset)
def evaluate_mae_inverse_scaled(model, loader, device, y_scaler):
    """Compute MAE after inverting the target scaling, so the metric is in original units."""
    model.eval()
    ys = []
    preds = []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            pred = model(batch).detach().cpu().numpy().tolist()
            y_batch = batch.y.view(-1).detach().cpu().numpy().tolist()
            preds.extend(pred)
            ys.extend(y_batch)
    # inverse transform (the scaler expects a 2-D array)
    preds = np.array(preds).reshape(-1, 1)
    ys = np.array(ys).reshape(-1, 1)
    preds_orig = y_scaler.inverse_transform(preds).ravel()
    ys_orig = y_scaler.inverse_transform(ys).ravel()
    return mean_absolute_error(ys_orig, preds_orig), ys_orig, preds_orig
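# Why invert: the model trains on z-scored targets, but MAE on z-scores is not
# comparable across properties. StandardScaler round-trips exactly, e.g.:
#   sc = StandardScaler().fit(np.array([[1.0], [2.0], [3.0]]))
#   sc.inverse_transform(sc.transform(np.array([[2.5]])))  # -> [[2.5]]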
def randomized_smiles(smiles):
    """Return a randomized (non-canonical) SMILES for the same molecule, or None."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    return Chem.MolToSmiles(mol, doRandom=True)
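# Example: the same molecule written with a different atom ordering.
#   randomized_smiles("c1ccccc1O")  # e.g. "Oc1ccccc1" (varies per call)
# Note: with a permutation-invariant readout (mean pooling) this acts mostly
# as oversampling of small training sets rather than true graph augmentation.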
def run_single_train_until_target(csv_path,
                                  target_col,
                                  descriptor_cols,
                                  smiles_col="SMILES",
                                  device='cuda' if torch.cuda.is_available() else 'cpu',
                                  seed=42,
                                  max_epochs=1000,
                                  patience=30,
                                  batch_size=32,
                                  conv_type='gcn',
                                  n_augment_small=1,
                                  clip_grad_norm=5.0,
                                  target_mae=None,
                                  tol_rel=0.05,
                                  tol_abs=1e-6,
                                  verbose=True):
    """
    Train one model on a single scaffold-based train/val split and stop early
    when the validation MAE is within tolerance of `target_mae` (or when early
    stopping triggers). Returns saved artifact paths and the achieved val MAE.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    df = pd.read_csv(csv_path)
    df = df[df[target_col].notna()].reset_index(drop=True)
    # Create scaffold folds and hold out a single validation fold (fold 0)
    fold_assign = scaffold_fold_assignments(df, n_folds=5, smiles_col=smiles_col, seed=seed)
    train_idx = [i for i, f in enumerate(fold_assign) if f != 0]
    val_idx = [i for i, f in enumerate(fold_assign) if f == 0]
    train_df = df.iloc[train_idx].reset_index(drop=True)
    val_df = df.iloc[val_idx].reset_index(drop=True)
    # Fit scalers on TRAIN only, to avoid leakage into validation
    desc_scaler = StandardScaler().fit(train_df[descriptor_cols].values.astype(float))
    y_scaler = StandardScaler().fit(train_df[[target_col]].values.astype(float))
    # Transform
    X_train = desc_scaler.transform(train_df[descriptor_cols].values.astype(float))
    X_val = desc_scaler.transform(val_df[descriptor_cols].values.astype(float))
    y_train = y_scaler.transform(train_df[[target_col]].values.astype(float)).ravel()
    y_val = y_scaler.transform(val_df[[target_col]].values.astype(float)).ravel()
    # Build graph objects
    train_data = []
    for i in range(len(train_df)):
        smi = train_df.loc[i, smiles_col]
        d = mol_to_pyg_data(smi, global_features=X_train[i], y=float(y_train[i]))
        if d is None:
            continue
        d.idx = train_idx[i]
        d.orig_smiles = smi
        train_data.append(d)
    # Optional augmentation for small training sets
    if len(train_data) < 2000 and n_augment_small > 0:
        aug_list = []
        for d in train_data:
            for _ in range(n_augment_small):
                rs = randomized_smiles(d.orig_smiles)
                if rs is None:
                    continue
                aug_d = mol_to_pyg_data(rs,
                                        global_features=d.global_feats.detach().cpu().numpy().ravel(),
                                        y=d.y.item())
                if aug_d is None:
                    continue
                aug_d.idx = d.idx
                aug_d.orig_smiles = rs
                aug_list.append(aug_d)
        if aug_list:
            train_data += aug_list
            if verbose:
                print(f"Augmented train set with {len(aug_list)} randomized-SMILES copies.")
    val_data = []
    for i in range(len(val_df)):
        smi = val_df.loc[i, smiles_col]
        d = mol_to_pyg_data(smi, global_features=X_val[i], y=float(y_val[i]))
        if d is None:
            continue
        d.idx = val_idx[i]
        d.orig_smiles = smi
        val_data.append(d)
    # Dataloaders
    train_loader = PyGDataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = PyGDataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=0)
    # Model dimensions, inferred from the first graph
    node_dim = train_data[0].x.shape[1]
    edge_dim = train_data[0].edge_attr.shape[1] if hasattr(train_data[0], 'edge_attr') else 0
    global_dim = (train_data[0].global_feats.shape[0]
                  if train_data[0].global_feats.dim() == 1
                  else train_data[0].global_feats.shape[1])
    model = GNNWithGlobalFeats(node_in_dim=node_dim, edge_in_dim=edge_dim, global_in_dim=global_dim,
                               gnn_hidden=128, n_gnn_layers=3, mlp_hidden=128, dropout=0.2,
                               conv_type=conv_type).to(device)
    loss_fn = nn.SmoothL1Loss()
    optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-6)
    # Note: the `verbose` kwarg of ReduceLROnPlateau is deprecated in recent
    # PyTorch releases, so it is omitted here.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
                                                           patience=7, min_lr=1e-6)
    best_val_mae = float('inf')
    best_state = None
    no_improve = 0
    best_epoch = -1
    # target-stopping thresholds
    stop_enabled = (target_mae is not None)
    if stop_enabled:
        tol = max(tol_rel * float(target_mae), tol_abs)
        if verbose:
            print(f"Target MAE {target_mae} with tolerance {tol} (relative tol {tol_rel}, abs tol {tol_abs})")
    for epoch in range(1, max_epochs + 1):
        train_loss = train_one_epoch(model, train_loader, optimizer, device, loss_fn, clip_norm=clip_grad_norm)
        val_mae, ys_orig, preds_orig = evaluate_mae_inverse_scaled(model, val_loader, device, y_scaler)
        scheduler.step(val_mae)
        if val_mae < best_val_mae - 1e-6:
            best_val_mae = val_mae
            best_epoch = epoch
            no_improve = 0
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        else:
            no_improve += 1
        if verbose and (epoch % 5 == 0 or epoch == 1):
            print(f"Epoch {epoch:04d} | train_loss {train_loss:.6f} | val_mae_orig {val_mae:.6f} | best_val {best_val_mae:.6f}")
        # stop if we reached the target MAE within tolerance
        if stop_enabled and abs(val_mae - float(target_mae)) <= tol:
            if verbose:
                print(f"Stopping at epoch {epoch} because val_mae {val_mae:.6f} is within tolerance of target {target_mae}.")
            break
        # normal early stopping
        if no_improve >= patience:
            if verbose:
                print(f"Early stopping at epoch {epoch} (best_epoch {best_epoch} | best_val_mae {best_val_mae:.6f})")
            break
    # restore best weights
    if best_state is not None:
        model.load_state_dict({k: best_state[k].to(device) for k in best_state})
    # final metric
    val_mae_final, ys_orig, preds_orig = evaluate_mae_inverse_scaled(model, val_loader, device, y_scaler)
    if verbose:
        print(f"Final val MAE (orig scale): {val_mae_final:.6f} (best_epoch {best_epoch})")
    # save model + scalers (single-model artifacts)
    base = f"{target_col}_single"
    pkg = {
        'state_dict': {k: v.cpu().clone() for k, v in model.state_dict().items()},
        'node_dim': node_dim, 'edge_dim': edge_dim, 'global_dim': global_dim,
        'gnn_hidden': 128, 'n_gnn_layers': 3, 'mlp_hidden': 128, 'dropout': 0.2,
        'conv_type': conv_type, 'val_mae_orig': float(val_mae_final)
    }
    model_path = f"model_{base}.pt"
    desc_path = f"desc_scaler_{base}.pkl"
    y_path = f"y_scaler_{base}.pkl"
    joblib.dump(desc_scaler, desc_path)
    joblib.dump(y_scaler, y_path)
    torch.save(pkg, model_path)
    if verbose:
        print(f"Saved model -> {model_path}; scalers -> {desc_path}, {y_path}")
    return {
        'model_path': model_path,
        'desc_scaler_path': desc_path,
        'y_scaler_path': y_path,
        'val_mae': val_mae_final,
        'epoch': best_epoch,
        'preds_val': preds_orig,
        'y_val': ys_orig
    }
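# Reload sketch (a sketch based on the pkg dict saved above; paths follow the
# f"model_{target}_single.pt" naming, e.g. for the 'Density' target):
#   pkg = torch.load("model_Density_single.pt", map_location="cpu")
#   model = GNNWithGlobalFeats(pkg['node_dim'], pkg['edge_dim'], pkg['global_dim'],
#                              gnn_hidden=pkg['gnn_hidden'], n_gnn_layers=pkg['n_gnn_layers'],
#                              mlp_hidden=pkg['mlp_hidden'], dropout=pkg['dropout'],
#                              conv_type=pkg['conv_type'])
#   model.load_state_dict(pkg['state_dict'])
#   y_scaler = joblib.load("y_scaler_Density_single.pkl")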
if __name__ == "__main__":
    dataset_dir = "/kaggle/input/augmented-polymer-data/results"
    # target -> validation-MAE goal used by the target-stopping criterion
    targets_to_train = {
        # 'Tg': 39.6794,
        # 'FFV': 0.0042,
        # 'Tc': 0.0211,
        'Density': 0.0182,
        'Rg': 1.1638
    }
    results = {}
    for t, goal in targets_to_train.items():
        csv_path = os.path.join(dataset_dir, f"{t}_data.csv")
        print(f"\n=== Training single model for {t}, aiming for MAE ~ {goal} ===")
        df_tmp = pd.read_csv(csv_path)
        descriptor_cols = df_tmp.drop(columns=[t, 'SMILES']).columns.tolist()
        joblib.dump(descriptor_cols, f"desc_cols_{t}_single.pkl")
        res = run_single_train_until_target(csv_path=csv_path,
                                            target_col=t,
                                            descriptor_cols=descriptor_cols,
                                            smiles_col="SMILES",
                                            device='cuda' if torch.cuda.is_available() else 'cpu',
                                            seed=42,
                                            max_epochs=1000,
                                            patience=40,
                                            batch_size=32,
                                            conv_type='gcn',
                                            n_augment_small=1,
                                            clip_grad_norm=5.0,
                                            target_mae=goal,
                                            tol_rel=0.05,
                                            tol_abs=1e-6,
                                            verbose=True)
        results[t] = res
        print(f"-> {t} achieved val MAE {res['val_mae']:.6f} (target {goal})")