Spaces:
Configuration error
Configuration error
| import torch | |
| from sklearn.metrics import mean_absolute_error | |
| import joblib | |
| import pandas as pd | |
| import numpy as np | |
| import glob | |
| import os | |
| from MY_GNN.train import GNNWithGlobalFeats, mol_to_pyg_data, atom_features, Data, PyGDataLoader | |
| from tqdm import tqdm | |
| from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator, GetAtomPairGenerator, GetTopologicalTorsionGenerator | |
| from rdkit.Chem import MACCSkeys | |
| from rdkit import Chem | |
| import networkx as nx | |
| from rdkit.Chem import Descriptors, rdmolops | |
| from rdkit.Chem.Descriptors import MolWt, MolLogP | |
| from rdkit.Chem.rdMolDescriptors import CalcTPSA, CalcNumRotatableBonds | |
# Columns that are always kept for every target, regardless of the per-target
# feature selection below.
required_descriptors = {'graph_diameter','num_cycles','avg_shortest_path','MolWt', 'LogP', 'TPSA', 'RotatableBonds', 'NumAtoms', 'SMILES'}

# Per-target whitelist of column names, consumed via
# ``X.filter(items=filters[target])`` in ``eval_on_host_train``.
# Each entry is the union of a target-specific selection and
# ``required_descriptors``.
#
# The union is materialised with ``sorted(...)`` rather than ``list(...)``:
# iteration order of a set of strings varies between interpreter runs (hash
# randomisation), which would make the resulting column order -- and any
# feature matrix derived from it without an explicit column list --
# non-reproducible.
filters = {
    'Tg': sorted(set(['deg_mean', 'FractionCSP3', 'num_cycles', 'RingCount', 'HallKierAlpha', 'SMR_VSA7', 'BertzCT', 'ring_size_6', 'fr_benzene', 'NumAromaticCarbocycles', 'NumAromaticRings', 'SlogP_VSA6', 'SlogP_VSA1', 'betw_mean', 'VSA_EState6', 'BalabanJ', 'Chi4n', 'FP_446', 'PEOE_VSA14', 'Chi3n', 'AvgIpc', 'FP_489', 'Chi1', 'HeavyAtomCount', 'NumHeterocycles', 'FP_485', 'fr_bicyclic', 'SMR_VSA10', 'FP_537', 'VSA_EState2', 'FP_539', 'FP_529', 'HeavyAtomMolWt', 'LabuteASA', 'ring_size_5', 'FP_505', 'NumAmideBonds', 'MolMR', 'FP_80', 'FP_195', 'FP_310', 'fr_amide', 'FP_509', 'FP_378', 'ExactMolWt', 'MolWt', 'FP_211', 'Chi2n', 'FP_266', 'FP_379', 'FP_207', 'FP_504', 'FP_203', 'NumAtoms', 'FP_199', 'FP_519', 'FP_123', 'FP_278', 'FpDensityMorgan1', 'FP_119', 'fr_imide', 'FP_279', 'FP_223', 'betw_std', 'FP_231', 'FP_219', 'FP_251', 'NumValenceElectrons', 'FP_480', 'Chi0n', 'FP_517', 'FP_255', 'Chi0', 'FP_522', 'FP_528', 'FP_526', 'FpDensityMorgan2', 'FP_354', 'Chi1n', 'FP_459', 'FP_547', 'FP_476', 'Chi0v', 'FP_210', 'FP_516', 'FP_382', 'FP_215', 'FP_243', 'FP_521', 'FP_227', 'NumAliphaticHeterocycles', 'FP_469', 'FP_467', 'FP_342', 'FP_549', 'FP_357', 'FP_494', 'FP_194', 'FP_546', 'FP_302']).union(required_descriptors)),
    'FFV': sorted(set(['MolLogP', 'LogP', 'Chi3v', 'Chi2v', 'Chi4v', 'Chi4n', 'Chi3n', 'VSA_EState6', 'SMR_VSA7', 'Chi1v', 'Chi2n', 'MolMR', 'Chi1n', 'Chi0v', 'SlogP_VSA6', 'BertzCT', 'PEOE_VSA14', 'Chi0n', 'LabuteASA', 'EState_VSA8', 'BalabanJ', 'Ipc', 'Chi1', 'deg_mean', 'VSA_EState8', 'MolWt', 'ExactMolWt', 'SMR_VSA9', 'HeavyAtomMolWt', 'Chi0', 'SMR_VSA6', 'EState_VSA5', 'FpDensityMorgan3', 'Kappa1', 'AvgIpc', 'FpDensityMorgan2', 'SlogP_VSA8', 'HallKierAlpha', 'FP_39', 'avg_shortest_path', 'SMR_VSA1', 'SlogP_VSA5', 'betw_mean', 'TPSA', 'FpDensityMorgan1', 'lap_eig_6', 'qed', 'lap_eig_7', 'lap_eig_8', 'RingCount', 'NumValenceElectrons', 'NumAromaticRings', 'lap_eig_5', 'num_cycles', 'EState_VSA7', 'Kappa2', 'NumAtoms', 'ring_size_6', 'betw_std', 'lap_eig_4', 'lap_eig_3', 'HeavyAtomCount', 'fr_benzene', 'NumHDonors', 'NumAromaticCarbocycles', 'PEOE_VSA7', 'SlogP_VSA2', 'SlogP_VSA3', 'NHOHCount', 'NOCount', 'fr_NH1', 'EState_VSA4', 'FP_515', 'MaxEStateIndex', 'MaxAbsEStateIndex', 'lap_eig_2', 'PEOE_VSA6', 'VSA_EState5', 'EState_VSA6', 'FractionCSP3', 'EState_VSA3', 'Phi', 'FP_535', 'NumHAcceptors', 'SlogP_VSA4', 'fr_C_O', 'FP_446', 'FP_488', 'SMR_VSA10', 'PEOE_VSA9', 'fr_C_O_noCOO', 'FP_125', 'FP_474', 'SlogP_VSA7', 'graph_diameter', 'SlogP_VSA12', 'FP_507', 'fr_bicyclic', 'MinAbsEStateIndex', 'deg_std']).union(required_descriptors)),
    'Tc': sorted(set(['deg_std', 'fr_unbrch_alkane', 'FP_287', 'FP_286', 'betw_mean', 'avg_shortest_path', 'Kappa3', 'graph_diameter', 'FP_285', 'FP_187', 'FP_191', 'FP_139', 'VSA_EState7', 'FpDensityMorgan3', 'FP_143', 'FpDensityMorgan2', 'FP_171', 'Phi', 'FpDensityMorgan1', 'FP_167', 'FP_175', 'qed', 'FP_163', 'FP_131', 'FP_531', 'FP_502', 'FP_142', 'Kappa2', 'FP_135', 'FP_179', 'SlogP_VSA5', 'FP_513', 'FP_134', 'SMR_VSA5', 'FP_93', 'FP_138', 'FP_130', 'RotatableBonds', 'NumRotatableBonds', 'FP_182', 'FP_162', 'FP_518', 'FP_491', 'FP_174', 'FP_496', 'FP_284', 'FP_141', 'FP_475', 'lap_eig_7', 'FP_154', 'lap_eig_8', 'FP_512', 'FP_453', 'FP_256', 'FP_488', 'deg_max', 'FP_170', 'FP_137', 'FP_190', 'FP_17', 'FP_466', 'FP_535', 'FP_474', 'fr_NH1', 'lap_eig_6', 'Chi3n', 'FP_133', 'PEOE_VSA6', 'FP_178', 'FP_186', 'Chi1n', 'FP_450', 'FP_102', 'FP_508', 'EState_VSA5', 'NumHDonors', 'NumAtomStereoCenters', 'NumUnspecifiedAtomStereoCenters', 'NHOHCount', 'FP_183', 'Chi2n', 'Chi3v', 'FP_89', 'SPS', 'betw_max', 'AvgIpc', 'Chi4n', 'Chi1v', 'FP_478', 'lap_eig_5', 'FP_484', 'SMR_VSA3', 'FP_192', 'FP_166', 'Kappa1', 'FP_495', 'FP_526', 'fr_halogen', 'FP_153', 'FP_28']).union(required_descriptors)),
    'Density': sorted(set(['SMR_VSA5', 'VSA_EState8', 'VSA_EState7', 'SlogP_VSA5', 'SMR_VSA10', 'FractionCSP3', 'EState_VSA5', 'SlogP_VSA12', 'VSA_EState10', 'fr_unbrch_alkane', 'NumRotatableBonds', 'RotatableBonds', 'FP_119', 'FP_513', 'PEOE_VSA8', 'Kappa3', 'PEOE_VSA7', 'FP_180', 'FP_472', 'FP_428', 'Kappa2', 'FP_80', 'Phi', 'FP_539', 'FP_512', 'FP_531', 'EState_VSA7', 'FP_537', 'FP_502', 'FP_98', 'NumHAcceptors', 'Chi1n', 'MaxAbsEStateIndex', 'MaxEStateIndex', 'PEOE_VSA14', 'FP_500', 'MolLogP', 'LogP', 'FP_465', 'MinAbsEStateIndex', 'Chi2n', 'SlogP_VSA7', 'FP_176', 'avg_shortest_path', 'EState_VSA4', 'FP_181', 'lap_eig_5', 'Chi0n', 'HallKierAlpha', 'PEOE_VSA5', 'qed', 'graph_diameter', 'FP_186', 'betw_mean', 'FP_287', 'FP_179', 'lap_eig_4', 'FP_134', 'Chi3n', 'NOCount', 'fr_C_S', 'FP_131', 'FP_177', 'FP_166', 'FP_127', 'FP_162', 'FP_191', 'FP_143', 'Chi4n', 'TPSA', 'lap_eig_3', 'Chi1v', 'SlogP_VSA6', 'FP_178', 'FP_457', 'FP_139', 'FP_163', 'SMR_VSA7', 'SlogP_VSA11', 'SlogP_VSA3', 'FP_183', 'Chi0', 'FP_137', 'ring_size_6', 'FP_138', 'fr_benzene', 'NumAromaticCarbocycles', 'FP_420', 'NumAromaticRings', 'NumAromaticHeterocycles', 'FP_492', 'FP_169', 'FP_284', 'Chi1', 'FP_141', 'FP_35', 'FP_182', 'FP_521', 'EState_VSA3', 'FP_135']).union(required_descriptors)),
    'Rg': sorted(set(['FP_93', 'SlogP_VSA7', 'PEOE_VSA14', 'qed', 'FP_544', 'VSA_EState8', 'FP_499', 'SlogP_VSA1', 'fr_unbrch_alkane', 'FP_42', 'EState_VSA4', 'FP_192', 'FP_508', 'FP_520', 'lap_eig_8', 'Phi', 'FP_155', 'NumAtomStereoCenters', 'NumUnspecifiedAtomStereoCenters', 'avg_shortest_path', 'FP_17', 'FP_317', 'lap_eig_7', 'FP_73', 'VSA_EState7', 'FP_224', 'fr_ester', 'graph_diameter', 'Kappa2', 'NumAmideBonds', 'fr_NH1', 'FP_191', 'fr_amide', 'FP_286', 'Kappa3', 'FP_159', 'FP_488', 'FP_33', 'deg_std', 'FP_280', 'FP_364', 'FP_287', 'EState_VSA5', 'SlogP_VSA5', 'FP_515', 'TPSA', 'FP_151', 'SMR_VSA5', 'FP_498', 'NOCount', 'betw_mean', 'RotatableBonds', 'NumRotatableBonds', 'FP_273', 'SMR_VSA3', 'FP_163', 'FP_134', 'FP_478', 'FP_138', 'FP_187', 'FP_137', 'FP_252', 'VSA_EState3', 'FP_171', 'FP_175', 'lap_eig_6', 'NHOHCount', 'Chi4v', 'FpDensityMorgan1', 'FP_182', 'FP_526', 'FP_167', 'FP_486', 'FP_142', 'FP_316', 'AvgIpc', 'MolLogP', 'LogP', 'FP_183', 'FP_130', 'FP_102', 'FP_1', 'FP_115', 'SMR_VSA10', 'Chi4n', 'FP_24', 'FP_533', 'NumHDonors', 'FP_193', 'FP_147', 'FP_38', 'Chi3n', 'FP_249', 'FP_453', 'FP_535', 'FP_492', 'Chi3v', 'FP_240', 'FP_501', 'FP_139']).union(required_descriptors)),
}
def smiles_to_combined_fingerprints_with_descriptors(smiles_list, radius=2, n_bits=128):
    """Featurise SMILES strings into fingerprints plus a descriptor dict.

    For every SMILES that RDKit can parse, this builds:

    * a fingerprint vector: Morgan, atom-pair and topological-torsion
      fingerprints (``n_bits`` each) concatenated with the MACCS keys;
    * a dict holding every descriptor in ``Descriptors.descList`` plus
      ``MolWt``, ``LogP``, ``TPSA``, ``RotatableBonds``, ``NumAtoms``,
      ``SMILES``, graph statistics (``graph_diameter``,
      ``avg_shortest_path``, ``num_cycles``, ``ring_size_3`` ..
      ``ring_size_8``), degree / betweenness / clustering summaries
      (``*_mean``, ``*_std``, ``*_max``) and the eight smallest normalised
      Laplacian eigenvalues (``lap_eig_1`` .. ``lap_eig_8``).

    ``Ipc`` is replaced by its base-10 logarithm and ``lap_eig_1`` by a signed
    log10 of its magnitude.

    Args:
        smiles_list: Sequence of SMILES strings.
        radius: Morgan fingerprint radius.
        n_bits: Size of each of the three generated fingerprints.

    Returns:
        A 4-tuple ``(fingerprints, descriptors, valid_smiles, invalid_indices)``:
        ``fingerprints`` is ``np.array`` of the per-row vectors (an all-zero
        row of length ``n_bits * 3 + 167`` for unparsable SMILES);
        ``descriptors`` and ``valid_smiles`` hold ``None`` at unparsable
        positions; ``invalid_indices`` lists those positions.
    """
    # Number of Laplacian eigenvalues exported. The original code relied on a
    # leaked loop variable that happened to equal 8 at this point; the value is
    # now explicit so that editing the ring-size loop cannot change it.
    n_lap_eigs = 8
    ring_sizes = range(3, 9)

    generator = GetMorganGenerator(radius=radius, fpSize=n_bits)
    atom_pair_gen = GetAtomPairGenerator(fpSize=n_bits)
    torsion_gen = GetTopologicalTorsionGenerator(fpSize=n_bits)

    fingerprints = []
    descriptors = []
    valid_smiles = []
    invalid_indices = []

    for i, smiles in tqdm(enumerate(smiles_list), total=len(smiles_list), desc="🔬 Data Augmentation"):
        mol = Chem.MolFromSmiles(smiles)
        if not mol:
            # Keep row alignment with the input: emit placeholders and record
            # the position of the unparsable SMILES.
            fingerprints.append(np.zeros(n_bits * 3 + 167))
            descriptors.append(None)
            valid_smiles.append(None)
            invalid_indices.append(i)
            continue

        # Fingerprints
        combined_fp = np.concatenate([
            np.array(generator.GetFingerprint(mol)),
            np.array(atom_pair_gen.GetFingerprint(mol)),
            np.array(torsion_gen.GetFingerprint(mol)),
            np.array(MACCSkeys.GenMACCSKeys(mol)),
        ])
        fingerprints.append(combined_fp)

        # RDKit descriptors (best effort: a failing descriptor becomes None)
        descriptor_values = {}
        for name, func in Descriptors.descList:
            try:
                descriptor_values[name] = func(mol)
            except Exception:
                descriptor_values[name] = None

        # Specific descriptors
        descriptor_values['MolWt'] = MolWt(mol)
        descriptor_values['LogP'] = MolLogP(mol)
        descriptor_values['TPSA'] = CalcTPSA(mol)
        descriptor_values['RotatableBonds'] = CalcNumRotatableBonds(mol)
        descriptor_values['NumAtoms'] = mol.GetNumAtoms()
        descriptor_values['SMILES'] = smiles

        # The molecular graph is shared by all graph-derived features below.
        adj = rdmolops.GetAdjacencyMatrix(mol)
        graph = nx.from_numpy_array(adj)

        # Graph-based features (best effort: all become None on failure)
        try:
            if nx.is_connected(graph):
                descriptor_values['graph_diameter'] = nx.diameter(graph)
                descriptor_values['avg_shortest_path'] = nx.average_shortest_path_length(graph)
            else:
                descriptor_values['graph_diameter'] = 0
                descriptor_values['avg_shortest_path'] = 0
            cycles = nx.cycle_basis(graph)
            descriptor_values['num_cycles'] = len(cycles)
            sizes = [len(c) for c in cycles]
            for size in ring_sizes:
                descriptor_values[f'ring_size_{size}'] = sizes.count(size)
        except Exception:
            descriptor_values['graph_diameter'] = None
            descriptor_values['avg_shortest_path'] = None
            descriptor_values['num_cycles'] = None
            for size in ring_sizes:
                descriptor_values[f'ring_size_{size}'] = None

        # Centralities: mean / std / max over the per-node values
        centralities = [
            ('deg', dict(nx.degree(graph))),
            ('betw', nx.betweenness_centrality(graph)),
            ('clust', nx.clustering(graph)),
        ]
        for label, metric in centralities:
            vals = np.array(list(metric.values()), dtype=float)
            descriptor_values[f'{label}_mean'] = vals.mean()
            descriptor_values[f'{label}_std'] = vals.std()
            descriptor_values[f'{label}_max'] = vals.max()

        # Spectral: smallest eigenvalues of the normalised Laplacian,
        # zero-padded when the molecule has fewer atoms than n_lap_eigs.
        laplacian = nx.normalized_laplacian_matrix(graph).toarray()
        eigs = np.sort(np.linalg.eigvals(laplacian).real)
        for pos in range(n_lap_eigs):
            descriptor_values[f'lap_eig_{pos + 1}'] = eigs[pos] if pos < len(eigs) else 0.0

        # Log-compress Ipc; skip when the descriptor could not be computed
        # (np.log10(None) would raise).
        ipc = descriptor_values.get('Ipc')
        if ipc is not None:
            descriptor_values['Ipc'] = np.log10(ipc)
        first_eig = descriptor_values['lap_eig_1']
        descriptor_values['lap_eig_1'] = np.sign(first_eig) * np.log10(np.abs(first_eig) + 1e-20)

        descriptors.append(descriptor_values)
        valid_smiles.append(smiles)

    return np.array(fingerprints), descriptors, valid_smiles, invalid_indices
def eval_on_host_train(
    target,
    host_train_csv,
    desc_cols_file,
    model_pattern="model_{}_fold*.pt",
    aggregate="mean",
    evaluate=False,
):
    """Predict ``target`` for host molecules with an ensemble of fold models.

    The host SMILES are featurised, each fold checkpoint matching
    ``model_pattern`` is loaded together with its descriptor and target
    scalers, per-fold predictions are inverse-scaled to original units, and
    the folds are averaged.

    Args:
        target: Key into ``filters`` and the name of the target column.
        host_train_csv: A DataFrame with a ``SMILES`` column, or a path to a
            CSV file containing one.
        desc_cols_file: Path to a joblib file with the ordered feature column
            names; if falsy, every column except ``"SMILES"`` and ``target``
            is used.
        model_pattern: Glob pattern for checkpoints; ``{}`` is filled with
            ``target``.
        aggregate: Only the mean is implemented; every value yields the mean.
        evaluate: When true, also report the MAE against the target column.

    Returns:
        ``(final, all_preds)`` when ``evaluate`` is false, otherwise
        ``(host_mae, final, all_preds)`` where ``host_mae`` is ``None`` if no
        target column is available. ``all_preds`` has shape
        ``(n_folds, n_rows)``.

    Raises:
        AssertionError: If no checkpoint matches ``model_pattern``.
    """
    # Accept either an in-memory frame or a CSV path (the script entry point
    # passes a path).
    if isinstance(host_train_csv, (str, os.PathLike)):
        host_df = pd.read_csv(host_train_csv)
    else:
        host_df = host_train_csv

    test_smiles = host_df['SMILES'].tolist()
    fingerprints, descriptors, _valid_smiles, _invalid_indices = smiles_to_combined_fingerprints_with_descriptors(test_smiles, radius=2, n_bits=128)

    X = pd.DataFrame(descriptors)
    # errors="ignore": these columns are discarded by the filter below anyway,
    # so a descriptor set that lacks some of them must not abort the run.
    X = X.drop(['BCUT2D_MWLOW','BCUT2D_MWHI','BCUT2D_CHGHI','BCUT2D_CHGLO','BCUT2D_LOGPHI','BCUT2D_LOGPLOW','BCUT2D_MRLOW','BCUT2D_MRHI','MinAbsPartialCharge','MaxPartialCharge','MinPartialCharge','MaxAbsPartialCharge',], axis=1, errors="ignore")
    selected = filters[target]
    X = X.filter(items=selected)

    fp_df = pd.DataFrame(fingerprints, columns=[f'FP_{i}' for i in range(fingerprints.shape[1])])
    print(fp_df.shape)
    fp_df.reset_index(drop=True, inplace=True)
    X.reset_index(drop=True, inplace=True)
    X = pd.concat([X, fp_df], axis=1)
    print(f"After concat: {X.shape}")

    df = X
    n = len(df)
    desc_cols = (
        joblib.load(desc_cols_file)
        if desc_cols_file
        else [c for c in df.columns if c not in ["SMILES", target]]
    )

    # collect models
    model_files = sorted(glob.glob(model_pattern.format(target)))
    if not model_files:
        # Explicit raise instead of `assert` so the check survives `python -O`.
        raise AssertionError("No fold models found")

    # The raw feature matrix is identical for every fold; only scaling differs.
    raw_features = df[desc_cols].values.astype(float)

    all_preds = np.zeros((len(model_files), n), dtype=float)
    for i, mp in enumerate(model_files):
        pkg = torch.load(mp, map_location="cpu")

        # load per-fold scalers (must exist); relative paths are resolved
        # against the checkpoint's directory
        model_dir = os.path.dirname(mp)
        desc_scaler_path = pkg["scaler_files"]["desc"]
        y_scaler_path = pkg["scaler_files"]["y"]
        if not os.path.isabs(desc_scaler_path):
            desc_scaler_path = os.path.join(model_dir, desc_scaler_path)
        if not os.path.isabs(y_scaler_path):
            y_scaler_path = os.path.join(model_dir, y_scaler_path)
        desc_scaler = joblib.load(desc_scaler_path)
        y_scaler = joblib.load(y_scaler_path)

        # scale whole-host descriptors using this fold's scaler
        Xs = desc_scaler.transform(raw_features)

        # build graphs for host rows (order preserved)
        data_list = []
        for idx in range(n):
            smi = df.loc[idx, "SMILES"]
            d = mol_to_pyg_data(smi, global_features=Xs[idx], y=None)
            if d is None:
                # fallback single-node graph with a self-loop and zero features
                zero = torch.zeros(
                    (1, len(atom_features(Chem.Atom("C")))), dtype=torch.float
                )
                d = Data(
                    x=zero,
                    edge_index=torch.tensor([[0], [0]], dtype=torch.long),
                    edge_attr=torch.zeros((1, 4)),
                    global_feats=torch.tensor(Xs[idx], dtype=torch.float),
                )
            # remember the source row so predictions can be scattered back
            d.orig_idx = torch.tensor(idx, dtype=torch.long)
            data_list.append(d)
        loader = PyGDataLoader(
            data_list, batch_size=64, shuffle=False, num_workers=0
        )

        # instantiate model and load weights
        model = GNNWithGlobalFeats(
            node_in_dim=pkg["node_dim"],
            edge_in_dim=pkg["edge_dim"],
            global_in_dim=pkg["global_dim"],
            gnn_hidden=pkg.get("gnn_hidden", 128),
            n_gnn_layers=pkg.get("n_gnn_layers", 3),
            mlp_hidden=pkg.get("mlp_hidden", 128),
            dropout=pkg.get("dropout", 0.2),
            conv_type=pkg.get("conv_type", "gcn"),
        )
        model.load_state_dict(pkg["state_dict"])
        model.eval()

        preds_fold = np.zeros(n, dtype=float)
        cursor = 0  # next row position, used when a batch carries no orig_idx
        with torch.no_grad():
            for batch in loader:
                batch = batch.to("cpu")
                # Flatten so (B,), (B, 1) and 0-d outputs are handled alike.
                # NOTE(review): assumes one scalar prediction per graph.
                out = model(batch).detach().cpu().numpy().reshape(-1)
                if hasattr(batch, "orig_idx"):
                    idxs = batch.orig_idx.detach().cpu().numpy().ravel().astype(int)
                else:
                    # The loader is not shuffled, so batches arrive in row
                    # order and positions can be assigned sequentially.
                    idxs = np.arange(cursor, cursor + out.size)
                count = min(out.size, idxs.size)
                preds_fold[idxs[:count]] = out[:count]
                cursor += out.size

        # inverse-scale fold preds to original units
        preds_orig = y_scaler.inverse_transform(
            preds_fold.reshape(-1, 1)
        ).ravel()
        all_preds[i, :] = preds_orig

    # aggregate: only the mean is implemented (extendable to weighted)
    final = all_preds.mean(axis=0)

    if not evaluate:
        return final, all_preds

    # compute host-train MAE; the target lives in the host data, not in the
    # feature frame, so look there first
    truth_source = host_df if target in host_df.columns else df
    if target in truth_source.columns:
        host_mae = mean_absolute_error(truth_source[target].values.astype(float), final)
        print(
            f"Host-train MAE for {target}: {host_mae:.6f} (using {len(model_files)} fold models)"
        )
        return host_mae, final, all_preds
    print(f"Target column '{target}' not found, skipping MAE calculation.")
    return None, final, all_preds
if __name__ == "__main__":
    # Run the fold ensemble for each label and collect the averaged
    # predictions, keyed by label.
    submission_df = {}
    for label in ["Density", "Rg"]:
        host_csv = f"./Datasets/{label}/{label}.csv"
        # eval_on_host_train indexes its second argument by column name, so
        # hand it the loaded table rather than the path string.
        host_frame = pd.read_csv(host_csv)
        mae, preds, allp = eval_on_host_train(
            label,
            host_frame,
            model_pattern=f"model_{label}_fold*.pt",
            desc_cols_file=f"desc_cols_{label}.pkl",
            evaluate=True,
        )
        submission_df[label] = preds
    print(submission_df)