Gilmullin Almaz committed on
Commit
2df9441
·
1 Parent(s): bcecb42
Files changed (1) hide show
  1. app.py +7 -253
app.py CHANGED
@@ -127,7 +127,6 @@ if 'clustering_started' not in st.session_state:
127
  if 'clusters_downloaded' not in st.session_state:
128
  st.session_state.clusters_downloaded = False
129
 
130
- # st.write("Initial session state:", dict(st.session_state))
131
 
132
  intro_text = '''
133
  This is a demo of the graphical user interface of
@@ -256,7 +255,7 @@ if submit_planning and not st.session_state.planning_done:
256
  st.session_state.planning_done = True
257
 
258
  # Display results if planning has been completed
259
- if st.session_state.planning_done and st.session_state.res is not None and st.session_state.clustering_started:
260
  res = st.session_state.res
261
  tree = st.session_state.tree
262
 
@@ -275,235 +274,9 @@ if st.session_state.planning_done and st.session_state.res is not None and st.se
275
  image_counter += 1
276
  num_steps = len(tree.synthesis_route(node_id))
277
  route_score = round(tree.route_score(node_id), 3)
278
- st.image(get_route_svg(tree, node_id), caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")
279
-
280
 
281
- ### Modified part
282
- # cluster_box, z = st.columns(2, gap="medium")
283
- # with cluster_box:
284
- # num_clusters = st.slider('Number of clusters to display', min_value=2, max_value=10, value=2)
285
-
286
- # submit_clustering = st.button('Start clustering')
287
-
288
- # if submit_clustering:
289
- # st.subheader("Examples of clusters")
290
- # super_cgrs_dict = reassign_nums(tree)
291
-
292
- # reduced_super_cgrs_dict = process_all_rs_cgrs(super_cgrs_dict)
293
-
294
- # mfp = MorganFingerprint()
295
-
296
- # results = cluster_molecules(reduced_super_cgrs_dict, mfp, max_clusters=num_clusters)
297
- # cluster_box, z = st.columns(2, gap="medium")
298
- # with cluster_box:
299
- # # Initialize session state if not exists
300
- # if 'memory_warning_shown' not in st.session_state:
301
- # st.session_state.memory_warning_shown = False
302
-
303
- # current_memory = psutil.Process().memory_info().rss / 1024 / 1024
304
- # st.write(f"Current memory usage: {current_memory:.2f} MB")
305
- # st.write(f"Number of winning nodes: {len(tree.winning_nodes)}")
306
-
307
- # # Memory warning
308
- # if current_memory > 1000 and not st.session_state.memory_warning_shown:
309
- # st.warning("Memory usage is high. Consider reducing the number of routes or clearing cache.")
310
- # st.session_state.memory_warning_shown = True
311
-
312
- # # Store the previous value in session state
313
- # if 'prev_num_clusters' not in st.session_state:
314
- # st.session_state.prev_num_clusters = 2
315
-
316
- # num_clusters = st.slider(
317
- # 'Number of clusters to display',
318
- # min_value=2,
319
- # max_value=min(10, len(tree.winning_nodes)),
320
- # value=st.session_state.prev_num_clusters
321
- # )
322
-
323
- # # Update the stored value only if it changed
324
- # if num_clusters != st.session_state.prev_num_clusters:
325
- # st.session_state.prev_num_clusters = num_clusters
326
-
327
- # submit_clustering = st.button('Start clustering')
328
-
329
- # if submit_clustering:
330
- # try:
331
- # with st.spinner("Processing clusters..."):
332
- # # Clear memory before starting
333
- # gc.collect()
334
-
335
- # st.write("Starting clustering process...")
336
- # memory_before = psutil.Process().memory_info().rss / 1024 / 1024
337
- # st.write(f"Memory before clustering: {memory_before:.2f} MB")
338
-
339
- # super_cgrs_dict = reassign_nums(tree)
340
- # del tree # Free up memory from the tree object since we don't need it anymore
341
- # gc.collect()
342
-
343
- # reduced_super_cgrs_dict = process_all_rs_cgrs(super_cgrs_dict)
344
- # del super_cgrs_dict # Free up memory
345
- # gc.collect()
346
-
347
- # memory_after = psutil.Process().memory_info().rss / 1024 / 1024
348
- # st.write(f"Memory after CGR processing: {memory_after:.2f} MB")
349
-
350
- # mfp = MorganFingerprint()
351
- # results = cluster_molecules(reduced_super_cgrs_dict, mfp, max_clusters=num_clusters)
352
- # del reduced_super_cgrs_dict # Free up memory
353
- # gc.collect()
354
-
355
- # st.write("Clustering completed")
356
-
357
- # except Exception as e:
358
- # st.error(f"Clustering failed with error: {str(e)}")
359
- # st.write(f"Memory at error: {psutil.Process().memory_info().rss / 1024 / 1024:.2f} MB")
360
- # raise e
361
-
362
-
363
- # Access results
364
- # clusters = results['clusters_dict']
365
-
366
- # for cluster_num, node_id_list in clusters.items():
367
- # st.markdown(f"Cluster's number: ``{cluster_num}``")
368
- # node_id = node_id_list[0]
369
- # num_steps = len(tree.synthesis_route(node_id))
370
- # route_score = round(tree.route_score(node_id), 3)
371
- # st.image(get_route_svg(tree, node_id), caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")
372
-
373
- @st.cache_data(hash_funcs={Tree: lambda _: None})
374
- def prepare_clustering_data(tree):
375
- try:
376
- # Log the start and basic info from the Tree
377
- print("Starting clustering data preparation.")
378
- total_nodes = len(tree.winning_nodes)
379
- print(f"Total winning nodes: {total_nodes}")
380
- print(f"Tree id: {id(tree)}")
381
-
382
- chunk_size = 10
383
- super_cgrs_dict = {}
384
-
385
- # Process winning nodes in chunks
386
- for i in range(0, total_nodes, chunk_size):
387
- current_chunk = list(tree.winning_nodes)[i:i+chunk_size]
388
- print(f"Processing chunk {i // chunk_size + 1}: Nodes {current_chunk}")
389
-
390
- temp_dict = {}
391
- for node in current_chunk:
392
- try:
393
- # Log before processing each node
394
- print(f"Processing node {node}")
395
- route = tree.synthesis_route(node)
396
- temp_dict[node] = route
397
- print(f"Node {node} processed successfully (route length: {len(route)}).")
398
- except Exception as e:
399
- print(f"Error processing node {node}: {e}")
400
-
401
- # Log before calling reassign_nums_chunk
402
- print(f"Calling reassign_nums_chunk for nodes: {list(temp_dict.keys())}")
403
- chunk_super_cgrs = reassign_nums_chunk(temp_dict)
404
- super_cgrs_dict.update(chunk_super_cgrs)
405
- print(f"Chunk {i // chunk_size + 1} processed. Keys: {list(chunk_super_cgrs.keys())}")
406
-
407
- del temp_dict
408
- gc.collect()
409
-
410
- # Process reduced CGRs in chunks
411
- reduced_super_cgrs_dict = {}
412
- for i in range(0, len(super_cgrs_dict), chunk_size):
413
- keys = list(super_cgrs_dict.keys())[i:i+chunk_size]
414
- chunk_dict = {k: super_cgrs_dict[k] for k in keys}
415
- print(f"Reducing chunk for keys: {keys}")
416
- reduced_chunk = process_all_rs_cgrs(chunk_dict)
417
- reduced_super_cgrs_dict.update(reduced_chunk)
418
- print(f"Reduced chunk processed for keys: {list(reduced_chunk.keys())}")
419
-
420
- del chunk_dict
421
- gc.collect()
422
-
423
- print("Clustering data preparation complete.")
424
- return reduced_super_cgrs_dict
425
- except Exception as e:
426
- print(f"Error in prepare_clustering_data: {str(e)}")
427
- st.error(f"Error in prepare_clustering_data: {str(e)}")
428
- return None
429
-
430
-
431
- def memory_status():
432
- """Get current memory status"""
433
- process = psutil.Process()
434
- memory = process.memory_info().rss / 1024 / 1024
435
- return f"Memory usage: {memory:.2f} MB"
436
-
437
- # Initialize session state for tree and clustering data
438
- if 'tree_data' not in st.session_state:
439
- st.session_state.tree_data = tree
440
- if 'clustering_state' not in st.session_state:
441
- st.session_state.clustering_state = {
442
- 'prepared': False,
443
- 'data': None,
444
- 'num_clusters': 2
445
- }
446
-
447
- cluster_box, z = st.columns(2, gap="medium")
448
- with cluster_box:
449
- st.write(memory_status())
450
- st.write(f"Number of winning nodes: {len(st.session_state.tree_data.winning_nodes)}")
451
-
452
- # Step 1: Prepare Data Button
453
- if not st.session_state.clustering_state['prepared']:
454
- if st.button('Step 1: Prepare clustering data'):
455
- with st.spinner("Preparing data..."):
456
- try:
457
- st.session_state.clustering_state['data'] = prepare_clustering_data(st.session_state.tree_data)
458
- st.session_state.clustering_state['prepared'] = True
459
- st.success("Data prepared! Now you can proceed to Step 2.")
460
- except Exception as e:
461
- st.error(f"Preparation failed: {str(e)}")
462
-
463
- # Step 2: Only show clustering controls if data is prepared
464
- if st.session_state.clustering_state['prepared']:
465
- st.markdown("### Step 2: Select number of clusters")
466
- # Store slider value in session state
467
- st.session_state.clustering_state['num_clusters'] = st.slider(
468
- 'Number of clusters',
469
- min_value=2,
470
- max_value=min(10, len(st.session_state.tree_data.winning_nodes)),
471
- value=st.session_state.clustering_state['num_clusters']
472
- )
473
-
474
- # Step 3: Generate Clusters Button
475
- if st.button('Step 3: Generate clusters'):
476
- with st.spinner("Clustering..."):
477
- try:
478
- results = perform_clustering(
479
- st.session_state.clustering_state['data'],
480
- st.session_state.clustering_state['num_clusters']
481
- )
482
-
483
- if results:
484
- st.success("Clustering complete!")
485
- for cluster_num, node_ids in results['clusters_dict'].items():
486
- with st.expander(f"Cluster {cluster_num}"):
487
- if node_ids:
488
- node_id = node_ids[0]
489
- num_steps = len(st.session_state.tree_data.synthesis_route(node_id))
490
- route_score = round(st.session_state.tree_data.route_score(node_id), 3)
491
- st.image(
492
- get_route_svg(st.session_state.tree_data, node_id),
493
- caption=f"Route {node_id}; {num_steps} steps; Score: {route_score}"
494
- )
495
- except Exception as e:
496
- st.error(f"Clustering failed: {str(e)}")
497
-
498
- # Clear memory button
499
- if st.button('Clear memory and start over'):
500
- st.cache_data.clear()
501
- del st.session_state.clustering_state
502
- del st.session_state.tree_data
503
- gc.collect()
504
- st.success("Memory cleared! Please refresh the page to start over.")
505
- st.rerun()
506
-
507
  stat_col, download_col = st.columns(2, gap="medium")
508
  with stat_col:
509
  st.subheader("Statistics")
@@ -519,7 +292,6 @@ if st.session_state.planning_done and st.session_state.res is not None and st.se
519
 
520
  st.header("Clustering the retrosynthetic routes")
521
 
522
- # Initialize slider state if not already set
523
  if 'num_clusters' not in st.session_state:
524
  st.session_state['num_clusters'] = 10
525
 
@@ -532,13 +304,10 @@ if st.session_state.planning_done and st.session_state.res is not None and st.se
532
  value=st.session_state['num_clusters'],
533
  key='cluster_slider'
534
  )
535
- # Save the current slider value to session_state
536
  st.session_state['num_clusters'] = num_clusters
537
 
538
  if st.button('Start clustering', key='submit_clustering'):
539
  st.session_state.clustering_started = True
540
- # st.write("Clustering started; session state now:", dict(st.session_state))
541
- # st.write("Clustering started!")
542
  st.subheader("Examples of clusters")
543
  super_cgrs_dict = reassign_nums(tree)
544
 
@@ -562,16 +331,13 @@ if st.session_state.planning_done and st.session_state.res is not None and st.se
562
  cluster_stat_col, cluster_download_col = st.columns(2, gap="medium")
563
  with cluster_stat_col:
564
  st.subheader("Statistics")
565
- # st.write(cluster_sizes)
566
  cluster_df = pd.DataFrame({'Cluster': range(len(cluster_sizes)), 'Routes': cluster_sizes})
567
- # cluster_df = pd.DataFrame(cluster_sizes, index=[0])
568
  st.write(cluster_df)
569
 
570
  def on_download_click():
571
  st.session_state.clusters_downloaded = True
572
  st.write("Download clusters button pressed via on_click. Updated session state:", dict(st.session_state))
573
  save_route_images(tree, reactions_dict, cluster_dict=clusters_converted)
574
- # Here you can call save_route_images(...) if desired.
575
 
576
  with cluster_download_col:
577
  st.subheader("Downloads: Don't work. Resets evey time")
@@ -581,24 +347,8 @@ if st.session_state.planning_done and st.session_state.res is not None and st.se
581
  # Use on_click to capture the click event reliably.
582
  st.button('Download clusters', key='download_clusters_button', on_click=on_download_click)
583
 
584
- # Log whether the flag has been set after the button definition.
585
  st.write("Clusters downloaded flag (from session_state):", st.session_state.get("clusters_downloaded"))
586
 
587
- # # save_route_images(tree, reactions_dict, cluster_dict=clusters_converted)
588
- # with cluster_download_col:
589
- # st.subheader("Downloads")
590
- # reactions_dict = extract_reactions(tree)
591
- # clusters_converted = {int(key): value for key, value in clusters.items()} if clusters else clusters
592
-
593
- # if st.session_state.clustering_started:
594
- # st.write("Rendering download clusters button. Session state:", dict(st.session_state))
595
- # # Use a more unique key for the download button.
596
- # download_clusters = st.button('Download clusters', key='download_clusters_button')
597
- # st.write("download_clusters value:", download_clusters)
598
- # if download_clusters:
599
- # st.session_state.clusters_downloaded = True
600
- # st.write("Download clusters button pressed. Updated session state:", dict(st.session_state))
601
-
602
  col1, _ = st.columns([.2, .8])
603
  with col1:
604
  fig = pie_chart(cluster_sizes)
@@ -631,4 +381,8 @@ st.divider()
631
  st.header('Restart from the beginning?')
632
  if st.button("Restart"):
633
  st.session_state.planning_done = False
 
 
 
 
634
  st.rerun()
 
127
  if 'clusters_downloaded' not in st.session_state:
128
  st.session_state.clusters_downloaded = False
129
 
 
130
 
131
  intro_text = '''
132
  This is a demo of the graphical user interface of
 
255
  st.session_state.planning_done = True
256
 
257
  # Display results if planning has been completed
258
+ if st.session_state.planning_done and st.session_state.res is not None:
259
  res = st.session_state.res
260
  tree = st.session_state.tree
261
 
 
274
  image_counter += 1
275
  num_steps = len(tree.synthesis_route(node_id))
276
  route_score = round(tree.route_score(node_id), 3)
277
+ st.image(get_route_svg(tree, node_id),
278
+ caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")
279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  stat_col, download_col = st.columns(2, gap="medium")
281
  with stat_col:
282
  st.subheader("Statistics")
 
292
 
293
  st.header("Clustering the retrosynthetic routes")
294
 
 
295
  if 'num_clusters' not in st.session_state:
296
  st.session_state['num_clusters'] = 10
297
 
 
304
  value=st.session_state['num_clusters'],
305
  key='cluster_slider'
306
  )
 
307
  st.session_state['num_clusters'] = num_clusters
308
 
309
  if st.button('Start clustering', key='submit_clustering'):
310
  st.session_state.clustering_started = True
 
 
311
  st.subheader("Examples of clusters")
312
  super_cgrs_dict = reassign_nums(tree)
313
 
 
331
  cluster_stat_col, cluster_download_col = st.columns(2, gap="medium")
332
  with cluster_stat_col:
333
  st.subheader("Statistics")
 
334
  cluster_df = pd.DataFrame({'Cluster': range(len(cluster_sizes)), 'Routes': cluster_sizes})
 
335
  st.write(cluster_df)
336
 
337
  def on_download_click():
338
  st.session_state.clusters_downloaded = True
339
  st.write("Download clusters button pressed via on_click. Updated session state:", dict(st.session_state))
340
  save_route_images(tree, reactions_dict, cluster_dict=clusters_converted)
 
341
 
342
  with cluster_download_col:
343
  st.subheader("Downloads: Don't work. Resets evey time")
 
347
  # Use on_click to capture the click event reliably.
348
  st.button('Download clusters', key='download_clusters_button', on_click=on_download_click)
349
 
 
350
  st.write("Clusters downloaded flag (from session_state):", st.session_state.get("clusters_downloaded"))
351
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
  col1, _ = st.columns([.2, .8])
353
  with col1:
354
  fig = pie_chart(cluster_sizes)
 
381
  st.header('Restart from the beginning?')
382
  if st.button("Restart"):
383
  st.session_state.planning_done = False
384
+ st.session_state.tree = None
385
+ st.session_state.res = None
386
+ st.session_state.clustering_started = False
387
+ st.session_state.clusters_downloaded = False
388
  st.rerun()