Spaces:
Sleeping
Sleeping
Gilmullin Almaz
commited on
Commit
·
43616cc
1
Parent(s):
088dd6f
Refactor error handling in setup_planning_options to raise specific exceptions and improve clarity in SMILES parsing and resource loading.
Browse files
app.py
CHANGED
|
@@ -322,7 +322,7 @@ def setup_planning_options():
|
|
| 322 |
options=("uct", "puct", "value"),
|
| 323 |
index=0,
|
| 324 |
key="ucb_type_input",
|
| 325 |
-
)
|
| 326 |
c_ucb = st.number_input(
|
| 327 |
"C coefficient of UCB",
|
| 328 |
value=0.1,
|
|
@@ -392,70 +392,70 @@ def setup_planning_options():
|
|
| 392 |
try:
|
| 393 |
target_molecule = mol_from_smiles(active_smile_code, clean_stereo=True)
|
| 394 |
if target_molecule is None:
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
)
|
| 408 |
-
st.write("Loading reaction rules...")
|
| 409 |
-
reaction_rules = load_reaction_rules(reaction_rules_path)
|
| 410 |
-
st.write("Loading policy network...")
|
| 411 |
-
policy_config = PolicyNetworkConfig(
|
| 412 |
-
weights_path=ranking_policy_weights_path
|
| 413 |
-
)
|
| 414 |
-
policy_function = PolicyNetworkFunction(
|
| 415 |
-
policy_config=policy_config
|
| 416 |
-
)
|
| 417 |
-
status.update(label="Resources loaded!", state="complete")
|
| 418 |
-
|
| 419 |
-
tree_config = TreeConfig(
|
| 420 |
-
search_strategy=planning_params["search_strategy"],
|
| 421 |
-
evaluation_type="rollout", # This was hardcoded, keeping it.
|
| 422 |
-
max_iterations=planning_params["max_iterations"],
|
| 423 |
-
max_depth=planning_params["max_depth"],
|
| 424 |
-
min_mol_size=planning_params["min_mol_size"],
|
| 425 |
-
init_node_value=0.5, # This was hardcoded
|
| 426 |
-
ucb_type=planning_params["ucb_type"],
|
| 427 |
-
c_ucb=planning_params["c_ucb"],
|
| 428 |
-
silent=True, # This was hardcoded
|
| 429 |
)
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
building_blocks=building_blocks,
|
| 436 |
-
expansion_function=policy_function,
|
| 437 |
-
evaluation_function=None, # This was hardcoded
|
| 438 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 450 |
|
| 451 |
-
|
| 452 |
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
|
| 458 |
-
except
|
| 459 |
st.error(f"An error occurred during planning: {e}")
|
| 460 |
st.session_state.planning_done = False
|
| 461 |
|
|
|
|
| 322 |
options=("uct", "puct", "value"),
|
| 323 |
index=0,
|
| 324 |
key="ucb_type_input",
|
| 325 |
+
)
|
| 326 |
c_ucb = st.number_input(
|
| 327 |
"C coefficient of UCB",
|
| 328 |
value=0.1,
|
|
|
|
| 392 |
try:
|
| 393 |
target_molecule = mol_from_smiles(active_smile_code, clean_stereo=True)
|
| 394 |
if target_molecule is None:
|
| 395 |
+
raise ValueError(f"Could not parse the input SMILES: {active_smile_code}")
|
| 396 |
+
|
| 397 |
+
(
|
| 398 |
+
building_blocks_path,
|
| 399 |
+
ranking_policy_weights_path,
|
| 400 |
+
reaction_rules_path,
|
| 401 |
+
) = load_planning_resources_cached()
|
| 402 |
+
with st.spinner("Running retrosynthetic planning..."):
|
| 403 |
+
with st.status("Loading resources...", expanded=False) as status:
|
| 404 |
+
st.write("Loading building blocks...")
|
| 405 |
+
building_blocks = load_building_blocks(
|
| 406 |
+
building_blocks_path, standardize=False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
)
|
| 408 |
+
st.write("Loading reaction rules...")
|
| 409 |
+
reaction_rules = load_reaction_rules(reaction_rules_path)
|
| 410 |
+
st.write("Loading policy network...")
|
| 411 |
+
policy_config = PolicyNetworkConfig(
|
| 412 |
+
weights_path=ranking_policy_weights_path
|
|
|
|
|
|
|
|
|
|
| 413 |
)
|
| 414 |
+
policy_function = PolicyNetworkFunction(
|
| 415 |
+
policy_config=policy_config
|
| 416 |
+
)
|
| 417 |
+
status.update(label="Resources loaded!", state="complete")
|
| 418 |
+
|
| 419 |
+
tree_config = TreeConfig(
|
| 420 |
+
search_strategy=planning_params["search_strategy"],
|
| 421 |
+
evaluation_type="rollout",
|
| 422 |
+
max_iterations=planning_params["max_iterations"],
|
| 423 |
+
max_depth=planning_params["max_depth"],
|
| 424 |
+
min_mol_size=planning_params["min_mol_size"],
|
| 425 |
+
init_node_value=0.5,
|
| 426 |
+
ucb_type=planning_params["ucb_type"],
|
| 427 |
+
c_ucb=planning_params["c_ucb"],
|
| 428 |
+
silent=True,
|
| 429 |
+
)
|
| 430 |
|
| 431 |
+
tree = Tree(
|
| 432 |
+
target=target_molecule,
|
| 433 |
+
config=tree_config,
|
| 434 |
+
reaction_rules=reaction_rules,
|
| 435 |
+
building_blocks=building_blocks,
|
| 436 |
+
expansion_function=policy_function,
|
| 437 |
+
evaluation_function=None,
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
mcts_progress_text = "Running MCTS iterations..."
|
| 441 |
+
mcts_bar = st.progress(0, text=mcts_progress_text)
|
| 442 |
+
for step, (solved, route_id) in enumerate(tree):
|
| 443 |
+
progress_value = min(
|
| 444 |
+
1.0, (step + 1) / planning_params["max_iterations"]
|
| 445 |
+
)
|
| 446 |
+
mcts_bar.progress(
|
| 447 |
+
progress_value,
|
| 448 |
+
text=f"{mcts_progress_text} ({step+1}/{planning_params['max_iterations']})",
|
| 449 |
+
)
|
| 450 |
|
| 451 |
+
res = extract_tree_stats(tree, target_molecule)
|
| 452 |
|
| 453 |
+
st.session_state["tree"] = tree
|
| 454 |
+
st.session_state["res"] = res
|
| 455 |
+
st.session_state.planning_done = True
|
| 456 |
+
st.rerun()
|
| 457 |
|
| 458 |
+
except (ValueError, KeyError, FileNotFoundError, TypeError) as e:
|
| 459 |
st.error(f"An error occurred during planning: {e}")
|
| 460 |
st.session_state.planning_done = False
|
| 461 |
|