Spaces:
Sleeping
Sleeping
Gilmullin Almaz
commited on
Commit
·
27a7101
1
Parent(s):
57a9d9a
draft subclustering - need to solve resetting
Browse files
app.py
CHANGED
|
@@ -23,6 +23,9 @@ from synplan.utils.visualisation import generate_results_html, get_route_svg
|
|
| 23 |
from cluster.super_cgr import *
|
| 24 |
from cluster.rs_cgr import *
|
| 25 |
from cluster.clustering import *
|
|
|
|
|
|
|
|
|
|
| 26 |
from StructureFingerprint import MorganFingerprint
|
| 27 |
|
| 28 |
import psutil
|
|
@@ -33,83 +36,6 @@ disable_progress_bars("huggingface_hub")
|
|
| 33 |
|
| 34 |
smiles_parser = SMILESRead.create_parser(ignore=True)
|
| 35 |
|
| 36 |
-
|
| 37 |
-
def reassign_nums_chunk(route_dict):
|
| 38 |
-
"""Process a chunk of routes for reassigning numbers"""
|
| 39 |
-
return {k: reassign_nums(v) for k, v in route_dict.items()}
|
| 40 |
-
|
| 41 |
-
def cluster_molecules_optimized(fingerprints_dict, max_clusters):
|
| 42 |
-
"""Memory-optimized version of cluster_molecules.
|
| 43 |
-
|
| 44 |
-
Args:
|
| 45 |
-
fingerprints_dict (dict): Dictionary of pre-computed fingerprints
|
| 46 |
-
max_clusters (int): Maximum number of clusters
|
| 47 |
-
|
| 48 |
-
Returns:
|
| 49 |
-
dict: Clustering results containing clusters_dict and cluster_labels
|
| 50 |
-
"""
|
| 51 |
-
try:
|
| 52 |
-
# Convert dictionary to arrays for efficient processing
|
| 53 |
-
labels = np.array(list(fingerprints_dict.keys()))
|
| 54 |
-
fingerprints = np.array(list(fingerprints_dict.values()))
|
| 55 |
-
|
| 56 |
-
# Calculate similarity matrix in chunks to save memory
|
| 57 |
-
chunk_size = 100
|
| 58 |
-
n_samples = len(fingerprints)
|
| 59 |
-
similarity_matrix = np.zeros((n_samples, n_samples))
|
| 60 |
-
|
| 61 |
-
for i in range(0, n_samples, chunk_size):
|
| 62 |
-
chunk_end = min(i + chunk_size, n_samples)
|
| 63 |
-
chunk = fingerprints[i:chunk_end]
|
| 64 |
-
|
| 65 |
-
# Calculate similarity for this chunk against all fingerprints
|
| 66 |
-
similarity_chunk = tanimoto_similarity_continuous(chunk, fingerprints)
|
| 67 |
-
similarity_matrix[i:chunk_end] = similarity_chunk
|
| 68 |
-
|
| 69 |
-
# Clear memory
|
| 70 |
-
del similarity_chunk
|
| 71 |
-
gc.collect()
|
| 72 |
-
|
| 73 |
-
# Convert to distance matrix
|
| 74 |
-
distance_matrix = 1 - similarity_matrix
|
| 75 |
-
|
| 76 |
-
# Free memory
|
| 77 |
-
del similarity_matrix
|
| 78 |
-
gc.collect()
|
| 79 |
-
|
| 80 |
-
# Calculate condensed distance matrix
|
| 81 |
-
condensed_distance = squareform(distance_matrix)
|
| 82 |
-
|
| 83 |
-
# Free memory
|
| 84 |
-
del distance_matrix
|
| 85 |
-
gc.collect()
|
| 86 |
-
|
| 87 |
-
# Calculate linkage
|
| 88 |
-
Z = fastcluster.linkage(condensed_distance, method='average')
|
| 89 |
-
|
| 90 |
-
# Free memory
|
| 91 |
-
del condensed_distance
|
| 92 |
-
gc.collect()
|
| 93 |
-
|
| 94 |
-
# Perform clustering
|
| 95 |
-
cluster_labels = fcluster(Z, max_clusters, criterion='maxclust')
|
| 96 |
-
|
| 97 |
-
# Create clusters dictionary
|
| 98 |
-
clusters_dict = {}
|
| 99 |
-
for cluster in range(1, max_clusters + 1):
|
| 100 |
-
cluster_indices = np.where(cluster_labels == cluster)[0]
|
| 101 |
-
clusters_dict[cluster] = list(labels[cluster_indices])
|
| 102 |
-
|
| 103 |
-
return {
|
| 104 |
-
'clusters_dict': clusters_dict,
|
| 105 |
-
'cluster_labels': cluster_labels,
|
| 106 |
-
'linkage_matrix': Z
|
| 107 |
-
}
|
| 108 |
-
|
| 109 |
-
except Exception as e:
|
| 110 |
-
print(f"Error in cluster_molecules_optimized: {str(e)}")
|
| 111 |
-
raise e
|
| 112 |
-
|
| 113 |
def download_button(object_to_download, download_filename, button_text, pickle_it=False):
|
| 114 |
"""
|
| 115 |
Issued from
|
|
@@ -186,6 +112,23 @@ def download_button(object_to_download, download_filename, button_text, pickle_i
|
|
| 186 |
|
| 187 |
st.set_page_config(page_title="SynPlanner GUI", page_icon="🧪", layout="wide")
|
| 188 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
intro_text = '''
|
| 190 |
This is a demo of the graphical user interface of
|
| 191 |
[SynPlanner](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/).
|
|
@@ -268,7 +211,8 @@ search_strategy = search_strategy_translator[search_strategy_input]
|
|
| 268 |
|
| 269 |
submit_planning = st.button('Start retrosynthetic planning')
|
| 270 |
|
| 271 |
-
if submit_planning:
|
|
|
|
| 272 |
with st.status("Downloading data"):
|
| 273 |
st.write("Downloading building blocks")
|
| 274 |
building_blocks = load_building_blocks(building_blocks_path, standardize=False)
|
|
@@ -306,12 +250,21 @@ if submit_planning:
|
|
| 306 |
|
| 307 |
res = extract_tree_stats(tree, target_molecule)
|
| 308 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
st.header('Results')
|
| 310 |
if res["solved"]:
|
| 311 |
-
st.balloons()
|
| 312 |
-
|
| 313 |
st.subheader("Examples of found retrosynthetic routes")
|
| 314 |
-
|
| 315 |
image_counter = 0
|
| 316 |
visualised_node_ids = set()
|
| 317 |
for n, node_id in enumerate(sorted(set(tree.winning_nodes))):
|
|
@@ -322,255 +275,134 @@ if submit_planning:
|
|
| 322 |
image_counter += 1
|
| 323 |
num_steps = len(tree.synthesis_route(node_id))
|
| 324 |
route_score = round(tree.route_score(node_id), 3)
|
| 325 |
-
st.image(get_route_svg(tree, node_id),
|
|
|
|
| 326 |
|
| 327 |
-
|
| 328 |
-
### Modified part
|
| 329 |
-
# cluster_box, z = st.columns(2, gap="medium")
|
| 330 |
-
# with cluster_box:
|
| 331 |
-
# num_clusters = st.slider('Number of clusters to display', min_value=2, max_value=10, value=2)
|
| 332 |
-
|
| 333 |
-
# submit_clustering = st.button('Start clustering')
|
| 334 |
-
|
| 335 |
-
# if submit_clustering:
|
| 336 |
-
# st.subheader("Examples of clusters")
|
| 337 |
-
# super_cgrs_dict = reassign_nums(tree)
|
| 338 |
-
|
| 339 |
-
# reduced_super_cgrs_dict = process_all_rs_cgrs(super_cgrs_dict)
|
| 340 |
-
|
| 341 |
-
# mfp = MorganFingerprint()
|
| 342 |
-
|
| 343 |
-
# results = cluster_molecules(reduced_super_cgrs_dict, mfp, max_clusters=num_clusters)
|
| 344 |
-
# cluster_box, z = st.columns(2, gap="medium")
|
| 345 |
-
# with cluster_box:
|
| 346 |
-
# # Initialize session state if not exists
|
| 347 |
-
# if 'memory_warning_shown' not in st.session_state:
|
| 348 |
-
# st.session_state.memory_warning_shown = False
|
| 349 |
-
|
| 350 |
-
# current_memory = psutil.Process().memory_info().rss / 1024 / 1024
|
| 351 |
-
# st.write(f"Current memory usage: {current_memory:.2f} MB")
|
| 352 |
-
# st.write(f"Number of winning nodes: {len(tree.winning_nodes)}")
|
| 353 |
-
|
| 354 |
-
# # Memory warning
|
| 355 |
-
# if current_memory > 1000 and not st.session_state.memory_warning_shown:
|
| 356 |
-
# st.warning("Memory usage is high. Consider reducing the number of routes or clearing cache.")
|
| 357 |
-
# st.session_state.memory_warning_shown = True
|
| 358 |
-
|
| 359 |
-
# # Store the previous value in session state
|
| 360 |
-
# if 'prev_num_clusters' not in st.session_state:
|
| 361 |
-
# st.session_state.prev_num_clusters = 2
|
| 362 |
-
|
| 363 |
-
# num_clusters = st.slider(
|
| 364 |
-
# 'Number of clusters to display',
|
| 365 |
-
# min_value=2,
|
| 366 |
-
# max_value=min(10, len(tree.winning_nodes)),
|
| 367 |
-
# value=st.session_state.prev_num_clusters
|
| 368 |
-
# )
|
| 369 |
-
|
| 370 |
-
# # Update the stored value only if it changed
|
| 371 |
-
# if num_clusters != st.session_state.prev_num_clusters:
|
| 372 |
-
# st.session_state.prev_num_clusters = num_clusters
|
| 373 |
-
|
| 374 |
-
# submit_clustering = st.button('Start clustering')
|
| 375 |
-
|
| 376 |
-
# if submit_clustering:
|
| 377 |
-
# try:
|
| 378 |
-
# with st.spinner("Processing clusters..."):
|
| 379 |
-
# # Clear memory before starting
|
| 380 |
-
# gc.collect()
|
| 381 |
-
|
| 382 |
-
# st.write("Starting clustering process...")
|
| 383 |
-
# memory_before = psutil.Process().memory_info().rss / 1024 / 1024
|
| 384 |
-
# st.write(f"Memory before clustering: {memory_before:.2f} MB")
|
| 385 |
-
|
| 386 |
-
# super_cgrs_dict = reassign_nums(tree)
|
| 387 |
-
# del tree # Free up memory from the tree object since we don't need it anymore
|
| 388 |
-
# gc.collect()
|
| 389 |
-
|
| 390 |
-
# reduced_super_cgrs_dict = process_all_rs_cgrs(super_cgrs_dict)
|
| 391 |
-
# del super_cgrs_dict # Free up memory
|
| 392 |
-
# gc.collect()
|
| 393 |
-
|
| 394 |
-
# memory_after = psutil.Process().memory_info().rss / 1024 / 1024
|
| 395 |
-
# st.write(f"Memory after CGR processing: {memory_after:.2f} MB")
|
| 396 |
-
|
| 397 |
-
# mfp = MorganFingerprint()
|
| 398 |
-
# results = cluster_molecules(reduced_super_cgrs_dict, mfp, max_clusters=num_clusters)
|
| 399 |
-
# del reduced_super_cgrs_dict # Free up memory
|
| 400 |
-
# gc.collect()
|
| 401 |
-
|
| 402 |
-
# st.write("Clustering completed")
|
| 403 |
-
|
| 404 |
-
# except Exception as e:
|
| 405 |
-
# st.error(f"Clustering failed with error: {str(e)}")
|
| 406 |
-
# st.write(f"Memory at error: {psutil.Process().memory_info().rss / 1024 / 1024:.2f} MB")
|
| 407 |
-
# raise e
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
# Access results
|
| 411 |
-
# clusters = results['clusters_dict']
|
| 412 |
-
|
| 413 |
-
# for cluster_num, node_id_list in clusters.items():
|
| 414 |
-
# st.markdown(f"Cluster's number: ``{cluster_num}``")
|
| 415 |
-
# node_id = node_id_list[0]
|
| 416 |
-
# num_steps = len(tree.synthesis_route(node_id))
|
| 417 |
-
# route_score = round(tree.route_score(node_id), 3)
|
| 418 |
-
# st.image(get_route_svg(tree, node_id), caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")
|
| 419 |
-
|
| 420 |
-
@st.cache_data(hash_funcs={Tree: lambda _: None})
|
| 421 |
-
def prepare_clustering_data(tree):
|
| 422 |
-
try:
|
| 423 |
-
# Log the start and basic info from the Tree
|
| 424 |
-
print("Starting clustering data preparation.")
|
| 425 |
-
total_nodes = len(tree.winning_nodes)
|
| 426 |
-
print(f"Total winning nodes: {total_nodes}")
|
| 427 |
-
print(f"Tree id: {id(tree)}")
|
| 428 |
-
|
| 429 |
-
chunk_size = 10
|
| 430 |
-
super_cgrs_dict = {}
|
| 431 |
-
|
| 432 |
-
# Process winning nodes in chunks
|
| 433 |
-
for i in range(0, total_nodes, chunk_size):
|
| 434 |
-
current_chunk = list(tree.winning_nodes)[i:i+chunk_size]
|
| 435 |
-
print(f"Processing chunk {i // chunk_size + 1}: Nodes {current_chunk}")
|
| 436 |
-
|
| 437 |
-
temp_dict = {}
|
| 438 |
-
for node in current_chunk:
|
| 439 |
-
try:
|
| 440 |
-
# Log before processing each node
|
| 441 |
-
print(f"Processing node {node}")
|
| 442 |
-
route = tree.synthesis_route(node)
|
| 443 |
-
temp_dict[node] = route
|
| 444 |
-
print(f"Node {node} processed successfully (route length: {len(route)}).")
|
| 445 |
-
except Exception as e:
|
| 446 |
-
print(f"Error processing node {node}: {e}")
|
| 447 |
-
|
| 448 |
-
# Log before calling reassign_nums_chunk
|
| 449 |
-
print(f"Calling reassign_nums_chunk for nodes: {list(temp_dict.keys())}")
|
| 450 |
-
chunk_super_cgrs = reassign_nums_chunk(temp_dict)
|
| 451 |
-
super_cgrs_dict.update(chunk_super_cgrs)
|
| 452 |
-
print(f"Chunk {i // chunk_size + 1} processed. Keys: {list(chunk_super_cgrs.keys())}")
|
| 453 |
-
|
| 454 |
-
del temp_dict
|
| 455 |
-
gc.collect()
|
| 456 |
-
|
| 457 |
-
# Process reduced CGRs in chunks
|
| 458 |
-
reduced_super_cgrs_dict = {}
|
| 459 |
-
for i in range(0, len(super_cgrs_dict), chunk_size):
|
| 460 |
-
keys = list(super_cgrs_dict.keys())[i:i+chunk_size]
|
| 461 |
-
chunk_dict = {k: super_cgrs_dict[k] for k in keys}
|
| 462 |
-
print(f"Reducing chunk for keys: {keys}")
|
| 463 |
-
reduced_chunk = process_all_rs_cgrs(chunk_dict)
|
| 464 |
-
reduced_super_cgrs_dict.update(reduced_chunk)
|
| 465 |
-
print(f"Reduced chunk processed for keys: {list(reduced_chunk.keys())}")
|
| 466 |
-
|
| 467 |
-
del chunk_dict
|
| 468 |
-
gc.collect()
|
| 469 |
-
|
| 470 |
-
print("Clustering data preparation complete.")
|
| 471 |
-
return reduced_super_cgrs_dict
|
| 472 |
-
except Exception as e:
|
| 473 |
-
print(f"Error in prepare_clustering_data: {str(e)}")
|
| 474 |
-
st.error(f"Error in prepare_clustering_data: {str(e)}")
|
| 475 |
-
return None
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
def memory_status():
|
| 479 |
-
"""Get current memory status"""
|
| 480 |
-
process = psutil.Process()
|
| 481 |
-
memory = process.memory_info().rss / 1024 / 1024
|
| 482 |
-
return f"Memory usage: {memory:.2f} MB"
|
| 483 |
-
|
| 484 |
-
# Initialize session state for tree and clustering data
|
| 485 |
-
if 'tree_data' not in st.session_state:
|
| 486 |
-
st.session_state.tree_data = tree
|
| 487 |
-
if 'clustering_state' not in st.session_state:
|
| 488 |
-
st.session_state.clustering_state = {
|
| 489 |
-
'prepared': False,
|
| 490 |
-
'data': None,
|
| 491 |
-
'num_clusters': 2
|
| 492 |
-
}
|
| 493 |
-
|
| 494 |
-
cluster_box, z = st.columns(2, gap="medium")
|
| 495 |
-
with cluster_box:
|
| 496 |
-
st.write(memory_status())
|
| 497 |
-
st.write(f"Number of winning nodes: {len(st.session_state.tree_data.winning_nodes)}")
|
| 498 |
-
|
| 499 |
-
# Step 1: Prepare Data Button
|
| 500 |
-
if not st.session_state.clustering_state['prepared']:
|
| 501 |
-
if st.button('Step 1: Prepare clustering data'):
|
| 502 |
-
with st.spinner("Preparing data..."):
|
| 503 |
-
try:
|
| 504 |
-
st.session_state.clustering_state['data'] = prepare_clustering_data(st.session_state.tree_data)
|
| 505 |
-
st.session_state.clustering_state['prepared'] = True
|
| 506 |
-
st.success("Data prepared! Now you can proceed to Step 2.")
|
| 507 |
-
except Exception as e:
|
| 508 |
-
st.error(f"Preparation failed: {str(e)}")
|
| 509 |
-
|
| 510 |
-
# Step 2: Only show clustering controls if data is prepared
|
| 511 |
-
if st.session_state.clustering_state['prepared']:
|
| 512 |
-
st.markdown("### Step 2: Select number of clusters")
|
| 513 |
-
# Store slider value in session state
|
| 514 |
-
st.session_state.clustering_state['num_clusters'] = st.slider(
|
| 515 |
-
'Number of clusters',
|
| 516 |
-
min_value=2,
|
| 517 |
-
max_value=min(10, len(st.session_state.tree_data.winning_nodes)),
|
| 518 |
-
value=st.session_state.clustering_state['num_clusters']
|
| 519 |
-
)
|
| 520 |
-
|
| 521 |
-
# Step 3: Generate Clusters Button
|
| 522 |
-
if st.button('Step 3: Generate clusters'):
|
| 523 |
-
with st.spinner("Clustering..."):
|
| 524 |
-
try:
|
| 525 |
-
results = perform_clustering(
|
| 526 |
-
st.session_state.clustering_state['data'],
|
| 527 |
-
st.session_state.clustering_state['num_clusters']
|
| 528 |
-
)
|
| 529 |
-
|
| 530 |
-
if results:
|
| 531 |
-
st.success("Clustering complete!")
|
| 532 |
-
for cluster_num, node_ids in results['clusters_dict'].items():
|
| 533 |
-
with st.expander(f"Cluster {cluster_num}"):
|
| 534 |
-
if node_ids:
|
| 535 |
-
node_id = node_ids[0]
|
| 536 |
-
num_steps = len(st.session_state.tree_data.synthesis_route(node_id))
|
| 537 |
-
route_score = round(st.session_state.tree_data.route_score(node_id), 3)
|
| 538 |
-
st.image(
|
| 539 |
-
get_route_svg(st.session_state.tree_data, node_id),
|
| 540 |
-
caption=f"Route {node_id}; {num_steps} steps; Score: {route_score}"
|
| 541 |
-
)
|
| 542 |
-
except Exception as e:
|
| 543 |
-
st.error(f"Clustering failed: {str(e)}")
|
| 544 |
-
|
| 545 |
-
# Clear memory button
|
| 546 |
-
if st.button('Clear memory and start over'):
|
| 547 |
-
st.cache_data.clear()
|
| 548 |
-
del st.session_state.clustering_state
|
| 549 |
-
del st.session_state.tree_data
|
| 550 |
-
gc.collect()
|
| 551 |
-
st.success("Memory cleared! Please refresh the page to start over.")
|
| 552 |
-
st.rerun()
|
| 553 |
-
|
| 554 |
stat_col, download_col = st.columns(2, gap="medium")
|
| 555 |
-
|
| 556 |
with stat_col:
|
| 557 |
st.subheader("Statistics")
|
| 558 |
df = pd.DataFrame(res, index=[0])
|
| 559 |
st.write(df[["target_smiles", "num_routes", "num_nodes", "num_iter", "search_time"]])
|
| 560 |
-
|
| 561 |
with download_col:
|
| 562 |
st.subheader("Downloads")
|
| 563 |
html_body = generate_results_html(tree, html_path=None, extended=True)
|
| 564 |
dl_html = download_button(html_body, 'results_synplanner.html', 'Download results as a HTML file')
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
'Download statistics as a csv file')
|
| 568 |
st.markdown(dl_html + dl_csv, unsafe_allow_html=True)
|
| 569 |
|
| 570 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 571 |
st.write("Found no reaction path.")
|
| 572 |
|
| 573 |
st.divider()
|
| 574 |
st.header('Restart from the beginning?')
|
| 575 |
if st.button("Restart"):
|
|
|
|
| 576 |
st.rerun()
|
|
|
|
| 23 |
from cluster.super_cgr import *
|
| 24 |
from cluster.rs_cgr import *
|
| 25 |
from cluster.clustering import *
|
| 26 |
+
from cluster.visualize import *
|
| 27 |
+
from cluster.utils import *
|
| 28 |
+
from cluster.subcluster import *
|
| 29 |
from StructureFingerprint import MorganFingerprint
|
| 30 |
|
| 31 |
import psutil
|
|
|
|
| 36 |
|
| 37 |
smiles_parser = SMILESRead.create_parser(ignore=True)
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
def download_button(object_to_download, download_filename, button_text, pickle_it=False):
|
| 40 |
"""
|
| 41 |
Issued from
|
|
|
|
| 112 |
|
| 113 |
st.set_page_config(page_title="SynPlanner GUI", page_icon="🧪", layout="wide")
|
| 114 |
|
| 115 |
+
# Initialize session state variables if they don't exist.
|
| 116 |
+
if "planning_done" not in st.session_state:
|
| 117 |
+
st.session_state.planning_done = False
|
| 118 |
+
if "tree" not in st.session_state:
|
| 119 |
+
st.session_state.tree = None
|
| 120 |
+
if "res" not in st.session_state:
|
| 121 |
+
st.session_state.res = None
|
| 122 |
+
if "num_clusters" not in st.session_state:
|
| 123 |
+
st.session_state.num_clusters = 10
|
| 124 |
+
|
| 125 |
+
if 'clustering_started' not in st.session_state:
|
| 126 |
+
st.session_state.clustering_started = False
|
| 127 |
+
if 'clusters_downloaded' not in st.session_state:
|
| 128 |
+
st.session_state.clusters_downloaded = False
|
| 129 |
+
|
| 130 |
+
# st.write("Initial session state:", dict(st.session_state))
|
| 131 |
+
|
| 132 |
intro_text = '''
|
| 133 |
This is a demo of the graphical user interface of
|
| 134 |
[SynPlanner](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/).
|
|
|
|
| 211 |
|
| 212 |
submit_planning = st.button('Start retrosynthetic planning')
|
| 213 |
|
| 214 |
+
# if submit_planning:
|
| 215 |
+
if submit_planning and not st.session_state.planning_done:
|
| 216 |
with st.status("Downloading data"):
|
| 217 |
st.write("Downloading building blocks")
|
| 218 |
building_blocks = load_building_blocks(building_blocks_path, standardize=False)
|
|
|
|
| 250 |
|
| 251 |
res = extract_tree_stats(tree, target_molecule)
|
| 252 |
|
| 253 |
+
# Store planning outputs in session_state so they persist
|
| 254 |
+
st.session_state['tree'] = tree
|
| 255 |
+
st.session_state['res'] = res
|
| 256 |
+
st.session_state.planning_done = True
|
| 257 |
+
|
| 258 |
+
# Display results if planning has been completed
|
| 259 |
+
if st.session_state.planning_done and st.session_state.res is not None and st.session_state.clustering_started:
|
| 260 |
+
res = st.session_state.res
|
| 261 |
+
tree = st.session_state.tree
|
| 262 |
+
|
| 263 |
st.header('Results')
|
| 264 |
if res["solved"]:
|
| 265 |
+
# st.balloons()
|
| 266 |
+
|
| 267 |
st.subheader("Examples of found retrosynthetic routes")
|
|
|
|
| 268 |
image_counter = 0
|
| 269 |
visualised_node_ids = set()
|
| 270 |
for n, node_id in enumerate(sorted(set(tree.winning_nodes))):
|
|
|
|
| 275 |
image_counter += 1
|
| 276 |
num_steps = len(tree.synthesis_route(node_id))
|
| 277 |
route_score = round(tree.route_score(node_id), 3)
|
| 278 |
+
st.image(get_route_svg(tree, node_id),
|
| 279 |
+
caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")
|
| 280 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
stat_col, download_col = st.columns(2, gap="medium")
|
|
|
|
| 282 |
with stat_col:
|
| 283 |
st.subheader("Statistics")
|
| 284 |
df = pd.DataFrame(res, index=[0])
|
| 285 |
st.write(df[["target_smiles", "num_routes", "num_nodes", "num_iter", "search_time"]])
|
|
|
|
| 286 |
with download_col:
|
| 287 |
st.subheader("Downloads")
|
| 288 |
html_body = generate_results_html(tree, html_path=None, extended=True)
|
| 289 |
dl_html = download_button(html_body, 'results_synplanner.html', 'Download results as a HTML file')
|
| 290 |
+
dl_csv = download_button(pd.DataFrame(res, index=[0]),
|
| 291 |
+
'results_synplanner.csv', 'Download statistics as a csv file')
|
|
|
|
| 292 |
st.markdown(dl_html + dl_csv, unsafe_allow_html=True)
|
| 293 |
|
| 294 |
+
st.header("Clustering the retrosynthetic routes")
|
| 295 |
+
|
| 296 |
+
# Initialize slider state if not already set
|
| 297 |
+
if 'num_clusters' not in st.session_state:
|
| 298 |
+
st.session_state['num_clusters'] = 10
|
| 299 |
+
|
| 300 |
+
cluster_box, _ = st.columns(2, gap="medium")
|
| 301 |
+
with cluster_box:
|
| 302 |
+
num_clusters = st.slider(
|
| 303 |
+
'Number of clusters to display',
|
| 304 |
+
min_value=2,
|
| 305 |
+
max_value=10,
|
| 306 |
+
value=st.session_state['num_clusters'],
|
| 307 |
+
key='cluster_slider'
|
| 308 |
+
)
|
| 309 |
+
# Save the current slider value to session_state
|
| 310 |
+
st.session_state['num_clusters'] = num_clusters
|
| 311 |
+
|
| 312 |
+
if st.button('Start clustering', key='submit_clustering'):
|
| 313 |
+
st.session_state.clustering_started = True
|
| 314 |
+
# st.write("Clustering started; session state now:", dict(st.session_state))
|
| 315 |
+
# st.write("Clustering started!")
|
| 316 |
+
st.subheader("Examples of clusters")
|
| 317 |
+
super_cgrs_dict = reassign_nums(tree)
|
| 318 |
+
|
| 319 |
+
reduced_super_cgrs_dict = process_all_rs_cgrs(super_cgrs_dict)
|
| 320 |
+
|
| 321 |
+
mfp = MorganFingerprint()
|
| 322 |
+
|
| 323 |
+
results = cluster_molecules(reduced_super_cgrs_dict, mfp, max_clusters=num_clusters)
|
| 324 |
+
|
| 325 |
+
clusters = results['clusters_dict']
|
| 326 |
+
|
| 327 |
+
for cluster_num, node_id_list in clusters.items():
|
| 328 |
+
st.markdown(f"Cluster's number: {cluster_num}; Size {len(node_id_list)}")
|
| 329 |
+
node_id = node_id_list[0]
|
| 330 |
+
num_steps = len(tree.synthesis_route(node_id))
|
| 331 |
+
route_score = round(tree.route_score(node_id), 3)
|
| 332 |
+
st.image(get_route_svg(tree, node_id), caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")
|
| 333 |
+
|
| 334 |
+
cluster_sizes = [len(cluster) for cluster in clusters.values()]
|
| 335 |
+
|
| 336 |
+
cluster_stat_col, cluster_download_col = st.columns(2, gap="medium")
|
| 337 |
+
with cluster_stat_col:
|
| 338 |
+
st.subheader("Statistics")
|
| 339 |
+
# st.write(cluster_sizes)
|
| 340 |
+
cluster_df = pd.DataFrame({'Cluster': range(len(cluster_sizes)), 'Routes': cluster_sizes})
|
| 341 |
+
# cluster_df = pd.DataFrame(cluster_sizes, index=[0])
|
| 342 |
+
st.write(cluster_df)
|
| 343 |
+
|
| 344 |
+
def on_download_click():
|
| 345 |
+
st.session_state.clusters_downloaded = True
|
| 346 |
+
st.write("Download clusters button pressed via on_click. Updated session state:", dict(st.session_state))
|
| 347 |
+
save_route_images(tree, reactions_dict, cluster_dict=clusters_converted)
|
| 348 |
+
# Here you can call save_route_images(...) if desired.
|
| 349 |
+
|
| 350 |
+
with cluster_download_col:
|
| 351 |
+
st.subheader("Downloads: Don't work. Resets evey time")
|
| 352 |
+
reactions_dict = extract_reactions(tree)
|
| 353 |
+
clusters_converted = {int(key): value for key, value in clusters.items()} if clusters else clusters
|
| 354 |
+
|
| 355 |
+
# Use on_click to capture the click event reliably.
|
| 356 |
+
st.button('Download clusters', key='download_clusters_button', on_click=on_download_click)
|
| 357 |
+
|
| 358 |
+
# Log whether the flag has been set after the button definition.
|
| 359 |
+
st.write("Clusters downloaded flag (from session_state):", st.session_state.get("clusters_downloaded"))
|
| 360 |
+
|
| 361 |
+
# # save_route_images(tree, reactions_dict, cluster_dict=clusters_converted)
|
| 362 |
+
# with cluster_download_col:
|
| 363 |
+
# st.subheader("Downloads")
|
| 364 |
+
# reactions_dict = extract_reactions(tree)
|
| 365 |
+
# clusters_converted = {int(key): value for key, value in clusters.items()} if clusters else clusters
|
| 366 |
+
|
| 367 |
+
# if st.session_state.clustering_started:
|
| 368 |
+
# st.write("Rendering download clusters button. Session state:", dict(st.session_state))
|
| 369 |
+
# # Use a more unique key for the download button.
|
| 370 |
+
# download_clusters = st.button('Download clusters', key='download_clusters_button')
|
| 371 |
+
# st.write("download_clusters value:", download_clusters)
|
| 372 |
+
# if download_clusters:
|
| 373 |
+
# st.session_state.clusters_downloaded = True
|
| 374 |
+
# st.write("Download clusters button pressed. Updated session state:", dict(st.session_state))
|
| 375 |
+
|
| 376 |
+
col1, _ = st.columns([.2, .8])
|
| 377 |
+
with col1:
|
| 378 |
+
fig = pie_chart(cluster_sizes)
|
| 379 |
+
st.pyplot(fig)
|
| 380 |
+
st.header("Sub Clustering the retrosynthetic routes - Resets every time when i interact with input widget")
|
| 381 |
+
sub = sublcuster_all(clusters, reactions_dict)
|
| 382 |
+
col2, _ = st.columns([.2, .8])
|
| 383 |
+
with col2:
|
| 384 |
+
user_input_cluster_num = st.number_input("Enter a number:", min_value=1,
|
| 385 |
+
max_value=max(clusters.keys()), value=1, step=1)
|
| 386 |
+
|
| 387 |
+
st.write(f"You entered the # cluster: {user_input_cluster_num}")
|
| 388 |
+
sub_step_cluster = sub[user_input_cluster_num]
|
| 389 |
+
allowed_numbers = sub_step_cluster.keys()
|
| 390 |
+
selected_number = st.selectbox("Choose a number:", allowed_numbers)
|
| 391 |
+
st.write(f"You entered number of steps: {selected_number}")
|
| 392 |
+
subclusters = sub_step_cluster[selected_number]
|
| 393 |
+
|
| 394 |
+
st.subheader(f"Found number of subclusters: {len(subclusters)}")
|
| 395 |
+
for subcluster_num, subcluster_set in enumerate(subclusters):
|
| 396 |
+
st.write(f"Subcluster #: {subcluster_num + 1}")
|
| 397 |
+
for route_id in subcluster_set:
|
| 398 |
+
st.write(f"Node_ID: {route_id}")
|
| 399 |
+
st.image(get_route_svg(tree, route_id), caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")
|
| 400 |
+
|
| 401 |
+
else:
|
| 402 |
st.write("Found no reaction path.")
|
| 403 |
|
| 404 |
st.divider()
|
| 405 |
st.header('Restart from the beginning?')
|
| 406 |
if st.button("Restart"):
|
| 407 |
+
st.session_state.planning_done = False
|
| 408 |
st.rerun()
|