Spaces:
Sleeping
Sleeping
Merge branch 'main' of https://huggingface.co/spaces/Protolaw/SynPlanner
Browse files
app.py
CHANGED
|
@@ -275,9 +275,235 @@ if st.session_state.planning_done and st.session_state.res is not None and st.se
|
|
| 275 |
image_counter += 1
|
| 276 |
num_steps = len(tree.synthesis_route(node_id))
|
| 277 |
route_score = round(tree.route_score(node_id), 3)
|
| 278 |
-
st.image(get_route_svg(tree, node_id),
|
| 279 |
-
|
| 280 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
stat_col, download_col = st.columns(2, gap="medium")
|
| 282 |
with stat_col:
|
| 283 |
st.subheader("Statistics")
|
|
|
|
| 275 |
image_counter += 1
|
| 276 |
num_steps = len(tree.synthesis_route(node_id))
|
| 277 |
route_score = round(tree.route_score(node_id), 3)
|
| 278 |
+
st.image(get_route_svg(tree, node_id), caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")
|
| 279 |
+
|
| 280 |
|
| 281 |
+
### Modified part
|
| 282 |
+
# cluster_box, z = st.columns(2, gap="medium")
|
| 283 |
+
# with cluster_box:
|
| 284 |
+
# num_clusters = st.slider('Number of clusters to display', min_value=2, max_value=10, value=2)
|
| 285 |
+
|
| 286 |
+
# submit_clustering = st.button('Start clustering')
|
| 287 |
+
|
| 288 |
+
# if submit_clustering:
|
| 289 |
+
# st.subheader("Examples of clusters")
|
| 290 |
+
# super_cgrs_dict = reassign_nums(tree)
|
| 291 |
+
|
| 292 |
+
# reduced_super_cgrs_dict = process_all_rs_cgrs(super_cgrs_dict)
|
| 293 |
+
|
| 294 |
+
# mfp = MorganFingerprint()
|
| 295 |
+
|
| 296 |
+
# results = cluster_molecules(reduced_super_cgrs_dict, mfp, max_clusters=num_clusters)
|
| 297 |
+
# cluster_box, z = st.columns(2, gap="medium")
|
| 298 |
+
# with cluster_box:
|
| 299 |
+
# # Initialize session state if not exists
|
| 300 |
+
# if 'memory_warning_shown' not in st.session_state:
|
| 301 |
+
# st.session_state.memory_warning_shown = False
|
| 302 |
+
|
| 303 |
+
# current_memory = psutil.Process().memory_info().rss / 1024 / 1024
|
| 304 |
+
# st.write(f"Current memory usage: {current_memory:.2f} MB")
|
| 305 |
+
# st.write(f"Number of winning nodes: {len(tree.winning_nodes)}")
|
| 306 |
+
|
| 307 |
+
# # Memory warning
|
| 308 |
+
# if current_memory > 1000 and not st.session_state.memory_warning_shown:
|
| 309 |
+
# st.warning("Memory usage is high. Consider reducing the number of routes or clearing cache.")
|
| 310 |
+
# st.session_state.memory_warning_shown = True
|
| 311 |
+
|
| 312 |
+
# # Store the previous value in session state
|
| 313 |
+
# if 'prev_num_clusters' not in st.session_state:
|
| 314 |
+
# st.session_state.prev_num_clusters = 2
|
| 315 |
+
|
| 316 |
+
# num_clusters = st.slider(
|
| 317 |
+
# 'Number of clusters to display',
|
| 318 |
+
# min_value=2,
|
| 319 |
+
# max_value=min(10, len(tree.winning_nodes)),
|
| 320 |
+
# value=st.session_state.prev_num_clusters
|
| 321 |
+
# )
|
| 322 |
+
|
| 323 |
+
# # Update the stored value only if it changed
|
| 324 |
+
# if num_clusters != st.session_state.prev_num_clusters:
|
| 325 |
+
# st.session_state.prev_num_clusters = num_clusters
|
| 326 |
+
|
| 327 |
+
# submit_clustering = st.button('Start clustering')
|
| 328 |
+
|
| 329 |
+
# if submit_clustering:
|
| 330 |
+
# try:
|
| 331 |
+
# with st.spinner("Processing clusters..."):
|
| 332 |
+
# # Clear memory before starting
|
| 333 |
+
# gc.collect()
|
| 334 |
+
|
| 335 |
+
# st.write("Starting clustering process...")
|
| 336 |
+
# memory_before = psutil.Process().memory_info().rss / 1024 / 1024
|
| 337 |
+
# st.write(f"Memory before clustering: {memory_before:.2f} MB")
|
| 338 |
+
|
| 339 |
+
# super_cgrs_dict = reassign_nums(tree)
|
| 340 |
+
# del tree # Free up memory from the tree object since we don't need it anymore
|
| 341 |
+
# gc.collect()
|
| 342 |
+
|
| 343 |
+
# reduced_super_cgrs_dict = process_all_rs_cgrs(super_cgrs_dict)
|
| 344 |
+
# del super_cgrs_dict # Free up memory
|
| 345 |
+
# gc.collect()
|
| 346 |
+
|
| 347 |
+
# memory_after = psutil.Process().memory_info().rss / 1024 / 1024
|
| 348 |
+
# st.write(f"Memory after CGR processing: {memory_after:.2f} MB")
|
| 349 |
+
|
| 350 |
+
# mfp = MorganFingerprint()
|
| 351 |
+
# results = cluster_molecules(reduced_super_cgrs_dict, mfp, max_clusters=num_clusters)
|
| 352 |
+
# del reduced_super_cgrs_dict # Free up memory
|
| 353 |
+
# gc.collect()
|
| 354 |
+
|
| 355 |
+
# st.write("Clustering completed")
|
| 356 |
+
|
| 357 |
+
# except Exception as e:
|
| 358 |
+
# st.error(f"Clustering failed with error: {str(e)}")
|
| 359 |
+
# st.write(f"Memory at error: {psutil.Process().memory_info().rss / 1024 / 1024:.2f} MB")
|
| 360 |
+
# raise e
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
# Access results
|
| 364 |
+
# clusters = results['clusters_dict']
|
| 365 |
+
|
| 366 |
+
# for cluster_num, node_id_list in clusters.items():
|
| 367 |
+
# st.markdown(f"Cluster's number: ``{cluster_num}``")
|
| 368 |
+
# node_id = node_id_list[0]
|
| 369 |
+
# num_steps = len(tree.synthesis_route(node_id))
|
| 370 |
+
# route_score = round(tree.route_score(node_id), 3)
|
| 371 |
+
# st.image(get_route_svg(tree, node_id), caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")
|
| 372 |
+
|
| 373 |
+
@st.cache_data(hash_funcs={Tree: lambda _: None})
|
| 374 |
+
def prepare_clustering_data(tree):
|
| 375 |
+
try:
|
| 376 |
+
# Log the start and basic info from the Tree
|
| 377 |
+
print("Starting clustering data preparation.")
|
| 378 |
+
total_nodes = len(tree.winning_nodes)
|
| 379 |
+
print(f"Total winning nodes: {total_nodes}")
|
| 380 |
+
print(f"Tree id: {id(tree)}")
|
| 381 |
+
|
| 382 |
+
chunk_size = 10
|
| 383 |
+
super_cgrs_dict = {}
|
| 384 |
+
|
| 385 |
+
# Process winning nodes in chunks
|
| 386 |
+
for i in range(0, total_nodes, chunk_size):
|
| 387 |
+
current_chunk = list(tree.winning_nodes)[i:i+chunk_size]
|
| 388 |
+
print(f"Processing chunk {i // chunk_size + 1}: Nodes {current_chunk}")
|
| 389 |
+
|
| 390 |
+
temp_dict = {}
|
| 391 |
+
for node in current_chunk:
|
| 392 |
+
try:
|
| 393 |
+
# Log before processing each node
|
| 394 |
+
print(f"Processing node {node}")
|
| 395 |
+
route = tree.synthesis_route(node)
|
| 396 |
+
temp_dict[node] = route
|
| 397 |
+
print(f"Node {node} processed successfully (route length: {len(route)}).")
|
| 398 |
+
except Exception as e:
|
| 399 |
+
print(f"Error processing node {node}: {e}")
|
| 400 |
+
|
| 401 |
+
# Log before calling reassign_nums_chunk
|
| 402 |
+
print(f"Calling reassign_nums_chunk for nodes: {list(temp_dict.keys())}")
|
| 403 |
+
chunk_super_cgrs = reassign_nums_chunk(temp_dict)
|
| 404 |
+
super_cgrs_dict.update(chunk_super_cgrs)
|
| 405 |
+
print(f"Chunk {i // chunk_size + 1} processed. Keys: {list(chunk_super_cgrs.keys())}")
|
| 406 |
+
|
| 407 |
+
del temp_dict
|
| 408 |
+
gc.collect()
|
| 409 |
+
|
| 410 |
+
# Process reduced CGRs in chunks
|
| 411 |
+
reduced_super_cgrs_dict = {}
|
| 412 |
+
for i in range(0, len(super_cgrs_dict), chunk_size):
|
| 413 |
+
keys = list(super_cgrs_dict.keys())[i:i+chunk_size]
|
| 414 |
+
chunk_dict = {k: super_cgrs_dict[k] for k in keys}
|
| 415 |
+
print(f"Reducing chunk for keys: {keys}")
|
| 416 |
+
reduced_chunk = process_all_rs_cgrs(chunk_dict)
|
| 417 |
+
reduced_super_cgrs_dict.update(reduced_chunk)
|
| 418 |
+
print(f"Reduced chunk processed for keys: {list(reduced_chunk.keys())}")
|
| 419 |
+
|
| 420 |
+
del chunk_dict
|
| 421 |
+
gc.collect()
|
| 422 |
+
|
| 423 |
+
print("Clustering data preparation complete.")
|
| 424 |
+
return reduced_super_cgrs_dict
|
| 425 |
+
except Exception as e:
|
| 426 |
+
print(f"Error in prepare_clustering_data: {str(e)}")
|
| 427 |
+
st.error(f"Error in prepare_clustering_data: {str(e)}")
|
| 428 |
+
return None
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def memory_status():
|
| 432 |
+
"""Get current memory status"""
|
| 433 |
+
process = psutil.Process()
|
| 434 |
+
memory = process.memory_info().rss / 1024 / 1024
|
| 435 |
+
return f"Memory usage: {memory:.2f} MB"
|
| 436 |
+
|
| 437 |
+
# Initialize session state for tree and clustering data
|
| 438 |
+
if 'tree_data' not in st.session_state:
|
| 439 |
+
st.session_state.tree_data = tree
|
| 440 |
+
if 'clustering_state' not in st.session_state:
|
| 441 |
+
st.session_state.clustering_state = {
|
| 442 |
+
'prepared': False,
|
| 443 |
+
'data': None,
|
| 444 |
+
'num_clusters': 2
|
| 445 |
+
}
|
| 446 |
+
|
| 447 |
+
cluster_box, z = st.columns(2, gap="medium")
|
| 448 |
+
with cluster_box:
|
| 449 |
+
st.write(memory_status())
|
| 450 |
+
st.write(f"Number of winning nodes: {len(st.session_state.tree_data.winning_nodes)}")
|
| 451 |
+
|
| 452 |
+
# Step 1: Prepare Data Button
|
| 453 |
+
if not st.session_state.clustering_state['prepared']:
|
| 454 |
+
if st.button('Step 1: Prepare clustering data'):
|
| 455 |
+
with st.spinner("Preparing data..."):
|
| 456 |
+
try:
|
| 457 |
+
st.session_state.clustering_state['data'] = prepare_clustering_data(st.session_state.tree_data)
|
| 458 |
+
st.session_state.clustering_state['prepared'] = True
|
| 459 |
+
st.success("Data prepared! Now you can proceed to Step 2.")
|
| 460 |
+
except Exception as e:
|
| 461 |
+
st.error(f"Preparation failed: {str(e)}")
|
| 462 |
+
|
| 463 |
+
# Step 2: Only show clustering controls if data is prepared
|
| 464 |
+
if st.session_state.clustering_state['prepared']:
|
| 465 |
+
st.markdown("### Step 2: Select number of clusters")
|
| 466 |
+
# Store slider value in session state
|
| 467 |
+
st.session_state.clustering_state['num_clusters'] = st.slider(
|
| 468 |
+
'Number of clusters',
|
| 469 |
+
min_value=2,
|
| 470 |
+
max_value=min(10, len(st.session_state.tree_data.winning_nodes)),
|
| 471 |
+
value=st.session_state.clustering_state['num_clusters']
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
# Step 3: Generate Clusters Button
|
| 475 |
+
if st.button('Step 3: Generate clusters'):
|
| 476 |
+
with st.spinner("Clustering..."):
|
| 477 |
+
try:
|
| 478 |
+
results = perform_clustering(
|
| 479 |
+
st.session_state.clustering_state['data'],
|
| 480 |
+
st.session_state.clustering_state['num_clusters']
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
if results:
|
| 484 |
+
st.success("Clustering complete!")
|
| 485 |
+
for cluster_num, node_ids in results['clusters_dict'].items():
|
| 486 |
+
with st.expander(f"Cluster {cluster_num}"):
|
| 487 |
+
if node_ids:
|
| 488 |
+
node_id = node_ids[0]
|
| 489 |
+
num_steps = len(st.session_state.tree_data.synthesis_route(node_id))
|
| 490 |
+
route_score = round(st.session_state.tree_data.route_score(node_id), 3)
|
| 491 |
+
st.image(
|
| 492 |
+
get_route_svg(st.session_state.tree_data, node_id),
|
| 493 |
+
caption=f"Route {node_id}; {num_steps} steps; Score: {route_score}"
|
| 494 |
+
)
|
| 495 |
+
except Exception as e:
|
| 496 |
+
st.error(f"Clustering failed: {str(e)}")
|
| 497 |
+
|
| 498 |
+
# Clear memory button
|
| 499 |
+
if st.button('Clear memory and start over'):
|
| 500 |
+
st.cache_data.clear()
|
| 501 |
+
del st.session_state.clustering_state
|
| 502 |
+
del st.session_state.tree_data
|
| 503 |
+
gc.collect()
|
| 504 |
+
st.success("Memory cleared! Please refresh the page to start over.")
|
| 505 |
+
st.rerun()
|
| 506 |
+
|
| 507 |
stat_col, download_col = st.columns(2, gap="medium")
|
| 508 |
with stat_col:
|
| 509 |
st.subheader("Statistics")
|