"""Featurize molecules from SMILES into PyTorch Geometric ``Data`` graphs.

Nodes are heavy atoms and edges are bonds, each stored in both directions.
One bond feature is the bond length measured on an MMFF-relaxed 3D conformer.
"""
from typing import Any, List, Sequence

import torch
from rdkit import Chem
from rdkit.Chem import AllChem
from torch_geometric.data import Data

# Atomic numbers that get their own one-hot slot (C, N, O, S, P, F, Cl, Br, I).
# Any other element encodes as all zeros in this segment.
ATOM_TYPES = [6, 7, 8, 16, 15, 9, 17, 35, 53]

HYBRIDIZATION_TYPES = [
    Chem.rdchem.HybridizationType.SP,
    Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP3,
    Chem.rdchem.HybridizationType.SP3D,
    Chem.rdchem.HybridizationType.SP3D2,
    Chem.rdchem.HybridizationType.UNSPECIFIED,
]

BOND_TYPES = [
    Chem.rdchem.BondType.SINGLE,
    Chem.rdchem.BondType.DOUBLE,
    Chem.rdchem.BondType.TRIPLE,
    Chem.rdchem.BondType.AROMATIC,
]

STEREO_TYPES = [
    Chem.rdchem.BondStereo.STEREONONE,
    Chem.rdchem.BondStereo.STEREOANY,
    Chem.rdchem.BondStereo.STEREOZ,
    Chem.rdchem.BondStereo.STEREOE,
]

# Width of the vector returned by get_bond_features. The parts are the
# bond-type one-hot, the conjugated flag, the in-ring flag, the stereo one-hot
# and the scaled bond length. Derived rather than hard-coded so the empty-graph
# fallback in smiles_to_graph stays in sync if the lists above change.
NUM_BOND_FEATURES = len(BOND_TYPES) + 2 + len(STEREO_TYPES) + 1


def one_hot(val: Any, choices: Sequence[Any]) -> List[int]:
    """Return a 0/1 indicator list marking which entry of *choices* equals *val*.

    A value not present in *choices* yields all zeros; there is no "other" slot.
    """
    return [1 if val == c else 0 for c in choices]


def get_atom_features(atom: Chem.Atom) -> List[float]:
    """Return the feature vector for a single atom.

    Layout, in order:
      * element one-hot over ``ATOM_TYPES``
      * degree one-hot over 0..6
      * total valence one-hot over 0..6
      * formal charge one-hot over -3..3
      * hybridization one-hot over ``HYBRIDIZATION_TYPES``
      * aromatic flag
      * total hydrogen count one-hot over 0..4
      * number of radical electrons (raw count)
      * in-ring flag
      * atomic mass divided by 200
    """
    feat: List[float] = []
    feat += one_hot(atom.GetAtomicNum(), ATOM_TYPES)
    feat += one_hot(atom.GetDegree(), list(range(7)))
    feat += one_hot(atom.GetTotalValence(), list(range(7)))
    feat += one_hot(atom.GetFormalCharge(), list(range(-3, 4)))
    feat += one_hot(atom.GetHybridization(), HYBRIDIZATION_TYPES)
    feat.append(1.0 if atom.GetIsAromatic() else 0.0)
    feat += one_hot(atom.GetTotalNumHs(), list(range(5)))
    feat.append(float(atom.GetNumRadicalElectrons()))
    feat.append(1.0 if atom.IsInRing() else 0.0)
    feat.append(float(atom.GetMass()) / 200.0)
    return feat


def get_bond_features(bond: Chem.Bond, mol: Chem.Mol) -> List[float]:
    """Return the ``NUM_BOND_FEATURES``-long feature vector for *bond*.

    *mol* must carry a conformer. Its default conformer supplies the
    bond-length feature.

    Layout, in order:
      * bond-type one-hot over ``BOND_TYPES``
      * conjugated flag
      * in-ring flag
      * stereo one-hot over ``STEREO_TYPES``
      * bond length divided by 2.0 and clipped to at most 1.0
    """
    feat: List[float] = []
    feat += one_hot(bond.GetBondType(), BOND_TYPES)
    feat.append(1.0 if bond.GetIsConjugated() else 0.0)
    feat.append(1.0 if bond.IsInRing() else 0.0)
    feat += one_hot(bond.GetStereo(), STEREO_TYPES)
    conf = mol.GetConformer()
    i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
    # Distance in RDKit's coordinate units (Angstrom); the clip keeps the
    # feature in [0, 1].
    dist = conf.GetAtomPosition(i).Distance(conf.GetAtomPosition(j))
    feat.append(min(dist / 2.0, 1.0))
    return feat


def _embed_3d(mol: Chem.Mol, smiles: str, random_seed: int) -> Chem.Mol:
    """Return a heavy-atom copy of *mol* carrying one MMFF-optimized conformer.

    Raises:
        ValueError: if RDKit cannot generate any 3D conformer for *mol*.
    """
    mol_h = Chem.AddHs(mol)
    # EmbedMolecule signals failure by returning -1 rather than raising.
    # Left unchecked, the later MMFF / GetConformer calls fail with an opaque
    # "Bad Conformer Id". Retrying from random starting coordinates often
    # embeds molecules that the default start cannot.
    if AllChem.EmbedMolecule(mol_h, randomSeed=random_seed) == -1:
        if (
            AllChem.EmbedMolecule(
                mol_h, randomSeed=random_seed, useRandomCoords=True
            )
            == -1
        ):
            raise ValueError(
                f"Could not generate 3D coordinates for SMILES: {smiles}"
            )
    # The return status is ignored, as before. Setting up MMFF can fail
    # (e.g. missing parameters), in which case the embedded coordinates are
    # used unoptimized.
    AllChem.MMFFOptimizeMolecule(mol_h)
    # Drop explicit hydrogens again; heavy-atom indices and the conformer are kept.
    return Chem.RemoveHs(mol_h)


def smiles_to_graph(smiles: str, *, random_seed: int = 42) -> Data:
    """Convert a SMILES string into a PyTorch Geometric ``Data`` graph.

    Args:
        smiles: Molecule in SMILES notation.
        random_seed: Seed for RDKit's conformer embedding. It determines the
            3D geometry and therefore the bond-length feature. The default
            reproduces the historical behaviour.

    Returns:
        A ``Data`` object with these fields:
          * ``x``: float32 tensor of shape (num_heavy_atoms, num_atom_features).
          * ``edge_index``: long tensor of shape (2, 2 * num_bonds). Each bond
            appears once per direction.
          * ``edge_attr``: float32 tensor of shape
            (2 * num_bonds, NUM_BOND_FEATURES).

    Raises:
        ValueError: if the SMILES cannot be parsed, contains no atoms, or no
            3D conformer can be generated for it.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles}")
    # An empty string parses to a zero-atom molecule rather than None. It
    # would otherwise yield an ``x`` of shape (0,) instead of (0, F).
    if mol.GetNumAtoms() == 0:
        raise ValueError(f"Invalid SMILES (no atoms): {smiles}")

    mol = _embed_3d(mol, smiles, random_seed)

    x = torch.tensor(
        [get_atom_features(atom) for atom in mol.GetAtoms()],
        dtype=torch.float32,
    )

    edge_indices: List[List[int]] = []
    edge_features: List[List[float]] = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        bf = get_bond_features(bond, mol)
        # Each undirected bond becomes two directed edges with identical
        # features, so messages can flow both ways.
        edge_indices += [[i, j], [j, i]]
        edge_features += [bf, bf]

    if not edge_indices:
        # Bond-free molecules (e.g. single atoms) still get correctly shaped tensors.
        edge_index = torch.zeros((2, 0), dtype=torch.long)
        edge_attr = torch.zeros((0, NUM_BOND_FEATURES), dtype=torch.float32)
    else:
        edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_features, dtype=torch.float32)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)