Upload folder using huggingface_hub
Browse files- README.md +70 -0
- dict_idx2frag.pt +3 -0
- tree_rnn-vae.pt +3 -0
- tree_rnn_vae_infer.py +141 -0
- tree_rnn_vae_model.py +132 -0
README.md
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# De Novo Drug Generator - RNN-VAE
|
| 3 |
+
|
| 4 |
+
De Novo Drug Generator - RNN-VAE is a deep learning model designed for generating novel drug molecules.
|
| 5 |
+
Training data from ChemBL library
|
| 6 |
+
|
| 7 |
+
Full project file at https://github.com/teohyc/drug_agent
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
## Usage
|
| 11 |
+
|
| 12 |
+
```python
|
| 13 |
+
from rdkit import Chem
|
| 14 |
+
from rdkit.Chem import Draw, Descriptors
|
| 15 |
+
from tree_rnn_vae_infer import generate_candidate_mol
|
| 16 |
+
from tree_rnn_vae_model import TreeEncoder, LatentHead, TreeVAE, TreeDecoder
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def compute_molecule_props(mol):
|
| 20 |
+
return {
|
| 21 |
+
"MW": Descriptors.MolWt(mol),
|
| 22 |
+
"logP": Descriptors.MolLogP(mol),
|
| 23 |
+
"HBD": Descriptors.NumHDonors(mol),
|
| 24 |
+
"HBA": Descriptors.NumHAcceptors(mol),
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# display molecule
|
| 29 |
+
def render_molecule_grid(selected):
|
| 30 |
+
if not selected:
|
| 31 |
+
return
|
| 32 |
+
|
| 33 |
+
mols, legends = [], []
|
| 34 |
+
|
| 35 |
+
if isinstance(selected, dict):
|
| 36 |
+
iterable = selected.items()
|
| 37 |
+
else:
|
| 38 |
+
iterable = enumerate(selected, 1)
|
| 39 |
+
|
| 40 |
+
for i, item in iterable:
|
| 41 |
+
if isinstance(selected, dict):
|
| 42 |
+
smi, props = i, item
|
| 43 |
+
else:
|
| 44 |
+
smi, props = item, None
|
| 45 |
+
|
| 46 |
+
mol = Chem.MolFromSmiles(smi)
|
| 47 |
+
if mol:
|
| 48 |
+
mols.append(mol)
|
| 49 |
+
if props is None:
|
| 50 |
+
props = compute_molecule_props(mol)
|
| 51 |
+
legends.append(
|
| 52 |
+
f"M{i}
|
| 53 |
+
MW={props['MW']:.0f}, logP={props['logP']:.2f}, "
|
| 54 |
+
f"HBD={props['HBD']}, HBA={props['HBA']}"
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
img = Draw.MolsToGridImage(
|
| 58 |
+
mols,
|
| 59 |
+
molsPerRow=3,
|
| 60 |
+
subImgSize=(400, 400),
|
| 61 |
+
legends=legends,
|
| 62 |
+
useSVG=False,
|
| 63 |
+
)
|
| 64 |
+
return img
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
molecules = generate_candidate_mol(num_samples=6, max_len=20) #change to your desired molecule size and number
|
| 68 |
+
img = render_molecule_grid(molecules)
|
| 69 |
+
img.show()
|
| 70 |
+
```
|
dict_idx2frag.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b45ac2d4f65aa76bd54a6742c48cc5b749d58fedb2f0309c648a258481155e5d
|
| 3 |
+
size 18769
|
tree_rnn-vae.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5abe40c97c5915ac21890b44d309edba62c9933fc66d8c9e0043644ef46e7cc8
|
| 3 |
+
size 3963609
|
tree_rnn_vae_infer.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
from torch.utils.data import Dataset, DataLoader
|
| 5 |
+
import math
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from rdkit import Chem
|
| 9 |
+
from rdkit.Chem import Draw, AllChem
|
| 10 |
+
import random
|
| 11 |
+
import re
|
| 12 |
+
from tree_rnn_vae_model import TreeEncoder, LatentHead, TreeVAE, TreeDecoder
|
| 13 |
+
|
| 14 |
+
def decode_fragments_to_smiles(batch_indices, vocab):
    """Map sampled index sequences back to lists of fragment SMILES strings.

    Args:
        batch_indices: [B, L] tensor of sampled vocabulary indices
            (anything with a ``.tolist()`` per row works).
        vocab: mapping {index: fragment_SMILES}.

    Returns:
        A list (one entry per sequence) of fragment-SMILES lists, ready for
        ``process_one_molecule`` which iterates fragments individually.

    NOTE(review): the original implementation passed the *list* of fragments
    to ``Chem.MolFromSmiles`` (which requires a string), so it raised
    ``TypeError`` on every call and the bare ``except`` fallback — returning
    the raw fragment list — always executed. Downstream code relies on
    receiving that list, so the dead canonicalisation path has been removed
    and the effective behavior made explicit.
    """
    smiles_list = []
    for seq in batch_indices:
        fragments = [vocab[idx] for idx in seq.tolist()]
        smiles_list.append(fragments)
    return smiles_list
|
| 28 |
+
|
| 29 |
+
def normalize_attachment_points(smiles):
    """Collapse numbered attachment points such as ``[3*]`` to plain ``[*]``."""
    numbered_star = re.compile(r'\[\d+\*\]')
    return numbered_star.sub('[*]', smiles)
|
| 31 |
+
|
| 32 |
+
def deduplicate_fragments(fragments):
    """Return *fragments* with duplicates removed, keeping first-seen order.

    ``dict.fromkeys`` preserves insertion order, which matches the
    set-and-list bookkeeping it replaces.
    """
    return list(dict.fromkeys(fragments))
|
| 41 |
+
|
| 42 |
+
def fragments_to_mols(fragments):
    """Parse each fragment SMILES with RDKit, silently dropping any that fail.

    Returns a list of RDKit Mol objects in the same relative order as the
    input; unparsable fragments are skipped rather than raising.
    """
    candidates = (Chem.MolFromSmiles(frag) for frag in fragments)
    return [mol for mol in candidates if mol is not None]
|
| 50 |
+
|
| 51 |
+
def get_dummy_atom_indices(mol):
    """Return indices of the dummy ('*') attachment atoms in *mol*, in atom order."""
    indices = []
    for atom in mol.GetAtoms():
        if atom.GetSymbol() == '*':
            indices.append(atom.GetIdx())
    return indices
|
| 53 |
+
|
| 54 |
+
def get_linear_connection_pairs(mols):
    """Chain fragments linearly: last dummy atom of mol i to first dummy atom of mol i+1.

    Args:
        mols: list of RDKit Mol objects.

    Returns:
        List of ``(frag_i, atom_i, frag_j, atom_j)`` tuples where the atom
        indices are local to each fragment.

    BUG FIX: the original indexed ``[-1]`` / ``[0]`` unconditionally, so a
    fragment without any dummy atom raised IndexError and aborted the whole
    generation run. Such joints are now skipped instead.
    """
    pairs = []
    for i in range(len(mols) - 1):
        left_dummies = get_dummy_atom_indices(mols[i])
        right_dummies = get_dummy_atom_indices(mols[i + 1])
        if not left_dummies or not right_dummies:
            continue  # no attachment point available on one side — skip this joint
        pairs.append((i, left_dummies[-1], i + 1, right_dummies[0]))
    return pairs
|
| 62 |
+
|
| 63 |
+
def combine_fragments_and_connect(fragments, connection_pairs):
    """Merge fragment SMILES into one molecule and bond them together.

    Args:
        fragments: list of fragment SMILES strings ('*' marks attachment points).
        connection_pairs: ``(frag_i, atom_i, frag_j, atom_j)`` tuples whose
            atom indices are local to each fragment.

    Returns:
        Canonical SMILES of the combined, sanitized molecule, or None when
        no fragment parses or sanitization fails.
    """
    mols = fragments_to_mols(fragments)
    if not mols:
        return None

    # Combine molecules sequentially into one disconnected Mol.
    # offsets[i] is the index of fragment i's first atom inside the combined Mol.
    combined = Chem.CombineMols(mols[0], mols[1]) if len(mols) > 1 else mols[0]
    offsets = [0, mols[0].GetNumAtoms()]
    for i in range(2, len(mols)):
        combined = Chem.CombineMols(combined, mols[i])
        offsets.append(sum([m.GetNumAtoms() for m in mols[:i]]))

    # RWMol allows in-place edits such as AddBond.
    rw_mol = Chem.RWMol(combined)

    # Add a single bond per connection pair, translating per-fragment atom
    # indices to combined-Mol indices via the offsets.
    # NOTE(review): `fragments` is re-parsed here while `connection_pairs`
    # may have been computed on a *filtered* mol list upstream — if any
    # fragment fails to parse, pair indices and offsets could disagree.
    # Confirm callers only pass parsable fragments.
    for f1, a1, f2, a2 in connection_pairs:
        rw_mol.AddBond(offsets[f1] + a1, offsets[f2] + a2, Chem.BondType.SINGLE)

    # Sanitize; any RDKit failure (valence, kekulization, bad bond) makes the
    # candidate unusable, so report None rather than propagating the error.
    try:
        Chem.SanitizeMol(rw_mol)
        return Chem.MolToSmiles(rw_mol)
    except:
        return None
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def process_one_molecule(raw_fragments):
    """Turn one decoded fragment list into a single candidate SMILES string.

    Pipeline: normalise attachment points -> deduplicate -> parse with RDKit
    -> derive linear connection pairs -> combine and sanitize.

    Returns:
        The assembled SMILES string, or None when nothing parses or the
        combined molecule fails sanitization.
    """
    cleaned = deduplicate_fragments(
        [normalize_attachment_points(frag) for frag in raw_fragments]
    )
    if not cleaned:
        return None

    parsed = fragments_to_mols(cleaned)
    if not parsed:
        return None

    # Connections are linear for now: each fragment joins the next in order.
    links = get_linear_connection_pairs(parsed)

    return combine_fragments_and_connect(cleaned, links)
|
| 108 |
+
|
| 109 |
+
def replace_dummy_with_carbon(smiles):
    """Substitute every dummy atom '*' with carbon so the SMILES parses as a real molecule."""
    return smiles.translate(str.maketrans('*', 'C'))
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
#main
|
| 114 |
+
def generate_candidate_mol(num_samples=6, max_len=20):
    """Generate candidate molecules with the pretrained Tree-RNN VAE.

    Samples ``num_samples`` latent vectors, decodes each into a fragment
    sequence of at most ``max_len`` tokens, and assembles the fragments into
    SMILES strings (dummy '*' atoms replaced by carbon).

    Args:
        num_samples: number of latent vectors (candidate molecules) to sample.
        max_len: maximum number of fragment tokens decoded per molecule.

    Returns:
        List of SMILES strings. Sequences whose fragments could not be
        assembled/sanitized are dropped, so the list may be shorter than
        ``num_samples``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # {index: fragment_string} vocabulary saved at training time.
    idx2frag = torch.load("dict_idx2frag.pt")

    # Full pickled model (architecture + weights), hence weights_only=False.
    # NOTE(review): unpickling a model executes arbitrary code — only load
    # checkpoints from a trusted source.
    model = torch.load("tree_rnn-vae.pt", weights_only=False, map_location=device)
    model.eval()

    # Latent dimensionality is recovered from the decoder's z->hidden layer.
    z = torch.randn(num_samples, model.decoder.z_to_hidden.in_features, device=device)
    sampled_indices = model.sample_from_z(z, sos_idx=None, max_len=max_len)

    fragment_lists = decode_fragments_to_smiles(sampled_indices, idx2frag)

    smiles_list = []
    for fragments in fragment_lists:
        smiles = process_one_molecule(fragments)
        if smiles is None:
            # BUG FIX: assembly can fail and return None; the original fed
            # that None straight into replace_dummy_with_carbon, crashing
            # with AttributeError. Failed candidates are skipped instead.
            continue
        smiles_list.append(replace_dummy_with_carbon(smiles))

    return smiles_list
|
| 141 |
+
|
tree_rnn_vae_model.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
from torch.utils.data import Dataset, DataLoader
|
| 5 |
+
import math
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from rdkit import Chem
|
| 9 |
+
from rdkit.Chem import Draw, AllChem
|
| 10 |
+
import random
|
| 11 |
+
|
| 12 |
+
class TreeEncoder(nn.Module):
    """Sequence encoder: embeds padded fragment-index sequences and runs a
    single-layer GRU, exposing the final hidden state per batch element."""

    def __init__(self, vocab_size, embed_dim, enc_hidden, pad_idx=0):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.gru = nn.GRU(embed_dim, enc_hidden, batch_first=True)
        self.enc_hidden = enc_hidden

    def forward(self, x, lengths):
        """Encode a padded batch.

        Args:
            x: [B, L] LongTensor of padded fragment indices.
            lengths: [B] LongTensor of true (unpadded) sequence lengths.

        Returns:
            [B, enc_hidden] tensor — the GRU's last hidden state.
        """
        embedded = self.embed(x)  # [B, L, E]
        # Pack so padding never enters the recurrence; pack_padded_sequence
        # requires the lengths tensor on the CPU.
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _, final_hidden = self.gru(packed)  # final_hidden: [1, B, enc_hidden]
        return final_hidden.squeeze(0)
|
| 30 |
+
|
| 31 |
+
class LatentHead(nn.Module):
    """Projects an encoder state to the mean and log-variance of q(z | x)."""

    def __init__(self, enc_hidden, z_dim):
        super().__init__()
        self.linear_mu = nn.Linear(enc_hidden, z_dim)
        self.linear_logvar = nn.Linear(enc_hidden, z_dim)

    def forward(self, h):
        """Return ``(mu, logvar)``, each [B, z_dim], for encoder output *h*."""
        return self.linear_mu(h), self.linear_logvar(h)
|
| 41 |
+
|
| 42 |
+
def reparameterize(mu, logvar):
    """Draw z ~ N(mu, exp(logvar)) via the reparameterization trick.

    Keeps the sampling differentiable w.r.t. mu and logvar by expressing
    z as mu + eps * std with eps ~ N(0, I).
    """
    std = torch.exp(0.5 * logvar)
    noise = torch.randn_like(std)
    return mu + noise * std
|
| 46 |
+
|
| 47 |
+
class TreeDecoder(nn.Module):
    """Autoregressive GRU decoder conditioned on a latent code z.

    The latent vector plays two roles: it initialises the hidden state
    (through ``z_to_hidden``) and it is concatenated to the token embedding
    at every time step.
    """

    def __init__(self, vocab_size, embed_dim, dec_hidden, z_dim, pad_idx=0):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.z_to_hidden = nn.Linear(z_dim, dec_hidden)
        # The GRU consumes the token embedding concatenated with z.
        self.gru = nn.GRU(embed_dim + z_dim, dec_hidden, batch_first=True)
        self.out = nn.Linear(dec_hidden, vocab_size)

    def init_hidden_from_z(self, z):
        """Map z [B, z_dim] to an initial GRU hidden state [1, B, dec_hidden]."""
        return torch.tanh(self.z_to_hidden(z)).unsqueeze(0)

    def forward(self, inputs, z, hidden=None):
        """Teacher-forced full-sequence pass.

        Args:
            inputs: [B, L] LongTensor of input tokens.
            z: [B, z_dim] latent codes.
            hidden: optional initial hidden state; derived from z when None.

        Returns:
            ``(logits [B, L, vocab], final hidden state)``.
        """
        seq_len = inputs.size(1)
        if hidden is None:
            hidden = self.init_hidden_from_z(z)

        token_emb = self.embed(inputs)                        # [B, L, E]
        z_per_step = z.unsqueeze(1).expand(-1, seq_len, -1)   # [B, L, z_dim]
        rnn_in = torch.cat([token_emb, z_per_step], dim=-1)   # [B, L, E+z]
        rnn_out, hidden = self.gru(rnn_in, hidden)            # [B, L, H]
        return self.out(rnn_out), hidden

    def step(self, input_token, z, hidden=None):
        """Single autoregressive step for inference.

        Args:
            input_token: [B] LongTensor of current tokens.
            z: [B, z_dim] latent codes.
            hidden: previous hidden state (None lets the GRU start from zeros).

        Returns:
            ``(logits [B, vocab], new hidden state)``.
        """
        token_emb = self.embed(input_token).unsqueeze(1)         # [B, 1, E]
        step_in = torch.cat([token_emb, z.unsqueeze(1)], dim=-1) # [B, 1, E+z]
        rnn_out, new_hidden = self.gru(step_in, hidden)          # [B, 1, H]
        return self.out(rnn_out.squeeze(1)), new_hidden
|
| 86 |
+
|
| 87 |
+
class TreeVAE(nn.Module):
    """Fragment-sequence VAE: TreeEncoder -> LatentHead -> TreeDecoder."""

    def __init__(self, vocab_size, embed_dim, enc_hidden, dec_hidden, z_dim, pad_idx=0):
        super().__init__()
        self.encoder = TreeEncoder(vocab_size, embed_dim, enc_hidden, pad_idx)
        self.latent = LatentHead(enc_hidden, z_dim)
        self.decoder = TreeDecoder(vocab_size, embed_dim, dec_hidden, z_dim, pad_idx)

    def forward(self, x, lengths, tf_prob=1.0):
        """Teacher-forced training pass.

        Args:
            x: [B, L] padded target sequences (also used as decoder input).
            lengths: [B] actual sequence lengths.
            tf_prob: teacher-forcing probability (currently unused — full
                teacher forcing is always applied).

        Returns:
            ``(mu, logvar, z, decoder_hidden)``. Note the decoder's *logits*
            are discarded here; callers needing them should run the decoder
            directly.
        """
        h_enc = self.encoder(x, lengths)   # [B, enc_hidden]
        mu, logvar = self.latent(h_enc)    # each [B, z_dim]
        # BUG FIX: the original called `.to(device)` on z, but `device` was
        # never defined in this scope (NameError at runtime). reparameterize
        # already returns z on the same device as mu, so no transfer is needed.
        z = reparameterize(mu, logvar)     # [B, z_dim]

        # Teacher forcing: feed the target sequence itself to the decoder.
        _, hidden = self.decoder(x, z)
        return mu, logvar, z, hidden

    def sample_from_z(self, z, sos_idx=None, max_len=32):
        """Autoregressively sample token indices from latent codes.

        Args:
            z: [B, z_dim] latent vectors.
            sos_idx: start-token index. When None, one of five hard-coded
                indices is chosen at random — presumably the most frequent
                starting fragments seen in training (TODO: confirm against
                the training vocabulary).
            max_len: number of tokens sampled after the start token.

        Returns:
            [B, max_len + 1] LongTensor of sampled indices, start token
            included.
        """
        if sos_idx is None:
            sos_idx = random.choice([17, 9, 5, 11, 2])
        self.eval()
        B = z.size(0)
        with torch.no_grad():
            input_tok = torch.full((B,), sos_idx, dtype=torch.long, device=z.device)
            hidden = None
            generated = [input_tok.unsqueeze(1)]
            for _ in range(max_len):
                logits, hidden = self.decoder.step(input_tok, z, hidden)  # [B, vocab]
                probs = F.softmax(logits, dim=-1)
                # Sample (rather than argmax) to get diverse outputs.
                input_tok = torch.multinomial(probs, num_samples=1).squeeze(1)  # [B]
                generated.append(input_tok.unsqueeze(1))
            gen = torch.cat(generated, dim=1)  # [B, max_len + 1]
        return gen
|