"""Shared clustering summary components."""

import os

import pandas as pd
import streamlit as st

from shared.utils.taxonomy_tree import build_taxonomic_tree, format_tree_string, get_tree_statistics
from shared.components.representatives import render_representative_images
from shared.utils.logging_config import get_logger

logger = get_logger(__name__)

# Prefix shared by every KMeans result column, e.g. "KMeans (k=5)".
_KMEANS_COLUMN_PREFIX = "KMeans (k="


def _kmeans_columns(df):
    """Return the KMeans result columns of *df*, ordered by ascending k.

    A column qualifies when its name starts with ``"KMeans (k="``; ``k`` is
    parsed from the text between the first ``=`` and the trailing ``)``.
    """
    return sorted(
        (col for col in df.columns if col.startswith(_KMEANS_COLUMN_PREFIX)),
        key=lambda col: int(col.split("=")[1].rstrip(")")),
    )


def _select_tree_rows(df_plot, filtered_df, group_by, selected_cluster):
    """Pick the rows of *filtered_df* that feed the taxonomy tree.

    Returns a ``(tree_df, display_title)`` pair. When a grouping column and a
    specific cluster are selected, only rows whose ``uuid`` belongs to that
    cluster in *df_plot* are kept; otherwise *filtered_df* is returned as is.
    """
    if group_by and selected_cluster != "All" and group_by in df_plot.columns:
        # The selectbox always yields a string, while the column may hold ints
        # or strings; compare on the string form so both representations match.
        in_cluster = df_plot[group_by].astype(str) == selected_cluster
        cluster_uuids = df_plot.loc[in_cluster, 'uuid'].tolist()
        tree_df = filtered_df[filtered_df['uuid'].isin(cluster_uuids)]
        return tree_df, f"Taxonomic Tree for {group_by} = {selected_cluster}"
    return filtered_df, "Taxonomic Tree for All Data"


def render_taxonomic_tree_summary():
    """Render taxonomic tree summary for precalculated embeddings.

    Reads ``data`` and ``filtered_df_for_clustering`` from the Streamlit
    session state and renders nothing when either is missing. Otherwise it
    shows the filter controls (group-by column, cluster, minimum count, tree
    depth), builds the taxonomy tree for the selected rows, and displays the
    summary metrics plus the formatted tree.

    The most recently built tree is kept in the session state under a key
    derived from the data and the control values, so a rerun with unchanged
    inputs does not rebuild it.
    """
    df_plot = st.session_state.get("data", None)
    filtered_df = st.session_state.get("filtered_df_for_clustering", None)
    if df_plot is None or filtered_df is None:
        return

    st.markdown("### Taxonomic Distribution")

    kmeans_cols = _kmeans_columns(df_plot)
    # Fallback for embed_explore app (has 'cluster' column directly)
    has_embed_explore_cluster = 'cluster' in df_plot.columns and not kmeans_cols

    # One control per column, left to right.
    group_col, cluster_col, count_col, depth_col = st.columns([1.5, 1.5, 1, 1])

    with group_col:
        if kmeans_cols:
            # Precalculated app: let user pick which KMeans run
            group_by = st.selectbox(
                "Group by",
                options=["(none)"] + kmeans_cols,
                index=0,
                key="taxonomy_group_by",
                help="Select a KMeans result to filter taxonomy by cluster"
            )
            if group_by == "(none)":
                group_by = None
        elif has_embed_explore_cluster:
            group_by = "cluster"
        else:
            group_by = None

    with cluster_col:
        if group_by and group_by in df_plot.columns:
            # Drop missing labels first: int() cannot convert NaN.
            unique_clusters = sorted(
                df_plot[group_by].dropna().unique(), key=lambda x: int(x)
            )
            selected_cluster = st.selectbox(
                "Cluster",
                options=["All"] + [str(c) for c in unique_clusters],
                index=0,
                key="taxonomy_cluster_selector",
                help="Select a specific cluster or 'All'"
            )
        else:
            selected_cluster = "All"

    with count_col:
        min_count = st.number_input(
            "Minimum count",
            min_value=1,
            max_value=1000,
            value=5,
            step=1,
            key="taxonomy_min_count",
            help="Minimum number of records for a taxon to appear in the tree"
        )

    with depth_col:
        tree_depth = st.slider(
            "Tree depth",
            min_value=1,
            max_value=7,
            value=7,
            key="taxonomy_tree_depth",
            help="Maximum depth of the taxonomy tree to display"
        )

    # Stable cache key built from cheap data characteristics (row count and
    # first uuid) plus every control value that influences the output.
    sample_uuids = filtered_df['uuid'].iloc[:10].tolist()
    first_uuid = sample_uuids[0] if sample_uuids else 'empty'
    data_id = f"{len(filtered_df)}_{len(sample_uuids)}_{first_uuid}"
    cache_key = f"taxonomy_{data_id}_{group_by}_{selected_cluster}_{min_count}_{tree_depth}"

    previous_key = st.session_state.get("taxonomy_cache_key")
    if cache_key not in st.session_state or previous_key != cache_key:
        with st.spinner("Building taxonomy tree..."):
            tree_df, display_title = _select_tree_rows(
                df_plot, filtered_df, group_by, selected_cluster
            )
            tree = build_taxonomic_tree(tree_df)
            stats = get_tree_statistics(tree)
            tree_string = format_tree_string(tree, max_depth=tree_depth, min_count=min_count)

        # Only the active entry is ever reused (a key change always triggers a
        # rebuild), so drop the previous one instead of letting trees pile up
        # in the session state.
        if (
            previous_key is not None
            and previous_key != cache_key
            and previous_key in st.session_state
        ):
            del st.session_state[previous_key]

        st.session_state[cache_key] = {
            'tree': tree,
            'stats': stats,
            'tree_string': tree_string,
            'display_title': display_title
        }
        st.session_state["taxonomy_cache_key"] = cache_key

    cached_data = st.session_state[cache_key]
    stats = cached_data['stats']

    # Show statistics
    st.markdown(f"**{cached_data['display_title']}**")
    metrics = (
        ("Total Records", f"{stats['total_records']:,}"),
        ("Kingdoms", stats['kingdoms']),
        ("Families", stats['families']),
        ("Species", stats['species']),
    )
    for column, (label, value) in zip(st.columns(4), metrics):
        with column:
            st.metric(label, value)

    # Display the tree
    if cached_data['tree_string']:
        st.code(cached_data['tree_string'], language="text")
    else:
        st.info("No taxonomic data meets the display criteria. Try lowering the minimum count.")


def render_clustering_summary(show_taxonomy=False):
    """Render the clustering summary panel using cached results per KMeans run.

    Behaviour depends on the ``data`` frame held in the session state:

    * no data yet: a placeholder message is shown;
    * data with an ``image_path`` column (embed_explore app): the user picks
      one of the available KMeans runs and its cached summary table and
      representative images are displayed. The summaries and representatives
      are read from the ``clustering_summaries`` and
      ``clustering_representatives_by_col`` session entries, which this
      module does not populate itself;
    * data without ``image_path`` (precalculated app): the taxonomy tree is
      rendered when *show_taxonomy* is true and filtered data is available.

    Args:
        show_taxonomy: Whether to render the taxonomy tree in the
            precalculated app. Has no effect when images are available.
    """
    df_plot = st.session_state.get("data", None)
    if df_plot is None:
        st.info("Summary will appear here after projection.")
        return

    if 'image_path' not in df_plot.columns:
        # Precalculated app: show taxonomy tree (works with or without KMeans)
        if show_taxonomy:
            filtered_df = st.session_state.get("filtered_df_for_clustering", None)
            if filtered_df is not None:
                render_taxonomic_tree_summary()
        return

    # embed_explore app: full clustering summary with representative images
    kmeans_cols = _kmeans_columns(df_plot)

    st.subheader("Clustering Summary")
    if not kmeans_cols:
        st.info("Run KMeans to see the clustering summary and representative images.")
        return

    summaries = st.session_state.get("clustering_summaries", {}) or {}
    reps_by_col = st.session_state.get("clustering_representatives_by_col", {}) or {}

    selected_kmeans_col = st.selectbox(
        "KMeans result",
        options=kmeans_cols,
        index=len(kmeans_cols) - 1,  # default to the last entry (largest k)
        key="summary_kmeans_selector",
        help="Select which KMeans run to view summary + representative images for.",
    )

    summary_df = summaries.get(selected_kmeans_col)
    representatives = reps_by_col.get(selected_kmeans_col)
    if summary_df is None or representatives is None:
        st.info(
            f"No cached summary for {selected_kmeans_col}. "
            "Re-run KMeans with this k to regenerate it."
        )
        return

    logger.debug(f"Displaying cached clustering summary for {selected_kmeans_col}")
    st.dataframe(summary_df, hide_index=True, width='stretch')

    st.markdown("#### Representative Images")

    def _resolve_local_image(idx):
        """Return the local image path if it exists, else None (fallback)."""
        path = df_plot.iloc[idx]["image_path"]
        if isinstance(path, str) and os.path.exists(path):
            return path
        return None

    def _local_caption(idx):
        """Return the image file name for row *idx*, or None if it has no path."""
        path = df_plot.iloc[idx]["image_path"]
        return os.path.basename(path) if isinstance(path, str) else None

    render_representative_images(
        representatives,
        resolve_image=_resolve_local_image,
        n_per_cluster=3,
        caption_fn=_local_caption,
    )