# tox21-classifier / model.py
import torch
import torch.nn as nn
from rdkit import Chem
# =============================================================================
# Featurization Utils
# =============================================================================
ATOM_FEATURES = {
    'atomic_num': list(range(1, 101)),
    'degree': [0, 1, 2, 3, 4, 5],
    'formal_charge': [-1, -2, 1, 2, 0],
    'chiral_tag': [0, 1, 2, 3],
    'num_hs': [0, 1, 2, 3, 4],
    'hybridization': [
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2,
    ],
}
BOND_FDIM = 13
def get_atom_fdim():
    # Sum of one-hot sizes (each +1 for an "unknown" slot), plus the aromatic
    # flag and scaled atomic mass appended in atom_features.
    return sum(len(choices) + 1 for choices in ATOM_FEATURES.values()) + 2

def get_bond_fdim():
    return BOND_FDIM

def onek_encoding_unk(value, choices):
    # One-hot encoding with a trailing slot for values not in `choices`.
    encoding = [0] * (len(choices) + 1)
    index = choices.index(value) if value in choices else -1
    encoding[index] = 1
    return encoding
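# Illustrative example (not part of the original file): with choices [0, 1, 2, 3],
# a known value sets its own slot and an out-of-vocabulary value sets the last one:
#   onek_encoding_unk(2, [0, 1, 2, 3]) -> [0, 0, 1, 0, 0]
#   onek_encoding_unk(7, [0, 1, 2, 3]) -> [0, 0, 0, 0, 1]
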
def atom_features(atom):
    return (
        onek_encoding_unk(atom.GetAtomicNum(), ATOM_FEATURES['atomic_num']) +
        onek_encoding_unk(atom.GetTotalDegree(), ATOM_FEATURES['degree']) +
        onek_encoding_unk(atom.GetFormalCharge(), ATOM_FEATURES['formal_charge']) +
        onek_encoding_unk(int(atom.GetChiralTag()), ATOM_FEATURES['chiral_tag']) +
        onek_encoding_unk(int(atom.GetTotalNumHs()), ATOM_FEATURES['num_hs']) +
        onek_encoding_unk(int(atom.GetHybridization()), ATOM_FEATURES['hybridization']) +
        [1 if atom.GetIsAromatic() else 0] +
        [atom.GetMass() * 0.01]
    )
def bond_features(bond):
    bt = bond.GetBondType()
    feats = [
        bt == Chem.rdchem.BondType.SINGLE,
        bt == Chem.rdchem.BondType.DOUBLE,
        bt == Chem.rdchem.BondType.TRIPLE,
        bt == Chem.rdchem.BondType.AROMATIC,
        bond.GetIsConjugated() if bt else 0,
        bond.IsInRing() if bt else 0,
    ]
    feats += onek_encoding_unk(int(bond.GetStereo()), list(range(6)))
    return feats
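# Illustrative example (assuming RDKit defaults): a plain, non-conjugated,
# acyclic C-C single bond yields [1, 0, 0, 0, 0, 0] for the type/conjugation/ring
# flags plus a 7-way stereo one-hot with STEREONONE set, i.e. 13 entries in
# total, matching BOND_FDIM.
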
class MolGraph:
    """Directed molecular graph built from a SMILES string.

    Each RDKit bond becomes two directed bonds (a1 -> a2 and a2 -> a1) so that
    messages can be passed along both directions independently.
    """

    def __init__(self, smiles):
        self.smiles = smiles
        self.f_atoms = []   # atom feature vectors
        self.f_bonds = []   # directed-bond feature vectors
        self.a2b = []       # atom index -> incoming directed-bond indices
        self.b2a = []       # directed-bond index -> source atom index
        self.b2revb = []    # directed-bond index -> reverse directed-bond index
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # Unparseable SMILES: keep an empty graph.
            self.n_atoms = 0
            self.n_bonds = 0
            return
        self.n_atoms = mol.GetNumAtoms()
        for atom in mol.GetAtoms():
            self.f_atoms.append(atom_features(atom))
        self.a2b = [[] for _ in range(self.n_atoms)]
        self.n_bonds = 0
        for bond in mol.GetBonds():
            a1, a2 = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            b_feat = bond_features(bond)
            # Directed bond b1: a1 -> a2
            b1 = self.n_bonds
            self.f_bonds.append(b_feat)
            self.b2a.append(a1)
            self.a2b[a2].append(b1)
            # Directed bond b2: a2 -> a1
            b2 = self.n_bonds + 1
            self.f_bonds.append(b_feat)
            self.b2a.append(a2)
            self.a2b[a1].append(b2)
            self.b2revb.extend([b2, b1])
            self.n_bonds += 2
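# Illustrative example (assumes RDKit's default atom ordering for this SMILES):
#   g = MolGraph("CO")       # methanol without explicit Hs: atoms [C, O], one bond
#   g.n_atoms, g.n_bonds     # -> 2, 2 (one bond, two directions)
#   g.b2a                    # -> [0, 1]   (bond 0 starts at C, bond 1 at O)
#   g.a2b                    # -> [[1], [0]]  (incoming directed bonds per atom)
#   g.b2revb                 # -> [1, 0]
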
class BatchMolGraph:
    """Concatenates several MolGraphs into one disconnected batch graph.

    Atom and bond indices of each graph are offset by the running totals so the
    whole batch is processed as a single graph; a_scope records each molecule's
    (start index, number of atoms) for the readout step.
    """

    def __init__(self, mol_graphs):
        self.atom_features = []
        self.bond_features = []
        self.a2b = []
        self.b2a = []
        self.b2revb = []
        self.a_scope = []
        total_atoms = 0
        total_bonds = 0
        for g in mol_graphs:
            self.atom_features.extend(g.f_atoms)
            self.bond_features.extend(g.f_bonds)
            for lst in g.a2b:
                self.a2b.append([b + total_bonds for b in lst])
            self.b2a.extend(a + total_atoms for a in g.b2a)
            self.b2revb.extend(b + total_bonds for b in g.b2revb)
            self.a_scope.append((total_atoms, g.n_atoms))
            total_atoms += g.n_atoms
            total_bonds += g.n_bonds
        self.atom_features = torch.tensor(self.atom_features, dtype=torch.float)
        self.bond_features = torch.tensor(self.bond_features, dtype=torch.float)
        self.b2a = torch.tensor(self.b2a, dtype=torch.long)
        self.b2revb = torch.tensor(self.b2revb, dtype=torch.long)

    def get_components(self):
        return (
            self.atom_features,
            self.bond_features,
            self.a2b,
            self.b2a,
            self.b2revb,
            self.a_scope,
        )
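# Illustrative example (not part of the original file):
#   batch = BatchMolGraph([MolGraph("CO"), MolGraph("CCO")])
#   batch.a_scope              # -> [(0, 2), (2, 3)]: molecule 0 owns atoms 0-1, molecule 1 atoms 2-4
#   batch.atom_features.shape  # -> (5, get_atom_fdim())
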
# =============================================================================
# D-MPNN (Backward-Compatible)
# =============================================================================
class DMPNN(nn.Module):
    def __init__(
        self,
        hidden_size=300,
        depth=3,
        tasks=12,
        global_feats_size=217,
        n_tasks=None,  # ← compatibility with old checkpoints
        **kwargs       # ← ignore legacy args safely
    ):
        super().__init__()
        if n_tasks is not None:
            tasks = n_tasks
        self.hidden_size = hidden_size
        self.depth = depth
        self.atom_fdim = get_atom_fdim()
        self.bond_fdim = get_bond_fdim()
        self.global_feats_size = global_feats_size
        self.W_i = nn.Linear(self.atom_fdim + self.bond_fdim, hidden_size, bias=False)
        self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_o = nn.Linear(self.atom_fdim + hidden_size, hidden_size)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
        self.readout_1 = nn.Linear(hidden_size + global_feats_size, hidden_size)
        self.readout_2 = nn.Linear(hidden_size, tasks)
    def forward(self, batch_graph, global_feats=None):
        f_atoms, f_bonds, a2b, b2a, b2revb, a_scope = batch_graph.get_components()
        # Empty batch (e.g. every SMILES failed to parse): return all-zero logits.
        if f_atoms.size(0) == 0:
            return torch.zeros((len(a_scope), self.readout_2.out_features),
                               device=self.W_i.weight.device)
        device = self.W_i.weight.device
        f_atoms, f_bonds = f_atoms.to(device), f_bonds.to(device)
        b2a, b2revb = b2a.to(device), b2revb.to(device)
        # Initial directed-bond messages from source-atom and bond features.
        h0 = self.act(self.W_i(torch.cat([f_atoms.index_select(0, b2a), f_bonds], 1)))
        h = h0
        for _ in range(self.depth):
            # Sum each directed bond's message into its target atom.
            atom_msg = torch.zeros(f_atoms.size(0), self.hidden_size, device=device)
            atom_msg.index_add_(0, b2a.index_select(0, b2revb), h)
            # Message for bond a1 -> a2: incoming messages at a1 minus the reverse bond's message.
            m = atom_msg.index_select(0, b2a) - h.index_select(0, b2revb)
            h = self.dropout(self.act(h0 + self.W_h(m)))
        # Final atom embeddings from atom features plus aggregated bond messages.
        atom_msg = torch.zeros(f_atoms.size(0), self.hidden_size, device=device)
        atom_msg.index_add_(0, b2a.index_select(0, b2revb), h)
        atom_h = self.act(self.W_o(torch.cat([f_atoms, atom_msg], 1)))
        # Readout: sum atom embeddings per molecule, then optionally append global descriptors.
        mol_vecs = [
            atom_h.narrow(0, s, n).sum(0) if n > 0 else torch.zeros(self.hidden_size, device=device)
            for s, n in a_scope
        ]
        mol_vecs = torch.stack(mol_vecs)
        if self.global_feats_size > 0:
            if global_feats is None:
                raise ValueError("Global features expected but not provided")
            mol_vecs = torch.cat([mol_vecs, global_feats.to(device)], 1)
        x = self.dropout(self.act(self.readout_1(mol_vecs)))
        return self.readout_2(x)
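
# =============================================================================
# Usage sketch (illustrative, not part of the original file). The default
# global_feats_size=217 suggests the pretrained checkpoint expects a
# 217-dimensional global descriptor vector per molecule, computed elsewhere;
# to keep this example self-contained we build a model with
# global_feats_size=0 and random weights.
# =============================================================================
if __name__ == "__main__":
    smiles_list = ["CCO", "c1ccccc1O"]
    batch = BatchMolGraph([MolGraph(s) for s in smiles_list])
    model = DMPNN(hidden_size=300, depth=3, tasks=12, global_feats_size=0)
    model.eval()
    with torch.no_grad():
        logits = model(batch)          # (2, 12): one logit per molecule and task
        probs = torch.sigmoid(logits)  # per-task probabilities for multi-label Tox21
    print(probs.shape)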