File size: 3,206 Bytes
e2b7617
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import random
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed: Value handed unchanged to every seeding function.
    """
    # Same order as always: stdlib, NumPy, torch CPU, torch CUDA.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def top_k_logits(logits, k):
    """Keep the ``k`` largest logits along the last dimension; set the rest to ``-inf``.

    Args:
        logits: Tensor of scores. Filtering runs over the last dimension,
            independently for every leading index (1-D input is accepted too).
        k: Number of entries to keep. Values larger than the size of the last
            dimension are clamped to it, in which case nothing is masked.

    Returns:
        A new tensor with the same shape as ``logits``; the input is left
        unmodified. Entries tied with the k-th largest value are kept.
    """
    # Clamp so an over-large k means "keep everything" instead of raising.
    k = min(k, logits.size(-1))
    # topk returns values in descending order, so the last one is the cutoff.
    top_values, _ = torch.topk(logits, k)
    cutoff = top_values[..., -1:]
    return logits.masked_fill(logits < cutoff, -float('Inf'))

@torch.no_grad()
def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
    """Append ``steps`` indices to ``x``, one per iteration, predicted by ``model``.

    Every iteration feeds at most the last ``model.get_block_size()`` columns
    of ``x`` to the model, takes the logits at the final position, divides
    them by ``temperature``, optionally filters them with ``top_k_logits``,
    and concatenates one chosen index per row onto ``x``.

    NOTE(review): a second ``sample`` is defined further down in this file and
    rebinds the name, so this version is not reachable through the
    module-level name once the whole file has executed.

    Args:
        model: Callable returning a ``(logits, _)`` pair; must also provide
            ``get_block_size()`` and ``eval()``. ``eval()`` is called here and
            not undone.
        x: Index tensor addressed as (row, position).
        steps: Number of indices to append.
        temperature: Divisor applied to the last-position logits.
        sample: If true, draw from the softmax distribution with
            ``torch.multinomial``; otherwise pick the most probable index.
        top_k: When not ``None``, forwarded to ``top_k_logits``.

    Returns:
        ``x`` extended by ``steps`` columns.
    """
    max_context = model.get_block_size()
    model.eval()
    for _ in range(steps):
        # Crop to the trailing window only when the sequence is too long.
        if x.size(1) > max_context:
            context = x[:, -max_context:]
        else:
            context = x
        all_logits, _ignored = model(context)
        next_logits = all_logits[:, -1, :] / temperature
        if top_k is not None:
            next_logits = top_k_logits(next_logits, top_k)
        next_probs = F.softmax(next_logits, dim=-1)
        if sample:
            next_ix = torch.multinomial(next_probs, num_samples=1)
        else:
            next_ix = torch.topk(next_probs, k=1, dim=-1).indices
        x = torch.cat((x, next_ix), dim=1)

    return x




@torch.no_grad()
def sample(model, x, steps, temperature=1.0,boundary=None):
    """Append ``steps`` sampled indices to ``x``, forwarding ``boundary`` to ``model``.

    Every iteration feeds at most the last ``model.get_block_size()`` columns
    of ``x`` to the model together with ``boundary``, scales the
    final-position logits by ``1 / temperature``, and draws one index per row
    with ``torch.multinomial``.

    Args:
        model: Callable invoked as ``model(context, boundary=boundary)`` and
            returning a ``(logits, _)`` pair; must also provide
            ``get_block_size()`` and ``eval()``. ``eval()`` is called here and
            not undone.
        x: Index tensor addressed as (row, position).
        steps: Number of indices to append.
        temperature: Divisor applied to the last-position logits.
        boundary: Passed through to the model untouched on every step.
            NOTE(review): it is not adjusted when the context is cropped —
            confirm the model tolerates that.

    Returns:
        ``x`` extended by ``steps`` columns.
    """
    max_context = model.get_block_size()
    model.eval()
    for _ in range(steps):
        # Only the trailing window is kept once the sequence outgrows the model.
        if x.size(1) > max_context:
            context = x[:, -max_context:]
        else:
            context = x
        all_logits, _ignored = model(context, boundary=boundary)
        next_probs = F.softmax(all_logits[:, -1, :] / temperature, dim=-1)
        drawn = torch.multinomial(next_probs, num_samples=1)
        x = torch.cat((x, drawn), dim=1)
    return x

# Scratch example kept for reference; as a bare string expression it has no
# runtime effect. It follows the layout of the prompt built in sample_L:
# an 'L_<i>' prefix, then '*'-prefixed pieces joined by '.'.
# NOTE(review): the text after the second piece is presumably the target /
# generated continuation — confirm against the training data.
'L_5*C(=O)NCc1cccc(OC)c1.*c1nsc2ccccc12COc1cccc(CNC(=O)c2cccc(NC(=O)c3nsc4ccccc34)c2)c1'

# for i in range(1,21):
def sample_L(i,option='string'):
    """Sample 32 continuations of a fixed prompt and print them with validity flags.

    Builds the prompt ``'L_<i>*O=C1NN=Cc2c1cccc2.*O=C(C1CC1)N1CCNCC1'``,
    encodes it character by character (after a leading ``'<bos>'``) through
    the module-level ``vocab``, and runs ``sample`` for 250 steps on a batch
    of 32 identical rows. Each continuation is decoded through ``inv`` with
    ``'<pad>'`` entries dropped and a trailing ``'<eos>'`` removed, then
    printed. Afterwards one line per continuation is printed: ``1`` if
    ``test_valid`` accepts it, ``0`` otherwise.

    Relies on ``vocab``, ``inv``, ``model``, ``sample`` and ``test_valid``
    being defined at module level.

    Args:
        i: Value appended to ``'L_'`` to form the prompt prefix.
        option: Currently unused; kept so existing calls keep working.
    """
    num_samples = 32
    num_steps = 250
    # Fall back to CPU so the helper still runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    string_input = 'L_' + str(i) + '*O=C1NN=Cc2c1cccc2.*O=C(C1CC1)N1CCNCC1'
    array_input = [vocab[a] for a in ['<bos>'] + list(string_input)]
    prompt_len = len(array_input)
    tensor_input = torch.tensor(array_input, device=device).unsqueeze(0).repeat(num_samples, 1)
    # One boundary entry (the prompt length) per row of the batch.
    boundary = [prompt_len] * num_samples
    tensor_output = sample(model, tensor_input, num_steps, boundary=boundary)

    pad_id = vocab['<pad>']  # looked up once instead of once per token
    strings_output = []
    # Every row shares the same prompt length, so slice and move to host once.
    for row in tensor_output[:, prompt_len:].tolist():
        tokens = [inv[a] for a in row if a != pad_id]
        # Guard the empty case: an all-padding row used to raise IndexError.
        # NOTE(review): only a trailing '<eos>' is stripped; an '<eos>' followed
        # by non-pad tokens stays in the string — confirm the model pads after it.
        if tokens and tokens[-1] == '<eos>':
            tokens = tokens[:-1]
        string_output = ''.join(tokens)
        strings_output.append(string_output)
        print(string_output)
    # Validity flags come after all strings, in the same row order.
    for string_output in strings_output:
        print(1 if test_valid(string_output) else 0)


# Scratch example kept for reference; as a bare list expression it has no
# runtime effect. It is the 'L_5*C(=O)...' example string from earlier in this
# file split into single characters and wrapped in '<bos>' / '<eos>', matching
# the per-character encoding used in sample_L.
['<bos>', 'L', '_', '5', '*', 'C', '(', '=', 'O', ')', 'N', 'C', 'c', '1', 'c', 'c', 'c', 'c', '(', 'O', 'C', ')', 'c', '1', '.', '*', 'c', '1', 'n', 's', 'c', '2', 'c', 'c', 'c', 'c', 'c', '1', '2', 'C', 'O', 'c', '1', 'c', 'c', 'c', 'c', '(', 'C', 'N', 'C', '(', '=', 'O', ')', 'c', '2', 'c', 'c', 'c', 'c', '(', 'N', 'C', '(', '=', 'O', ')', 'c', '3', 'n', 's', 'c', '4', 'c', 'c', 'c', 'c', 'c', '3', '4', ')', 'c', '2', ')', 'c', '1', '<eos>']