teohyc commited on
Commit
e49b3c8
·
verified ·
1 Parent(s): e9b3e75

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. README.md +70 -0
  2. dict_idx2frag.pt +3 -0
  3. tree_rnn-vae.pt +3 -0
  4. tree_rnn_vae_infer.py +141 -0
  5. tree_rnn_vae_model.py +132 -0
README.md ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # De Novo Drug Generator - RNN-VAE
3
+
4
+ De Novo Drug Generator - RNN-VAE is a deep learning model designed for generating novel drug molecules.
5
+ Training data is from the ChEMBL library.
6
+
7
+ Full project file at https://github.com/teohyc/drug_agent
8
+
9
+
10
+ ## Usage
11
+
12
+ ```python
13
+ from rdkit import Chem
14
+ from rdkit.Chem import Draw, Descriptors
15
+ from tree_rnn_vae_infer import generate_candidate_mol
16
+ from tree_rnn_vae_model import TreeEncoder, LatentHead, TreeVAE, TreeDecoder
17
+
18
+
19
+ def compute_molecule_props(mol):
20
+ return {
21
+ "MW": Descriptors.MolWt(mol),
22
+ "logP": Descriptors.MolLogP(mol),
23
+ "HBD": Descriptors.NumHDonors(mol),
24
+ "HBA": Descriptors.NumHAcceptors(mol),
25
+ }
26
+
27
+
28
+ # display molecule
29
+ def render_molecule_grid(selected):
30
+ if not selected:
31
+ return
32
+
33
+ mols, legends = [], []
34
+
35
+ if isinstance(selected, dict):
36
+ iterable = selected.items()
37
+ else:
38
+ iterable = enumerate(selected, 1)
39
+
40
+ for i, item in iterable:
41
+ if isinstance(selected, dict):
42
+ smi, props = i, item
43
+ else:
44
+ smi, props = item, None
45
+
46
+ mol = Chem.MolFromSmiles(smi)
47
+ if mol:
48
+ mols.append(mol)
49
+ if props is None:
50
+ props = compute_molecule_props(mol)
51
+ legends.append(
52
+ f"M{i}
53
+ MW={props['MW']:.0f}, logP={props['logP']:.2f}, "
54
+ f"HBD={props['HBD']}, HBA={props['HBA']}"
55
+ )
56
+
57
+ img = Draw.MolsToGridImage(
58
+ mols,
59
+ molsPerRow=3,
60
+ subImgSize=(400, 400),
61
+ legends=legends,
62
+ useSVG=False,
63
+ )
64
+ return img
65
+
66
+
67
+ molecules = generate_candidate_mol(num_samples=6, max_len=20) #change to your desired molecule size and number
68
+ img = render_molecule_grid(molecules)
69
+ img.show()
70
+ ```
dict_idx2frag.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b45ac2d4f65aa76bd54a6742c48cc5b749d58fedb2f0309c648a258481155e5d
3
+ size 18769
tree_rnn-vae.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5abe40c97c5915ac21890b44d309edba62c9933fc66d8c9e0043644ef46e7cc8
3
+ size 3963609
tree_rnn_vae_infer.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import torch
3
+ from torch import nn
4
+ from torch.utils.data import Dataset, DataLoader
5
+ import math
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from rdkit import Chem
9
+ from rdkit.Chem import Draw, AllChem
10
+ import random
11
+ import re
12
+ from tree_rnn_vae_model import TreeEncoder, LatentHead, TreeVAE, TreeDecoder
13
+
14
def decode_fragments_to_smiles(batch_indices, vocab):
    """Map sampled token-index sequences back to per-molecule fragment lists.

    Parameters
    ----------
    batch_indices : iterable of 1-D index tensors, one per sampled molecule.
    vocab : dict mapping index -> fragment SMILES string (idx2frag).

    Returns
    -------
    list[list[str]] : one fragment-SMILES list per molecule, ready for
    process_one_molecule().

    Fix: the original passed the *list* of fragments to Chem.MolFromSmiles
    (which expects a string), so the call always raised, the error was
    swallowed by a bare ``except:``, and the raw fragment list was appended
    as a "fallback" on every iteration.  That dead try/except is removed;
    the (identical) fragment-list result is kept because the downstream
    pipeline expects fragment lists, not joined SMILES.
    """
    smiles_list = []
    for seq in batch_indices:
        fragments = [vocab[idx] for idx in seq.tolist()]
        smiles_list.append(fragments)
    return smiles_list
28
+
29
def normalize_attachment_points(smiles):
    """Collapse numbered attachment points (e.g. ``[3*]``) to plain ``[*]``."""
    pattern = re.compile(r'\[\d+\*\]')
    return pattern.sub('[*]', smiles)
31
+
32
def deduplicate_fragments(fragments):
    """Drop duplicate fragments while preserving first-seen order.

    dict.fromkeys keeps insertion order (Python 3.7+), giving the same
    result as the manual seen-set loop.
    """
    return list(dict.fromkeys(fragments))
41
+
42
def fragments_to_mols(fragments):
    """Parse fragment SMILES with RDKit, silently dropping unparseable ones.

    Returns a list of RDKit Mol objects in the original order; fragments
    that fail to parse (MolFromSmiles returns None) are skipped.
    """
    parsed = (Chem.MolFromSmiles(frag) for frag in fragments)
    return [mol for mol in parsed if mol is not None]
50
+
51
def get_dummy_atom_indices(mol):
    """Return the atom indices of all attachment-point ('*') atoms in mol."""
    indices = []
    for atom in mol.GetAtoms():
        if atom.GetSymbol() == '*':
            indices.append(atom.GetIdx())
    return indices
53
+
54
def get_linear_connection_pairs(mols):
    """Plan a linear chain: last dummy atom of mol i -> first dummy atom of mol i+1.

    Parameters
    ----------
    mols : list of RDKit Mol fragments.

    Returns
    -------
    list of (frag_i, atom_i, frag_j, atom_j) tuples with fragment-local atom
    indices, suitable for combine_fragments_and_connect().

    Fix: a fragment with no '*' attachment point used to raise IndexError
    (indexing ``[-1]`` / ``[0]`` on an empty list); such pairs are now
    skipped so one capless fragment no longer aborts the whole molecule.
    """
    pairs = []
    for i in range(len(mols) - 1):
        left = get_dummy_atom_indices(mols[i])
        right = get_dummy_atom_indices(mols[i + 1])
        if not left or not right:
            continue  # no attachment point on one side -> nothing to bond
        pairs.append((i, left[-1], i + 1, right[0]))
    return pairs
62
+
63
def combine_fragments_and_connect(fragments, connection_pairs):
    """Merge fragment SMILES into one molecule and bond the connection pairs.

    Parameters
    ----------
    fragments : list of fragment SMILES strings.
    connection_pairs : (frag_i, atom_i, frag_j, atom_j) tuples whose atom
        indices are local to each (parseable) fragment, in the same order
        that fragments_to_mols() yields them.

    Returns
    -------
    str or None : the sanitized SMILES, or None when nothing parses or
    sanitization fails.

    Fixes: offsets were recomputed with ``sum(...)`` inside the loop
    (accidental O(n^2)) and included a stray trailing entry for the
    single-fragment case; the bare ``except:`` around sanitization now
    catches ``Exception`` only.
    """
    mols = fragments_to_mols(fragments)
    if not mols:
        return None

    # Combine all fragments into one (still disconnected) molecule while
    # tracking each fragment's atom-index offset inside the combined mol.
    combined = mols[0]
    offsets = [0]
    running = mols[0].GetNumAtoms()
    for mol in mols[1:]:
        combined = Chem.CombineMols(combined, mol)
        offsets.append(running)
        running += mol.GetNumAtoms()

    rw_mol = Chem.RWMol(combined)

    # Translate fragment-local atom indices to combined-molecule indices.
    for f1, a1, f2, a2 in connection_pairs:
        rw_mol.AddBond(offsets[f1] + a1, offsets[f2] + a2, Chem.BondType.SINGLE)

    # Sanitization validates valences/aromaticity; reject molecules that fail.
    try:
        Chem.SanitizeMol(rw_mol)
        return Chem.MolToSmiles(rw_mol)
    except Exception:  # was a bare except; keep best-effort semantics
        return None
87
+
88
+
89
def process_one_molecule(raw_fragments):
    """Assemble one candidate SMILES from a list of raw fragment strings.

    Pipeline: normalize attachment points -> deduplicate -> parse with
    RDKit -> plan a linear connection -> combine & sanitize.
    Returns the assembled SMILES string, or None when nothing usable remains
    (no fragments, none parse, or sanitization fails).
    """
    # Canonicalise attachment points, then drop repeats (order preserved).
    frags = deduplicate_fragments(
        [normalize_attachment_points(f) for f in raw_fragments]
    )
    if not frags:
        return None

    # Bail out early when no fragment parses at all.
    mols = fragments_to_mols(frags)
    if not mols:
        return None

    # Simple linear topology for now: chain fragment i to fragment i+1.
    connection_pairs = get_linear_connection_pairs(mols)

    # Combine, bond, sanitize; returns None on failure.
    return combine_fragments_and_connect(frags, connection_pairs)
108
+
109
def replace_dummy_with_carbon(smiles):
    """Cap every dummy atom '*' in the SMILES with a carbon."""
    # Single-character substitution: translate is a one-pass equivalent
    # of str.replace('*', 'C').
    return smiles.translate(str.maketrans('*', 'C'))
111
+
112
+
113
+ #main
114
def generate_candidate_mol(num_samples=6, max_len=20):
    """Generate candidate molecules with the pretrained Tree-RNN VAE.

    Parameters
    ----------
    num_samples : int, number of latent vectors to sample (molecules to try).
    max_len : int, maximum number of fragments decoded per molecule.

    Returns
    -------
    list[str] : SMILES strings with dummy atoms capped by carbon.  May be
    shorter than num_samples because failed assemblies are dropped.

    Fix vs. original: molecules whose assembly failed
    (process_one_molecule returned None) used to crash the final
    replace step with AttributeError; they are now filtered out.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    idx2frag = torch.load("dict_idx2frag.pt")  # {index: fragment_string}

    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from a trusted source.
    model = torch.load("tree_rnn-vae.pt", weights_only=False, map_location=device)
    model.eval()

    # Sample latents sized to the decoder's z dimension, then decode
    # fragment-index sequences autoregressively.
    z = torch.randn(num_samples, model.decoder.z_to_hidden.in_features, device=device)
    sampled_indices = model.sample_from_z(z, sos_idx=None, max_len=max_len)

    fragment_lists = decode_fragments_to_smiles(sampled_indices, idx2frag)

    smiles_list = []
    for fragments in fragment_lists:
        smiles = process_one_molecule(fragments)
        if smiles is None:  # assembly/sanitization failed -> skip, don't crash
            continue
        smiles_list.append(replace_dummy_with_carbon(smiles))

    return smiles_list
141
+
tree_rnn_vae_model.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import torch
3
+ from torch import nn
4
+ from torch.utils.data import Dataset, DataLoader
5
+ import math
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from rdkit import Chem
9
+ from rdkit.Chem import Draw, AllChem
10
+ import random
11
+
12
class TreeEncoder(nn.Module):
    """GRU encoder over padded fragment-index sequences.

    Embeds each fragment index and runs a single-layer GRU; the GRU's final
    hidden state summarises the whole sequence.
    """

    def __init__(self, vocab_size, embed_dim, enc_hidden, pad_idx=0):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.gru = nn.GRU(embed_dim, enc_hidden, batch_first=True)
        self.enc_hidden = enc_hidden

    def forward(self, x, lengths):
        """Encode a padded batch.

        x: [B, L] LongTensor of padded fragment indices.
        lengths: [B] LongTensor of true sequence lengths.
        Returns the final hidden state, [B, enc_hidden].
        """
        embedded = self.embed(x)  # [B, L, E]
        # Pack so the GRU skips padding; pack_padded_sequence wants CPU lengths.
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _, final_hidden = self.gru(packed)  # [1, B, enc_hidden]
        return final_hidden.squeeze(0)
30
+
31
class LatentHead(nn.Module):
    """Projects an encoder state to the mean and log-variance of q(z|x)."""

    def __init__(self, enc_hidden, z_dim):
        super().__init__()
        self.linear_mu = nn.Linear(enc_hidden, z_dim)
        self.linear_logvar = nn.Linear(enc_hidden, z_dim)

    def forward(self, h):
        """h: [B, enc_hidden] -> (mu, logvar), each [B, z_dim]."""
        return self.linear_mu(h), self.linear_logvar(h)
41
+
42
def reparameterize(mu, logvar):
    """Draw z ~ N(mu, diag(exp(logvar))) via the reparameterization trick.

    Keeps the sample differentiable w.r.t. mu and logvar by expressing it
    as mu + eps * sigma with eps ~ N(0, I).
    """
    sigma = torch.exp(0.5 * logvar)
    noise = torch.randn_like(sigma)
    return mu + noise * sigma
46
+
47
class TreeDecoder(nn.Module):
    """GRU decoder conditioned on a latent z.

    z conditions generation in two ways: it initialises the hidden state
    (via z_to_hidden + tanh) and is concatenated to every input embedding.
    """

    def __init__(self, vocab_size, embed_dim, dec_hidden, z_dim, pad_idx=0):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.z_to_hidden = nn.Linear(z_dim, dec_hidden)
        # The GRU consumes [token embedding ; z] at every step.
        self.gru = nn.GRU(embed_dim + z_dim, dec_hidden, batch_first=True)
        self.out = nn.Linear(dec_hidden, vocab_size)

    def init_hidden_from_z(self, z):
        """z: [B, z_dim] -> initial GRU hidden state [1, B, dec_hidden]."""
        return torch.tanh(self.z_to_hidden(z)).unsqueeze(0)

    def forward(self, inputs, z, hidden=None):
        """Teacher-forced full-sequence decode.

        inputs: [B, L] token indices; z: [B, z_dim]; hidden: optional
        initial state (derived from z when omitted).
        Returns (logits [B, L, vocab], final hidden [1, B, dec_hidden]).
        """
        if hidden is None:
            hidden = self.init_hidden_from_z(z)

        seq_len = inputs.size(1)
        step_inputs = torch.cat(
            [self.embed(inputs), z.unsqueeze(1).expand(-1, seq_len, -1)],
            dim=-1,
        )  # [B, L, E + z_dim]
        outputs, hidden = self.gru(step_inputs, hidden)  # [B, L, H]
        return self.out(outputs), hidden

    def step(self, input_token, z, hidden=None):
        """One autoregressive decoding step.

        input_token: [B] long; z: [B, z_dim]; hidden: previous GRU state
        (None lets the GRU start from zeros).
        Returns (logits [B, vocab], new hidden state).
        """
        gru_in = torch.cat(
            [self.embed(input_token).unsqueeze(1), z.unsqueeze(1)], dim=-1
        )  # [B, 1, E + z_dim]
        outputs, new_hidden = self.gru(gru_in, hidden)  # [B, 1, dec_hidden]
        return self.out(outputs.squeeze(1)), new_hidden
86
+
87
class TreeVAE(nn.Module):
    """Sequence VAE over fragment indices: GRU encoder -> Gaussian latent -> GRU decoder."""

    def __init__(self, vocab_size, embed_dim, enc_hidden, dec_hidden, z_dim, pad_idx=0):
        super().__init__()
        self.encoder = TreeEncoder(vocab_size, embed_dim, enc_hidden, pad_idx)
        self.latent = LatentHead(enc_hidden, z_dim)
        self.decoder = TreeDecoder(vocab_size, embed_dim, dec_hidden, z_dim, pad_idx)

    def forward(self, x, lengths, tf_prob=1.0):
        """Teacher-forced VAE forward pass.

        x: [B, L] padded target sequences; lengths: [B] true lengths.
        tf_prob is accepted for interface compatibility but currently
        unused — full teacher forcing is always applied.
        Returns (mu, logvar, z, hidden).

        Fix: the original did ``reparameterize(mu, logvar).to(device)`` with
        ``device`` undefined in this scope (NameError at runtime); z is
        already on the same device as mu, so no transfer is needed.
        """
        h_enc = self.encoder(x, lengths)   # [B, enc_hidden]
        mu, logvar = self.latent(h_enc)    # each [B, z_dim]
        z = reparameterize(mu, logvar)     # [B, z_dim]

        # Teacher forcing: the target sequence itself is the decoder input.
        _, hidden = self.decoder(x, z)
        return mu, logvar, z, hidden

    def sample_from_z(self, z, sos_idx=None, max_len=32):
        """Autoregressively sample token indices from latents.

        z: [B, z_dim] latent batch.
        sos_idx: start token; when None one of the five most frequent
        starting fragments is chosen at random.
        Returns a LongTensor of shape [B, max_len + 1] — the SOS token
        followed by max_len multinomially sampled tokens (the original
        comment claimed [B, max_len], which undercounted the SOS column).
        """
        if sos_idx is None:
            # Top-5 most frequent starting fragments in the training vocab.
            sos_idx = random.choice([17, 9, 5, 11, 2])
        self.eval()
        B = z.size(0)
        with torch.no_grad():
            input_tok = torch.full((B,), sos_idx, dtype=torch.long, device=z.device)
            hidden = None
            generated = [input_tok.unsqueeze(1)]
            for _ in range(max_len):
                logits, hidden = self.decoder.step(input_tok, z, hidden)  # [B, vocab]
                probs = F.softmax(logits, dim=-1)
                # Multinomial sampling (not argmax) keeps the outputs diverse.
                input_tok = torch.multinomial(probs, num_samples=1).squeeze(1)
                generated.append(input_tok.unsqueeze(1))
            return torch.cat(generated, dim=1)  # [B, max_len + 1]
+ return gen # indices