Gilmullin Almaz committed · 52cfb6f
Parent(s): 53d6b47

optimized clustering
app.py
CHANGED
@@ -34,6 +34,82 @@ disable_progress_bars("huggingface_hub")
 smiles_parser = SMILESRead.create_parser(ignore=True)
 
 
+def reassign_nums_chunk(route_dict):
+    """Process a chunk of routes for reassigning numbers"""
+    return {k: reassign_nums(v) for k, v in route_dict.items()}
+
+def cluster_molecules_optimized(fingerprints_dict, max_clusters):
+    """Memory-optimized version of cluster_molecules.
+
+    Args:
+        fingerprints_dict (dict): Dictionary of pre-computed fingerprints
+        max_clusters (int): Maximum number of clusters
+
+    Returns:
+        dict: Clustering results containing clusters_dict and cluster_labels
+    """
+    try:
+        # Convert dictionary to arrays for efficient processing
+        labels = np.array(list(fingerprints_dict.keys()))
+        fingerprints = np.array(list(fingerprints_dict.values()))
+
+        # Calculate similarity matrix in chunks to save memory
+        chunk_size = 100
+        n_samples = len(fingerprints)
+        similarity_matrix = np.zeros((n_samples, n_samples))
+
+        for i in range(0, n_samples, chunk_size):
+            chunk_end = min(i + chunk_size, n_samples)
+            chunk = fingerprints[i:chunk_end]
+
+            # Calculate similarity for this chunk against all fingerprints
+            similarity_chunk = tanimoto_similarity_continuous(chunk, fingerprints)
+            similarity_matrix[i:chunk_end] = similarity_chunk
+
+            # Clear memory
+            del similarity_chunk
+            gc.collect()
+
+        # Convert to distance matrix
+        distance_matrix = 1 - similarity_matrix
+
+        # Free memory
+        del similarity_matrix
+        gc.collect()
+
+        # Calculate condensed distance matrix
+        condensed_distance = squareform(distance_matrix)
+
+        # Free memory
+        del distance_matrix
+        gc.collect()
+
+        # Calculate linkage
+        Z = fastcluster.linkage(condensed_distance, method='average')
+
+        # Free memory
+        del condensed_distance
+        gc.collect()
+
+        # Perform clustering
+        cluster_labels = fcluster(Z, max_clusters, criterion='maxclust')
+
+        # Create clusters dictionary
+        clusters_dict = {}
+        for cluster in range(1, max_clusters + 1):
+            cluster_indices = np.where(cluster_labels == cluster)[0]
+            clusters_dict[cluster] = list(labels[cluster_indices])
+
+        return {
+            'clusters_dict': clusters_dict,
+            'cluster_labels': cluster_labels,
+            'linkage_matrix': Z
+        }
+
+    except Exception as e:
+        print(f"Error in cluster_molecules_optimized: {str(e)}")
+        raise e
+
 def download_button(object_to_download, download_filename, button_text, pickle_it=False):
     """
     Issued from
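Two notes on cluster_molecules_optimized. First, the full similarity matrix is allocated as n_samples × n_samples float64, i.e. 8·n² bytes; computing it 100 rows at a time does not shrink that matrix, but it does bound the temporary per-chunk arrays at chunk_size × n_samples. Second, tanimoto_similarity_continuous is defined elsewhere in app.py and is not part of this commit; purely for orientation, a minimal NumPy sketch consistent with how it is called above (an (m, d) chunk against the full (n, d) fingerprint matrix, returning an (m, n) block) might look like the following. The continuous-Tanimoto formula here is an assumption, not code from the repository:

import numpy as np

def tanimoto_similarity_continuous(chunk, fingerprints):
    """Hypothetical sketch: continuous Tanimoto between count vectors,
    T(a, b) = a.b / (|a|^2 + |b|^2 - a.b), vectorized over rows."""
    dot = chunk @ fingerprints.T                       # (m, n) pairwise dot products
    sq_chunk = (chunk ** 2).sum(axis=1)[:, None]       # (m, 1) squared norms
    sq_all = (fingerprints ** 2).sum(axis=1)[None, :]  # (1, n) squared norms
    denom = sq_chunk + sq_all - dot
    # Guard against all-zero vectors, where the denominator vanishes.
    return np.divide(dot, denom, out=np.zeros(dot.shape), where=denom != 0)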
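And a hypothetical smoke test for the new function, assuming the module-level imports app.py already relies on (numpy as np, gc, fastcluster, SciPy's squareform and fcluster) plus a tanimoto_similarity_continuous like the sketch above; the toy integer vectors merely stand in for real fingerprints:

import numpy as np

rng = np.random.default_rng(0)
# Twelve fake "routes", each with a 64-dimensional count fingerprint.
toy_fps = {f"route_{i}": rng.integers(0, 3, size=64) for i in range(12)}

result = cluster_molecules_optimized(toy_fps, max_clusters=3)
print(result['clusters_dict'])   # e.g. {1: ['route_0', ...], 2: [...], 3: [...]}
print(result['cluster_labels'])  # one cluster id (1..3) per route, in key order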
@@ -341,86 +417,139 @@ if submit_planning:
     # route_score = round(tree.route_score(node_id), 3)
     # st.image(get_route_svg(tree, node_id), caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")
 
-    # Add these functions outside the if submit_planning block
     @st.cache_data
     def prepare_clustering_data(tree):
-        """Pre-compute and cache the clustering data"""
+        """Pre-compute and cache the clustering data in chunks"""
+        try:
+            # Free memory before starting
+            gc.collect()
+
+            # Process in chunks
+            chunk_size = 10
+            super_cgrs_dict = {}
+
+            for i in range(0, len(tree.winning_nodes), chunk_size):
+                chunk = list(tree.winning_nodes)[i:i+chunk_size]
+                temp_dict = {node: tree.synthesis_route(node) for node in chunk}
+                chunk_super_cgrs = reassign_nums_chunk(temp_dict)
+                super_cgrs_dict.update(chunk_super_cgrs)
+                del temp_dict
+                gc.collect()
+
+            # Process reduced CGRs in chunks
+            reduced_super_cgrs_dict = {}
+            for i in range(0, len(super_cgrs_dict), chunk_size):
+                keys = list(super_cgrs_dict.keys())[i:i+chunk_size]
+                chunk_dict = {k: super_cgrs_dict[k] for k in keys}
+                reduced_chunk = process_all_rs_cgrs(chunk_dict)
+                reduced_super_cgrs_dict.update(reduced_chunk)
+                del chunk_dict
+                gc.collect()
+
+            del super_cgrs_dict
+            gc.collect()
+
+            return reduced_super_cgrs_dict
+        except Exception as e:
+            st.error(f"Error in prepare_clustering_data: {str(e)}")
+            return None
 
     @st.cache_data
-    def perform_clustering(_reduced_super_cgrs_dict, num_clusters):
+    def perform_clustering(_reduced_super_cgrs_dict, num_clusters, chunk_size=10):
+        """Perform clustering with memory-efficient processing"""
+        try:
+            mfp = MorganFingerprint()
+
+            # Process fingerprints in chunks
+            all_fingerprints = {}
+            for i in range(0, len(_reduced_super_cgrs_dict), chunk_size):
+                keys = list(_reduced_super_cgrs_dict.keys())[i:i+chunk_size]
+                chunk_dict = {k: _reduced_super_cgrs_dict[k] for k in keys}
+                chunk_fingerprints = {k: mfp.calculate(v) for k, v in chunk_dict.items()}
+                all_fingerprints.update(chunk_fingerprints)
+                del chunk_dict
+                gc.collect()
+
+            return cluster_molecules_optimized(all_fingerprints, max_clusters=num_clusters)
+        except Exception as e:
+            st.error(f"Error in perform_clustering: {str(e)}")
+            return None
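A note on the new signature: st.cache_data builds its cache key by hashing the function's arguments, and a leading underscore (as in _reduced_super_cgrs_dict) tells Streamlit to skip hashing that argument. That is presumably why the dict of CGR objects, which are not cheaply hashable, is passed this way; the flip side is that the cache key then depends only on num_clusters and chunk_size, so stale results are possible unless the cache is cleared when the prepared data changes (the 'Clear memory' button below does exactly that via st.cache_data.clear()).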
+
+    def memory_status():
+        """Get current memory status"""
+        process = psutil.Process()
+        memory = process.memory_info().rss / 1024 / 1024
+        return f"Memory usage: {memory:.2f} MB"
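memory_status reads the resident set size of the current process via psutil; memory_info().rss is reported in bytes, hence the two divisions by 1024 to display MB.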
 
     cluster_box, z = st.columns(2, gap="medium")
     with cluster_box:
-        # Initialize session state
+        # Initialize session state
+        if 'clustering_state' not in st.session_state:
+            st.session_state.clustering_state = {
+                'prepared': False,
+                'data': None,
+                'last_memory': 0
+            }
 
-        st.write(f"Current memory usage: {current_memory:.2f} MB")
+        st.write(memory_status())
         st.write(f"Number of winning nodes: {len(tree.winning_nodes)}")
 
+        # Memory management controls
+        if st.button('Clear memory'):
+            st.cache_data.clear()
+            st.session_state.clustering_state = {
+                'prepared': False,
+                'data': None,
+                'last_memory': 0
+            }
+            gc.collect()
+            st.success("Memory cleared!")
+            st.rerun()
+
+        # Prepare data with progress tracking
+        if not st.session_state.clustering_state['prepared']:
             if st.button('Prepare clustering data'):
+                with st.spinner("Preparing data..."):
                     try:
+                        progress_bar = st.progress(0)
+                        st.session_state.clustering_state['data'] = prepare_clustering_data(tree)
+                        st.session_state.clustering_state['prepared'] = True
+                        progress_bar.progress(100)
+                        st.success("Data prepared!")
                     except Exception as e:
-                        st.write(f"Memory at error: {current_memory:.2f} MB")
+                        st.error(f"Preparation failed: {str(e)}")
 
+        # Clustering controls
+        if st.session_state.clustering_state['prepared']:
             num_clusters = st.slider(
+                'Number of clusters',
-            min_value=2,
-            max_value=min(10, len(tree.winning_nodes)),
+                min_value=2,
+                max_value=min(10, len(tree.winning_nodes)),
                 value=2
             )
+
             if st.button('Generate clusters'):
+                with st.spinner("Clustering..."):
                     try:
                         results = perform_clustering(
+                            st.session_state.clustering_state['data'],
                             num_clusters
                         )
 
+                        if results:
+                            for cluster_num, node_ids in results['clusters_dict'].items():
+                                with st.expander(f"Cluster {cluster_num}"):
+                                    if node_ids:
+                                        node_id = node_ids[0]
+                                        num_steps = len(tree.synthesis_route(node_id))
+                                        route_score = round(tree.route_score(node_id), 3)
+                                        st.image(
+                                            get_route_svg(tree, node_id),
+                                            caption=f"Route {node_id}; {num_steps} steps; Score: {route_score}"
+                                        )
                     except Exception as e:
                         st.error(f"Clustering failed: {str(e)}")
-
-        # Add clear cache button
-        if st.button('Clear cache and memory'):
-            st.cache_data.clear()
-            st.session_state.clustering_prepared = False
-            st.session_state.reduced_super_cgrs_dict = None
-            gc.collect()
-            st.success("Cache and memory cleared!")
-            st.rerun()
+
     stat_col, download_col = st.columns(2, gap="medium")
 
     with stat_col:
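On the clustering cut itself: fcluster(Z, max_clusters, criterion='maxclust') flattens the average-linkage dendrogram into at most max_clusters groups (fewer if the tree cannot be split that many ways). A tiny self-contained illustration of that SciPy/fastcluster behavior, independent of the app (the distance values are made up):

import numpy as np
import fastcluster
from scipy.cluster.hierarchy import fcluster
from scipy.spatial.distance import squareform

# Two tight pairs plus one point far from both; cutting at 3 clusters
# should recover exactly that structure.
dist = np.array([
    [0.0, 0.1, 0.9, 0.9, 0.8],
    [0.1, 0.0, 0.9, 0.9, 0.8],
    [0.9, 0.9, 0.0, 0.1, 0.8],
    [0.9, 0.9, 0.1, 0.0, 0.8],
    [0.8, 0.8, 0.8, 0.8, 0.0],
])
Z = fastcluster.linkage(squareform(dist), method='average')
print(fcluster(Z, 3, criterion='maxclust'))  # e.g. [1 1 2 2 3]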