Gilmullin Almaz committed on
Commit
1c0b880
·
2 Parent(s): 27a7101 a227a1b

Merge branch 'main' of https://huggingface.co/spaces/Protolaw/SynPlanner

Browse files
Files changed (1) hide show
  1. app.py +228 -2
app.py CHANGED
@@ -275,9 +275,235 @@ if st.session_state.planning_done and st.session_state.res is not None and st.se
275
  image_counter += 1
276
  num_steps = len(tree.synthesis_route(node_id))
277
  route_score = round(tree.route_score(node_id), 3)
278
- st.image(get_route_svg(tree, node_id),
279
- caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")
280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  stat_col, download_col = st.columns(2, gap="medium")
282
  with stat_col:
283
  st.subheader("Statistics")
 
275
  image_counter += 1
276
  num_steps = len(tree.synthesis_route(node_id))
277
  route_score = round(tree.route_score(node_id), 3)
278
+ st.image(get_route_svg(tree, node_id), caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")
279
+
280
 
281
+ ### Modified part
282
+ # cluster_box, z = st.columns(2, gap="medium")
283
+ # with cluster_box:
284
+ # num_clusters = st.slider('Number of clusters to display', min_value=2, max_value=10, value=2)
285
+
286
+ # submit_clustering = st.button('Start clustering')
287
+
288
+ # if submit_clustering:
289
+ # st.subheader("Examples of clusters")
290
+ # super_cgrs_dict = reassign_nums(tree)
291
+
292
+ # reduced_super_cgrs_dict = process_all_rs_cgrs(super_cgrs_dict)
293
+
294
+ # mfp = MorganFingerprint()
295
+
296
+ # results = cluster_molecules(reduced_super_cgrs_dict, mfp, max_clusters=num_clusters)
297
+ # cluster_box, z = st.columns(2, gap="medium")
298
+ # with cluster_box:
299
+ # # Initialize session state if not exists
300
+ # if 'memory_warning_shown' not in st.session_state:
301
+ # st.session_state.memory_warning_shown = False
302
+
303
+ # current_memory = psutil.Process().memory_info().rss / 1024 / 1024
304
+ # st.write(f"Current memory usage: {current_memory:.2f} MB")
305
+ # st.write(f"Number of winning nodes: {len(tree.winning_nodes)}")
306
+
307
+ # # Memory warning
308
+ # if current_memory > 1000 and not st.session_state.memory_warning_shown:
309
+ # st.warning("Memory usage is high. Consider reducing the number of routes or clearing cache.")
310
+ # st.session_state.memory_warning_shown = True
311
+
312
+ # # Store the previous value in session state
313
+ # if 'prev_num_clusters' not in st.session_state:
314
+ # st.session_state.prev_num_clusters = 2
315
+
316
+ # num_clusters = st.slider(
317
+ # 'Number of clusters to display',
318
+ # min_value=2,
319
+ # max_value=min(10, len(tree.winning_nodes)),
320
+ # value=st.session_state.prev_num_clusters
321
+ # )
322
+
323
+ # # Update the stored value only if it changed
324
+ # if num_clusters != st.session_state.prev_num_clusters:
325
+ # st.session_state.prev_num_clusters = num_clusters
326
+
327
+ # submit_clustering = st.button('Start clustering')
328
+
329
+ # if submit_clustering:
330
+ # try:
331
+ # with st.spinner("Processing clusters..."):
332
+ # # Clear memory before starting
333
+ # gc.collect()
334
+
335
+ # st.write("Starting clustering process...")
336
+ # memory_before = psutil.Process().memory_info().rss / 1024 / 1024
337
+ # st.write(f"Memory before clustering: {memory_before:.2f} MB")
338
+
339
+ # super_cgrs_dict = reassign_nums(tree)
340
+ # del tree # Free up memory from the tree object since we don't need it anymore
341
+ # gc.collect()
342
+
343
+ # reduced_super_cgrs_dict = process_all_rs_cgrs(super_cgrs_dict)
344
+ # del super_cgrs_dict # Free up memory
345
+ # gc.collect()
346
+
347
+ # memory_after = psutil.Process().memory_info().rss / 1024 / 1024
348
+ # st.write(f"Memory after CGR processing: {memory_after:.2f} MB")
349
+
350
+ # mfp = MorganFingerprint()
351
+ # results = cluster_molecules(reduced_super_cgrs_dict, mfp, max_clusters=num_clusters)
352
+ # del reduced_super_cgrs_dict # Free up memory
353
+ # gc.collect()
354
+
355
+ # st.write("Clustering completed")
356
+
357
+ # except Exception as e:
358
+ # st.error(f"Clustering failed with error: {str(e)}")
359
+ # st.write(f"Memory at error: {psutil.Process().memory_info().rss / 1024 / 1024:.2f} MB")
360
+ # raise e
361
+
362
+
363
+ # Access results
364
+ # clusters = results['clusters_dict']
365
+
366
+ # for cluster_num, node_id_list in clusters.items():
367
+ # st.markdown(f"Cluster's number: ``{cluster_num}``")
368
+ # node_id = node_id_list[0]
369
+ # num_steps = len(tree.synthesis_route(node_id))
370
+ # route_score = round(tree.route_score(node_id), 3)
371
+ # st.image(get_route_svg(tree, node_id), caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")
372
+
373
@st.cache_data(hash_funcs={Tree: lambda _: None})
def prepare_clustering_data(tree):
    """Build the reduced super-CGR dictionary used for route clustering.

    Walks ``tree.winning_nodes`` in chunks of ``chunk_size``, extracting each
    node's synthesis route and converting the chunk via ``reassign_nums_chunk``;
    the resulting super-CGRs are then reduced chunk-wise with
    ``process_all_rs_cgrs``. Chunking plus explicit ``gc.collect()`` keeps peak
    memory down on large trees.

    NOTE(review): ``hash_funcs={Tree: lambda _: None}`` makes every Tree hash
    to the same cache key, so calling this with a *different* tree later will
    return the first tree's cached result — confirm this is intended, or clear
    ``st.cache_data`` whenever the tree changes.

    :param tree: search tree exposing ``winning_nodes``, ``synthesis_route``.
    :return: dict mapping node id -> reduced super-CGR, or ``None`` on error
        (the error is also surfaced via ``st.error``).
    """
    try:
        # Log the start and basic info from the Tree
        print("Starting clustering data preparation.")
        total_nodes = len(tree.winning_nodes)
        print(f"Total winning nodes: {total_nodes}")
        print(f"Tree id: {id(tree)}")

        chunk_size = 10
        super_cgrs_dict = {}

        # Materialize the node list ONCE; the original rebuilt
        # list(tree.winning_nodes) on every chunk iteration.
        winning_nodes = list(tree.winning_nodes)

        # Process winning nodes in chunks
        for i in range(0, total_nodes, chunk_size):
            current_chunk = winning_nodes[i:i + chunk_size]
            print(f"Processing chunk {i // chunk_size + 1}: Nodes {current_chunk}")

            temp_dict = {}
            for node in current_chunk:
                try:
                    # Log before processing each node
                    print(f"Processing node {node}")
                    route = tree.synthesis_route(node)
                    temp_dict[node] = route
                    print(f"Node {node} processed successfully (route length: {len(route)}).")
                except Exception as e:
                    # Best-effort: a node whose route extraction fails is
                    # skipped rather than aborting the whole preparation.
                    print(f"Error processing node {node}: {e}")

            # Log before calling reassign_nums_chunk
            print(f"Calling reassign_nums_chunk for nodes: {list(temp_dict.keys())}")
            chunk_super_cgrs = reassign_nums_chunk(temp_dict)
            super_cgrs_dict.update(chunk_super_cgrs)
            print(f"Chunk {i // chunk_size + 1} processed. Keys: {list(chunk_super_cgrs.keys())}")

            del temp_dict
            gc.collect()

        # Process reduced CGRs in chunks; hoist the key list out of the loop
        # (the original rebuilt list(super_cgrs_dict.keys()) per iteration).
        all_keys = list(super_cgrs_dict.keys())
        reduced_super_cgrs_dict = {}
        for i in range(0, len(all_keys), chunk_size):
            keys = all_keys[i:i + chunk_size]
            chunk_dict = {k: super_cgrs_dict[k] for k in keys}
            print(f"Reducing chunk for keys: {keys}")
            reduced_chunk = process_all_rs_cgrs(chunk_dict)
            reduced_super_cgrs_dict.update(reduced_chunk)
            print(f"Reduced chunk processed for keys: {list(reduced_chunk.keys())}")

            del chunk_dict
            gc.collect()

        print("Clustering data preparation complete.")
        return reduced_super_cgrs_dict
    except Exception as e:
        print(f"Error in prepare_clustering_data: {str(e)}")
        st.error(f"Error in prepare_clustering_data: {str(e)}")
        return None
429
+
430
+
431
def memory_status():
    """Return a human-readable summary of this process's RSS, in MB."""
    rss_mb = psutil.Process().memory_info().rss / (1024 * 1024)
    return f"Memory usage: {rss_mb:.2f} MB"
436
+
437
# Initialize session state for the tree and the three-step clustering flow.
if 'tree_data' not in st.session_state:
    st.session_state.tree_data = tree
if 'clustering_state' not in st.session_state:
    st.session_state.clustering_state = {
        'prepared': False,   # set True only after data is actually prepared
        'data': None,        # reduced super-CGR dict from prepare_clustering_data
        'num_clusters': 2,   # persisted slider value
    }

cluster_box, z = st.columns(2, gap="medium")
with cluster_box:
    st.write(memory_status())
    st.write(f"Number of winning nodes: {len(st.session_state.tree_data.winning_nodes)}")

    # Step 1: Prepare Data Button
    if not st.session_state.clustering_state['prepared']:
        if st.button('Step 1: Prepare clustering data'):
            with st.spinner("Preparing data..."):
                try:
                    # BUGFIX: prepare_clustering_data() swallows its own
                    # exceptions and returns None, so the except branch here
                    # never fired — the old code marked 'prepared' True and
                    # showed success even on failure, and Step 3 then tried
                    # to cluster None. Only mark prepared on a real result.
                    prepared_data = prepare_clustering_data(st.session_state.tree_data)
                    if prepared_data is not None:
                        st.session_state.clustering_state['data'] = prepared_data
                        st.session_state.clustering_state['prepared'] = True
                        st.success("Data prepared! Now you can proceed to Step 2.")
                    else:
                        st.error("Preparation failed: no clustering data was produced.")
                except Exception as e:
                    st.error(f"Preparation failed: {str(e)}")

    # Step 2: Only show clustering controls if data is prepared
    if st.session_state.clustering_state['prepared']:
        st.markdown("### Step 2: Select number of clusters")
        # Guard: st.slider raises if max_value < min_value, which happens
        # when there are fewer than 2 winning nodes.
        slider_max = max(2, min(10, len(st.session_state.tree_data.winning_nodes)))
        # Store slider value in session state so it survives reruns.
        st.session_state.clustering_state['num_clusters'] = st.slider(
            'Number of clusters',
            min_value=2,
            max_value=slider_max,
            value=min(st.session_state.clustering_state['num_clusters'], slider_max),
        )

        # Step 3: Generate Clusters Button
        if st.button('Step 3: Generate clusters'):
            with st.spinner("Clustering..."):
                try:
                    results = perform_clustering(
                        st.session_state.clustering_state['data'],
                        st.session_state.clustering_state['num_clusters']
                    )

                    if results:
                        st.success("Clustering complete!")
                        # Show one representative route per cluster.
                        for cluster_num, node_ids in results['clusters_dict'].items():
                            with st.expander(f"Cluster {cluster_num}"):
                                if node_ids:
                                    node_id = node_ids[0]
                                    num_steps = len(st.session_state.tree_data.synthesis_route(node_id))
                                    route_score = round(st.session_state.tree_data.route_score(node_id), 3)
                                    st.image(
                                        get_route_svg(st.session_state.tree_data, node_id),
                                        caption=f"Route {node_id}; {num_steps} steps; Score: {route_score}"
                                    )
                except Exception as e:
                    st.error(f"Clustering failed: {str(e)}")

    # Clear memory button
    if st.button('Clear memory and start over'):
        st.cache_data.clear()
        del st.session_state.clustering_state
        del st.session_state.tree_data
        gc.collect()
        st.success("Memory cleared! Please refresh the page to start over.")
        st.rerun()
506
+
507
  stat_col, download_col = st.columns(2, gap="medium")
508
  with stat_col:
509
  st.subheader("Statistics")