# toxipredict-api/models/molecule_graph.py
# Last commit 09e6d6a (verified), shown under user Arko006:
#   fix: correct node features to 45-dim (exact one-hot + mass)
import torch
from rdkit import Chem
from rdkit.Chem import AllChem
from torch_geometric.data import Data
# Atomic numbers that get their own one-hot column in the node features,
# in column order: C, N, O, S, P, F, Cl, Br, I.
# An atom whose element is not listed encodes as all zeros in this segment.
ATOM_TYPES = [6, 7, 8, 16, 15, 9, 17, 35, 53]

# Hybridization states one-hot encoded per atom; list order is column order.
# States not listed here encode as all zeros.
HYBRIDIZATION_TYPES = [
    Chem.rdchem.HybridizationType.SP,
    Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP3,
    Chem.rdchem.HybridizationType.SP3D,
    Chem.rdchem.HybridizationType.SP3D2,
    Chem.rdchem.HybridizationType.UNSPECIFIED,
]

# Bond types one-hot encoded per edge; list order is column order.
BOND_TYPES = [
    Chem.rdchem.BondType.SINGLE,
    Chem.rdchem.BondType.DOUBLE,
    Chem.rdchem.BondType.TRIPLE,
    Chem.rdchem.BondType.AROMATIC,
]

# Bond stereo labels one-hot encoded per edge; list order is column order.
STEREO_TYPES = [
    Chem.rdchem.BondStereo.STEREONONE,
    Chem.rdchem.BondStereo.STEREOANY,
    Chem.rdchem.BondStereo.STEREOZ,
    Chem.rdchem.BondStereo.STEREOE,
]
def one_hot(val, choices):
    """Encode ``val`` as a list of 0/1 integers aligned with ``choices``.

    Position ``k`` holds 1 when ``choices[k]`` equals ``val`` and 0 otherwise,
    so a value that appears nowhere in ``choices`` yields all zeros.
    """
    encoding = []
    for candidate in choices:
        if val == candidate:
            encoding.append(1)
        else:
            encoding.append(0)
    return encoding
def get_atom_features(atom):
    """Build the flat feature list describing a single atom.

    Segments, in order:
      * element one-hot over ``ATOM_TYPES``
      * degree one-hot over 0..6
      * total valence one-hot over 0..6
      * formal charge one-hot over -3..3
      * hybridization one-hot over ``HYBRIDIZATION_TYPES``
      * aromaticity flag (1.0 / 0.0)
      * total hydrogen count one-hot over 0..4
      * number of radical electrons, as a float
      * ring-membership flag (1.0 / 0.0)
      * atomic mass divided by 200

    With the constants as defined in this module that is 45 entries.
    A value outside a one-hot range leaves that segment all zeros.

    Args:
        atom: object exposing the RDKit atom accessors called below.

    Returns:
        A list of numbers (ints from the one-hot segments, floats elsewhere).
    """
    # A list display evaluates left to right, so the accessors are called in
    # the same order as the feature layout documented above.
    return [
        *one_hot(atom.GetAtomicNum(), ATOM_TYPES),
        *one_hot(atom.GetDegree(), range(7)),
        *one_hot(atom.GetTotalValence(), range(7)),
        *one_hot(atom.GetFormalCharge(), range(-3, 4)),
        *one_hot(atom.GetHybridization(), HYBRIDIZATION_TYPES),
        1.0 if atom.GetIsAromatic() else 0.0,
        *one_hot(atom.GetTotalNumHs(), range(5)),
        float(atom.GetNumRadicalElectrons()),
        1.0 if atom.IsInRing() else 0.0,
        float(atom.GetMass()) / 200.0,
    ]
def get_bond_features(bond, mol):
    """Build the flat feature list describing a single bond.

    Segments, in order:
      * bond type one-hot over ``BOND_TYPES``
      * conjugation flag (1.0 / 0.0)
      * ring-membership flag (1.0 / 0.0)
      * stereo label one-hot over ``STEREO_TYPES``
      * distance between the two bonded atoms' positions in the molecule's
        default conformer, halved and capped at 1.0

    With the constants as defined in this module that is 11 entries.

    Args:
        bond: object exposing the RDKit bond accessors called below.
        mol: the molecule owning ``bond``; it must carry a conformer, since
            atom positions are read from ``mol.GetConformer()``.

    Returns:
        A list of numbers (ints from the one-hot segments, floats elsewhere).
    """
    conformer = mol.GetConformer()
    begin_pos = conformer.GetAtomPosition(bond.GetBeginAtomIdx())
    end_pos = conformer.GetAtomPosition(bond.GetEndAtomIdx())
    # Halve the raw distance and clamp so the feature never exceeds 1.0.
    scaled_length = min(begin_pos.Distance(end_pos) / 2.0, 1.0)

    return [
        *one_hot(bond.GetBondType(), BOND_TYPES),
        1.0 if bond.GetIsConjugated() else 0.0,
        1.0 if bond.IsInRing() else 0.0,
        *one_hot(bond.GetStereo(), STEREO_TYPES),
        scaled_length,
    ]
def smiles_to_graph(smiles: str):
    """Convert a SMILES string into a PyTorch Geometric ``Data`` graph.

    The molecule is parsed, hydrogens are added, a 3D conformer is embedded
    with a fixed random seed and refined with MMFF, and the hydrogens are
    removed again. Each remaining atom becomes a node (see
    ``get_atom_features``) and each bond becomes two directed edges that share
    one feature vector (see ``get_bond_features``).

    Args:
        smiles: SMILES representation of the molecule.

    Returns:
        ``Data`` with ``x`` (float32 node features), ``edge_index`` (long,
        shape ``(2, num_edges)``) and ``edge_attr`` (float32 edge features).

    Raises:
        ValueError: if the SMILES cannot be parsed, describes no atoms, or no
            3D conformer can be generated for it.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles}")
    # An empty string parses to a molecule object with zero atoms; there is
    # nothing to featurize, so reject it explicitly.
    if mol.GetNumAtoms() == 0:
        raise ValueError(f"SMILES contains no atoms: {smiles!r}")

    mol = Chem.AddHs(mol)
    # EmbedMolecule reports failure by returning -1 rather than raising.
    # Without this check the failure only surfaces later as an opaque
    # conformer lookup error, so retry once from random starting coordinates
    # and then fail with a clear message.
    if AllChem.EmbedMolecule(mol, randomSeed=42) == -1:
        retry_status = AllChem.EmbedMolecule(
            mol, randomSeed=42, useRandomCoords=True
        )
        if retry_status == -1:
            raise ValueError(
                f"Could not generate 3D coordinates for SMILES: {smiles}"
            )
    # Best-effort refinement: the return code (non-convergence or missing
    # MMFF parameters) is deliberately ignored and the embedded geometry is
    # used as-is in that case.
    AllChem.MMFFOptimizeMolecule(mol)
    mol = Chem.RemoveHs(mol)

    atom_features = [get_atom_features(atom) for atom in mol.GetAtoms()]
    x = torch.tensor(atom_features, dtype=torch.float32)

    edge_indices = []
    edge_features = []
    for bond in mol.GetBonds():
        begin = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        bond_feat = get_bond_features(bond, mol)
        # A chemical bond is undirected: emit both directions with the same
        # feature vector.
        edge_indices.append([begin, end])
        edge_indices.append([end, begin])
        edge_features.append(bond_feat)
        edge_features.append(bond_feat)

    if edge_indices:
        edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_features, dtype=torch.float32)
    else:
        # No bonds (e.g. a single heavy atom): return correctly shaped empty
        # tensors. The width mirrors get_bond_features: bond-type one-hot,
        # conjugated flag, ring flag, stereo one-hot, scaled distance.
        edge_dim = len(BOND_TYPES) + len(STEREO_TYPES) + 3
        edge_index = torch.zeros((2, 0), dtype=torch.long)
        edge_attr = torch.zeros((0, edge_dim), dtype=torch.float32)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)