netzhang's picture
Deploy merged demo: representative images (#42), t-SNE exact solver (#45), PCA reproducibility (#46), decoupled projection/KMeans + thread pipeline, demo header/footer
269ea1f verified
Raw
History Blame Contribute Delete
9.06 kB
"""
Shared clustering summary components.
"""
import streamlit as st
import os
import pandas as pd
from shared.utils.taxonomy_tree import build_taxonomic_tree, format_tree_string, get_tree_statistics
from shared.components.representatives import render_representative_images
from shared.utils.logging_config import get_logger
# Module logger obtained from the shared logging configuration, named after this module.
logger = get_logger(__name__)
def _cluster_sort_key(value):
    """Order cluster ids numerically when possible, non-numeric ids after them as text."""
    try:
        return (0, int(value), "")
    except (TypeError, ValueError):
        return (1, 0, str(value))


def _uuid_fingerprint(df):
    """Return a deterministic fingerprint of ``df["uuid"]`` for cache keying.

    Every uuid contributes (not just a prefix sample), so two different row
    subsets with the same length and the same first uuid no longer collide.
    """
    if df.empty:
        return "empty"
    row_hashes = pd.util.hash_pandas_object(df["uuid"], index=False)
    return f"{len(df)}_{int(row_hashes.sum())}"


def _select_tree_rows(df_plot, filtered_df, group_by, selected_cluster):
    """Return ``(tree_df, display_title)`` for the chosen group-by / cluster filter."""
    if group_by and selected_cluster != "All" and group_by in df_plot.columns:
        # The selectbox options are str(cluster_id), while the column may hold
        # ints, floats or strings; compare as strings so every case matches.
        cluster_mask = df_plot[group_by].astype(str) == selected_cluster
        cluster_uuids = df_plot.loc[cluster_mask, "uuid"]
        tree_df = filtered_df[filtered_df["uuid"].isin(cluster_uuids)]
        return tree_df, f"Taxonomic Tree for {group_by} = {selected_cluster}"
    return filtered_df, "Taxonomic Tree for All Data"


def render_taxonomic_tree_summary():
    """Render taxonomic tree summary for precalculated embeddings.

    Reads from ``st.session_state``:
      - ``"data"``: the plotted DataFrame. Its ``"KMeans (k=N)"`` columns (or a
        plain ``"cluster"`` column in the embed_explore app) supply cluster ids.
      - ``"filtered_df_for_clustering"``: the records (with a ``uuid`` column)
        the tree is built from.

    Renders nothing when either is missing. The built tree, its statistics and
    its formatted text are cached in session state. The cache key covers the
    exact rows in the tree plus every display option, so unchanged reruns skip
    the rebuild.
    """
    df_plot = st.session_state.get("data", None)
    filtered_df = st.session_state.get("filtered_df_for_clustering", None)
    if df_plot is None or filtered_df is None:
        return

    st.markdown("### Taxonomic Distribution")

    # Detect available KMeans columns, ordered by k.
    kmeans_cols = sorted(
        [c for c in df_plot.columns if isinstance(c, str) and c.startswith("KMeans (k=")],
        key=lambda c: int(c.split("=")[1].rstrip(")")),
    )
    # Fallback for embed_explore app (has 'cluster' column directly).
    has_embed_explore_cluster = "cluster" in df_plot.columns and not kmeans_cols

    # One column per control: group-by, cluster, minimum count, tree depth.
    col1, col2, col3, col4 = st.columns([1.5, 1.5, 1, 1])
    with col1:
        if kmeans_cols:
            # Precalculated app: let user pick which KMeans run.
            choice = st.selectbox(
                "Group by",
                options=["(none)"] + kmeans_cols,
                index=0,
                key="taxonomy_group_by",
                help="Select a KMeans result to filter taxonomy by cluster",
            )
            group_by = None if choice == "(none)" else choice
        elif has_embed_explore_cluster:
            group_by = "cluster"
        else:
            group_by = None

    with col2:
        if group_by and group_by in df_plot.columns:
            # dropna: a missing cluster id cannot be selected and would break int().
            unique_clusters = sorted(df_plot[group_by].dropna().unique(), key=_cluster_sort_key)
            selected_cluster = st.selectbox(
                "Cluster",
                options=["All"] + [str(c) for c in unique_clusters],
                index=0,
                key="taxonomy_cluster_selector",
                help="Select a specific cluster or 'All'",
            )
        else:
            selected_cluster = "All"

    with col3:
        min_count = st.number_input(
            "Minimum count",
            min_value=1,
            max_value=1000,
            value=5,
            step=1,
            key="taxonomy_min_count",
            help="Minimum number of records for a taxon to appear in the tree",
        )

    with col4:
        tree_depth = st.slider(
            "Tree depth",
            min_value=1,
            max_value=7,
            value=7,
            key="taxonomy_tree_depth",
            help="Maximum depth of the taxonomy tree to display",
        )

    # Filtering is cheap next to tree building. Doing it first lets the cache
    # key reflect the exact rows, so re-running KMeans or changing the
    # upstream filter invalidates the cached tree.
    tree_df, display_title = _select_tree_rows(df_plot, filtered_df, group_by, selected_cluster)
    cache_key = (
        f"taxonomy_{_uuid_fingerprint(tree_df)}_{group_by}_{selected_cluster}"
        f"_{min_count}_{tree_depth}"
    )

    previous_key = st.session_state.get("taxonomy_cache_key")
    if previous_key != cache_key or cache_key not in st.session_state:
        # Only the entry named by "taxonomy_cache_key" is ever reused, so drop
        # the stale one instead of letting entries pile up in session state.
        if previous_key not in (None, cache_key) and previous_key in st.session_state:
            del st.session_state[previous_key]
        with st.spinner("Building taxonomy tree..."):
            tree = build_taxonomic_tree(tree_df)
            st.session_state[cache_key] = {
                'tree': tree,
                'stats': get_tree_statistics(tree),
                'tree_string': format_tree_string(tree, max_depth=tree_depth, min_count=min_count),
                'display_title': display_title,
            }
        st.session_state["taxonomy_cache_key"] = cache_key

    # Use cached results (no regeneration).
    cached_data = st.session_state[cache_key]
    stats = cached_data['stats']

    st.markdown(f"**{cached_data['display_title']}**")
    stat_cols = st.columns(4)
    with stat_cols[0]:
        st.metric("Total Records", f"{stats['total_records']:,}")
    with stat_cols[1]:
        st.metric("Kingdoms", stats['kingdoms'])
    with stat_cols[2]:
        st.metric("Families", stats['families'])
    with stat_cols[3]:
        st.metric("Species", stats['species'])

    if cached_data['tree_string']:
        st.code(cached_data['tree_string'], language="text")
    else:
        st.info("No taxonomic data meets the display criteria. Try lowering the minimum count.")
def render_clustering_summary(show_taxonomy=False):
    """Show the clustering summary panel for the current session.

    The layout depends on whether ``st.session_state["data"]`` has an
    ``image_path`` column:

    * With images (embed_explore app): the user picks one of the
      ``"KMeans (k=N)"`` runs. The panel shows that run's summary table and
      representative images, both read from the per-column caches
      ``clustering_summaries`` and ``clustering_representatives_by_col``
      (filled by ``_run_kmeans``), so switching runs is instant.
    * Without images (precalculated app): the taxonomic tree summary is
      rendered when ``show_taxonomy`` is true and filtered data is present.

    Args:
        show_taxonomy: Whether the precalculated app should render the taxonomy tree.
    """
    data = st.session_state.get("data", None)
    if data is None:
        st.info("Summary will appear here after projection.")
        return

    if 'image_path' not in data.columns:
        # Precalculated app: taxonomy tree (works with or without KMeans).
        wants_tree = (
            show_taxonomy
            and st.session_state.get("filtered_df_for_clustering", None) is not None
        )
        if wants_tree:
            render_taxonomic_tree_summary()
        return

    def _k_of(column_name):
        # "KMeans (k=12)" -> 12
        return int(column_name.split("=")[1].rstrip(")"))

    runs = [name for name in data.columns if name.startswith("KMeans (k=")]
    runs.sort(key=_k_of)

    if not runs:
        st.subheader("Clustering Summary")
        st.info("Run KMeans to see the clustering summary and representative images.")
        return

    cached_summaries = st.session_state.get("clustering_summaries", {}) or {}
    cached_reps = st.session_state.get("clustering_representatives_by_col", {}) or {}

    st.subheader("Clustering Summary")
    chosen_run = st.selectbox(
        "KMeans result",
        options=runs,
        index=len(runs) - 1,  # last after sorting by k, i.e. the highest-k run
        key="summary_kmeans_selector",
        help="Select which KMeans run to view summary + representative images for.",
    )

    summary_table = cached_summaries.get(chosen_run)
    rep_indices = cached_reps.get(chosen_run)
    if summary_table is None or rep_indices is None:
        st.info(
            f"No cached summary for {chosen_run}. "
            "Re-run KMeans with this k to regenerate it."
        )
        return

    logger.debug(f"Displaying cached clustering summary for {chosen_run}")
    st.dataframe(summary_table, hide_index=True, width='stretch')
    st.markdown("#### Representative Images")

    def _path_at(position):
        # Positional row lookup into the plotted DataFrame.
        return data.iloc[position]["image_path"]

    def _existing_image(position):
        """Local image path when the file exists on disk, otherwise None."""
        candidate = _path_at(position)
        return candidate if isinstance(candidate, str) and os.path.exists(candidate) else None

    def _filename_caption(position):
        candidate = _path_at(position)
        return os.path.basename(candidate) if isinstance(candidate, str) else None

    render_representative_images(
        rep_indices,
        resolve_image=_existing_image,
        n_per_cluster=3,
        caption_fn=_filename_caption,
    )