Spaces:
Runtime error
Runtime error
| import os | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| from rdkit import Chem | |
| from dockformer.data.utils import FeatureTensorDict | |
| from dockformer.utils.consts import POSSIBLE_BOND_TYPES, POSSIBLE_ATOM_TYPES, POSSIBLE_CHARGES, POSSIBLE_CHIRALITIES | |
def get_atom_features(atom: Chem.Atom):
    """Encode one RDKit atom as integer indices into the module's vocabularies.

    Args:
        atom: The RDKit atom to encode.

    Returns:
        A dict with three integer entries:
        ``"atom_type"`` (index into ``POSSIBLE_ATOM_TYPES``),
        ``"atom_charge"`` (index into ``POSSIBLE_CHARGES``) and
        ``"atom_chirality"`` (index into ``POSSIBLE_CHIRALITIES``).

    Raises:
        ValueError: If the clamped charge or the chiral tag is not present in
            its vocabulary list (propagated from ``list.index``).
    """
    symbol = atom.GetSymbol()

    # TODO: this is temporary, we need to add more features, for example for Zn
    if symbol in POSSIBLE_ATOM_TYPES:
        type_idx = POSSIBLE_ATOM_TYPES.index(symbol)
    else:
        # Elements outside the vocabulary are reported and encoded as "Ni".
        print(f"********Unknown atom type {symbol}")
        type_idx = POSSIBLE_ATOM_TYPES.index("Ni")

    # The formal charge is clamped to the range [-1, 1] before the lookup.
    clamped_charge = min(max(atom.GetFormalCharge(), -1), 1)
    charge_idx = POSSIBLE_CHARGES.index(clamped_charge)

    chirality_idx = POSSIBLE_CHIRALITIES.index(atom.GetChiralTag())

    return {
        "atom_type": type_idx,
        "atom_charge": charge_idx,
        "atom_chirality": chirality_idx,
    }
def get_bond_features(bond: Chem.Bond):
    """Encode one RDKit bond as an index into ``POSSIBLE_BOND_TYPES``.

    Args:
        bond: The RDKit bond to encode.

    Returns:
        A dict with a single ``"bond_type"`` entry holding the position of the
        bond's type in ``POSSIBLE_BOND_TYPES``.

    Raises:
        ValueError: If the bond type is not listed in ``POSSIBLE_BOND_TYPES``
            (propagated from ``list.index``).
    """
    return {"bond_type": POSSIBLE_BOND_TYPES.index(bond.GetBondType())}
def make_ligand_features(ligand: Chem.Mol) -> FeatureTensorDict:
    """Build the per-atom and per-bond feature tensors for a ligand.

    Args:
        ligand: The RDKit molecule to featurize.

    Returns:
        A dict of tensors, where ``n_atoms`` / ``n_bonds`` are the number of
        atoms / bonds iterated from ``ligand``:

        * ``"ligand_atype"``, ``"ligand_charge"``, ``"ligand_chirality"``:
          int64 tensors of length ``n_atoms`` holding the indices produced by
          ``get_atom_features``.
        * ``"ligand_bonds"``: int64 tensor of shape ``(n_bonds, 3)`` whose rows
          are ``(begin_atom_pos, end_atom_pos, bond_type_index)``. The shape is
          ``(0, 3)`` when the ligand has no bonds.
        * ``"ligand_target_feat"``: float32 tensor of shape
          ``(n_atoms, len(POSSIBLE_ATOM_TYPES) + len(POSSIBLE_CHARGES)
          + len(POSSIBLE_CHIRALITIES))``; the three one-hot encodings
          concatenated along the last axis.
        * ``"ligand_bonds_feat"``: float32 tensor of shape
          ``(n_atoms, n_atoms, len(POSSIBLE_BOND_TYPES))`` with a 1 at
          ``[begin, end, bond_type]`` for every bond. Only the
          (begin, end) direction is set; the matrix is not symmetrized.
    """
    atoms_features = []
    atom_idx_to_atom_pos_idx = {}
    for atom in ligand.GetAtoms():
        # Map the RDKit atom index to this atom's row in the feature tensors.
        atom_idx_to_atom_pos_idx[atom.GetIdx()] = len(atoms_features)
        atoms_features.append(get_atom_features(atom))
    num_atoms = len(atoms_features)

    def _index_tensor(key):
        # Gather one integer feature across all atoms into an int64 vector.
        return torch.tensor([feat[key] for feat in atoms_features], dtype=torch.int64)

    atom_types = _index_tensor("atom_type")
    atom_charges = _index_tensor("atom_charge")
    atom_chiralities = _index_tensor("atom_chirality")

    # num_classes is passed explicitly so the width does not depend on which
    # indices happen to occur (and so an empty input still yields a 2-D result).
    ligand_target_feat = torch.cat(
        [
            nn.functional.one_hot(atom_types, num_classes=len(POSSIBLE_ATOM_TYPES)).float(),
            nn.functional.one_hot(atom_charges, num_classes=len(POSSIBLE_CHARGES)).float(),
            nn.functional.one_hot(atom_chiralities, num_classes=len(POSSIBLE_CHIRALITIES)).float(),
        ],
        dim=1,
    )

    # create one-hot matrix encoding for bonds
    ligand_bonds_feat = torch.zeros(
        (num_atoms, num_atoms, len(POSSIBLE_BOND_TYPES)), dtype=torch.float32
    )
    ligand_bonds = []
    for bond in ligand.GetBonds():
        begin_pos = atom_idx_to_atom_pos_idx[bond.GetBeginAtomIdx()]
        end_pos = atom_idx_to_atom_pos_idx[bond.GetEndAtomIdx()]
        bond_type = get_bond_features(bond)["bond_type"]
        ligand_bonds.append((begin_pos, end_pos, bond_type))
        ligand_bonds_feat[begin_pos, end_pos, bond_type] = 1

    return {
        # These are used for reconstruction at the end of the pipeline
        "ligand_atype": atom_types,
        "ligand_charge": atom_charges,
        "ligand_chirality": atom_chiralities,
        # reshape keeps the (n_bonds, 3) layout even when there are no bonds;
        # torch.tensor([]) alone would produce a 1-D tensor of shape (0,).
        "ligand_bonds": torch.tensor(ligand_bonds, dtype=torch.int64).reshape(-1, 3),
        # these are the actual features
        "ligand_target_feat": ligand_target_feat,
        "ligand_bonds_feat": ligand_bonds_feat,
    }