"""Host-train evaluation of per-fold GNN models on RDKit/graph/fingerprint features."""
import glob
import os

import joblib
import networkx as nx
import numpy as np
import pandas as pd
import torch
from rdkit import Chem
from rdkit.Chem import Descriptors, MACCSkeys, rdmolops
from rdkit.Chem.Descriptors import MolLogP, MolWt
from rdkit.Chem.rdFingerprintGenerator import (
    GetAtomPairGenerator,
    GetMorganGenerator,
    GetTopologicalTorsionGenerator,
)
from rdkit.Chem.rdMolDescriptors import CalcNumRotatableBonds, CalcTPSA
from sklearn.metrics import mean_absolute_error
from tqdm import tqdm

from MY_GNN.train import (
    Data,
    GNNWithGlobalFeats,
    PyGDataLoader,
    atom_features,
    mol_to_pyg_data,
)

# Columns every target keeps regardless of its own feature-selection list.
required_descriptors = {'graph_diameter', 'num_cycles', 'avg_shortest_path', 'MolWt', 'LogP',
                        'TPSA', 'RotatableBonds', 'NumAtoms', 'SMILES'}

# Number of bits in an RDKit MACCS key fingerprint (also the width used for invalid rows).
_MACCS_BITS = 167
# Ring sizes counted as ``ring_size_{k}`` features.
_RING_SIZES = range(3, 9)
# Number of smallest normalized-Laplacian eigenvalues kept as ``lap_eig_{i}``.
# (The original code read this from a loop variable leaked by the ring-size loop,
# which always ended at 8.)
_N_LAP_EIGS = 8
# RDKit descriptors removed before feature selection.
# NOTE(review): presumably excluded because they are unstable/NaN-prone for these
# inputs — confirm against the training pipeline.
_DROPPED_DESCRIPTORS = [
    'BCUT2D_MWLOW', 'BCUT2D_MWHI', 'BCUT2D_CHGHI', 'BCUT2D_CHGLO', 'BCUT2D_LOGPHI',
    'BCUT2D_LOGPLOW', 'BCUT2D_MRLOW', 'BCUT2D_MRHI', 'MinAbsPartialCharge',
    'MaxPartialCharge', 'MinPartialCharge', 'MaxAbsPartialCharge',
]


def _with_required(names):
    """Return *names* plus ``required_descriptors`` as a sorted, de-duplicated list.

    Sorting (rather than ``list(set(...))``) makes the column order reproducible
    across interpreter runs: the iteration order of a set of strings depends on
    the per-process hash seed.
    """
    return sorted(set(names).union(required_descriptors))


# Per-target descriptor/fingerprint columns to keep (FP_* names refer to the
# concatenated fingerprint columns produced below).
filters = {
    'Tg': _with_required([
        'deg_mean', 'FractionCSP3', 'num_cycles', 'RingCount', 'HallKierAlpha', 'SMR_VSA7',
        'BertzCT', 'ring_size_6', 'fr_benzene', 'NumAromaticCarbocycles', 'NumAromaticRings',
        'SlogP_VSA6', 'SlogP_VSA1', 'betw_mean', 'VSA_EState6', 'BalabanJ', 'Chi4n', 'FP_446',
        'PEOE_VSA14', 'Chi3n', 'AvgIpc', 'FP_489', 'Chi1', 'HeavyAtomCount', 'NumHeterocycles',
        'FP_485', 'fr_bicyclic', 'SMR_VSA10', 'FP_537', 'VSA_EState2', 'FP_539', 'FP_529',
        'HeavyAtomMolWt', 'LabuteASA', 'ring_size_5', 'FP_505', 'NumAmideBonds', 'MolMR', 'FP_80',
        'FP_195', 'FP_310', 'fr_amide', 'FP_509', 'FP_378', 'ExactMolWt', 'MolWt', 'FP_211',
        'Chi2n', 'FP_266', 'FP_379', 'FP_207', 'FP_504', 'FP_203', 'NumAtoms', 'FP_199', 'FP_519',
        'FP_123', 'FP_278', 'FpDensityMorgan1', 'FP_119', 'fr_imide', 'FP_279', 'FP_223',
        'betw_std', 'FP_231', 'FP_219', 'FP_251', 'NumValenceElectrons', 'FP_480', 'Chi0n',
        'FP_517', 'FP_255', 'Chi0', 'FP_522', 'FP_528', 'FP_526', 'FpDensityMorgan2', 'FP_354',
        'Chi1n', 'FP_459', 'FP_547', 'FP_476', 'Chi0v', 'FP_210', 'FP_516', 'FP_382', 'FP_215',
        'FP_243', 'FP_521', 'FP_227', 'NumAliphaticHeterocycles', 'FP_469', 'FP_467', 'FP_342',
        'FP_549', 'FP_357', 'FP_494', 'FP_194', 'FP_546', 'FP_302',
    ]),
    'FFV': _with_required([
        'MolLogP', 'LogP', 'Chi3v', 'Chi2v', 'Chi4v', 'Chi4n', 'Chi3n', 'VSA_EState6', 'SMR_VSA7',
        'Chi1v', 'Chi2n', 'MolMR', 'Chi1n', 'Chi0v', 'SlogP_VSA6', 'BertzCT', 'PEOE_VSA14',
        'Chi0n', 'LabuteASA', 'EState_VSA8', 'BalabanJ', 'Ipc', 'Chi1', 'deg_mean', 'VSA_EState8',
        'MolWt', 'ExactMolWt', 'SMR_VSA9', 'HeavyAtomMolWt', 'Chi0', 'SMR_VSA6', 'EState_VSA5',
        'FpDensityMorgan3', 'Kappa1', 'AvgIpc', 'FpDensityMorgan2', 'SlogP_VSA8', 'HallKierAlpha',
        'FP_39', 'avg_shortest_path', 'SMR_VSA1', 'SlogP_VSA5', 'betw_mean', 'TPSA',
        'FpDensityMorgan1', 'lap_eig_6', 'qed', 'lap_eig_7', 'lap_eig_8', 'RingCount',
        'NumValenceElectrons', 'NumAromaticRings', 'lap_eig_5', 'num_cycles', 'EState_VSA7',
        'Kappa2', 'NumAtoms', 'ring_size_6', 'betw_std', 'lap_eig_4', 'lap_eig_3',
        'HeavyAtomCount', 'fr_benzene', 'NumHDonors', 'NumAromaticCarbocycles', 'PEOE_VSA7',
        'SlogP_VSA2', 'SlogP_VSA3', 'NHOHCount', 'NOCount', 'fr_NH1', 'EState_VSA4', 'FP_515',
        'MaxEStateIndex', 'MaxAbsEStateIndex', 'lap_eig_2', 'PEOE_VSA6', 'VSA_EState5',
        'EState_VSA6', 'FractionCSP3', 'EState_VSA3', 'Phi', 'FP_535', 'NumHAcceptors',
        'SlogP_VSA4', 'fr_C_O', 'FP_446', 'FP_488', 'SMR_VSA10', 'PEOE_VSA9', 'fr_C_O_noCOO',
        'FP_125', 'FP_474', 'SlogP_VSA7', 'graph_diameter', 'SlogP_VSA12', 'FP_507',
        'fr_bicyclic', 'MinAbsEStateIndex', 'deg_std',
    ]),
    'Tc': _with_required([
        'deg_std', 'fr_unbrch_alkane', 'FP_287', 'FP_286', 'betw_mean', 'avg_shortest_path',
        'Kappa3', 'graph_diameter', 'FP_285', 'FP_187', 'FP_191', 'FP_139', 'VSA_EState7',
        'FpDensityMorgan3', 'FP_143', 'FpDensityMorgan2', 'FP_171', 'Phi', 'FpDensityMorgan1',
        'FP_167', 'FP_175', 'qed', 'FP_163', 'FP_131', 'FP_531', 'FP_502', 'FP_142', 'Kappa2',
        'FP_135', 'FP_179', 'SlogP_VSA5', 'FP_513', 'FP_134', 'SMR_VSA5', 'FP_93', 'FP_138',
        'FP_130', 'RotatableBonds', 'NumRotatableBonds', 'FP_182', 'FP_162', 'FP_518', 'FP_491',
        'FP_174', 'FP_496', 'FP_284', 'FP_141', 'FP_475', 'lap_eig_7', 'FP_154', 'lap_eig_8',
        'FP_512', 'FP_453', 'FP_256', 'FP_488', 'deg_max', 'FP_170', 'FP_137', 'FP_190', 'FP_17',
        'FP_466', 'FP_535', 'FP_474', 'fr_NH1', 'lap_eig_6', 'Chi3n', 'FP_133', 'PEOE_VSA6',
        'FP_178', 'FP_186', 'Chi1n', 'FP_450', 'FP_102', 'FP_508', 'EState_VSA5', 'NumHDonors',
        'NumAtomStereoCenters', 'NumUnspecifiedAtomStereoCenters', 'NHOHCount', 'FP_183',
        'Chi2n', 'Chi3v', 'FP_89', 'SPS', 'betw_max', 'AvgIpc', 'Chi4n', 'Chi1v', 'FP_478',
        'lap_eig_5', 'FP_484', 'SMR_VSA3', 'FP_192', 'FP_166', 'Kappa1', 'FP_495', 'FP_526',
        'fr_halogen', 'FP_153', 'FP_28',
    ]),
    'Density': _with_required([
        'SMR_VSA5', 'VSA_EState8', 'VSA_EState7', 'SlogP_VSA5', 'SMR_VSA10', 'FractionCSP3',
        'EState_VSA5', 'SlogP_VSA12', 'VSA_EState10', 'fr_unbrch_alkane', 'NumRotatableBonds',
        'RotatableBonds', 'FP_119', 'FP_513', 'PEOE_VSA8', 'Kappa3', 'PEOE_VSA7', 'FP_180',
        'FP_472', 'FP_428', 'Kappa2', 'FP_80', 'Phi', 'FP_539', 'FP_512', 'FP_531', 'EState_VSA7',
        'FP_537', 'FP_502', 'FP_98', 'NumHAcceptors', 'Chi1n', 'MaxAbsEStateIndex',
        'MaxEStateIndex', 'PEOE_VSA14', 'FP_500', 'MolLogP', 'LogP', 'FP_465',
        'MinAbsEStateIndex', 'Chi2n', 'SlogP_VSA7', 'FP_176', 'avg_shortest_path', 'EState_VSA4',
        'FP_181', 'lap_eig_5', 'Chi0n', 'HallKierAlpha', 'PEOE_VSA5', 'qed', 'graph_diameter',
        'FP_186', 'betw_mean', 'FP_287', 'FP_179', 'lap_eig_4', 'FP_134', 'Chi3n', 'NOCount',
        'fr_C_S', 'FP_131', 'FP_177', 'FP_166', 'FP_127', 'FP_162', 'FP_191', 'FP_143', 'Chi4n',
        'TPSA', 'lap_eig_3', 'Chi1v', 'SlogP_VSA6', 'FP_178', 'FP_457', 'FP_139', 'FP_163',
        'SMR_VSA7', 'SlogP_VSA11', 'SlogP_VSA3', 'FP_183', 'Chi0', 'FP_137', 'ring_size_6',
        'FP_138', 'fr_benzene', 'NumAromaticCarbocycles', 'FP_420', 'NumAromaticRings',
        'NumAromaticHeterocycles', 'FP_492', 'FP_169', 'FP_284', 'Chi1', 'FP_141', 'FP_35',
        'FP_182', 'FP_521', 'EState_VSA3', 'FP_135',
    ]),
    'Rg': _with_required([
        'FP_93', 'SlogP_VSA7', 'PEOE_VSA14', 'qed', 'FP_544', 'VSA_EState8', 'FP_499',
        'SlogP_VSA1', 'fr_unbrch_alkane', 'FP_42', 'EState_VSA4', 'FP_192', 'FP_508', 'FP_520',
        'lap_eig_8', 'Phi', 'FP_155', 'NumAtomStereoCenters', 'NumUnspecifiedAtomStereoCenters',
        'avg_shortest_path', 'FP_17', 'FP_317', 'lap_eig_7', 'FP_73', 'VSA_EState7', 'FP_224',
        'fr_ester', 'graph_diameter', 'Kappa2', 'NumAmideBonds', 'fr_NH1', 'FP_191', 'fr_amide',
        'FP_286', 'Kappa3', 'FP_159', 'FP_488', 'FP_33', 'deg_std', 'FP_280', 'FP_364', 'FP_287',
        'EState_VSA5', 'SlogP_VSA5', 'FP_515', 'TPSA', 'FP_151', 'SMR_VSA5', 'FP_498', 'NOCount',
        'betw_mean', 'RotatableBonds', 'NumRotatableBonds', 'FP_273', 'SMR_VSA3', 'FP_163',
        'FP_134', 'FP_478', 'FP_138', 'FP_187', 'FP_137', 'FP_252', 'VSA_EState3', 'FP_171',
        'FP_175', 'lap_eig_6', 'NHOHCount', 'Chi4v', 'FpDensityMorgan1', 'FP_182', 'FP_526',
        'FP_167', 'FP_486', 'FP_142', 'FP_316', 'AvgIpc', 'MolLogP', 'LogP', 'FP_183', 'FP_130',
        'FP_102', 'FP_1', 'FP_115', 'SMR_VSA10', 'Chi4n', 'FP_24', 'FP_533', 'NumHDonors',
        'FP_193', 'FP_147', 'FP_38', 'Chi3n', 'FP_249', 'FP_453', 'FP_535', 'FP_492', 'Chi3v',
        'FP_240', 'FP_501', 'FP_139',
    ]),
}


# ---------------------------------------------------------------------------
# Featurisation
# ---------------------------------------------------------------------------

def _graph_descriptors(mol, n_eigs=_N_LAP_EIGS):
    """Return topology features of the heavy-atom adjacency graph of *mol*.

    Keys: graph_diameter, avg_shortest_path, num_cycles, ring_size_{3..8},
    {deg,betw,clust}_{mean,std,max} and lap_eig_{1..n_eigs}.
    """
    feats = {}
    graph = nx.from_numpy_array(rdmolops.GetAdjacencyMatrix(mol))

    # Path / cycle features are best-effort: any failure yields None for all of them.
    try:
        if nx.is_connected(graph):
            feats['graph_diameter'] = nx.diameter(graph)
            feats['avg_shortest_path'] = nx.average_shortest_path_length(graph)
        else:
            # Disconnected molecules (e.g. salts) have no finite diameter.
            feats['graph_diameter'] = 0
            feats['avg_shortest_path'] = 0
        cycles = nx.cycle_basis(graph)
        feats['num_cycles'] = len(cycles)
        sizes = [len(cycle) for cycle in cycles]
        for size in _RING_SIZES:
            feats[f'ring_size_{size}'] = sizes.count(size)
    except Exception:
        feats['graph_diameter'] = None
        feats['avg_shortest_path'] = None
        feats['num_cycles'] = None
        for size in _RING_SIZES:
            feats[f'ring_size_{size}'] = None

    # Node centralities summarised by mean / std / max.
    centralities = {
        'deg': dict(nx.degree(graph)),
        'betw': nx.betweenness_centrality(graph),
        'clust': nx.clustering(graph),
    }
    for label, metric in centralities.items():
        vals = np.array(list(metric.values()), dtype=float)
        feats[f'{label}_mean'] = vals.mean()
        feats[f'{label}_std'] = vals.std()
        feats[f'{label}_max'] = vals.max()

    # Smallest normalized-Laplacian eigenvalues, zero-padded for small molecules.
    # NOTE: np.linalg.eigvals (not eigvalsh) is kept on purpose: lap_eig_1 is later
    # log-transformed, which amplifies round-off, so switching solvers could shift
    # this feature away from what the training-time featuriser produced.
    laplacian = nx.normalized_laplacian_matrix(graph).toarray()
    eigs = np.sort(np.linalg.eigvals(laplacian).real)
    for j in range(n_eigs):
        feats[f'lap_eig_{j + 1}'] = eigs[j] if j < len(eigs) else 0.0
    return feats


def _molecule_descriptors(mol, smiles):
    """Return the full descriptor dict (RDKit + extra + graph features) for one molecule."""
    values = {}
    for name, func in Descriptors.descList:
        try:
            values[name] = func(mol)
        except Exception:
            # Best-effort: one failing RDKit descriptor must not drop the molecule.
            values[name] = None

    values['MolWt'] = MolWt(mol)
    values['LogP'] = MolLogP(mol)
    values['TPSA'] = CalcTPSA(mol)
    values['RotatableBonds'] = CalcNumRotatableBonds(mol)
    values['NumAtoms'] = mol.GetNumAtoms()
    values['SMILES'] = smiles
    values.update(_graph_descriptors(mol))

    # Compress heavy-tailed values. Ipc may be missing/None if its RDKit call failed;
    # a value of 0 still yields -inf, exactly as before.
    if values.get('Ipc') is not None:
        values['Ipc'] = np.log10(values['Ipc'])
    lap1 = values['lap_eig_1']
    values['lap_eig_1'] = np.sign(lap1) * np.log10(np.abs(lap1) + 1e-20)
    return values


def smiles_to_combined_fingerprints_with_descriptors(smiles_list, radius=2, n_bits=128):
    """Featurise SMILES into concatenated fingerprints plus one descriptor dict each.

    Args:
        smiles_list: Sized iterable of SMILES strings (non-strings count as invalid).
        radius: Morgan fingerprint radius.
        n_bits: Size of each of the Morgan, atom-pair and topological-torsion fingerprints.

    Returns:
        ``(fingerprints, descriptors, valid_smiles, invalid_indices)``:
        ``fingerprints`` stacks one row of length ``3 * n_bits + 167`` per input
        ([Morgan | atom pair | torsion | MACCS]; all zeros for invalid input);
        ``descriptors`` holds a dict per input, or None when the SMILES could not
        be parsed; ``valid_smiles`` mirrors the input with None for invalid
        entries; ``invalid_indices`` lists the positions of invalid entries.
    """
    morgan_gen = GetMorganGenerator(radius=radius, fpSize=n_bits)
    atom_pair_gen = GetAtomPairGenerator(fpSize=n_bits)
    torsion_gen = GetTopologicalTorsionGenerator(fpSize=n_bits)
    fp_len = n_bits * 3 + _MACCS_BITS

    fingerprints = []
    descriptors = []
    valid_smiles = []
    invalid_indices = []
    for i, smiles in tqdm(enumerate(smiles_list), total=len(smiles_list), desc="🔬 Data Augmentation"):
        mol = Chem.MolFromSmiles(smiles) if isinstance(smiles, str) else None
        # Atom-less molecules cannot yield graph statistics, so treat them as invalid too.
        if mol is None or mol.GetNumAtoms() == 0:
            fingerprints.append(np.zeros(fp_len))
            descriptors.append(None)
            valid_smiles.append(None)
            invalid_indices.append(i)
            continue

        fingerprints.append(np.concatenate([
            np.array(morgan_gen.GetFingerprint(mol)),
            np.array(atom_pair_gen.GetFingerprint(mol)),
            np.array(torsion_gen.GetFingerprint(mol)),
            np.array(MACCSkeys.GenMACCSKeys(mol)),
        ]))
        descriptors.append(_molecule_descriptors(mol, smiles))
        valid_smiles.append(smiles)

    return np.array(fingerprints), descriptors, valid_smiles, invalid_indices


# ---------------------------------------------------------------------------
# Host evaluation
# ---------------------------------------------------------------------------

def _build_host_features(smiles, target):
    """Build the per-row feature frame: filtered descriptors (incl. SMILES) + all FP_* columns."""
    fingerprints, descriptors, _, _ = smiles_to_combined_fingerprints_with_descriptors(
        smiles, radius=2, n_bits=128
    )
    # Invalid molecules carry None; an empty dict turns them into all-NaN rows.
    desc_df = pd.DataFrame([d if d is not None else {} for d in descriptors])
    desc_df = desc_df.drop(columns=_DROPPED_DESCRIPTORS, errors="ignore")
    desc_df = desc_df.filter(items=filters[target])

    n_fp_cols = fingerprints.shape[1] if fingerprints.ndim == 2 else 0
    fp_df = pd.DataFrame(fingerprints.reshape(len(smiles), n_fp_cols),
                         columns=[f'FP_{i}' for i in range(n_fp_cols)])
    print(fp_df.shape)

    combined = pd.concat(
        [desc_df.reset_index(drop=True), fp_df.reset_index(drop=True)], axis=1
    )
    print(f"After concat: {combined.shape}")
    return combined


def _load_fold_scalers(pkg, model_dir):
    """Load the (descriptor, target) scalers named in a fold checkpoint.

    Relative paths are resolved against *model_dir*. joblib uses pickle, so these
    files must come from a trusted source.
    """
    def resolve(path):
        return path if os.path.isabs(path) else os.path.join(model_dir, path)

    files = pkg["scaler_files"]
    return joblib.load(resolve(files["desc"])), joblib.load(resolve(files["y"]))


def _placeholder_graph(global_feats):
    """Single all-zero node with one self-loop that carries only the global features."""
    # Node feature width is taken from atom_features of a carbon atom.
    node_feats = torch.zeros((1, len(atom_features(Chem.Atom("C")))), dtype=torch.float)
    return Data(
        x=node_feats,
        edge_index=torch.tensor([[0], [0]], dtype=torch.long),
        # NOTE(review): edge width 4 is hard-coded; confirm it matches pkg["edge_dim"].
        edge_attr=torch.zeros((1, 4)),
        global_feats=torch.tensor(global_feats, dtype=torch.float),
    )


def _build_host_graphs(smiles, global_feats):
    """Build one graph per host row, tagged with ``orig_idx`` to restore row order."""
    data_list = []
    for idx, (smi, feats) in enumerate(zip(smiles, global_feats)):
        data = mol_to_pyg_data(smi, global_features=feats, y=None) if isinstance(smi, str) else None
        if data is None:
            data = _placeholder_graph(feats)
        data.orig_idx = torch.tensor(idx, dtype=torch.long)
        data_list.append(data)
    return data_list


def _load_fold_model(pkg):
    """Instantiate GNNWithGlobalFeats from checkpoint hyper-parameters and load its weights."""
    model = GNNWithGlobalFeats(
        node_in_dim=pkg["node_dim"],
        edge_in_dim=pkg["edge_dim"],
        global_in_dim=pkg["global_dim"],
        gnn_hidden=pkg.get("gnn_hidden", 128),
        n_gnn_layers=pkg.get("n_gnn_layers", 3),
        mlp_hidden=pkg.get("mlp_hidden", 128),
        dropout=pkg.get("dropout", 0.2),
        conv_type=pkg.get("conv_type", "gcn"),
    )
    model.load_state_dict(pkg["state_dict"])
    model.eval()
    return model


def _predict_fold(model, loader, n):
    """Return one (scaled) prediction per host row, placed by ``orig_idx`` when available."""
    preds = np.zeros(n, dtype=float)
    offset = 0
    with torch.no_grad():
        for batch in loader:
            batch = batch.to("cpu")
            # Flatten so (B,) and (B, 1) model outputs are handled alike.
            out = model(batch).detach().cpu().numpy().reshape(-1)
            if hasattr(batch, "orig_idx"):
                idxs = batch.orig_idx.detach().cpu().numpy().reshape(-1).astype(int)
                preds[idxs] = out
            else:
                # Fallback sequential pass: the loader does not shuffle, so rows
                # arrive in their original order.
                preds[offset:offset + len(out)] = out
            offset += len(out)
    return preds


def eval_on_host_train(
    target,
    host_train_csv,
    desc_cols_file,
    model_pattern="model_{}_fold*.pt",
    aggregate="mean",
    evaluate=False,
):
    """Predict *target* for every host row with an ensemble of per-fold GNN models.

    Args:
        target: Target name. It selects the ``filters`` entry and the label column
            used for the MAE.
        host_train_csv: Path to a CSV, or an already-loaded DataFrame, with a
            ``SMILES`` column and optionally a *target* column.
        desc_cols_file: joblib file holding the ordered global-feature columns the
            models expect. If falsy, every feature column except SMILES/target is used.
        model_pattern: Glob for the fold checkpoints. A ``{}`` placeholder is
            filled with *target*.
        aggregate: ``"median"`` takes the per-row median over folds; any other
            value takes the mean.
        evaluate: If True, also compute the host MAE against the *target* column.

    Returns:
        ``(mae, final, all_preds)`` when *evaluate* is True. ``mae`` is None if no
        usable labels exist. Otherwise returns ``(final, all_preds)``.
        ``all_preds`` has shape ``(n_folds, n_rows)`` in original target units.

    Raises:
        FileNotFoundError: If no checkpoint matches *model_pattern*.
    """
    if isinstance(host_train_csv, pd.DataFrame):
        host_df = host_train_csv
    else:
        host_df = pd.read_csv(host_train_csv)
    smiles = host_df['SMILES'].tolist()

    df = _build_host_features(smiles, target)
    if target in host_df.columns:
        # Positional copy: df has a fresh 0..n-1 index.
        df[target] = host_df[target].to_numpy()
    n = len(df)

    desc_cols = (
        joblib.load(desc_cols_file)
        if desc_cols_file
        else [c for c in df.columns if c not in ["SMILES", target]]
    )
    host_X = df[desc_cols].to_numpy(dtype=float)

    model_files = sorted(glob.glob(model_pattern.format(target)))
    if not model_files:
        raise FileNotFoundError(f"No fold models found for pattern {model_pattern!r}")

    all_preds = np.zeros((len(model_files), n), dtype=float)
    for fold, model_path in enumerate(model_files):
        pkg = torch.load(model_path, map_location="cpu")
        desc_scaler, y_scaler = _load_fold_scalers(pkg, os.path.dirname(model_path))

        # Each fold has its own scaler, so the graphs' global features differ per fold.
        scaled = desc_scaler.transform(host_X)
        loader = PyGDataLoader(
            _build_host_graphs(smiles, scaled), batch_size=64, shuffle=False, num_workers=0
        )
        preds_scaled = _predict_fold(_load_fold_model(pkg), loader, n)

        # Inverse-scale fold predictions to original units.
        all_preds[fold, :] = y_scaler.inverse_transform(preds_scaled.reshape(-1, 1)).ravel()

    if aggregate == "median":
        final = np.median(all_preds, axis=0)
    else:
        final = all_preds.mean(axis=0)  # "mean" and any other value; extendable to weighted

    if not evaluate:
        return final, all_preds

    if target not in df.columns:
        print(f"Target column '{target}' not found, skipping MAE calculation.")
        return None, final, all_preds

    y_true = df[target].to_numpy(dtype=float)
    # Ignore unlabeled rows and rows whose prediction is NaN (e.g. invalid SMILES).
    mask = np.isfinite(y_true) & np.isfinite(final)
    if not mask.any():
        print(f"No rows with both a finite '{target}' label and prediction; skipping MAE calculation.")
        return None, final, all_preds

    host_mae = mean_absolute_error(y_true[mask], final[mask])
    print(
        f"Host-train MAE for {target}: {host_mae:.6f} (using {len(model_files)} fold models)"
    )
    return host_mae, final, all_preds


if __name__ == "__main__":
    submission_df = {}
    for label in ["Density", "Rg"]:
        host_csv = f"./Datasets/{label}/{label}.csv"
        mae, preds, allp = eval_on_host_train(
            label,
            host_csv,
            model_pattern=f"model_{label}_fold*.pt",
            desc_cols_file=f"desc_cols_{label}.pkl",
            evaluate=True,
        )
        submission_df[label] = preds
    print(submission_df)