jinysun commited on
Commit
ceb03ca
·
verified ·
1 Parent(s): 1e09aa4

Update tool/comget/generator.py

Browse files
Files changed (1) hide show
  1. tool/comget/generator.py +3 -4
tool/comget/generator.py CHANGED
@@ -17,8 +17,7 @@ import sys
17
  sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
18
  from rdkit import Chem
19
  import os
20
- os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
21
-
22
  def get_mol(smiles_or_mol):
23
  '''
24
  Loads SMILES/molecule into RDKit's object
@@ -206,9 +205,9 @@ def generation(value):
206
  x = torch.tensor([stoi[s] for s in regex.findall(context)], dtype=torch.long)[None,...].repeat(args.batch_size, 1).to('cuda')
207
  p = None
208
  if len(args.props) == 1:
209
- p = torch.tensor([c]).repeat(args.batch_size, 1).to('cuda') # for single condition
210
  else:
211
- p = torch.tensor([c]).repeat(args.batch_size, 1).unsqueeze(1).to('cuda') # for multiple conditions
212
  sca = None
213
  y = sample(model, x, 300, temperature= 1.0, sample=True, top_k = 10, prop = p, scaffold = sca)
214
  for gen_mol in y:
 
17
  sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
18
  from rdkit import Chem
19
  import os
20
+
 
21
  def get_mol(smiles_or_mol):
22
  '''
23
  Loads SMILES/molecule into RDKit's object
 
205
  x = torch.tensor([stoi[s] for s in regex.findall(context)], dtype=torch.long)[None,...].repeat(args.batch_size, 1).to('cuda')
206
  p = None
207
  if len(args.props) == 1:
208
+ p = torch.tensor([c]).repeat(args.batch_size, 1).to('cpu') # for single condition
209
  else:
210
+ p = torch.tensor([c]).repeat(args.batch_size, 1).unsqueeze(1).to('cpu') # for multiple conditions
211
  sca = None
212
  y = sample(model, x, 300, temperature= 1.0, sample=True, top_k = 10, prop = p, scaffold = sca)
213
  for gen_mol in y: