| import torch |
| import numpy as np |
| from rdkit import Chem |
| import networkx as nx |
| from collections import defaultdict |
| from torch_geometric.data import Data |
| from polyatomic_complexes.src.complexes.abstract_complex import AbstractComplex |
| from polyatomic_complexes.src.complexes import PolyatomicGeometrySMILE |
|
|
|
|
def compressed_topsignal_graph_from_smiles(
    smile: str, y_val: int, topk_lap: int = 5
) -> Data | None:
    """Build a PyTorch Geometric ``Data`` graph for one SMILES string.

    Node features (one row per RDKit atom, 15 columns):
      * the atom's ``chain_0`` value from the complex (0.0 if missing),
      * one-hot of atomic number over ``[6, 7, 8, 15, 16, 17]``,
      * one-hot of hybridization over ``[SP, SP2, SP3]``,
      * degree, aromatic flag, formal charge,
      * closeness centrality and mean shortest-path length in the bond graph.

    Edge features (6 columns): one-hot of bond type over
    ``[SINGLE, DOUBLE, TRIPLE, AROMATIC]`` followed by two extra slots.
    Edges come from ``ac.get_bonds()`` when that yields any edge, otherwise
    from the RDKit molecule's bonds.

    Graph-level features (``data.graph_feats``, 1-D): mean/std of every raw
    k-chain except ``chain_0``, ``topk_lap`` Laplacian eigenvalues per
    Laplacian record, mean/std of every spectral k-chain, and the Betti
    numbers b0 / b1 of the bond graph. Its length therefore depends on how
    many chains and Laplacians the complex provides.

    Args:
        smile: SMILES string of the molecule.
        y_val: Target label; stored as a float tensor of shape ``(1,)``.
        topk_lap: Number of Laplacian eigenvalues kept per Laplacian record
            (zero-padded when fewer are available).

    Returns:
        The populated ``Data`` object, or ``None`` if the SMILES cannot be
        parsed by RDKit or any step of the construction raises.
    """
    # Local import so this block does not depend on the file's import header.
    import logging

    try:
        # Parse with RDKit first so an unparsable SMILES skips building the
        # complex entirely; the result (None) is the same either way.
        mol = Chem.MolFromSmiles(smile)
        if mol is None:
            return None

        pg = PolyatomicGeometrySMILE(smile=smile, mode="abstract")
        ac = pg.smiles_to_geom_complex()
        # Explicit check instead of `assert`, which is stripped under -O.
        if not isinstance(ac, AbstractComplex):
            raise TypeError(
                f"expected AbstractComplex, got {type(ac).__name__}"
            )

        # ---- Node features -------------------------------------------------
        chains = ac.get_raw_k_chains()
        chain0 = chains.get("chain_0", [])
        atom_types = [6, 7, 8, 15, 16, 17]
        hyb_types = [
            Chem.rdchem.HybridizationType.SP,
            Chem.rdchem.HybridizationType.SP2,
            Chem.rdchem.HybridizationType.SP3,
        ]
        node_feats = []
        for atom in mol.GetAtoms():
            idx = atom.GetIdx()
            # chain_0 is indexed by RDKit atom index here; atoms beyond its
            # length get 0.0.
            # NOTE(review): assumes chain_0 ordering matches RDKit atom
            # ordering — confirm against PolyatomicGeometrySMILE.
            c0 = float(chain0[idx]) if idx < len(chain0) else 0.0
            feats = [c0]
            feats += one_hot(atom.GetAtomicNum(), atom_types)
            feats += one_hot(atom.GetHybridization(), hyb_types)
            feats += [
                float(atom.GetDegree()),
                float(atom.GetIsAromatic()),
                float(atom.GetFormalCharge()),
            ]
            node_feats.append(feats)
        x = torch.tensor(node_feats, dtype=torch.float32)
        n = x.size(0)

        # ---- Map complex node ids to row indices ---------------------------
        sk = ac.get_skeleta().get("molecule_skeleta", [[]])[0]
        # Pick the entry tagged "0" from the first skeleton.
        zero = next((lst for dim, lst in sk if dim == "0"), [])
        # presumably each element is a set-like of tuples whose first item is
        # a node id string such as "<symbol>_<k>" — verify against the
        # skeleta producer.
        node_ids = [next(iter(fz))[0] for fz in zero]
        atom_map = defaultdict(list)
        for i, nid in enumerate(node_ids):
            symbol = nid.split("_")[0]
            atom_map[symbol].append(i)

        # ---- Edges ---------------------------------------------------------
        edge_index_list, edge_attr_list = [], []
        bond_types = [
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE,
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC,
        ]
        for a1, a2, (btype, order) in ac.get_bonds():
            bt_val = getattr(Chem.rdchem.BondType, btype, None)
            # Every index mapped to a1 is connected to every index mapped to
            # a2, in both directions.
            # NOTE(review): atom_map is keyed by the prefix before "_", so
            # this can pair many atoms per bond — confirm this is intended.
            for i in atom_map.get(a1, []):
                for j in atom_map.get(a2, []):
                    if i < n and j < n:
                        edge_index_list += [[i, j], [j, i]]
                        attr = one_hot(bt_val, bond_types) + [float(order), 0.0]
                        edge_attr_list += [attr, attr]
        if not edge_index_list:
            # Fallback: use the RDKit bonds directly.
            # NOTE(review): the two trailing edge slots hold
            # (order, 0.0) above but (is_conjugated, in_ring) here.
            for bond in mol.GetBonds():
                i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                edge_index_list += [[i, j], [j, i]]
                attr = one_hot(bond.GetBondType(), bond_types)
                attr += [float(bond.GetIsConjugated()), float(bond.IsInRing())]
                edge_attr_list += [attr, attr]

        if edge_index_list:
            edge_index = (
                torch.tensor(edge_index_list, dtype=torch.long).t().contiguous()
            )
            edge_attr = torch.tensor(edge_attr_list, dtype=torch.float32)
        else:
            # No bonds at all (e.g. a single atom): torch.tensor([]) would be
            # 1-D, so build correctly shaped empty tensors explicitly.
            edge_index = torch.empty((2, 0), dtype=torch.long)
            edge_attr = torch.empty(
                (0, len(bond_types) + 2), dtype=torch.float32
            )

        # ---- Per-node topological features ---------------------------------
        G = nx.Graph()
        G.add_nodes_from(range(n))
        G.add_edges_from(edge_index_list)
        cent = nx.closeness_centrality(G)
        spd = dict(nx.all_pairs_shortest_path_length(G))
        cent_vec = [cent.get(i, 0.0) for i in range(n)]
        # Mean shortest-path length from each node to the nodes it can reach
        # (the node itself, at distance 0, is included in the average).
        spd_vec = [
            sum(d.values()) / max(len(d), 1)
            for d in (spd.get(i, {}) for i in range(n))
        ]
        cent_t = torch.tensor(cent_vec, dtype=torch.float32).view(n, 1)
        spd_t = torch.tensor(spd_vec, dtype=torch.float32).view(n, 1)
        x = torch.cat([x, cent_t, spd_t], dim=1)

        # ---- Graph-level features ------------------------------------------
        # Mean/std of every raw chain except chain_0 (already used per node).
        g_stats, lap_feats = [], []
        for k, arr in chains.items():
            if k == "chain_0":
                continue
            a = np.array(arr, dtype=np.float32)
            g_stats += [a.mean(), a.std()]

        for grp in ac.get_laplacians().get("molecule_laplacians", []):
            recs = grp if isinstance(grp, list) else [grp]
            for _, mat in recs:
                M = np.array(mat, dtype=np.float32)
                try:
                    eigs = np.linalg.eigvalsh(M)
                except Exception:
                    # Best effort: a matrix eigvalsh rejects contributes
                    # only zero padding.
                    eigs = np.zeros(M.shape[0], dtype=np.float32)
                # eigvalsh returns eigenvalues in ascending order, so this
                # keeps the smallest `topk_lap` values above the 1e-6 cutoff.
                vals = eigs[eigs > 1e-6][:topk_lap]
                if len(vals) < topk_lap:
                    vals = np.pad(vals, (0, topk_lap - len(vals)))
                lap_feats += list(vals)

        spectral = ac.get_spectral_k_chains()
        spec_feats = []
        for arr in spectral.values():
            a = np.array(arr, dtype=np.float32)
            spec_feats += [a.mean(), a.std()]

        # Betti numbers of the bond graph: connected components and
        # independent cycles.
        b0 = nx.number_connected_components(G)
        b1 = sum(
            len(nx.cycle_basis(G.subgraph(comp)))
            for comp in nx.connected_components(G)
        )

        all_feats = g_stats + lap_feats + spec_feats + [float(b0), float(b1)]
        graph_feats = torch.tensor(all_feats, dtype=torch.float32)

        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
        data.graph_feats = graph_feats
        data.y = torch.tensor([y_val], dtype=torch.float)

        return data
    except Exception as exc:
        # Deliberately best-effort: callers receive None for molecules that
        # cannot be featurized, but the reason is logged, not discarded.
        logging.getLogger(__name__).warning(
            "Failed to build graph for SMILES %r: %s", smile, exc
        )
        return None
|
|
|
|
def one_hot(val, choices):
    """Return a list of floats marking which entries of *choices* equal *val*.

    The result has one element per entry of ``choices``: ``1.0`` where the
    entry compares equal to ``val`` and ``0.0`` elsewhere. If ``val`` matches
    nothing the list is all zeros.
    """
    encoding = []
    for candidate in choices:
        if val == candidate:
            encoding.append(1.0)
        else:
            encoding.append(0.0)
    return encoding
|
|