File size: 10,090 Bytes
51545af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269ea1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51545af
269ea1f
51545af
 
269ea1f
 
 
 
 
 
 
 
 
 
 
 
 
 
51545af
 
 
 
 
 
 
 
269ea1f
 
 
 
 
51545af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
"""
Shared visualization components for scatter plots.
"""

import streamlit as st
import altair as alt

from shared.utils.logging_config import get_logger

# Module-level logger, obtained from the project's shared logging helper.
logger = get_logger(__name__)


def render_scatter_plot():
    """Show the clustering scatter plot, or a hint when there is nothing to plot.

    Reads the table stored under ``st.session_state["data"]``.  When it holds
    more than one row, drawing is delegated to ``_render_chart_fragment``.
    Otherwise an informational message is shown (its wording depends on
    ``st.session_state["page_type"]``) and ``selected_image_idx`` is reset
    to ``None`` in the session state.
    """
    frame = st.session_state.get("data", None)
    has_points = frame is not None and len(frame) > 1

    if has_points:
        _render_chart_fragment(frame)
        return

    # Nothing to plot yet: pick the hint that matches the running app.
    if st.session_state.get("page_type") == "precalculated_app":
        hint = "Run projection to see the scatter plot."
    else:
        hint = "Run clustering to see the cluster scatter plot."
    st.info(hint)

    # Drop any stale selection so it does not point at a missing chart.
    st.session_state['selected_image_idx'] = None


@st.fragment
def _render_chart_fragment(df_plot):
    """Render the plot option widgets and the Altair scatter chart.

    Runs as a Streamlit fragment.  A full-app rerun is requested explicitly
    (``st.rerun(scope="app")``) only when the "Color by" choice changes after
    the first render or when a different point is selected.

    Args:
        df_plot: DataFrame-like table to plot.  The chart encodes the ``x``
            and ``y`` columns and selects points by the ``idx`` column; the
            remaining columns are candidates for coloring and tooltips.

    Session state written:
        ``_prev_density_mode``, ``_prev_heatmap_bins``, ``_prev_color_by``
        (change tracking), plus ``selected_image_idx`` and
        ``selection_data_version`` when a new point is selected.
    """
    # Previous density mode, kept only so a change can be logged below.
    prev_density_mode = st.session_state.get("_prev_density_mode", None)

    # Detect app type: precalculated has uuid but no image_path
    is_precalculated = 'uuid' in df_plot.columns and 'image_path' not in df_plot.columns

    # Plot options
    opt_col1, opt_col2 = st.columns([2, 1])

    with opt_col1:
        density_mode = st.radio(
            "Density visualization",
            options=["Off", "Opacity", "Heatmap"],
            index=0,
            horizontal=True,
            key="density_mode",
            help="Off: normal view | Opacity: lower opacity to show overlap | Heatmap: 2D binned density (disables selection)"
        )

    # Log density mode change
    if prev_density_mode != density_mode:
        logger.info(f"[Visualization] Density mode changed: {prev_density_mode} -> {density_mode}")
        st.session_state["_prev_density_mode"] = density_mode

    with opt_col2:
        if density_mode == "Heatmap":
            prev_bins = st.session_state.get("_prev_heatmap_bins", 40)
            heatmap_bins = st.slider(
                "Grid resolution",
                min_value=10,
                max_value=80,
                value=40,
                step=5,
                key="heatmap_bins",
                help="Number of bins for density grid (higher = finer detail)"
            )
            if prev_bins != heatmap_bins:
                logger.info(f"[Visualization] Heatmap bins changed: {prev_bins} -> {heatmap_bins}")
                st.session_state["_prev_heatmap_bins"] = heatmap_bins
        else:
            heatmap_bins = 40  # Default, not used

    # Determine color column — same dropdown pattern for both apps.
    # Technical/identifier columns are never offered for coloring.
    skip_color_cols = {'x', 'y', 'idx', 'uuid', 'emb', 'embedding', 'embeddings', 'vector',
                       'identifier', 'image_url', 'url', 'img_url', 'image',
                       'image_path', 'file_name'}

    # Distinct-value counts are computed once per column and reused for the
    # filter, the dropdown labels and the high-cardinality warning.
    col_nunique = _colorable_unique_counts(df_plot, skip_color_cols)
    colorable_cols = _order_color_columns(col_nunique)

    if colorable_cols:
        color_col = st.selectbox(
            "Color by",
            options=["(none)"] + colorable_cols,
            index=0,
            key="color_by_column",
            format_func=lambda c: c if c == "(none)" else f"{c} ({col_nunique[c]})",
            help="Select a column to color the points by"
        )
        if color_col == "(none)":
            color_col = None
    else:
        color_col = None

    # Warning for high cardinality
    if color_col and col_nunique[color_col] > 20:
        st.warning(f"'{color_col}' has {col_nunique[color_col]} unique values. Colors may repeat.")

    # Trigger full page rerun when color changes (so bottom section updates).
    # Use a sentinel to distinguish "never set" from "set to None".
    _sentinel = object()
    prev_color = st.session_state.get("_prev_color_by", _sentinel)
    if color_col != prev_color:
        st.session_state["_prev_color_by"] = color_col
        if prev_color is not _sentinel:
            st.rerun(scope="app")

    point_selector = alt.selection_point(fields=["idx"], name="point_selection")

    # Build tooltip fields
    tooltip_fields = []
    skip_cols = {'x', 'y', 'idx', 'emb', 'embedding', 'embeddings', 'vector',
                 'uuid', 'identifier', 'image_url', 'url', 'img_url', 'image'}

    # For embed_explore, include the file_name in the tooltip for quick reference
    if not is_precalculated and 'file_name' in df_plot.columns:
        tooltip_fields.append('file_name:N')
        skip_cols.add('file_name')
    skip_cols.add('image_path')

    # Add the color column first if set (and not already in tooltip)
    if color_col and color_col not in skip_cols:
        tooltip_fields.append(f'{color_col}:N')
        skip_cols.add(color_col)

    # Add remaining metadata columns (capped at 15 to keep the tooltip readable)
    metadata_cols = [c for c in df_plot.columns if c not in skip_cols][:15]
    tooltip_fields.extend(metadata_cols)

    # Title
    if is_precalculated:
        title = "Embedding Space (click a point to view details)"
    else:
        title = "Image Clusters (click a point to preview image)"

    # Set opacity based on density mode
    if density_mode == "Opacity":
        point_opacity = 0.15
    elif density_mode == "Heatmap":
        point_opacity = 0.5
    else:
        point_opacity = 0.7

    # Build chart: encodings shared by the colored and uncolored variants.
    encodings = {
        "x": alt.X('x:Q', scale=alt.Scale(zero=False)),
        "y": alt.Y('y:Q', scale=alt.Scale(zero=False)),
        "tooltip": tooltip_fields,
        "fillOpacity": alt.condition(point_selector, alt.value(1), alt.value(0.3)),
    }
    if color_col:
        # Sort legend: numeric for KMeans labels, alphabetical for strings
        sorted_vals = _sorted_legend_values(df_plot[color_col].unique())
        encodings["color"] = alt.Color(
            f'{color_col}:N',
            legend=alt.Legend(title=color_col),
            sort=sorted_vals,
            scale=alt.Scale(scheme='tableau20')
        )

    scatter = (
        alt.Chart(df_plot)
        .mark_circle(size=60, opacity=point_opacity)
        .encode(**encodings)
        .add_params(point_selector)
    )

    if density_mode == "Heatmap":
        density = (
            alt.Chart(df_plot)
            .mark_rect(opacity=0.4)
            .encode(
                x=alt.X('x:Q', bin=alt.Bin(maxbins=heatmap_bins), scale=alt.Scale(zero=False)),
                y=alt.Y('y:Q', bin=alt.Bin(maxbins=heatmap_bins), scale=alt.Scale(zero=False)),
                color=alt.Color(
                    'count():Q',
                    scale=alt.Scale(scheme='blues'),
                    legend=None
                )
            )
        )
        chart = alt.layer(density, scatter)
    else:
        chart = scatter

    # Apply common properties and interactivity
    title_suffix = " (scroll to zoom, drag to pan)"
    if density_mode != "Heatmap":
        title_suffix += ", click to select"

    chart = (
        chart
        .properties(
            width=800,
            height=700,
            title=title + title_suffix
        )
        .interactive()
    )

    logger.debug(f"[Visualization] Rendering chart: {len(df_plot)} points, density={density_mode}, "
                 f"color={color_col or 'none'}")

    # Include data_version in key so zoom/pan resets when projection changes
    data_version = st.session_state.get("data_version", "")
    chart_key = f"alt_chart_{data_version}"

    if density_mode == "Heatmap":
        st.altair_chart(chart, key=chart_key, width="stretch")
        st.caption("Note: Point selection is disabled when heatmap is shown.")
        return

    event = st.altair_chart(chart, key=chart_key, on_select="rerun", width="stretch")

    has_selection = (
        event
        and "selection" in event
        and "point_selection" in event["selection"]
        and event["selection"]["point_selection"]
    )
    if not has_selection:
        return

    new_idx = int(event["selection"]["point_selection"][0]["idx"])
    prev_idx = st.session_state.get("selected_image_idx")
    if prev_idx != new_idx:
        label = _selected_point_label(df_plot, color_col, new_idx)
        logger.info(f"[Visualization] Point selected: idx={new_idx}{label}")
        st.session_state["selected_image_idx"] = new_idx
        st.session_state["selection_data_version"] = st.session_state.get("data_version", None)
        st.rerun(scope="app")


def _colorable_unique_counts(df_plot, skip_cols, max_unique=100):
    """Map each column usable for coloring to its number of distinct values.

    Columns named in *skip_cols*, columns with more than *max_unique* distinct
    values, and columns whose cells cannot be counted are left out.  Column
    order of *df_plot* is preserved.
    """
    counts = {}
    for col in df_plot.columns:
        if col in skip_cols:
            continue
        try:
            n_unique = df_plot[col].nunique()
        except TypeError:
            # Unhashable cell values (lists, dicts, arrays) cannot be counted,
            # so such a column cannot drive a categorical color scale either.
            continue
        if n_unique <= max_unique:
            counts[col] = n_unique
    return counts


def _order_color_columns(columns):
    """Return *columns* with ``KMeans (k=N)`` columns first, ordered by N."""
    prefix = "KMeans (k="

    def k_order(name):
        # Text between the first "=" and the trailing ")" should be the k value.
        raw_k = name.split("=", 1)[1].rstrip(")")
        try:
            return (0, int(raw_k))
        except ValueError:
            # Not a plain integer k: keep the column, after the well-formed ones.
            return (1, 0)

    kmeans_cols = sorted((c for c in columns if c.startswith(prefix)), key=k_order)
    other_cols = [c for c in columns if not c.startswith(prefix)]
    return kmeans_cols + other_cols


def _sorted_legend_values(values):
    """Sort legend values numerically when possible, otherwise as strings."""
    try:
        return sorted(values, key=int)
    except (ValueError, TypeError, OverflowError):
        # Non-numeric labels, None/NaN, or infinities: fall back to text order.
        return sorted(values, key=str)


def _selected_point_label(df_plot, color_col, idx):
    """Build the ``", <column>=<value>"`` log suffix for the selected point.

    The row is looked up through the ``idx`` column (the field the chart
    selection reports) rather than by position.  Returns an empty string when
    no color column is active or no matching row exists.
    """
    if not color_col or color_col not in df_plot.columns or "idx" not in df_plot.columns:
        return ''
    matches = df_plot.loc[df_plot["idx"] == idx, color_col]
    if matches.empty:
        return ''
    return f", {color_col}={matches.iloc[0]}"