| """ |
| Shared clustering summary components. |
| """ |
|
|
| import streamlit as st |
| import os |
| import pandas as pd |
| from shared.utils.taxonomy_tree import build_taxonomic_tree, format_tree_string, get_tree_statistics |
| from shared.components.representatives import render_representative_images |
| from shared.utils.logging_config import get_logger |
|
|
| logger = get_logger(__name__) |
|
|
|
|
def render_taxonomic_tree_summary():
    """Render the taxonomic tree summary for precalculated embeddings.

    Reads ``data`` (the plotted frame) and ``filtered_df_for_clustering``
    from ``st.session_state`` and renders nothing unless both are present.

    The panel consists of:

    * a row of four controls, one per column: "Group by" (only offered when
      ``KMeans (k=...)`` columns exist), "Cluster", "Minimum count" and
      "Tree depth";
    * four summary metrics taken from ``get_tree_statistics``;
    * the text tree produced by ``format_tree_string``.

    The built tree is stored in ``st.session_state`` under a key derived from
    the data and the control values, and the name of that key is tracked in
    ``taxonomy_cache_key``. Only the most recent result is kept: when the key
    changes, the previous entry is removed so session state does not
    accumulate one tree per control combination.
    """
    df_plot = st.session_state.get("data", None)
    filtered_df = st.session_state.get("filtered_df_for_clustering", None)

    # Both frames are required; without them there is nothing to summarise.
    if df_plot is None or filtered_df is None:
        return

    st.markdown("### Taxonomic Distribution")

    # KMeans result columns, ordered by their k value.
    kmeans_cols = sorted(
        [c for c in df_plot.columns if c.startswith("KMeans (k=")],
        key=lambda c: int(c.split("=")[1].rstrip(")")),
    )
    # A plain 'cluster' column is used for grouping only when no KMeans
    # result columns are present.
    has_embed_explore_cluster = 'cluster' in df_plot.columns and not kmeans_cols

    # Controls row: one widget per column.
    group_col, cluster_col, count_col, depth_col = st.columns([1.5, 1.5, 1, 1])

    with group_col:
        if kmeans_cols:
            group_by = st.selectbox(
                "Group by",
                options=["(none)"] + kmeans_cols,
                index=0,
                key="taxonomy_group_by",
                help="Select a KMeans result to filter taxonomy by cluster"
            )
            if group_by == "(none)":
                group_by = None
        elif has_embed_explore_cluster:
            group_by = "cluster"
        else:
            group_by = None

    with cluster_col:
        if group_by and group_by in df_plot.columns:
            unique_clusters = sorted(df_plot[group_by].unique(), key=int)
            cluster_options = ["All"] + [str(c) for c in unique_clusters]
            selected_cluster = st.selectbox(
                "Cluster",
                options=cluster_options,
                index=0,
                key="taxonomy_cluster_selector",
                help="Select a specific cluster or 'All'"
            )
        else:
            selected_cluster = "All"

    with count_col:
        min_count = st.number_input(
            "Minimum count",
            min_value=1,
            max_value=1000,
            value=5,
            step=1,
            key="taxonomy_min_count",
            help="Minimum number of records for a taxon to appear in the tree"
        )

    with depth_col:
        tree_depth = st.slider(
            "Tree depth",
            min_value=1,
            max_value=7,
            value=7,
            key="taxonomy_tree_depth",
            help="Maximum depth of the taxonomy tree to display"
        )

    # Cache key: a cheap fingerprint of the data (row count, sample size,
    # first uuid) combined with every control value.
    # NOTE(review): the key does not capture cluster membership itself, so a
    # KMeans re-run that reuses the same column name could serve a stale
    # tree - confirm whether the caller resets these entries.
    data_length = len(filtered_df)
    sample_uuids = filtered_df['uuid'].iloc[:10].tolist()
    data_id = f"{data_length}_{len(sample_uuids)}_{sample_uuids[0] if sample_uuids else 'empty'}"
    cache_key = f"taxonomy_{data_id}_{group_by}_{selected_cluster}_{min_count}_{tree_depth}"

    previous_cache_key = st.session_state.get("taxonomy_cache_key")
    needs_rebuild = cache_key not in st.session_state or previous_cache_key != cache_key

    if needs_rebuild:
        with st.spinner("Building taxonomy tree..."):
            if group_by and selected_cluster != "All" and group_by in df_plot.columns:
                # The selectbox yields the cluster as a string, so compare
                # against the stringified column; a numeric cluster column
                # would otherwise never match.
                cluster_mask = df_plot[group_by].astype(str) == selected_cluster
                cluster_uuids = df_plot.loc[cluster_mask, 'uuid'].tolist()
                tree_df = filtered_df[filtered_df['uuid'].isin(cluster_uuids)]
                display_title = f"Taxonomic Tree for {group_by} = {selected_cluster}"
            else:
                tree_df = filtered_df
                display_title = "Taxonomic Tree for All Data"

            tree = build_taxonomic_tree(tree_df)
            stats = get_tree_statistics(tree)
            tree_string = format_tree_string(tree, max_depth=tree_depth, min_count=min_count)

            # Drop the superseded entry: it can never be served again because
            # a key mismatch always forces a rebuild.
            if (
                previous_cache_key is not None
                and previous_cache_key != cache_key
                and previous_cache_key in st.session_state
            ):
                del st.session_state[previous_cache_key]

            st.session_state[cache_key] = {
                'tree': tree,
                'stats': stats,
                'tree_string': tree_string,
                'display_title': display_title
            }
            st.session_state["taxonomy_cache_key"] = cache_key

    cached_data = st.session_state[cache_key]
    cached_stats = cached_data['stats']

    # Summary metrics.
    st.markdown(f"**{cached_data['display_title']}**")
    records_col, kingdoms_col, families_col, species_col = st.columns(4)
    with records_col:
        st.metric("Total Records", f"{cached_stats['total_records']:,}")
    with kingdoms_col:
        st.metric("Kingdoms", cached_stats['kingdoms'])
    with families_col:
        st.metric("Families", cached_stats['families'])
    with species_col:
        st.metric("Species", cached_stats['species'])

    # Tree text, or a hint when the filters leave nothing to show.
    if cached_data['tree_string']:
        st.code(cached_data['tree_string'], language="text")
    else:
        st.info("No taxonomic data meets the display criteria. Try lowering the minimum count.")
|
|
|
|
def render_clustering_summary(show_taxonomy=False):
    """Render the clustering summary panel from per-run cached results.

    Behaviour depends on the frame stored under ``data`` in
    ``st.session_state``:

    * no frame: an informational placeholder is shown;
    * frame without an ``image_path`` column: the taxonomic tree summary is
      rendered, but only when ``show_taxonomy`` is true and
      ``filtered_df_for_clustering`` is present in session state;
    * frame with ``image_path``: the user picks one of the ``KMeans (k=...)``
      columns (the largest k is preselected) and the summary table plus
      representative images cached for that column are displayed.

    Cached results are read from the ``clustering_summaries`` and
    ``clustering_representatives_by_col`` session-state mappings, keyed by
    KMeans column name. They are populated elsewhere (the original note
    attributes this to ``_run_kmeans``).

    Args:
        show_taxonomy: Whether to render the taxonomic tree summary when the
            data has no local image paths.
    """
    df_plot = st.session_state.get("data", None)
    if df_plot is None:
        st.info("Summary will appear here after projection.")
        return

    if 'image_path' not in df_plot.columns:
        # No local images: the only thing this panel can offer is taxonomy.
        if show_taxonomy and st.session_state.get("filtered_df_for_clustering", None) is not None:
            render_taxonomic_tree_summary()
        return

    # Collect the KMeans result columns and order them by k.
    run_columns = [name for name in df_plot.columns if name.startswith("KMeans (k=")]
    run_columns.sort(key=lambda name: int(name.split("=")[1].rstrip(")")))

    st.subheader("Clustering Summary")
    if not run_columns:
        st.info("Run KMeans to see the clustering summary and representative images.")
        return

    cached_summaries = st.session_state.get("clustering_summaries", {}) or {}
    cached_representatives = st.session_state.get("clustering_representatives_by_col", {}) or {}

    selected_kmeans_col = st.selectbox(
        "KMeans result",
        options=run_columns,
        index=len(run_columns) - 1,
        key="summary_kmeans_selector",
        help="Select which KMeans run to view summary + representative images for.",
    )

    summary_df = cached_summaries.get(selected_kmeans_col)
    representatives = cached_representatives.get(selected_kmeans_col)
    if any(item is None for item in (summary_df, representatives)):
        st.info(
            f"No cached summary for {selected_kmeans_col}. "
            "Re-run KMeans with this k to regenerate it."
        )
        return

    logger.debug(f"Displaying cached clustering summary for {selected_kmeans_col}")
    st.dataframe(summary_df, hide_index=True, width='stretch')

    st.markdown("#### Representative Images")

    def _image_path_at(idx):
        """Return the raw ``image_path`` value of the row at position idx."""
        return df_plot.iloc[idx]["image_path"]

    def _resolve_local_image(idx):
        """Return the row's image path when it is an existing file, else None."""
        candidate = _image_path_at(idx)
        return candidate if isinstance(candidate, str) and os.path.exists(candidate) else None

    def _local_caption(idx):
        """Return the image file name for the row, or None for non-string paths."""
        candidate = _image_path_at(idx)
        if not isinstance(candidate, str):
            return None
        return os.path.basename(candidate)

    render_representative_images(
        representatives,
        resolve_image=_resolve_local_image,
        n_per_cluster=3,
        caption_fn=_local_caption,
    )
|
|