Spaces:
Sleeping
Sleeping
Update generate.py
Browse files- generate.py +13 -13
generate.py
CHANGED
|
@@ -41,18 +41,18 @@ class GenerateRunner():
|
|
| 41 |
|
| 42 |
# self.save_path = os.path.join('experiments', opt.save_directory, opt.test_file_name,
|
| 43 |
# f'evaluation_{opt.epoch}')
|
| 44 |
-
path = Path(os.path.join(opt.save_directory))
|
| 45 |
-
path.mkdir(parents=True, exist_ok=True)
|
| 46 |
-
self.save_path = os.path.join(path)
|
| 47 |
-
self.exist_flag = Path(f'{self.save_path}/generated_molecules.csv').exists()
|
| 48 |
-
self.overwrite = opt.overwrite
|
| 49 |
self.dev_no = opt.dev_no
|
| 50 |
self.device = torch.device('cpu')
|
| 51 |
-
global LOG
|
| 52 |
-
LOG = ul.get_logger(name="generate",
|
| 53 |
-
|
| 54 |
-
LOG.info(opt)
|
| 55 |
-
LOG.info("Save directory: {}".format(self.save_path))
|
| 56 |
|
| 57 |
# Load vocabulary
|
| 58 |
with open(os.path.join(opt.vocab_path, 'vocab.pkl'), "rb") as input_file:
|
|
@@ -144,7 +144,7 @@ class GenerateRunner():
|
|
| 144 |
data_sorted['Predicted_smi_{}'.format(i + 1)] = sampled_smiles_list[:, i]
|
| 145 |
|
| 146 |
result_path = os.path.join(self.save_path, "generated_molecules.csv")
|
| 147 |
-
LOG.info("Save to {}".format(result_path))
|
| 148 |
data_sorted.to_csv(result_path, index=False)
|
| 149 |
|
| 150 |
def sample(self, model_choice, model, src, src_mask, source_length, decode_type, num_samples=10,
|
|
@@ -204,8 +204,8 @@ class GenerateRunner():
|
|
| 204 |
elif model_choice == 'seq2seq':
|
| 205 |
sequences = self.sample_seq2seq(model, mask_current, batch_index_current, decoder_hidden,
|
| 206 |
encoder_outputs, max_len, device)
|
| 207 |
-
else:
|
| 208 |
-
|
| 209 |
|
| 210 |
# Check valid and unique
|
| 211 |
smiles = []
|
|
|
|
| 41 |
|
| 42 |
# self.save_path = os.path.join('experiments', opt.save_directory, opt.test_file_name,
|
| 43 |
# f'evaluation_{opt.epoch}')
|
| 44 |
+
# path = Path(os.path.join(opt.save_directory))
|
| 45 |
+
# path.mkdir(parents=True, exist_ok=True)
|
| 46 |
+
# self.save_path = os.path.join(path)
|
| 47 |
+
# self.exist_flag = Path(f'{self.save_path}/generated_molecules.csv').exists()
|
| 48 |
+
# self.overwrite = opt.overwrite
|
| 49 |
self.dev_no = opt.dev_no
|
| 50 |
self.device = torch.device('cpu')
|
| 51 |
+
# global LOG
|
| 52 |
+
# LOG = ul.get_logger(name="generate",
|
| 53 |
+
# log_path=os.path.join(self.save_path, 'generate.log'))
|
| 54 |
+
# LOG.info(opt)
|
| 55 |
+
# LOG.info("Save directory: {}".format(self.save_path))
|
| 56 |
|
| 57 |
# Load vocabulary
|
| 58 |
with open(os.path.join(opt.vocab_path, 'vocab.pkl'), "rb") as input_file:
|
|
|
|
| 144 |
data_sorted['Predicted_smi_{}'.format(i + 1)] = sampled_smiles_list[:, i]
|
| 145 |
|
| 146 |
result_path = os.path.join(self.save_path, "generated_molecules.csv")
|
| 147 |
+
# LOG.info("Save to {}".format(result_path))
|
| 148 |
data_sorted.to_csv(result_path, index=False)
|
| 149 |
|
| 150 |
def sample(self, model_choice, model, src, src_mask, source_length, decode_type, num_samples=10,
|
|
|
|
| 204 |
elif model_choice == 'seq2seq':
|
| 205 |
sequences = self.sample_seq2seq(model, mask_current, batch_index_current, decoder_hidden,
|
| 206 |
encoder_outputs, max_len, device)
|
| 207 |
+
# else:
|
| 208 |
+
# LOG.info('Specify transformer or seq2seq for model_choice')
|
| 209 |
|
| 210 |
# Check valid and unique
|
| 211 |
smiles = []
|