Update tool/comget/generator.py
Browse files- tool/comget/generator.py +3 -4
tool/comget/generator.py
CHANGED
|
@@ -17,8 +17,7 @@ import sys
|
|
| 17 |
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
|
| 18 |
from rdkit import Chem
|
| 19 |
import os
|
| 20 |
-
|
| 21 |
-
|
| 22 |
def get_mol(smiles_or_mol):
|
| 23 |
'''
|
| 24 |
Loads SMILES/molecule into RDKit's object
|
|
@@ -206,9 +205,9 @@ def generation(value):
|
|
| 206 |
x = torch.tensor([stoi[s] for s in regex.findall(context)], dtype=torch.long)[None,...].repeat(args.batch_size, 1).to('cuda')
|
| 207 |
p = None
|
| 208 |
if len(args.props) == 1:
|
| 209 |
-
p = torch.tensor([c]).repeat(args.batch_size, 1).to('
|
| 210 |
else:
|
| 211 |
-
p = torch.tensor([c]).repeat(args.batch_size, 1).unsqueeze(1).to('
|
| 212 |
sca = None
|
| 213 |
y = sample(model, x, 300, temperature= 1.0, sample=True, top_k = 10, prop = p, scaffold = sca)
|
| 214 |
for gen_mol in y:
|
|
|
|
| 17 |
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
|
| 18 |
from rdkit import Chem
|
| 19 |
import os
|
| 20 |
+
|
|
|
|
| 21 |
def get_mol(smiles_or_mol):
|
| 22 |
'''
|
| 23 |
Loads SMILES/molecule into RDKit's object
|
|
|
|
| 205 |
x = torch.tensor([stoi[s] for s in regex.findall(context)], dtype=torch.long)[None,...].repeat(args.batch_size, 1).to('cuda')
|
| 206 |
p = None
|
| 207 |
if len(args.props) == 1:
|
| 208 |
+
p = torch.tensor([c]).repeat(args.batch_size, 1).to('cpu') # for single condition
|
| 209 |
else:
|
| 210 |
+
p = torch.tensor([c]).repeat(args.batch_size, 1).unsqueeze(1).to('cpu') # for multiple conditions
|
| 211 |
sca = None
|
| 212 |
y = sample(model, x, 300, temperature= 1.0, sample=True, top_k = 10, prop = p, scaffold = sca)
|
| 213 |
for gen_mol in y:
|