Songyou commited on
Commit
adf1e27
·
verified ·
1 Parent(s): 2516db3

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +13 -13
generate.py CHANGED
@@ -41,18 +41,18 @@ class GenerateRunner():
41
 
42
  # self.save_path = os.path.join('experiments', opt.save_directory, opt.test_file_name,
43
  # f'evaluation_{opt.epoch}')
44
- path = Path(os.path.join(opt.save_directory))
45
- path.mkdir(parents=True, exist_ok=True)
46
- self.save_path = os.path.join(path)
47
- self.exist_flag = Path(f'{self.save_path}/generated_molecules.csv').exists()
48
- self.overwrite = opt.overwrite
49
  self.dev_no = opt.dev_no
50
  self.device = torch.device('cpu')
51
- global LOG
52
- LOG = ul.get_logger(name="generate",
53
- log_path=os.path.join(self.save_path, 'generate.log'))
54
- LOG.info(opt)
55
- LOG.info("Save directory: {}".format(self.save_path))
56
 
57
  # Load vocabulary
58
  with open(os.path.join(opt.vocab_path, 'vocab.pkl'), "rb") as input_file:
@@ -144,7 +144,7 @@ class GenerateRunner():
144
  data_sorted['Predicted_smi_{}'.format(i + 1)] = sampled_smiles_list[:, i]
145
 
146
  result_path = os.path.join(self.save_path, "generated_molecules.csv")
147
- LOG.info("Save to {}".format(result_path))
148
  data_sorted.to_csv(result_path, index=False)
149
 
150
  def sample(self, model_choice, model, src, src_mask, source_length, decode_type, num_samples=10,
@@ -204,8 +204,8 @@ class GenerateRunner():
204
  elif model_choice == 'seq2seq':
205
  sequences = self.sample_seq2seq(model, mask_current, batch_index_current, decoder_hidden,
206
  encoder_outputs, max_len, device)
207
- else:
208
- LOG.info('Specify transformer or seq2seq for model_choice')
209
 
210
  # Check valid and unique
211
  smiles = []
 
41
 
42
  # self.save_path = os.path.join('experiments', opt.save_directory, opt.test_file_name,
43
  # f'evaluation_{opt.epoch}')
44
+ # path = Path(os.path.join(opt.save_directory))
45
+ # path.mkdir(parents=True, exist_ok=True)
46
+ # self.save_path = os.path.join(path)
47
+ # self.exist_flag = Path(f'{self.save_path}/generated_molecules.csv').exists()
48
+ # self.overwrite = opt.overwrite
49
  self.dev_no = opt.dev_no
50
  self.device = torch.device('cpu')
51
+ # global LOG
52
+ # LOG = ul.get_logger(name="generate",
53
+ # log_path=os.path.join(self.save_path, 'generate.log'))
54
+ # LOG.info(opt)
55
+ # LOG.info("Save directory: {}".format(self.save_path))
56
 
57
  # Load vocabulary
58
  with open(os.path.join(opt.vocab_path, 'vocab.pkl'), "rb") as input_file:
 
144
  data_sorted['Predicted_smi_{}'.format(i + 1)] = sampled_smiles_list[:, i]
145
 
146
  result_path = os.path.join(self.save_path, "generated_molecules.csv")
147
+ # LOG.info("Save to {}".format(result_path))
148
  data_sorted.to_csv(result_path, index=False)
149
 
150
  def sample(self, model_choice, model, src, src_mask, source_length, decode_type, num_samples=10,
 
204
  elif model_choice == 'seq2seq':
205
  sequences = self.sample_seq2seq(model, mask_current, batch_index_current, decoder_hidden,
206
  encoder_outputs, max_len, device)
207
+ # else:
208
+ # LOG.info('Specify transformer or seq2seq for model_choice')
209
 
210
  # Check valid and unique
211
  smiles = []