|
|
import torch |
|
|
import pandas as pd |
|
|
from torch_geometric.data import Data |
|
|
from rdkit import Chem |
|
|
from rdkit.Chem import AllChem, Descriptors |
|
|
from rdkit.Chem import Draw |
|
|
from rdkit import RDLogger |
|
|
RDLogger.DisableLog('rdApp.*') |
|
|
|
|
|
|
|
|
|
|
|
def smiles_to_graph(smiles):
    """Convert a SMILES string into a torch_geometric ``Data`` graph.

    The molecule is parsed, explicit hydrogens are added, and a 3D conformer
    is embedded with ETKDG and then optimized with UFF.

    Args:
        smiles: SMILES string describing the molecule.

    Returns:
        A ``Data`` object, or ``None`` when the SMILES cannot be parsed or
        the conformer cannot be generated/optimized. The fields are:

        * ``x``: float tensor ``(num_atoms, 1)`` holding atomic number / 100.
        * ``pos``: float tensor ``(num_atoms, 3)`` of conformer coordinates.
        * ``edge_index``: long tensor ``(2, 2 * num_bonds)``; every bond is
          stored in both directions.
        * ``edge_attr``: long tensor ``(2 * num_bonds, 1)`` with the bond
          class (0 single, 1 double, 2 triple, 3 aromatic; any other bond
          type falls back to 0).
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    mol = Chem.AddHs(mol)
    try:
        # EmbedMolecule reports failure through a negative conformer id
        # instead of raising, so check it explicitly.
        if AllChem.EmbedMolecule(mol, AllChem.ETKDG()) < 0:
            return None
        AllChem.UFFOptimizeMolecule(mol)
    except Exception:
        # Any RDKit failure means "no usable graph"; unlike a bare except
        # this still lets KeyboardInterrupt / SystemExit propagate.
        return None

    conf = mol.GetConformer()

    # Built once instead of once per bond.
    bond_classes = {
        Chem.BondType.SINGLE: 0,
        Chem.BondType.DOUBLE: 1,
        Chem.BondType.TRIPLE: 2,
        Chem.BondType.AROMATIC: 3,
    }

    node_feats = []
    pos = []
    for atom in mol.GetAtoms():
        node_feats.append([atom.GetAtomicNum() / 100.0])
        position = conf.GetAtomPosition(atom.GetIdx())
        pos.append([position.x, position.y, position.z])

    edge_index = []
    edge_attrs = []
    for bond in mol.GetBonds():
        start = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        bond_class = bond_classes.get(bond.GetBondType(), 0)
        # Undirected bond -> two directed edges sharing the same attribute.
        edge_index.extend([[start, end], [end, start]])
        edge_attrs.extend([[bond_class], [bond_class]])

    # reshape keeps the documented 2-D shapes even when the molecule has no
    # bonds (otherwise the empty lists would produce 1-D tensors).
    return Data(
        x=torch.tensor(node_feats, dtype=torch.float),
        pos=torch.tensor(pos, dtype=torch.float),
        edge_index=torch.tensor(edge_index, dtype=torch.long).reshape(-1, 2).t().contiguous(),
        edge_attr=torch.tensor(edge_attrs, dtype=torch.long).reshape(-1, 1),
    )
|
|
|
|
|
|
|
|
def load_goodscents_subset(filepath="../data/leffingwell-goodscent-merge-dataset.csv",
                           index=200,
                           shuffle=True,
                           random_state=None,
                           ):
    """Load a slice of the odor dataset as SMILES strings plus label vectors.

    Args:
        filepath: Path of the CSV file. It must contain a
            ``nonStereoSMILES`` column; every column from the third one
            onwards is treated as an integer-convertible label column.
        index: When positive, keep the first ``index`` rows. Otherwise keep
            the last ``-index`` rows (so ``0`` keeps no rows).
        shuffle: Shuffle all rows before slicing.
        random_state: Optional seed forwarded to ``DataFrame.sample`` so a
            shuffled split can be reproduced. ``None`` keeps the previous
            unseeded behavior.

    Returns:
        ``(smiles_list, label_map, descriptor_names)`` where ``smiles_list``
        holds the kept SMILES in row order, ``label_map`` maps each SMILES to
        its list of integer labels (for a repeated SMILES the last row wins)
        and ``descriptor_names`` lists the label column names. Rows with a
        missing/empty SMILES or with no non-zero label are dropped.
    """
    df = pd.read_csv(filepath)
    if shuffle:
        df = df.sample(frac=1, random_state=random_state).reset_index(drop=True)

    if index > 0:
        df = df.head(index)
    else:
        df = df.tail(-1 * index)

    descriptor_cols = df.columns[2:]
    # One vectorized conversion instead of an astype() call per row.
    label_rows = df[descriptor_cols].astype(int).to_numpy().tolist()

    smiles_list, label_map = [], {}
    for smiles, labels in zip(df["nonStereoSMILES"], label_rows):
        # An empty CSV cell is read as NaN, which is truthy; require a real,
        # non-empty string so NaN never ends up as a "SMILES".
        if not isinstance(smiles, str) or not smiles:
            continue
        if not any(labels):
            continue
        smiles_list.append(smiles)
        label_map[smiles] = labels

    return smiles_list, label_map, list(descriptor_cols)
|
|
|
|
|
|
|
|
|
|
|
def sample(model, conditioner, label_vec, constrained=True, steps=1000, debug=True):
    """Generate one molecule by iterative denoising and return its SMILES.

    A random graph with 10 candidate atoms and 20 random directed edges is
    refined for ``steps`` iterations, decoded into atomic numbers and bond
    classes, assembled with RDKit, sanitized and shown as an image.

    Args:
        model: Callable invoked as
            ``model(x_t, pos, edge_index, torch.tensor([t]), cond_embed)``
            and returning ``(pred_x, bond_logits)``.
        conditioner: Callable applied to ``label_vec.unsqueeze(0)`` to build
            the conditioning embedding.
        label_vec: Label tensor used for conditioning.
        constrained: When true, only atoms whose atomic number is in
            {6, 7, 8, 9, 15, 16, 17} are kept.
        steps: Number of denoising iterations; must be at least 1.
        debug: Print intermediate tensors and decoded atoms/bonds.

    Returns:
        The SMILES string of the generated molecule, or ``""`` when fewer
        than two atoms survive filtering or sanitization fails.

    Raises:
        ValueError: If ``steps`` is smaller than 1.
    """
    if steps < 1:
        # The loop below must run at least once to produce bond logits.
        raise ValueError("steps must be >= 1")

    x_t = torch.randn((10, 1))
    pos = torch.randn((10, 3))
    edge_index = torch.randint(0, 10, (2, 20))

    # Sampling never backpropagates; without no_grad each step can extend the
    # autograd graph attached to x_t when the model's parameters require grad.
    with torch.no_grad():
        for t in reversed(range(1, steps + 1)):
            cond_embed = conditioner(label_vec.unsqueeze(0))
            pred_x, bond_logits = model(x_t, pos, edge_index, torch.tensor([t]), cond_embed)
            # NOTE(review): temperature_scaled_softmax is not defined in the
            # visible part of this module - confirm it is in scope.
            bond_logits = temperature_scaled_softmax(bond_logits, temperature=(1/t))
            x_t = x_t - pred_x * (1.0 / steps)

        # Undo the /100 scaling applied to atomic numbers in smiles_to_graph.
        x_t = x_t * 100.0
        x_t.relu_()
        atom_types = torch.clamp(x_t.round(), 1, 118).int().squeeze().tolist()

        bond_logits.relu_()
        bond_preds = torch.argmax(bond_logits, dim=-1).tolist()

    allowed_atoms = [6, 7, 8, 9, 15, 16, 17]

    if debug:
        print(f"\tcond_embed: {cond_embed}")
        print(f"\tx_t: {x_t}")
        print(f"\tprediction: {x_t}")
        print(f"\tbond logits: {bond_logits}")
        print(f"\tatoms: {atom_types}")
        print(f"\tbonds: {bond_preds}")

    mol = Chem.RWMol()
    # Maps a position in the sampled graph to the atom index inside ``mol``;
    # filtered-out atoms have no entry.
    idx_map = {}
    for i, atomic_num in enumerate(atom_types):
        if constrained and atomic_num not in allowed_atoms:
            continue
        try:
            idx_map[i] = mol.AddAtom(Chem.Atom(int(atomic_num)))
        except Exception:
            continue

    if len(idx_map) < 2:
        print("Molecule too small or no valid atoms after filtering.")
        return ""

    bond_type_map = {
        0: Chem.BondType.SINGLE,
        1: Chem.BondType.DOUBLE,
        2: Chem.BondType.TRIPLE,
        3: Chem.BondType.AROMATIC,
    }

    # Atom pairs already bonded, stored in sorted order so that (a, b) and
    # (b, a) count as the same bond.
    added = set()
    for i in range(edge_index.shape[1]):
        a = int(edge_index[0, i])
        b = int(edge_index[1, i])
        if a == b or a not in idx_map or b not in idx_map:
            continue
        pair = (min(a, b), max(a, b))
        if pair in added:
            continue
        try:
            bond_type = bond_type_map.get(bond_preds[i], Chem.BondType.SINGLE)
            mol.AddBond(idx_map[a], idx_map[b], bond_type)
            added.add(pair)
        except Exception:
            continue

    try:
        mol = mol.GetMol()
        Chem.SanitizeMol(mol)
        smiles = Chem.MolToSmiles(mol)
    except Exception as e:
        print(f"Sanitization error: {e}")
        return ""

    # Showing the picture is best-effort: a rendering/display problem must
    # not be reported as a sanitization error or discard a valid SMILES.
    try:
        img = Draw.MolToImage(mol)
        img.show()
    except Exception as e:
        print(f"Could not display molecule image: {e}")

    print(f"Atom types: {atom_types}")
    print(f"Generated SMILES: {smiles}")
    return smiles
|
|
|
|
|
|
|
|
|
|
|
def sample_batch(model, conditioner, label_vec, steps=1000, batch_size=4):
    """Generate up to ``batch_size`` sanitized RDKit molecules.

    Each attempt starts from a random graph with 10 candidate atoms and 20
    random directed edges, runs ``steps`` denoising iterations, keeps only
    atoms whose atomic number is in {6, 7, 8, 9, 15, 16, 17} and assembles
    the molecule with RDKit.

    Args:
        model: Callable invoked as
            ``model(x_t, pos, edge_index, torch.tensor([t]), cond_embed)``
            and returning ``(pred_x, bond_logits)``.
        conditioner: Callable applied to ``label_vec.unsqueeze(0)`` to build
            the conditioning embedding.
        label_vec: Label tensor used for conditioning.
        steps: Number of denoising iterations per molecule; must be >= 1.
        batch_size: Number of generation attempts.

    Returns:
        List of sanitized molecules. Attempts that keep fewer than two atoms
        or fail sanitization are skipped, so the list may be shorter than
        ``batch_size``.

    Raises:
        ValueError: If ``steps`` is smaller than 1.
    """
    if steps < 1:
        # The denoising loop must run at least once to produce bond logits.
        raise ValueError("steps must be >= 1")

    allowed_atoms = [6, 7, 8, 9, 15, 16, 17]
    bond_type_map = {
        0: Chem.BondType.SINGLE,
        1: Chem.BondType.DOUBLE,
        2: Chem.BondType.TRIPLE,
        3: Chem.BondType.AROMATIC,
    }

    mols = []
    for _ in range(batch_size):
        x_t = torch.randn((10, 1))
        pos = torch.randn((10, 3))
        edge_index = torch.randint(0, 10, (2, 20))

        # Sampling never backpropagates; no_grad keeps the autograd graph
        # from growing with every step.
        with torch.no_grad():
            for t in reversed(range(1, steps + 1)):
                cond_embed = conditioner(label_vec.unsqueeze(0))
                pred_x, bond_logits = model(x_t, pos, edge_index, torch.tensor([t]), cond_embed)
                x_t = x_t - pred_x * (1.0 / steps)

            # Undo the /100 scaling applied to atomic numbers in smiles_to_graph.
            x_t = x_t * 100.0
            x_t.relu_()
            atom_types = torch.clamp(x_t.round(), 1, 118).int().squeeze().tolist()

            bond_logits.relu_()
            # Previously bond_preds was never assigned here: the NameError
            # was swallowed by the except below, so no bond was ever added.
            # Decoded the same way as in sample() (relu_, then argmax).
            bond_preds = torch.argmax(bond_logits, dim=-1).tolist()

        mol = Chem.RWMol()
        # Maps a position in the sampled graph to the atom index inside
        # ``mol``; filtered-out atoms have no entry.
        idx_map = {}
        for i, atomic_num in enumerate(atom_types):
            if atomic_num not in allowed_atoms:
                continue
            try:
                idx_map[i] = mol.AddAtom(Chem.Atom(int(atomic_num)))
            except Exception:
                continue

        if len(idx_map) < 2:
            continue

        # Atom pairs already bonded, stored in sorted order so that (a, b)
        # and (b, a) count as the same bond.
        added = set()
        for i in range(edge_index.shape[1]):
            a = int(edge_index[0, i])
            b = int(edge_index[1, i])
            if a == b or a not in idx_map or b not in idx_map:
                continue
            pair = (min(a, b), max(a, b))
            if pair in added:
                continue
            try:
                bond_type = bond_type_map.get(bond_preds[i], Chem.BondType.SINGLE)
                mol.AddBond(idx_map[a], idx_map[b], bond_type)
                added.add(pair)
            except Exception:
                continue

        try:
            mol = mol.GetMol()
            Chem.SanitizeMol(mol)
            mols.append(mol)
        except Exception:
            # Chemically invalid candidates are simply dropped.
            continue

    return mols
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def validate_molecule(smiles):
    """Check that a SMILES string parses and compute basic descriptors.

    Args:
        smiles: SMILES string to validate.

    Returns:
        ``(True, {"MolWt": ..., "LogP": ...})`` when RDKit parses the string
        into a molecule with at least one atom, otherwise ``(False, {})``.
    """
    # sample() returns "" on failure, and RDKit parses an empty SMILES into
    # an atom-less molecule instead of None, so reject blank or non-string
    # input up front rather than reporting it as valid.
    if not isinstance(smiles, str) or not smiles.strip():
        return False, {}

    mol = Chem.MolFromSmiles(smiles)
    if mol is None or mol.GetNumAtoms() == 0:
        return False, {}

    return True, {"MolWt": Descriptors.MolWt(mol), "LogP": Descriptors.MolLogP(mol)}
|
|
|
|
|
|
|
|
|
|
|
def test_models(test_model, test_conditioner):
    """Sample one molecule per dataset entry and report the success rate.

    Loads a dataset subset, converts every SMILES to a graph, then calls
    ``sample`` conditioned on each graph's label vector. A sample counts as
    a success when ``sample`` returns a non-empty SMILES string. Progress,
    validation results and the final success fraction are printed.

    Args:
        test_model: Model passed to ``sample``; switched to eval mode.
        test_conditioner: Conditioner passed to ``sample``; switched to eval
            mode.
    """
    good_count: int = 0
    # presumably 20% of a 4983-row dataset is used for testing - TODO confirm
    index: int = int(4983.0 * 0.2)
    smiles_list, label_map, _ = load_goodscents_subset(index=index)

    test_model.eval()
    test_conditioner.eval()

    dataset = []
    for smi in smiles_list:
        g = smiles_to_graph(smi)
        # smiles_to_graph signals failure with None; test for that explicitly
        # instead of relying on the truthiness of the graph object.
        if g is not None:
            g.y = torch.tensor(label_map[smi])
            dataset.append(g)

    if not dataset:
        # Nothing to evaluate; also avoids dividing by zero below.
        print("No valid molecules to test.")
        return

    total = len(dataset)
    for i, data in enumerate(dataset):
        print(f"Testing molecule {i+1}/{total}")
        # Only the label vector is needed: sample() draws its own random graph.
        new_smiles = sample(test_model, test_conditioner, label_vec=data.y)
        print(new_smiles)
        valid, props = validate_molecule(new_smiles)
        print(f"Generated SMILES: {new_smiles}\nValid: {valid}, Properties: {props}")
        if new_smiles != "":
            good_count += 1

    # Printed as a fraction in [0, 1], not scaled to 0-100.
    percent_correct: float = float(good_count) / float(total)
    print(f"Percent correct: {percent_correct}")
|
|
|
|
|
|
|
|
|