Gilmullin Almaz committed on
Commit
52cfb6f
·
1 Parent(s): 53d6b47

optimized clustering

Browse files
Files changed (1) hide show
  1. app.py +184 -55
app.py CHANGED
@@ -34,6 +34,82 @@ disable_progress_bars("huggingface_hub")
34
  smiles_parser = SMILESRead.create_parser(ignore=True)
35
 
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  def download_button(object_to_download, download_filename, button_text, pickle_it=False):
38
  """
39
  Issued from
@@ -341,86 +417,139 @@ if submit_planning:
341
  # route_score = round(tree.route_score(node_id), 3)
342
  # st.image(get_route_svg(tree, node_id), caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")
343
 
344
- # Add these functions outside the if submit_planning block
345
  @st.cache_data
346
  def prepare_clustering_data(tree):
347
- """Pre-compute and cache the clustering data"""
348
- super_cgrs_dict = reassign_nums(tree)
349
- reduced_super_cgrs_dict = process_all_rs_cgrs(super_cgrs_dict)
350
- return reduced_super_cgrs_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351
 
352
  @st.cache_data
353
- def perform_clustering(_reduced_super_cgrs_dict, num_clusters):
354
- """Perform the actual clustering with cached results"""
355
- mfp = MorganFingerprint()
356
- return cluster_molecules(_reduced_super_cgrs_dict, mfp, max_clusters=num_clusters)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
 
358
  cluster_box, z = st.columns(2, gap="medium")
359
  with cluster_box:
360
- # Initialize session state for clustering data
361
- if 'clustering_prepared' not in st.session_state:
362
- st.session_state.clustering_prepared = False
363
- st.session_state.reduced_super_cgrs_dict = None
 
 
 
364
 
365
- current_memory = psutil.Process().memory_info().rss / 1024 / 1024
366
- st.write(f"Current memory usage: {current_memory:.2f} MB")
367
  st.write(f"Number of winning nodes: {len(tree.winning_nodes)}")
368
 
369
- # Prepare data button
370
- if not st.session_state.clustering_prepared:
 
 
 
 
 
 
 
 
 
 
 
 
371
  if st.button('Prepare clustering data'):
372
- with st.spinner("Preparing clustering data..."):
373
  try:
374
- gc.collect()
375
- st.session_state.reduced_super_cgrs_dict = prepare_clustering_data(tree)
376
- st.session_state.clustering_prepared = True
377
- st.success("Data prepared successfully!")
 
378
  except Exception as e:
379
- st.error(f"Failed to prepare data: {str(e)}")
380
- st.write(f"Memory at error: {current_memory:.2f} MB")
381
 
382
- # Only show clustering controls if data is prepared
383
- if st.session_state.clustering_prepared:
384
  num_clusters = st.slider(
385
- 'Number of clusters to display',
386
- min_value=2,
387
- max_value=min(10, len(tree.winning_nodes)),
388
  value=2
389
  )
390
-
391
  if st.button('Generate clusters'):
392
- with st.spinner("Generating clusters..."):
393
  try:
394
  results = perform_clustering(
395
- st.session_state.reduced_super_cgrs_dict,
396
  num_clusters
397
  )
398
 
399
- # Display clusters
400
- clusters = results['clusters_dict']
401
- for cluster_num, node_id_list in clusters.items():
402
- st.markdown(f"Cluster's number: ``{cluster_num}``")
403
- node_id = node_id_list[0]
404
- num_steps = len(tree.synthesis_route(node_id))
405
- route_score = round(tree.route_score(node_id), 3)
406
- st.image(
407
- get_route_svg(tree, node_id),
408
- caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}"
409
- )
410
-
411
  except Exception as e:
412
  st.error(f"Clustering failed: {str(e)}")
413
- st.write(f"Memory at error: {current_memory:.2f} MB")
414
-
415
- # Add clear cache button
416
- if st.button('Clear cache and memory'):
417
- st.cache_data.clear()
418
- st.session_state.clustering_prepared = False
419
- st.session_state.reduced_super_cgrs_dict = None
420
- gc.collect()
421
- st.success("Cache and memory cleared!")
422
- st.rerun()
423
-
424
  stat_col, download_col = st.columns(2, gap="medium")
425
 
426
  with stat_col:
 
34
  smiles_parser = SMILESRead.create_parser(ignore=True)
35
 
36
 
37
def reassign_nums_chunk(route_dict):
    """Renumber every route in *route_dict*.

    Args:
        route_dict (dict): Maps a route key (node id) to a synthesis route.

    Returns:
        dict: Same keys, with `reassign_nums` applied to each route.
    """
    renumbered = {}
    for key, route in route_dict.items():
        renumbered[key] = reassign_nums(route)
    return renumbered
40
+
41
def cluster_molecules_optimized(fingerprints_dict, max_clusters):
    """Memory-optimized hierarchical clustering of molecules.

    Args:
        fingerprints_dict (dict): Maps a label (e.g. node id) to its
            pre-computed fingerprint vector; all vectors share one length.
        max_clusters (int): Maximum number of flat clusters to form.

    Returns:
        dict: {'clusters_dict': {cluster_id: [labels]},
               'cluster_labels': per-sample cluster assignment array,
               'linkage_matrix': hierarchical linkage matrix}

    Raises:
        Exception: any failure is logged and re-raised.
    """
    try:
        # Keep labels and fingerprints index-aligned as plain arrays.
        labels = np.array(list(fingerprints_dict.keys()))
        fingerprints = np.array(list(fingerprints_dict.values()))

        n_samples = len(fingerprints)
        similarity_matrix = np.zeros((n_samples, n_samples))

        # Fill the similarity matrix in row chunks so the intermediate
        # arrays produced by the similarity kernel stay bounded.
        chunk_size = 100
        for start in range(0, n_samples, chunk_size):
            end = min(start + chunk_size, n_samples)
            similarity_matrix[start:end] = tanimoto_similarity_continuous(
                fingerprints[start:end], fingerprints
            )
        # NOTE: the original called gc.collect() on every chunk; that is
        # pure overhead in this loop, so it was removed.

        # Convert similarity -> distance; release the similarity matrix
        # before the next large allocation to keep peak RSS down.
        distance_matrix = 1 - similarity_matrix
        del similarity_matrix
        gc.collect()

        # FIX: float round-off can leave a tiny non-zero diagonal (and
        # asymmetry), which squareform's default validation rejects.
        # Zero the diagonal explicitly and skip the redundant check.
        np.fill_diagonal(distance_matrix, 0.0)
        condensed_distance = squareform(distance_matrix, checks=False)
        del distance_matrix
        gc.collect()

        # Average-linkage hierarchical clustering on the condensed matrix.
        Z = fastcluster.linkage(condensed_distance, method='average')
        del condensed_distance
        gc.collect()

        # Cut the dendrogram into at most max_clusters flat clusters.
        cluster_labels = fcluster(Z, max_clusters, criterion='maxclust')

        # Group labels by their assigned cluster id (ids start at 1).
        clusters_dict = {}
        for cluster in range(1, max_clusters + 1):
            cluster_indices = np.where(cluster_labels == cluster)[0]
            clusters_dict[cluster] = list(labels[cluster_indices])

        return {
            'clusters_dict': clusters_dict,
            'cluster_labels': cluster_labels,
            'linkage_matrix': Z
        }

    except Exception as e:
        print(f"Error in cluster_molecules_optimized: {str(e)}")
        # FIX: bare raise preserves the original traceback (unlike `raise e`).
        raise
112
+
113
  def download_button(object_to_download, download_filename, button_text, pickle_it=False):
114
  """
115
  Issued from
 
417
  # route_score = round(tree.route_score(node_id), 3)
418
  # st.image(get_route_svg(tree, node_id), caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")
419
 
 
420
@st.cache_data
def prepare_clustering_data(tree):
    """Pre-compute and cache the clustering input for *tree*.

    Builds the super-CGR representation of every winning route in small
    chunks, with explicit GC between chunks, to keep peak memory low.

    Args:
        tree: Search tree exposing `winning_nodes` and `synthesis_route`.

    Returns:
        dict | None: node id -> reduced super-CGR, or None on failure
        (the error is surfaced via st.error).
    """
    try:
        # Free memory before starting the chunked build.
        gc.collect()

        chunk_size = 10
        # FIX: hoist the list() conversions out of the loops — the
        # original rebuilt list(...) on every iteration (O(n^2) overall).
        winning_nodes = list(tree.winning_nodes)

        # Pass 1: renumber each route, chunk by chunk.
        super_cgrs_dict = {}
        for i in range(0, len(winning_nodes), chunk_size):
            chunk = winning_nodes[i:i + chunk_size]
            temp_dict = {node: tree.synthesis_route(node) for node in chunk}
            super_cgrs_dict.update(reassign_nums_chunk(temp_dict))
            del temp_dict
            gc.collect()

        # Pass 2: reduce the super-CGRs, chunk by chunk.
        reduced_super_cgrs_dict = {}
        all_keys = list(super_cgrs_dict)
        for i in range(0, len(all_keys), chunk_size):
            keys = all_keys[i:i + chunk_size]
            chunk_dict = {k: super_cgrs_dict[k] for k in keys}
            reduced_super_cgrs_dict.update(process_all_rs_cgrs(chunk_dict))
            del chunk_dict
            gc.collect()

        # Release the intermediate dict before returning.
        del super_cgrs_dict
        gc.collect()

        return reduced_super_cgrs_dict
    except Exception as e:
        st.error(f"Error in prepare_clustering_data: {str(e)}")
        return None
456
 
457
@st.cache_data
def perform_clustering(_reduced_super_cgrs_dict, num_clusters, chunk_size=10):
    """Cluster the prepared routes into at most *num_clusters* groups.

    Args:
        _reduced_super_cgrs_dict (dict): node id -> reduced super-CGR.
            Leading underscore tells st.cache_data not to hash this arg.
        num_clusters (int): Maximum number of clusters.
        chunk_size (int): Fingerprints computed per batch before a GC pass.

    Returns:
        dict | None: Result of cluster_molecules_optimized, or None on
        failure (the error is surfaced via st.error).
    """
    try:
        mfp = MorganFingerprint()

        # Compute fingerprints in batches. FIX: hoist the key list out of
        # the loop — the original rebuilt list(...) per iteration (O(n^2))
        # and built a needless intermediate chunk dict.
        all_keys = list(_reduced_super_cgrs_dict)
        all_fingerprints = {}
        for i in range(0, len(all_keys), chunk_size):
            batch = all_keys[i:i + chunk_size]
            all_fingerprints.update(
                {k: mfp.calculate(_reduced_super_cgrs_dict[k]) for k in batch}
            )
            gc.collect()

        return cluster_molecules_optimized(all_fingerprints, max_clusters=num_clusters)
    except Exception as e:
        st.error(f"Error in perform_clustering: {str(e)}")
        return None
477
+
478
def memory_status():
    """Return a human-readable summary of this process's RSS memory usage."""
    rss_bytes = psutil.Process().memory_info().rss
    return f"Memory usage: {rss_bytes / 1024 / 1024:.2f} MB"
483
 
484
cluster_box, z = st.columns(2, gap="medium")
with cluster_box:
    # Per-session clustering state: whether data is prepared, the prepared
    # data itself, and the last recorded memory reading.
    if 'clustering_state' not in st.session_state:
        st.session_state.clustering_state = {
            'prepared': False,
            'data': None,
            'last_memory': 0
        }

    st.write(memory_status())
    st.write(f"Number of winning nodes: {len(tree.winning_nodes)}")

    # Memory management: clear Streamlit's data cache, reset session
    # clustering state, and force a GC pass before re-running the app.
    if st.button('Clear memory'):
        st.cache_data.clear()
        st.session_state.clustering_state = {
            'prepared': False,
            'data': None,
            'last_memory': 0
        }
        gc.collect()
        st.success("Memory cleared!")
        st.rerun()

    # Step 1: prepare the clustering data (cached by st.cache_data).
    if not st.session_state.clustering_state['prepared']:
        if st.button('Prepare clustering data'):
            with st.spinner("Preparing data..."):
                try:
                    progress_bar = st.progress(0)
                    prepared_data = prepare_clustering_data(tree)
                    # FIX: prepare_clustering_data returns None on failure;
                    # only flip 'prepared' when data actually exists, so the
                    # clustering controls never operate on missing data.
                    if prepared_data is not None:
                        st.session_state.clustering_state['data'] = prepared_data
                        st.session_state.clustering_state['prepared'] = True
                        progress_bar.progress(100)
                        st.success("Data prepared!")
                except Exception as e:
                    st.error(f"Preparation failed: {str(e)}")

    # Step 2: pick a cluster count and run the clustering.
    if st.session_state.clustering_state['prepared']:
        num_clusters = st.slider(
            'Number of clusters',
            min_value=2,
            max_value=min(10, len(tree.winning_nodes)),
            value=2
        )

        if st.button('Generate clusters'):
            with st.spinner("Clustering..."):
                try:
                    results = perform_clustering(
                        st.session_state.clustering_state['data'],
                        num_clusters
                    )

                    # Show one representative route per cluster, each in
                    # its own expander.
                    if results:
                        for cluster_num, node_ids in results['clusters_dict'].items():
                            with st.expander(f"Cluster {cluster_num}"):
                                if node_ids:
                                    node_id = node_ids[0]
                                    num_steps = len(tree.synthesis_route(node_id))
                                    route_score = round(tree.route_score(node_id), 3)
                                    st.image(
                                        get_route_svg(tree, node_id),
                                        caption=f"Route {node_id}; {num_steps} steps; Score: {route_score}"
                                    )
                except Exception as e:
                    st.error(f"Clustering failed: {str(e)}")
 
 
 
 
 
 
 
 
 
 
553
  stat_col, download_col = st.columns(2, gap="medium")
554
 
555
  with stat_col: