from MoleculeProcessing.utils.utils import *
from MoleculeProcessing.utils.utils_sample import *
import torch
import torch.nn.functional as F


def sample(model, vocab_bos, size_batch=32, size_block=70, temperature=1.0,):
    """Autoregressively sample a batch of token sequences from ``model``.

    Every sequence starts with ``vocab_bos`` and is extended one token at a
    time for ``size_block`` steps. At each step the model receives only the
    most recent token together with the ``hiddens`` state it returned on the
    previous step (``None`` on the first step).

    Args:
        model: Module to sample from. It is passed through ``load_to_device``
            (which returns the model and its device) and put in eval mode.
            It is called as ``model(tokens, hiddens)`` and must return
            ``(logits, hiddens)``; ``logits[:, -1]`` is used as the
            next-token scores (presumably shape (batch, seq, vocab) — confirm
            against the model definition).
        vocab_bos: Token id written into column 0 of every sequence.
        size_batch: Number of sequences sampled in parallel.
        size_block: Number of tokens generated after the BOS token.
        temperature: Softmax temperature. Scores are divided by it, so values
            above 1 flatten the distribution and values below 1 sharpen it.
            Must be positive.

    Returns:
        A ``torch.long`` tensor of shape ``(size_batch, size_block + 1)`` on
        the model's device; column 0 holds ``vocab_bos`` and columns
        ``1..size_block`` hold the sampled tokens.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature!r}")

    model, device = load_to_device(model)
    model.eval()
    with torch.no_grad():
        tensor_sampled = torch.zeros(
            size_batch, size_block + 1, dtype=torch.long, device=device
        )
        tensor_sampled[:, 0] = vocab_bos
        hiddens = None
        for i in range(size_block):
            # List index keeps the time dimension: shape (size_batch, 1).
            input_current = tensor_sampled[:, [i]]
            # Call the module itself (not .forward) so registered hooks run.
            logits, hiddens = model(input_current, hiddens)
            # Temperature divides the logits: T > 1 flattens, T < 1 sharpens.
            logits = logits[:, -1] / temperature
            probs = F.softmax(logits, dim=-1)
            next_tokens = torch.distributions.categorical.Categorical(probs).sample()
            tensor_sampled[:, i + 1] = next_tokens
    return tensor_sampled