| | import torch |
| | import numpy as np |
| | from rdkit import Chem |
| | import networkx as nx |
| | from collections import defaultdict |
| | from torch_geometric.data import Data |
| | from polyatomic_complexes.src.complexes.abstract_complex import AbstractComplex |
| | from polyatomic_complexes.src.complexes import PolyatomicGeometrySMILE |
| |
|
| |
|
# One-hot vocabularies shared by the featurization helpers below.
_ATOM_TYPES = [6, 7, 8, 15, 16, 17]  # atomic numbers: C, N, O, P, S, Cl
_HYB_TYPES = [
    Chem.rdchem.HybridizationType.SP,
    Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP3,
]
_BOND_TYPES = [
    Chem.rdchem.BondType.SINGLE,
    Chem.rdchem.BondType.DOUBLE,
    Chem.rdchem.BondType.TRIPLE,
    Chem.rdchem.BondType.AROMATIC,
]
# Edge attribute width: bond-type one-hot plus two scalar columns.
_EDGE_ATTR_DIM = len(_BOND_TYPES) + 2


def compressed_topsignal_graph_from_smiles(
    smile: str, y_val: int, topk_lap: int = 5
) -> Data | None:
    """Build a PyG ``Data`` graph for one SMILES string with topological features.

    Node features (``data.x``), one row per RDKit atom, in column order:
      * the atom's ``chain_0`` value from the abstract complex (0.0 if missing),
      * one-hot atomic number over C, N, O, P, S, Cl (6 columns),
      * one-hot hybridization over SP, SP2, SP3 (3 columns),
      * degree, aromatic flag, formal charge,
      * closeness centrality and mean shortest-path length in the bond graph.

    Edges come from ``ac.get_bonds()`` when that yields at least one usable
    pair, otherwise from the RDKit bonds; every bond is stored in both
    directions. ``data.graph_feats`` holds molecule-level features (see
    ``_graph_level_features``) and ``data.y`` is ``[y_val]`` as float32.

    Args:
        smile: SMILES string of the molecule.
        y_val: Target value stored in ``data.y``.
        topk_lap: Number of Laplacian eigenvalues kept per Laplacian matrix.

    Returns:
        The populated ``Data`` object, or ``None`` if the SMILES cannot be
        parsed, has no atoms, or any featurization step raises.
    """
    try:
        # RDKit parsing is cheap: reject unparsable / empty SMILES before
        # building the (more expensive) polyatomic complex.
        mol = Chem.MolFromSmiles(smile)
        if mol is None or mol.GetNumAtoms() == 0:
            return None

        pg = PolyatomicGeometrySMILE(smile=smile, mode="abstract")
        ac = pg.smiles_to_geom_complex()
        # Explicit check rather than ``assert`` so it still holds under ``python -O``.
        if not isinstance(ac, AbstractComplex):
            return None

        chains = ac.get_raw_k_chains()
        x = _atom_feature_matrix(mol, chains.get("chain_0", []))
        n = x.size(0)

        edge_pairs, edge_attrs = _edges_from_complex(ac, n)
        if not edge_pairs:
            edge_pairs, edge_attrs = _edges_from_rdkit(mol)
        edge_index, edge_attr = _edge_tensors(edge_pairs, edge_attrs)

        G = nx.Graph()
        G.add_nodes_from(range(n))
        G.add_edges_from(edge_pairs)
        x = torch.cat([x, _structural_node_features(G, n)], dim=1)

        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
        data.graph_feats = _graph_level_features(ac, chains, G, topk_lap)
        data.y = torch.tensor([y_val], dtype=torch.float)
        return data
    except Exception:
        # Deliberately best-effort (callers get ``None`` for any molecule that
        # cannot be featurized), but record why instead of failing silently.
        import logging

        logging.getLogger(__name__).debug(
            "Featurization failed for SMILES %r", smile, exc_info=True
        )
        return None


def _atom_feature_matrix(mol, chain0):
    """Return the (num_atoms, 13) float32 matrix of per-atom base features."""
    rows = []
    for atom in mol.GetAtoms():
        idx = atom.GetIdx()
        # chain_0 may be shorter than the atom list; missing entries become 0.0.
        c0 = float(chain0[idx]) if idx < len(chain0) else 0.0
        rows.append(
            [c0]
            + one_hot(atom.GetAtomicNum(), _ATOM_TYPES)
            + one_hot(atom.GetHybridization(), _HYB_TYPES)
            + [
                float(atom.GetDegree()),
                float(atom.GetIsAromatic()),
                float(atom.GetFormalCharge()),
            ]
        )
    return torch.tensor(rows, dtype=torch.float32)


def _edges_from_complex(ac, n):
    """Collect directed edge pairs and attributes from ``ac.get_bonds()``.

    Nodes of the complex's "0" skeleton are grouped by the prefix of their id
    before the first ``_``. Each bond ``(a1, a2, (btype, order))`` links every
    node in group ``a1`` to every node in group ``a2`` (both indices < ``n``).
    Attribute row: one-hot of ``BondType.<btype>`` + ``[order, 0.0]``.

    NOTE(review): if ``a1``/``a2`` are element symbols rather than unique atom
    ids, this connects every atom of one element to every atom of the other,
    and complex node order is assumed to match RDKit atom order -- confirm
    against ``AbstractComplex.get_bonds`` / ``get_skeleta``.
    """
    skeleton = ac.get_skeleta().get("molecule_skeleta", [[]])[0]
    zero_cells = next((cells for dim, cells in skeleton if dim == "0"), [])
    node_ids = [next(iter(cell))[0] for cell in zero_cells]
    groups = defaultdict(list)
    for i, node_id in enumerate(node_ids):
        groups[node_id.split("_")[0]].append(i)

    pairs, attrs = [], []
    for a1, a2, (btype, order) in ac.get_bonds():
        bt_val = getattr(Chem.rdchem.BondType, btype, None)
        for i in groups.get(a1, []):
            for j in groups.get(a2, []):
                if i < n and j < n:
                    # NOTE(review): the last two columns mean [order, 0.0] here
                    # but [is_conjugated, is_in_ring] in the RDKit fallback.
                    attr = one_hot(bt_val, _BOND_TYPES) + [float(order), 0.0]
                    pairs += [[i, j], [j, i]]
                    attrs += [attr, attr]
    return pairs, attrs


def _edges_from_rdkit(mol):
    """Collect directed edge pairs and attributes from the RDKit bonds.

    Attribute row: one-hot bond type + ``[is_conjugated, is_in_ring]``.
    """
    pairs, attrs = [], []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        attr = one_hot(bond.GetBondType(), _BOND_TYPES) + [
            float(bond.GetIsConjugated()),
            float(bond.IsInRing()),
        ]
        pairs += [[i, j], [j, i]]
        attrs += [attr, attr]
    return pairs, attrs


def _edge_tensors(pairs, attrs):
    """Convert edge lists to ``(edge_index, edge_attr)`` tensors for PyG.

    Returns shapes ``(2, E)`` and ``(E, _EDGE_ATTR_DIM)``, including ``E == 0``.
    """
    if not pairs:
        # torch.tensor([]).t() stays 1-D with shape (0,); PyG needs (2, 0).
        return (
            torch.empty((2, 0), dtype=torch.long),
            torch.empty((0, _EDGE_ATTR_DIM), dtype=torch.float32),
        )
    edge_index = torch.tensor(pairs, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(attrs, dtype=torch.float32)
    return edge_index, edge_attr


def _structural_node_features(G, n):
    """Return an (n, 2) float32 tensor of closeness centrality and mean distance.

    The mean shortest-path length of node ``i`` averages over every node
    reachable from ``i``, including ``i`` itself at distance 0.
    """
    closeness = nx.closeness_centrality(G)
    distances = dict(nx.all_pairs_shortest_path_length(G))
    rows = []
    for i in range(n):
        dist = distances.get(i, {})
        rows.append([closeness.get(i, 0.0), sum(dist.values()) / max(len(dist), 1)])
    return torch.tensor(rows, dtype=torch.float32)


def _mean_std(values):
    """Return ``[mean, std]`` of ``values`` as float32; ``[0.0, 0.0]`` if empty."""
    a = np.array(values, dtype=np.float32)
    if a.size == 0:
        # mean/std of an empty array are NaN (plus a RuntimeWarning).
        return [0.0, 0.0]
    return [a.mean(), a.std()]


def _laplacian_features(ac, topk_lap):
    """Return the first ``topk_lap`` eigenvalues > 1e-6 of each molecule Laplacian.

    Each matrix contributes exactly ``topk_lap`` values (zero-padded); a matrix
    whose eigen-decomposition fails contributes zeros.
    """
    feats = []
    for grp in ac.get_laplacians().get("molecule_laplacians", []):
        recs = grp if isinstance(grp, list) else [grp]
        for _, mat in recs:
            M = np.array(mat, dtype=np.float32)
            try:
                eigs = np.linalg.eigvalsh(M)
            except np.linalg.LinAlgError:
                eigs = np.zeros(M.shape[0], dtype=np.float32)
            vals = eigs[eigs > 1e-6][:topk_lap]
            if len(vals) < topk_lap:
                vals = np.pad(vals, (0, topk_lap - len(vals)))
            feats += list(vals)
    return feats


def _graph_level_features(ac, chains, G, topk_lap):
    """Return the 1-D float32 tensor of molecule-level features.

    Concatenates, in order: mean/std of every raw k-chain except ``chain_0``;
    Laplacian eigenvalue features; mean/std of every spectral k-chain; and
    ``[b0, b1]`` = number of connected components and cycle-basis size of ``G``.

    NOTE(review): the length depends on how many chains/Laplacians the complex
    returns, so it may differ between molecules; PyG batching would then
    concatenate misaligned vectors -- confirm downstream handling.
    """
    g_stats = []
    for key, arr in chains.items():
        if key != "chain_0":
            g_stats += _mean_std(arr)

    lap_feats = _laplacian_features(ac, topk_lap)

    spec_feats = []
    for arr in ac.get_spectral_k_chains().values():
        spec_feats += _mean_std(arr)

    b0 = nx.number_connected_components(G)
    # cycle_basis already spans every connected component.
    b1 = len(nx.cycle_basis(G))

    all_feats = g_stats + lap_feats + spec_feats + [float(b0), float(b1)]
    return torch.tensor(all_feats, dtype=torch.float32)
| |
|
| |
|
def one_hot(val, choices):
    """Return a float one-hot encoding of ``val`` against ``choices``.

    The result has one entry per element of ``choices``: ``1.0`` where
    ``val == choice`` is truthy, ``0.0`` otherwise. A value not present in
    ``choices`` therefore encodes as all zeros.
    """
    matches = (val == choice for choice in choices)
    return [1.0 if hit else 0.0 for hit in matches]
| |
|