Spaces:
Running
Running
from typing import Optional

import networkx as nx
import numpy as np
from numba import njit
from rdkit import Chem
def mol2graph(mol: Chem.Mol) -> nx.Graph:
    """Convert an RDKit molecule to a NetworkX graph.

    Every atom with a non-zero atomic number becomes a node keyed by its RDKit
    atom index and labelled with its element symbol. Atoms with atomic number
    0 (RDKit dummy / wildcard atoms such as ``*``) are dropped, as is every
    bond touching one of them. Each remaining bond becomes an edge labelled
    with its RDKit ``BondType``.

    Args:
        mol (Chem.Mol): The RDKit molecule to convert. ``None`` (e.g. the result
            of parsing an invalid SMILES) is accepted.

    Returns:
        nx.Graph: The NetworkX graph representation of the molecule, or an
        empty graph when ``mol`` is ``None``.
    """
    # NOTE: https://github.com/maxhodak/keras-molecules/pull/32/files
    # TODO: Double check this implementation too: https://gist.github.com/jhjensen2/6450138cda3ab796a30850610843cfff
    if mol is None:
        return nx.empty_graph()

    G = nx.Graph()
    for atom in mol.GetAtoms():
        # Atomic number 0 marks a dummy/wildcard atom; skip it. Explicit
        # hydrogens (atomic number 1) are NOT filtered out here.
        # TODO(review): confirm whether hydrogens were also meant to be skipped.
        if atom.GetAtomicNum() != 0:
            G.add_node(atom.GetIdx(), label=atom.GetSymbol())
    for bond in mol.GetBonds():
        # Skip bonds to dummy atoms; otherwise add_edge would silently create
        # an unlabelled node for the skipped atom.
        if bond.GetBeginAtom().GetAtomicNum() == 0 or bond.GetEndAtom().GetAtomicNum() == 0:
            continue
        G.add_edge(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), label=bond.GetBondType())
    return G
def smiles2graph(smiles: str) -> nx.Graph:
    """Parse a SMILES string with RDKit and return its molecular graph.

    Args:
        smiles (str): The SMILES string to convert.

    Returns:
        nx.Graph: The graph built by ``mol2graph``. This is an empty graph when
        RDKit cannot parse ``smiles``, because ``Chem.MolFromSmiles`` then
        yields ``None``.
    """
    molecule = Chem.MolFromSmiles(smiles)
    return mol2graph(molecule)
def get_smiles2graph_edit_distance(smi1: str, smi2: str, **kwargs) -> float:
    """Compute the graph edit distance between two SMILES strings.

    Both strings are converted with ``smiles2graph`` and compared with
    ``nx.graph_edit_distance``. Unless a ``node_match`` / ``edge_match`` is
    supplied through ``kwargs``, networkx does not compare the ``label``
    attributes, so only graph topology contributes to the distance.

    Args:
        smi1 (str): The first SMILES string.
        smi2 (str): The second SMILES string.
        **kwargs: Additional keyword arguments for `nx.graph_edit_distance`.

    Returns:
        float: The graph edit distance, or ``np.inf`` when networkx returns
        ``None`` (no edit path found, e.g. under a supplied ``upper_bound``).
    """
    first, second = smiles2graph(smi1), smiles2graph(smi2)
    distance = nx.graph_edit_distance(first, second, **kwargs)
    if distance is None:
        return np.inf
    return distance
def get_mol2graph_edit_distance(mol1: Chem.Mol, mol2: Chem.Mol, **kwargs) -> float:
    """Compute the graph edit distance between two RDKit molecules.

    Each molecule is converted with ``mol2graph`` and the graphs are compared
    with ``nx.graph_edit_distance``. Unless a ``node_match`` / ``edge_match``
    is supplied through ``kwargs``, networkx does not compare the ``label``
    attributes, so only graph topology contributes to the distance.

    Args:
        mol1 (Chem.Mol): The first RDKit molecule (``None`` is treated as an
            empty graph by ``mol2graph``).
        mol2 (Chem.Mol): The second RDKit molecule (same ``None`` handling).
        **kwargs: Additional keyword arguments for `nx.graph_edit_distance`.

    Returns:
        float: The graph edit distance between the two RDKit molecules, or
        ``np.inf`` when networkx returns ``None`` (no edit path found).
    """
    ged = nx.graph_edit_distance(mol2graph(mol1), mol2graph(mol2), **kwargs)
    return ged if ged is not None else np.inf
def get_smiles2graph_edit_distance_norm(
    smi1: str,
    smi2: str,
    ged_G1_G2: Optional[float] = None,
    eps: float = 1e-9,
    **kwargs,
) -> float:
    """Compute the normalized graph edit distance between two SMILES strings.

    The raw distance GED(G1, G2) is divided by GED(G1, G0) + GED(G2, G0),
    where G0 is the empty graph, i.e. by the cost of deleting both graphs
    entirely. With networkx's default unit costs, inserting an element costs
    the same as deleting it. Deleting all of G1 and then inserting all of G2
    is always a valid edit path, so the ratio then lies between 0 and 1.

    Args:
        smi1 (str): The first SMILES string.
        smi2 (str): The second SMILES string.
        ged_G1_G2 (float, optional): A precomputed graph edit distance between
            the two graphs. If None (the default), it is computed using
            `nx.graph_edit_distance`.
        eps (float): A small value added to the denominator to avoid division
            by zero when both graphs are empty.
        **kwargs: Additional keyword arguments for every
            `nx.graph_edit_distance` call.

    Returns:
        float: The normalized graph edit distance between the two SMILES
        strings, or ``np.inf`` if any of the three distances could not be
        computed (networkx returned ``None``).
    """
    G1 = smiles2graph(smi1)
    G2 = smiles2graph(smi2)
    G0 = nx.empty_graph()
    if ged_G1_G2 is None:
        ged_G1_G2 = nx.graph_edit_distance(G1, G2, **kwargs)
    # Cost of deleting each graph completely (distance to the empty graph).
    ged_G1_G0 = nx.graph_edit_distance(G1, G0, **kwargs)
    ged_G2_G0 = nx.graph_edit_distance(G2, G0, **kwargs)
    if any(d is None for d in (ged_G1_G2, ged_G1_G0, ged_G2_G0)):
        return np.inf
    return ged_G1_G2 / (ged_G1_G0 + ged_G2_G0 + eps)
def smiles2adjacency_matrix(smiles: str) -> np.ndarray:
    """Return the dense adjacency matrix of the molecular graph of ``smiles``.

    Rows and columns follow the node order of ``smiles2graph(smiles)``. Edges
    carry no ``weight`` attribute, so every bond contributes 1 regardless of
    its bond type.

    Args:
        smiles (str): The SMILES string to convert.

    Returns:
        np.ndarray: A 2-D adjacency array. Its shape is ``(0, 0)`` when the
        graph has no nodes (e.g. an unparsable SMILES).
    """
    G = smiles2graph(smiles)
    if G.number_of_nodes() == 0:
        # nx.adjacency_matrix raises on a graph without nodes.
        return np.zeros((0, 0), dtype=np.int64)
    # ``.toarray()`` yields a plain ndarray for both scipy sparse matrices
    # (networkx 2.x) and sparse arrays (networkx 3.x). ``.todense()`` returns
    # ``np.matrix`` for the former, contradicting the annotation.
    return nx.adjacency_matrix(G).toarray()
def build_label_mapping(G1, G2):
    """Map every node ``'label'`` appearing in ``G1`` or ``G2`` to a dense int id.

    Ids are assigned in sorted label order, so the mapping is deterministic
    for a given label set.

    Args:
        G1: First graph; every node must carry a ``'label'`` attribute.
        G2: Second graph; same requirement.

    Returns:
        dict: ``{label: id}`` with ids ``0 .. n_labels - 1``.

    Raises:
        KeyError: If a node has no ``'label'`` attribute.
    """
    seen = {graph.nodes[node]['label'] for graph in (G1, G2) for node in graph.nodes()}
    return {label: position for position, label in enumerate(sorted(seen))}
def preprocess_graph(G, label_to_int):
    """Encode ``G`` as a dense symmetric adjacency matrix plus integer labels.

    Nodes are renumbered ``0 .. n-1`` in ``G.nodes()`` iteration order. Each
    edge is written in both directions, i.e. the graph is treated as
    undirected.

    Args:
        G: Graph whose nodes all carry a ``'label'`` attribute.
        label_to_int (dict): Mapping from label to integer id (see
            ``build_label_mapping``).

    Returns:
        tuple[np.ndarray, np.ndarray]: ``(adj, labels)``. ``adj`` is an
        ``(n, n)`` int32 0/1 matrix and ``labels`` an ``(n,)`` int32 vector of
        label ids.
    """
    size = G.number_of_nodes()
    position_of = {node: pos for pos, node in enumerate(G.nodes())}

    labels = np.zeros(size, dtype=np.int32)
    for node, pos in position_of.items():
        labels[pos] = label_to_int[G.nodes[node]['label']]

    adj = np.zeros((size, size), dtype=np.int32)
    for u, v in G.edges():
        a, b = position_of[u], position_of[v]
        # Mirror every edge so the matrix is symmetric (undirected graph).
        adj[a, b] = adj[b, a] = 1
    return adj, labels
def compute_cost_matrix(labels1, labels2, degrees1, degrees2):
    """Build the node-substitution cost matrix between two encoded graphs.

    ``C[i, j] = (1 if labels1[i] != labels2[j] else 0) + |degrees1[i] - degrees2[j]|``.
    The degree difference acts as a cheap proxy for the edge edits a
    substitution would imply.

    Args:
        labels1 (np.ndarray): Integer label ids of the first graph's nodes.
        labels2 (np.ndarray): Integer label ids of the second graph's nodes.
        degrees1 (np.ndarray): Node degrees of the first graph.
        degrees2 (np.ndarray): Node degrees of the second graph.

    Returns:
        np.ndarray: A float64 array of shape ``(len(labels1), len(labels2))``.
    """
    lab_a = np.asarray(labels1)[:, None]
    lab_b = np.asarray(labels2)[None, :]
    deg_a = np.asarray(degrees1)[:, None]
    deg_b = np.asarray(degrees2)[None, :]
    mismatch = (lab_a != lab_b).astype(np.float64)
    return mismatch + np.abs(deg_a - deg_b)
def greedy_assignment(C):
    """Greedily match each row of the cost matrix ``C`` to a distinct column.

    Rows are processed in order, and each one takes the cheapest column not yet
    taken. Ties go to the lowest column index because the comparison is
    strict. A column whose cost is not ``< inf`` (``inf`` or NaN) is never
    chosen. The result is not guaranteed to minimise the total cost.

    Args:
        C (np.ndarray): 2-D cost matrix of shape ``(n_rows, n_cols)``.

    Returns:
        np.ndarray: int32 array of length ``n_rows`` holding the chosen column
        per row, or -1 for a row left unmatched (e.g. when
        ``n_rows > n_cols``).
    """
    num_rows, num_cols = C.shape
    col_taken = np.zeros(num_cols, dtype=np.bool_)
    assignment = np.full(num_rows, -1, dtype=np.int32)
    for row in range(num_rows):
        best_col = -1
        best_cost = np.inf
        for col in range(num_cols):
            if col_taken[col]:
                continue
            cost = C[row, col]
            if cost < best_cost:
                best_cost, best_col = cost, col
        if best_col >= 0:
            assignment[row] = best_col
            col_taken[best_col] = True
    return assignment
def compute_total_cost(C, row_ind, n1, n2, c_node_del, c_node_ins):
    """Total edit cost implied by a row-to-column node assignment.

    The sum has three parts:

    * ``C[i, row_ind[i]]`` for every matched row ``i``.
    * ``c_node_del`` for every row whose entry is -1 (node deleted).
    * ``c_node_ins`` for every column no row was matched to (node inserted).

    Args:
        C (np.ndarray): ``(n1, n2)`` substitution cost matrix.
        row_ind (np.ndarray): Column index per row, -1 when unmatched.
        n1 (int): Number of nodes in the first graph (rows).
        n2 (int): Number of nodes in the second graph (columns).
        c_node_del (float): Cost of deleting one node.
        c_node_ins (float): Cost of inserting one node.

    Returns:
        float: The accumulated cost.
    """
    covered = np.zeros(n2, dtype=np.bool_)
    total = 0.0
    for row in range(n1):
        col = row_ind[row]
        if col == -1:
            # Unmatched row: its node is deleted.
            total += c_node_del
            continue
        total += C[row, col]
        covered[col] = True
    # Every column left uncovered is a node that must be inserted.
    for _ in np.flatnonzero(~covered):
        total += c_node_ins
    return total
def approximate_graph_edit_distance(adj1, labels1, adj2, labels2, c_node_del=1.0, c_node_ins=1.0):
    """Greedy approximation of the graph edit distance between two encoded graphs.

    Node degrees are the row sums of the adjacency matrices. They feed
    ``compute_cost_matrix``, whose costs are assigned greedily by
    ``greedy_assignment`` and summed by ``compute_total_cost``. Edge edits are
    only reflected indirectly through the degree-difference term.

    Args:
        adj1 (np.ndarray): ``(n1, n1)`` adjacency matrix of the first graph.
        labels1 (np.ndarray): ``(n1,)`` integer labels of the first graph.
        adj2 (np.ndarray): ``(n2, n2)`` adjacency matrix of the second graph.
        labels2 (np.ndarray): ``(n2,)`` integer labels of the second graph.
        c_node_del (float): Cost of deleting one node.
        c_node_ins (float): Cost of inserting one node.

    Returns:
        float: The approximate edit cost.
    """
    cost_matrix = compute_cost_matrix(labels1, labels2, adj1.sum(axis=1), adj2.sum(axis=1))
    assignment = greedy_assignment(cost_matrix)
    n_rows, n_cols = labels1.shape[0], labels2.shape[0]
    return compute_total_cost(cost_matrix, assignment, n_rows, n_cols, c_node_del, c_node_ins)
def get_approximate_ged(G1, G2):
    """Approximate the graph edit distance between two labelled graphs.

    Both graphs are encoded over a shared label vocabulary and compared with
    ``approximate_graph_edit_distance``, using its default unit node
    insertion/deletion costs. Only node labels and node degrees enter the
    cost; edge attributes (e.g. bond types) are not looked at.

    Args:
        G1: First graph; every node must carry a ``'label'`` attribute.
        G2: Second graph; same requirement.

    Returns:
        float: The approximate (greedy) edit cost.
    """
    vocabulary = build_label_mapping(G1, G2)
    encoded1 = preprocess_graph(G1, vocabulary)
    encoded2 = preprocess_graph(G2, vocabulary)
    # encoded* are (adj, labels) pairs -> adj1, labels1, adj2, labels2.
    return approximate_graph_edit_distance(*encoded1, *encoded2)
def get_smiles2graph_edit_distance_approx(smi1: str, smi2: str) -> float:
    """Approximate the graph edit distance between two SMILES strings.

    This is a greedy, polynomial-time alternative to
    ``get_smiles2graph_edit_distance``. Both strings are turned into graphs
    with ``smiles2graph`` and compared with ``get_approximate_ged``.

    Args:
        smi1 (str): The first SMILES string.
        smi2 (str): The second SMILES string.

    Returns:
        float: The approximate edit cost between the two molecular graphs.
    """
    graphs = [smiles2graph(smiles) for smiles in (smi1, smi2)]
    return get_approximate_ged(*graphs)