Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| """ | |
| Created on Thu Jul 28 14:40:59 2022 | |
| @author: BM109X32G-10GPU-02 | |
| """ | |
| import os | |
| from collections import OrderedDict | |
| import numpy as np | |
| from rdkit import Chem | |
| from rdkit.Chem import AllChem | |
| from rdkit.Chem import rdchem | |
| from compound_constants import DAY_LIGHT_FG_SMARTS_LIST | |
def get_gasteiger_partial_charges(mol, n_iter=12):
    """
    Compute the Gasteiger partial charge of every atom in ``mol``.

    RDKit stores the charges on the atoms themselves (atom property
    ``_GasteigerCharge``); they are then read back in atom order.

    Args:
        mol: rdkit mol object; its atoms are annotated as a side effect.
        n_iter(int): number of Gasteiger iterations. Default 12.

    Returns:
        list of float, one partial charge per atom.
    """
    # throwOnParamFailure=True asks RDKit to raise instead of silently
    # producing charges for atoms it has no parameters for.
    Chem.rdPartialCharges.ComputeGasteigerCharges(
        mol, nIter=n_iter, throwOnParamFailure=True)
    charges = []
    for atom in mol.GetAtoms():
        charges.append(float(atom.GetProp('_GasteigerCharge')))
    return charges
def create_standardized_mol_id(smiles):
    """
    Build a stereo-agnostic InChI identifier for a SMILES string.

    Stereochemistry is removed by round-tripping through non-isomeric SMILES.
    If the SMILES holds several species ('.'-separated), only the species
    with the most atoms is converted.

    Args:
        smiles: smiles sequence.

    Returns:
        The InChI string, or None if the SMILES cannot be parsed (either as
        given or after stereochemistry removal).
    """
    if not check_smiles_validity(smiles):
        return None
    # remove stereochemistry
    smiles = AllChem.MolToSmiles(AllChem.MolFromSmiles(smiles),
                                 isomericSmiles=False)
    # The stereo-free SMILES can occasionally fail to re-parse (seen with
    # O=C1O[al]2oc(=O)c3ccc(cn3)c3ccccc3c3cccc(c3)c3ccccc3c3cc(C(F)(F)F)c(cc3o2)-c2ccccc2-c2cccc(c2)-c2ccccc2-c2cccnc21).
    # Check for None *before* AddHs: AddHs does not accept None, so checking
    # its result (as before) could never catch this case.
    parsed = AllChem.MolFromSmiles(smiles)
    if parsed is None:
        return None
    mol = Chem.AddHs(parsed)
    if '.' in smiles:  # if multiple species, pick largest molecule
        mol_species_list = split_rdkit_mol_obj(mol)
        largest_mol = get_largest_mol(mol_species_list)
        return AllChem.MolToInchi(largest_mol)
    return AllChem.MolToInchi(mol)
def check_smiles_validity(smiles):
    """
    Return True if ``smiles`` can be converted to an rdkit mol object.

    A falsy parse result (RDKit returns None on failure) and any exception
    raised during parsing both yield False.
    """
    try:
        return bool(Chem.MolFromSmiles(smiles))
    except Exception:
        return False
def split_rdkit_mol_obj(mol):
    """
    Split an rdkit mol object into one mol object per species.

    The molecule is written as isomeric SMILES and split on '.'. Each
    fragment that parses is converted back to a mol object; fragments that
    fail to parse are dropped.

    Args:
        mol: rdkit mol object.

    Returns:
        list of rdkit mol objects (a one-element list for a single species).
    """
    fragments = AllChem.MolToSmiles(mol, isomericSmiles=True).split('.')
    return [AllChem.MolFromSmiles(fragment)
            for fragment in fragments
            if check_smiles_validity(fragment)]
def get_largest_mol(mol_list):
    """
    Return the mol object with the most atoms.

    Ties are broken in favour of the earliest element of ``mol_list``.

    Args:
        mol_list(list): a list of rdkit mol objects.

    Returns:
        the largest mol.

    Raises:
        ValueError: if ``mol_list`` is empty.
    """
    # max() keeps the first of several equal maxima, matching the tie rule.
    return max(mol_list, key=lambda candidate: len(candidate.GetAtoms()))
def rdchem_enum_to_list(values):
    """
    Convert an RDKit enum ``values`` mapping into a list ordered by key.

    ``values`` is expected to map 0..len(values)-1 to enum members, e.g.
    ``rdkit.Chem.rdchem.ChiralType.values`` =
    {0: CHI_UNSPECIFIED, 1: CHI_TETRAHEDRAL_CW, 2: CHI_TETRAHEDRAL_CCW,
    3: CHI_OTHER}. A gap in the keys raises KeyError.
    """
    return list(map(values.__getitem__, range(len(values))))
def safe_index(alist, elem):
    """
    Return the index of ``elem`` in ``alist``.

    If ``elem`` is not present, return the last index (``len(alist) - 1``);
    e.g. for the atomic_num vocabulary that is its trailing 'misc' entry.
    An empty ``alist`` therefore yields -1.
    """
    if elem in alist:
        return alist.index(elem)
    return len(alist) - 1
def get_atom_feature_dims(list_acquired_feature_names):
    """
    Return the vocabulary size of each requested atom feature.

    Args:
        list_acquired_feature_names: keys of ``CompoundKit.atom_vocab_dict``;
            an unknown name raises KeyError.

    Returns:
        list of int, one size per requested name, in request order.
    """
    vocab = CompoundKit.atom_vocab_dict
    return [len(vocab[name]) for name in list_acquired_feature_names]
def get_bond_feature_dims(list_acquired_feature_names):
    """
    Return the vocabulary size of each requested bond feature, plus one.

    Args:
        list_acquired_feature_names: keys of ``CompoundKit.bond_vocab_dict``;
            an unknown name raises KeyError.

    Returns:
        list of int: ``len(vocab) + 1`` per requested name, in request order.
    """
    vocab = CompoundKit.bond_vocab_dict
    # The extra slot is the id reserved for self-loop edges.
    return [len(vocab[name]) + 1 for name in list_acquired_feature_names]
class CompoundKit(object):
    """
    Namespace of RDKit helpers that featurize atoms, bonds and molecules.

    Every method is a static helper called on the class itself, e.g.
    ``CompoundKit.get_atom_feature_id(atom, 'atomic_num')``.
    """
    # Categorical vocabularies: a feature value is encoded as its index in
    # the matching list (see safe_index).
    atom_vocab_dict = {
        "atomic_num": list(range(1, 119)) + ['misc'],
        "chiral_tag": rdchem_enum_to_list(rdchem.ChiralType.values),
    }
    bond_vocab_dict = {
        "bond_dir": rdchem_enum_to_list(rdchem.BondDir.values),
        "bond_type": rdchem_enum_to_list(rdchem.BondType.values),
    }
    # Vocabulary used by get_atom_names for the "in_num_ring_with_size3" ..
    # "in_num_ring_with_size8" features when atom_vocab_dict has no entry for
    # them. get_ring_size reports counts 0..8 and caps larger counts at 9, so
    # 0..9 plus 'misc' covers every value it can produce.
    ring_count_vocab = list(range(10)) + ['misc']
    # float features
    atom_float_names = ["van_der_waals_radis", "partial_charge", 'mass']
    # bond_float_feats= ["bond_length", "bond_angle"]  # optional

    ### functional groups
    day_light_fg_smarts_list = DAY_LIGHT_FG_SMARTS_LIST
    day_light_fg_mo_list = [Chem.MolFromSmarts(smarts) for smarts in day_light_fg_smarts_list]

    # Fingerprint sizes (the morgan values are passed to RDKit as nBits).
    morgan_fp_N = 200
    morgan2048_fp_N = 2048
    maccs_fp_N = 167

    period_table = Chem.GetPeriodicTable()

    ### atom
    @staticmethod
    def get_atom_value(atom, name):
        """
        Return the raw RDKit value of atom feature ``name``.

        Raises:
            ValueError: if ``name`` is not a supported atom feature.
        """
        if name == 'atomic_num':
            return atom.GetAtomicNum()
        elif name == 'chiral_tag':
            return atom.GetChiralTag()
        elif name == 'degree':
            return atom.GetDegree()
        elif name == 'explicit_valence':
            return atom.GetExplicitValence()
        elif name == 'formal_charge':
            return atom.GetFormalCharge()
        elif name == 'hybridization':
            return atom.GetHybridization()
        elif name == 'implicit_valence':
            return atom.GetImplicitValence()
        elif name == 'is_aromatic':
            return int(atom.GetIsAromatic())
        elif name == 'mass':
            # Truncated to an integer.
            return int(atom.GetMass())
        elif name == 'total_numHs':
            return atom.GetTotalNumHs()
        elif name == 'num_radical_e':
            return atom.GetNumRadicalElectrons()
        elif name == 'atom_is_in_ring':
            return int(atom.IsInRing())
        elif name == 'valence_out_shell':
            return CompoundKit.period_table.GetNOuterElecs(atom.GetAtomicNum())
        else:
            raise ValueError(name)

    @staticmethod
    def get_atom_feature_id(atom, name):
        """
        Return the index of the atom's ``name`` value in atom_vocab_dict[name].

        Values absent from the vocabulary map to its last index.
        """
        assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name
        return safe_index(CompoundKit.atom_vocab_dict[name], CompoundKit.get_atom_value(atom, name))

    @staticmethod
    def get_atom_feature_size(name):
        """Return the vocabulary size of atom feature ``name``."""
        assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name
        return len(CompoundKit.atom_vocab_dict[name])

    ### bond
    @staticmethod
    def get_bond_value(bond, name):
        """
        Return the raw RDKit value of bond feature ``name``.

        Raises:
            ValueError: if ``name`` is not a supported bond feature.
        """
        if name == 'bond_dir':
            return bond.GetBondDir()
        elif name == 'bond_type':
            return bond.GetBondType()
        elif name == 'is_in_ring':
            return int(bond.IsInRing())
        elif name == 'is_conjugated':
            return int(bond.GetIsConjugated())
        elif name == 'bond_stereo':
            return bond.GetStereo()
        else:
            raise ValueError(name)

    @staticmethod
    def get_bond_feature_id(bond, name):
        """
        Return the index of the bond's ``name`` value in bond_vocab_dict[name].

        Values absent from the vocabulary map to its last index.
        """
        assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name
        return safe_index(CompoundKit.bond_vocab_dict[name], CompoundKit.get_bond_value(bond, name))

    @staticmethod
    def get_bond_feature_size(name):
        """Return the vocabulary size of bond feature ``name``."""
        assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name
        return len(CompoundKit.bond_vocab_dict[name])

    ### fingerprint
    @staticmethod
    def get_morgan_fingerprint(mol, radius=2):
        """Return a morgan_fp_N-bit Morgan fingerprint as a list of 0/1 ints."""
        nBits = CompoundKit.morgan_fp_N
        mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
        return [int(b) for b in mfp.ToBitString()]

    @staticmethod
    def get_morgan2048_fingerprint(mol, radius=2):
        """Return a morgan2048_fp_N-bit Morgan fingerprint as a list of 0/1 ints."""
        nBits = CompoundKit.morgan2048_fp_N
        mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
        return [int(b) for b in mfp.ToBitString()]

    @staticmethod
    def get_maccs_fingerprint(mol):
        """Return the MACCS keys fingerprint as a list of 0/1 ints."""
        fp = AllChem.GetMACCSKeysFingerprint(mol)
        return [int(b) for b in fp.ToBitString()]

    ### functional groups
    @staticmethod
    def get_daylight_functional_group_counts(mol):
        """Return, per Daylight SMARTS pattern, the number of unique matches in ``mol``."""
        return [len(Chem.Mol.GetSubstructMatches(mol, fg_mol, uniquify=True))
                for fg_mol in CompoundKit.day_light_fg_mo_list]

    @staticmethod
    def get_ring_size(mol):
        """
        Return an (N, 6) nested list of per-atom ring-membership counts.

        Row i, column k holds how many rings of size k + 3 (sizes 3..8) contain
        atom i; counts above 8 are reported as 9.
        """
        atom_rings = list(mol.GetRingInfo().AtomRings())
        ring_list = []
        for atom in mol.GetAtoms():
            idx = atom.GetIdx()
            atom_result = []
            for ring_size in range(3, 9):
                count = sum(1 for ring in atom_rings
                            if len(ring) == ring_size and idx in ring)
                atom_result.append(9 if count > 8 else count)
            ring_list.append(atom_result)
        return ring_list

    @staticmethod
    def atom_to_feat_vector(atom):
        """Return {'atomic_num': vocab index of the atom's atomic number}."""
        atom_names = {
            "atomic_num": safe_index(CompoundKit.atom_vocab_dict["atomic_num"], atom.GetAtomicNum()),
        }
        return atom_names

    @staticmethod
    def get_atom_names(mol):
        """
        Build one feature dict per atom of ``mol``.

        Each dict holds the entries from atom_to_feat_vector plus
        ``in_num_ring_with_size3`` .. ``in_num_ring_with_size8``: the
        vocabulary index of the atom's ring count for that ring size (see
        get_ring_size). Gasteiger charges are computed on ``mol`` as a side
        effect.

        TODO: to be remove in the future
        """
        Chem.rdPartialCharges.ComputeGasteigerCharges(mol)
        atom_features_dicts = [CompoundKit.atom_to_feat_vector(atom)
                               for atom in mol.GetAtoms()]
        ring_list = CompoundKit.get_ring_size(mol)
        for feat_dict, ring_counts in zip(atom_features_dicts, ring_list):
            for ring_size, count in zip(range(3, 9), ring_counts):
                key = 'in_num_ring_with_size%d' % ring_size
                # atom_vocab_dict does not define these keys, so indexing it
                # directly always raised KeyError; fall back to ring_count_vocab.
                vocab = CompoundKit.atom_vocab_dict.get(key, CompoundKit.ring_count_vocab)
                feat_dict[key] = safe_index(vocab, count)
        return atom_features_dicts

    @staticmethod
    def check_partial_charge(atom):
        """
        Return the atom's Gasteiger charge with non-finite values replaced.

        NaN (unsupported atom) becomes 0, +inf becomes 10 and -inf becomes
        -10. The ``_GasteigerCharge`` property must already be computed.
        """
        pc = atom.GetDoubleProp('_GasteigerCharge')
        if pc != pc:
            # unsupported atom, replace nan with 0
            pc = 0
        elif pc == float('inf'):
            # original note: finite charges stay below ~4, so 10 marks overflow
            pc = 10
        elif pc == float('-inf'):
            # symmetric counterpart of the +inf case
            pc = -10
        return pc
class Compound3DKit(object):
    """
    Helpers for atom coordinates, bond lengths and bond angles of a compound.

    Every method is a static helper called on the class itself, e.g.
    ``Compound3DKit.get_bond_lengths(edges, atom_poses)``.
    """

    @staticmethod
    def get_atom_poses(mol, conf):
        """
        Return the [x, y, z] position of every atom of ``mol`` in ``conf``.

        If any atom is a dummy atom (atomic number 0), every position is
        reported as [0.0, 0.0, 0.0] instead.
        """
        atom_poses = []
        for i, atom in enumerate(mol.GetAtoms()):
            if atom.GetAtomicNum() == 0:
                # Independent rows, so mutating one cannot alias the others.
                return [[0.0, 0.0, 0.0] for _ in range(len(mol.GetAtoms()))]
            pos = conf.GetAtomPosition(i)
            atom_poses.append([pos.x, pos.y, pos.z])
        return atom_poses

    @staticmethod
    def get_MMFF_atom_poses(mol, numConfs=None, return_energy=False):
        """
        Embed ``mol`` with explicit Hs, MMFF-optimize the conformers and
        return the positions of the lowest-energy one.

        The atoms of mol will be changed in some cases: the returned molecule
        is ``Chem.AddHs(mol)``. If embedding or optimization fails, 2D
        coordinates are used instead and the energy is reported as 0.

        Args:
            mol: rdkit mol object.
            numConfs: number of conformers to embed; None uses RDKit's
                default for EmbedMultipleConfs.
            return_energy: if True, also return the chosen conformer's energy.

        Returns:
            (new_mol, atom_poses), or (new_mol, atom_poses, energy) when
            ``return_energy`` is True.
        """
        try:
            new_mol = Chem.AddHs(mol)
            if numConfs is None:
                # RDKit expects an int here; forwarding None raised and
                # silently forced the 2D fallback on every default call.
                AllChem.EmbedMultipleConfs(new_mol)
            else:
                AllChem.EmbedMultipleConfs(new_mol, numConfs=numConfs)
            ### MMFF generates multiple conformations
            res = AllChem.MMFFOptimizeMoleculeConfs(new_mol)
            #new_mol = Chem.RemoveHs(new_mol)
            # Each res entry is (not_converged, energy); entry i is assumed to
            # belong to conformer id i (ids assigned 0..n-1 by the embedding).
            index = int(np.argmin([x[1] for x in res]))
            energy = res[index][1]
            conf = new_mol.GetConformer(id=index)
        except Exception:
            # Presumably no conformer could be embedded/optimized (an empty
            # ``res`` makes argmin raise): fall back to 2D coordinates.
            new_mol = Chem.AddHs(mol)
            AllChem.Compute2DCoords(new_mol)
            energy = 0
            conf = new_mol.GetConformer()
        atom_poses = Compound3DKit.get_atom_poses(new_mol, conf)
        if return_energy:
            return new_mol, atom_poses, energy
        return new_mol, atom_poses

    @staticmethod
    def get_2d_atom_poses(mol):
        """Compute 2D coordinates on ``mol`` (in place) and return its atom positions."""
        AllChem.Compute2DCoords(mol)
        conf = mol.GetConformer()
        atom_poses = Compound3DKit.get_atom_poses(mol, conf)
        return atom_poses

    @staticmethod
    def get_bond_lengths(edges, atom_poses):
        """
        Return the Euclidean length of every (src, dst) edge as float32.

        ``atom_poses`` must support ``atom_poses[i] - atom_poses[j]``
        (e.g. a numpy array of coordinates).
        """
        bond_lengths = [np.linalg.norm(atom_poses[tar_node_j] - atom_poses[src_node_i])
                        for src_node_i, tar_node_j in edges]
        return np.array(bond_lengths, 'float32')

    @staticmethod
    def get_superedge_angles(edges, atom_poses, dir_type='HT'):
        """
        Build the bond-angle graph over ``edges``.

        For dir_type 'HT', edge s is connected to edge t when s ends where t
        starts; for 'HH', when both end at the same atom. An edge is never
        paired with itself.

        Returns:
            (super_edges, bond_angles, bond_angle_dirs): an int64 (M, 2) array
            of [src_edge_idx, tar_edge_idx], float32 angles in radians, and a
            list of bools (True when src edge head == tar edge tail).

        Raises:
            ValueError: for an unknown ``dir_type``.
        """
        def _get_vec(atom_poses, edge):
            return atom_poses[edge[1]] - atom_poses[edge[0]]

        def _get_angle(vec1, vec2):
            norm1 = np.linalg.norm(vec1)
            norm2 = np.linalg.norm(vec2)
            if norm1 == 0 or norm2 == 0:
                return 0
            vec1 = vec1 / (norm1 + 1e-5)  # 1e-5: prevent numerical errors
            vec2 = vec2 / (norm2 + 1e-5)
            angle = np.arccos(np.dot(vec1, vec2))
            return angle

        E = len(edges)
        edge_indices = np.arange(E)
        super_edges = []
        bond_angles = []
        bond_angle_dirs = []
        for tar_edge_i in range(E):
            tar_edge = edges[tar_edge_i]
            if dir_type == 'HT':
                src_edge_indices = edge_indices[edges[:, 1] == tar_edge[0]]
            elif dir_type == 'HH':
                src_edge_indices = edge_indices[edges[:, 1] == tar_edge[1]]
            else:
                raise ValueError(dir_type)
            for src_edge_i in src_edge_indices:
                if src_edge_i == tar_edge_i:
                    continue
                src_edge = edges[src_edge_i]
                src_vec = _get_vec(atom_poses, src_edge)
                tar_vec = _get_vec(atom_poses, tar_edge)
                super_edges.append([src_edge_i, tar_edge_i])
                angle = _get_angle(src_vec, tar_vec)
                bond_angles.append(angle)
                bond_angle_dirs.append(src_edge[1] == tar_edge[0])  # H -> H or H -> T

        if len(super_edges) == 0:
            super_edges = np.zeros([0, 2], 'int64')
            bond_angles = np.zeros([0, ], 'float32')
        else:
            super_edges = np.array(super_edges, 'int64')
            bond_angles = np.array(bond_angles, 'float32')
        return super_edges, bond_angles, bond_angle_dirs
def new_smiles_to_graph_data(smiles, **kwargs):
    """
    Convert smiles to graph data via new_mol_to_graph_data.

    Hydrogens are added explicitly before featurization.

    Args:
        smiles: smiles sequence.
        **kwargs: accepted for interface compatibility; currently unused.

    Returns:
        The dict produced by new_mol_to_graph_data, or None if the SMILES
        cannot be parsed.
    """
    parsed = AllChem.MolFromSmiles(smiles)
    # Check before AddHs: AddHs does not accept None, so checking its result
    # (as before) could never catch an unparsable SMILES.
    if parsed is None:
        return None
    mol = Chem.AddHs(parsed)
    return new_mol_to_graph_data(mol)
def new_mol_to_graph_data(mol):
    """
    Convert an rdkit mol object to a dict of numpy graph features.

    Keys, in insertion order:
        * every key of CompoundKit.atom_vocab_dict (int64, one per atom),
          then every name in CompoundKit.atom_float_names (float32, one per
          atom), read from the dicts returned by CompoundKit.get_atom_names;
        * every key of CompoundKit.bond_vocab_dict (int64, one per edge);
        * 'edges': int64 (src, dst) pairs -- both directions of every bond,
          followed by one (i, i) self loop per atom. Self loops use the id
          ``len(vocab)`` reserved by get_bond_feature_dims;
        * 'morgan_fp', 'maccs_fp', 'daylight_fg_counts' (int64).

    NOTE(review): CompoundKit.atom_to_feat_vector only emits 'atomic_num',
    so reading the other atom feature names here looks like it would raise
    KeyError -- confirm before relying on this path.

    Returns:
        The feature dict, or None if ``mol`` has no atoms.
    """
    if len(mol.GetAtoms()) == 0:
        return None

    int_atom_names = list(CompoundKit.atom_vocab_dict.keys())
    atom_id_names = int_atom_names + CompoundKit.atom_float_names
    bond_id_names = list(CompoundKit.bond_vocab_dict.keys())

    ### atom features
    raw_atom_feat_dicts = CompoundKit.get_atom_names(mol)
    data = {name: [feat[name] for feat in raw_atom_feat_dicts]
            for name in atom_id_names}

    ### bond and bond features (i->j and j->i for every bond)
    edges = []
    bond_feats = {name: [] for name in bond_id_names}
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edges.extend([(begin, end), (end, begin)])
        for name in bond_id_names:
            feat_id = CompoundKit.get_bond_feature_id(bond, name)
            bond_feats[name].extend([feat_id, feat_id])

    #### self loop
    num_atoms = len(data[atom_id_names[0]])
    edges.extend((idx, idx) for idx in range(num_atoms))
    for name in bond_id_names:
        loop_id = get_bond_feature_dims([name])[0] - 1  # self loop: value = len - 1
        bond_feats[name].extend([loop_id] * num_atoms)

    ### make ndarray
    for name in int_atom_names:
        data[name] = np.array(data[name], 'int64')
    for name in CompoundKit.atom_float_names:
        data[name] = np.array(data[name], 'float32')
    for name in bond_id_names:
        data[name] = np.array(bond_feats[name], 'int64')
    data['edges'] = np.array(edges, 'int64')

    ### fingerprints and functional groups
    data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64')
    # data['morgan2048_fp'] = np.array(CompoundKit.get_morgan2048_fingerprint(mol), 'int64')
    data['maccs_fp'] = np.array(CompoundKit.get_maccs_fingerprint(mol), 'int64')
    data['daylight_fg_counts'] = np.array(CompoundKit.get_daylight_functional_group_counts(mol), 'int64')
    return data
def mol_to_graph_data(mol):
    """
    Featurize an rdkit mol object as a graph.

    Returns:
        None if ``mol`` has no atoms or contains a dummy atom (atomic number
        0). Otherwise a dict with, in insertion order:
            'atomic_num': int64, atomic_num vocab index + 1 per atom
                (0 is left free for OOV);
            'mass': float32, integer atomic mass * 0.01 per atom;
            'bond_dir', 'bond_type': int64 per edge -- vocab index + 1 for
                bond edges, vocab size + 2 for self loops;
            'edges': int64 (E, 2): both directions of every bond, then one
                (i, i) self loop per atom;
            'atoms': array of element symbols in atom-index order.
    """
    if len(mol.GetAtoms()) == 0:
        return None
    bond_names = ("bond_dir", "bond_type")

    ### atom features
    atomic_num_ids = []
    masses = []
    for atom in mol.GetAtoms():
        if atom.GetAtomicNum() == 0:
            return None
        atomic_num_ids.append(CompoundKit.get_atom_feature_id(atom, "atomic_num") + 1)  # 0: OOV
        masses.append(CompoundKit.get_atom_value(atom, 'mass') * 0.01)

    ### bond features (i->j and j->i for every bond)
    edges = []
    bond_feats = {name: [] for name in bond_names}
    for bond in mol.GetBonds():
        src, dst = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edges += [(src, dst), (dst, src)]
        for name in bond_names:
            feat_id = CompoundKit.get_bond_feature_id(bond, name) + 1  # 0: OOV
            bond_feats[name] += [feat_id, feat_id]

    symbols = [mol.GetAtomWithIdx(idx).GetSymbol() for idx in range(mol.GetNumAtoms())]

    ### self loop (+2)
    n_atoms = len(atomic_num_ids)
    edges += [(idx, idx) for idx in range(n_atoms)]
    for name in bond_names:
        bond_feats[name] += [CompoundKit.get_bond_feature_size(name) + 2] * n_atoms  # N + 2: self loop

    ### make ndarray
    data = {
        'atomic_num': np.array(atomic_num_ids, 'int64'),
        'mass': np.array(masses, 'float32'),
    }
    for name in bond_names:
        data[name] = np.array(bond_feats[name], 'int64')
    # An empty edge list would otherwise give shape (0,) instead of (0, 2).
    # (With at least one atom the self loops make it non-empty.)
    data['edges'] = np.array(edges, 'int64') if edges else np.zeros((0, 2), dtype="int64")
    data['atoms'] = np.array(symbols)
    return data
def mol_to_geognn_graph_data(mol, atom_poses, dir_type):
    """
    Build the atom symbols and bond-length-weighted adjacency matrix of ``mol``.

    Args:
        mol: rdkit molecule.
        atom_poses: per-atom [x, y, z] coordinates in atom-index order.
        dir_type: direction type for the bond-angle graph; currently unused
            because the bond-angle computation below is disabled.

    Returns:
        (atoms, adj_node): the element-symbol array from mol_to_graph_data
        and the gen_adj matrix (1 on the diagonal, bond length at [i, j] for
        every bond edge i->j, 0 elsewhere). None if ``mol`` has no atoms or
        mol_to_graph_data rejects it (e.g. it contains a dummy atom).
    """
    if len(mol.GetAtoms()) == 0:
        return None
    data = mol_to_graph_data(mol)
    if data is None:
        # Previously this fell through and crashed on data['atom_pos'].
        return None
    data['atom_pos'] = np.array(atom_poses, 'float32')
    data['bond_length'] = Compound3DKit.get_bond_lengths(data['edges'], data['atom_pos'])
    # Bond-angle graph (disabled):
    # BondAngleGraph_edges, bond_angles, bond_angle_dirs = \
    #     Compound3DKit.get_superedge_angles(data['edges'], data['atom_pos'])
    # data['BondAngleGraph_edges'] = BondAngleGraph_edges
    # data['bond_angle'] = np.array(bond_angles, 'float32')
    data['adj_node'] = gen_adj(len(data['atoms']), data['edges'], data['bond_length'])
    # data['adj_edge'] = gen_adj(len(data['bond_dir']),data['BondAngleGraph_edges'],data['bond_angle'])
    return data['atoms'], data['adj_node']
def mol_to_geognn_graph_data_MMFF3d(smiles):
    """
    Featurize a SMILES string using MMFF-optimized 3D coordinates.

    Hydrogens are added first. Molecules with at most 400 atoms are embedded
    with 10 conformers and MMFF-optimized (Compound3DKit.get_MMFF_atom_poses);
    larger ones use 2D coordinates.

    Returns:
        The (atoms, adj_node) pair from mol_to_geognn_graph_data, or None if
        the SMILES cannot be parsed (or the molecule is rejected downstream).
    """
    parsed = AllChem.MolFromSmiles(smiles)
    # AddHs does not accept None, so an unparsable SMILES must be caught here.
    if parsed is None:
        return None
    mol = Chem.AddHs(parsed)
    if len(mol.GetAtoms()) <= 400:
        mol, atom_poses = Compound3DKit.get_MMFF_atom_poses(mol, numConfs=10)
    else:
        atom_poses = Compound3DKit.get_2d_atom_poses(mol)
    return mol_to_geognn_graph_data(mol, atom_poses, dir_type='HT')
def mol_to_geognn_graph_data_raw3d(mol):
    """
    Featurize ``mol`` using the coordinates of its default conformer as-is.

    Returns:
        Whatever mol_to_geognn_graph_data returns for ``mol``.
    """
    conformer = mol.GetConformer()
    positions = Compound3DKit.get_atom_poses(mol, conformer)
    return mol_to_geognn_graph_data(mol, positions, dir_type='HT')
def gen_adj(shape, edges, length):
    """
    Build a dense adjacency matrix weighted by edge length.

    Args:
        shape: number of nodes n; the result is an (n, n) float array.
        edges: 2-D integer array whose first ``len(length)`` rows are
            (src, dst) pairs.
        length: per-edge lengths aligned with the rows of ``edges``.

    Returns:
        np.ndarray (n, n): 1.0 on the diagonal, ``length[k]`` at
        ``[src_k, dst_k]`` for every edge k with src_k != dst_k, 0.0
        elsewhere. Self-loop edges leave the diagonal at 1.0.
    """
    adj_matrix = np.eye(shape)
    for k, edge_len in enumerate(length):
        src, dst = edges[k, 0], edges[k, 1]
        if src == dst:
            continue
        adj_matrix[src, dst] = float(edge_len)
    return adj_matrix
if __name__ == "__main__":
    import pandas as pd

    def _build_split(df, out_dir, out_csv, start=0, smiles_col=0):
        """
        Featurize rows ``start:`` of ``df`` with MMFF 3D coordinates.

        For each row, the adjacency matrix is saved to
        ``<out_dir>/adj<row>.npy`` and [atoms, npy path, PCE] is recorded;
        the records are written to ``out_csv`` at the end. Rows whose SMILES
        cannot be featurized are reported and skipped, so one bad row no
        longer discards all records collected so far.
        """
        os.makedirs(out_dir, exist_ok=True)
        pce = df['PCE']
        records = []
        for ind, smile in enumerate(df.iloc[start:, smiles_col], start=start):
            print(ind)
            result = mol_to_geognn_graph_data_MMFF3d(smile)
            if result is None:
                print('skipping row %d: could not featurize %r' % (ind, smile))
                continue
            atom, adj = result
            adj_path = out_dir + '/adj' + str(ind) + '.npy'
            np.save(adj_path, np.array(adj))
            records.append([atom, adj_path, pce[ind]])
        pd.DataFrame(records).to_csv(out_csv)

    f = pd.read_csv(r"J:\screenacc\new4.csv")
    # The train/test splits were produced the same way, with SMILES in
    # column 1, e.g.:
    #   _build_split(f, 'data/reg/train', 'data/reg/train/train.csv', smiles_col=1)
    #   _build_split(pd.read_csv(r'data/reg/test3.csv'), 'data/reg/test',
    #                'data/reg/test/test.csv', smiles_col=1)
    _build_split(f, 'data/reg/val', 'data/reg/val/val22000.csv', start=22000)