Spaces:
Runtime error
Runtime error
Update utils.py
Browse files
utils.py
CHANGED
|
@@ -59,10 +59,8 @@ class DataCollatorForCausalLMEval(object):
|
|
| 59 |
|
| 60 |
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
| 61 |
|
| 62 |
-
print(instances)
|
| 63 |
srcs = instances[0]['src']
|
| 64 |
task_type = instances[0]['task_type']
|
| 65 |
-
print(task_type)
|
| 66 |
|
| 67 |
if task_type == 'retrosynthesis':
|
| 68 |
src_start_str = self.product_start_str
|
|
@@ -78,7 +76,6 @@ class DataCollatorForCausalLMEval(object):
|
|
| 78 |
data_dict = {
|
| 79 |
'generation_prompts': generation_prompts
|
| 80 |
}
|
| 81 |
-
print(data_dict)
|
| 82 |
return data_dict
|
| 83 |
|
| 84 |
def smart_tokenizer_and_embedding_resize(
|
|
@@ -131,7 +128,6 @@ class ReactionPredictionModel():
|
|
| 131 |
)
|
| 132 |
self.load_forward_model(candidate_models[model])
|
| 133 |
|
| 134 |
-
print(self.forward_model.device, self.retro_model.device)
|
| 135 |
string_template_path = hf_hub_download(candidate_models[list(candidate_models.keys())[0]], filename="string_template.json", token = os.environ.get("TOKEN"))
|
| 136 |
string_template = json.load(open(string_template_path, 'r'))
|
| 137 |
reactant_start_str = string_template['REACTANTS_START_STRING']
|
|
@@ -220,8 +216,6 @@ class ReactionPredictionModel():
|
|
| 220 |
|
| 221 |
if task_type == "retrosynthesis":
|
| 222 |
inputs = {k: v.to(self.retro_model.device) for k, v in inputs.items()}
|
| 223 |
-
print(inputs)
|
| 224 |
-
print(self.retro_model.device)
|
| 225 |
with torch.no_grad():
|
| 226 |
outputs = self.retro_model.generate(**inputs, max_length=512, num_return_sequences=10,
|
| 227 |
do_sample=False, num_beams=10,
|
|
@@ -232,8 +226,6 @@ class ReactionPredictionModel():
|
|
| 232 |
)
|
| 233 |
else:
|
| 234 |
inputs = {k: v.to(self.forward_model.device) for k, v in inputs.items()}
|
| 235 |
-
print(inputs)
|
| 236 |
-
print(self.forward_model.device)
|
| 237 |
with torch.no_grad():
|
| 238 |
outputs = self.forward_model.generate(**inputs, max_length=512, num_return_sequences=10,
|
| 239 |
do_sample=False, num_beams=10,
|
|
@@ -243,11 +235,9 @@ class ReactionPredictionModel():
|
|
| 243 |
length_penalty=0.0,
|
| 244 |
)
|
| 245 |
|
| 246 |
-
print(outputs)
|
| 247 |
original_smiles_list = self.tokenizer.batch_decode(outputs.detach().cpu().numpy()[:, len(inputs['input_ids'][0]):],
|
| 248 |
skip_special_tokens=True)
|
| 249 |
original_smiles_list = map(lambda x: x.replace(" ", ""), original_smiles_list)
|
| 250 |
-
print(original_smiles_list)
|
| 251 |
# canonize the SMILES
|
| 252 |
canonized_smiles_list = []
|
| 253 |
temp = []
|
|
@@ -262,7 +252,6 @@ class ReactionPredictionModel():
|
|
| 262 |
predictions.append(canonized_smiles_list)
|
| 263 |
|
| 264 |
rank, invalid_rate = compute_rank(predictions)
|
| 265 |
-
print(predictions, rank)
|
| 266 |
return rank
|
| 267 |
|
| 268 |
def predict_single_smiles(self, smiles, task_type):
|
|
|
|
| 59 |
|
| 60 |
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
| 61 |
|
|
|
|
| 62 |
srcs = instances[0]['src']
|
| 63 |
task_type = instances[0]['task_type']
|
|
|
|
| 64 |
|
| 65 |
if task_type == 'retrosynthesis':
|
| 66 |
src_start_str = self.product_start_str
|
|
|
|
| 76 |
data_dict = {
|
| 77 |
'generation_prompts': generation_prompts
|
| 78 |
}
|
|
|
|
| 79 |
return data_dict
|
| 80 |
|
| 81 |
def smart_tokenizer_and_embedding_resize(
|
|
|
|
| 128 |
)
|
| 129 |
self.load_forward_model(candidate_models[model])
|
| 130 |
|
|
|
|
| 131 |
string_template_path = hf_hub_download(candidate_models[list(candidate_models.keys())[0]], filename="string_template.json", token = os.environ.get("TOKEN"))
|
| 132 |
string_template = json.load(open(string_template_path, 'r'))
|
| 133 |
reactant_start_str = string_template['REACTANTS_START_STRING']
|
|
|
|
| 216 |
|
| 217 |
if task_type == "retrosynthesis":
|
| 218 |
inputs = {k: v.to(self.retro_model.device) for k, v in inputs.items()}
|
|
|
|
|
|
|
| 219 |
with torch.no_grad():
|
| 220 |
outputs = self.retro_model.generate(**inputs, max_length=512, num_return_sequences=10,
|
| 221 |
do_sample=False, num_beams=10,
|
|
|
|
| 226 |
)
|
| 227 |
else:
|
| 228 |
inputs = {k: v.to(self.forward_model.device) for k, v in inputs.items()}
|
|
|
|
|
|
|
| 229 |
with torch.no_grad():
|
| 230 |
outputs = self.forward_model.generate(**inputs, max_length=512, num_return_sequences=10,
|
| 231 |
do_sample=False, num_beams=10,
|
|
|
|
| 235 |
length_penalty=0.0,
|
| 236 |
)
|
| 237 |
|
|
|
|
| 238 |
original_smiles_list = self.tokenizer.batch_decode(outputs.detach().cpu().numpy()[:, len(inputs['input_ids'][0]):],
|
| 239 |
skip_special_tokens=True)
|
| 240 |
original_smiles_list = map(lambda x: x.replace(" ", ""), original_smiles_list)
|
|
|
|
| 241 |
# canonize the SMILES
|
| 242 |
canonized_smiles_list = []
|
| 243 |
temp = []
|
|
|
|
| 252 |
predictions.append(canonized_smiles_list)
|
| 253 |
|
| 254 |
rank, invalid_rate = compute_rank(predictions)
|
|
|
|
| 255 |
return rank
|
| 256 |
|
| 257 |
def predict_single_smiles(self, smiles, task_type):
|