Spaces:
Running
Running
| # pages/4_Data_Explorer.py | |
| import streamlit as st | |
| import pandas as pd | |
| import plotly.express as px | |
| from .ui_utils import list_datasets, fetch_dataset_data # MODIFIED | |
| import torch # Needed if we want to recreate tensors for inspection/plotting | |
# --- Page chrome: wide layout plus heading and one-line description ---
st.set_page_config(
    page_title="Data Explorer",
    layout="wide",
)
st.title("🔍 Interactive Data Explorer")
st.caption("Browse, filter, and visualize Tensorus datasets.")

# --- Dataset Selection ---
# A falsy result from list_datasets() (e.g. nothing returned) leaves nothing
# to explore, so the page warns and halts the script run right here.
datasets = list_datasets()
if not datasets:
    st.warning("No datasets found or API connection failed. Cannot explore data.")
    st.stop()

selected_dataset = st.selectbox(label="Select Dataset:", options=datasets)
# --- Data Fetching & Filtering ---
def _rerun_script():
    """Restart the current script run.

    Newer Streamlit releases expose ``st.rerun``; older ones only have
    ``st.experimental_rerun``. Use whichever one is available.
    """
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def _resolve_torch_dtype(dtype_name):
    """Return the torch dtype named by ``dtype_name``, or ``torch.float32``.

    Accepts both bare names (``"float32"``) and qualified ones
    (``"torch.float32"``). Anything that does not resolve to an actual
    ``torch.dtype`` falls back to ``torch.float32``.
    """
    short_name = str(dtype_name).rsplit(".", 1)[-1]
    candidate = getattr(torch, short_name, None)
    return candidate if isinstance(candidate, torch.dtype) else torch.float32


def _unique_metadata_values(series):
    """Return the distinct non-null values of ``series`` as a list.

    Returns an empty list when the values are unhashable (e.g. list or dict
    cells), for which ``Series.unique()`` raises ``TypeError``; the caller
    then falls back to a free-text filter.
    """
    try:
        return series.dropna().unique().tolist()
    except TypeError:
        return []


if selected_dataset:
    st.subheader(f"Exploring: {selected_dataset}")

    # Pagination state is kept per dataset so switching datasets does not
    # carry over the page number of another one.
    PAGE_SIZE = 20
    page_key = f"page_{selected_dataset}"
    if page_key not in st.session_state:
        st.session_state[page_key] = 0
    page = st.session_state[page_key]
    offset = page * PAGE_SIZE

    records = fetch_dataset_data(selected_dataset, offset=offset, limit=PAGE_SIZE)

    # A short (or failed) page means there is nothing further to page into.
    prev_disabled = page == 0
    next_disabled = records is None or len(records) < PAGE_SIZE
    col_prev, col_next = st.columns(2)
    with col_prev:
        if st.button("Previous", disabled=prev_disabled, key="prev_btn"):
            st.session_state[page_key] = max(0, page - 1)
            _rerun_script()
    with col_next:
        if st.button("Next", disabled=next_disabled, key="next_btn"):
            st.session_state[page_key] = page + 1
            _rerun_script()

    # The navigation buttons are rendered before these early exits so the
    # user can still page back from an empty page.
    if records is None:
        st.error("Failed to fetch data for the selected dataset.")
        st.stop()
    elif not records:
        if page == 0:
            st.info("Selected dataset is empty.")
        else:
            # An empty page past the first only means we ran off the end.
            st.info("No more records on this page.")
        st.stop()

    start_idx = offset + 1
    end_idx = offset + len(records)
    st.info(f"Displaying records {start_idx} - {end_idx} (page {page + 1})")

    # Create DataFrame from metadata for filtering/display
    metadata_list = [r['metadata'] for r in records]
    df_meta = pd.DataFrame(metadata_list)

    # --- Metadata Filtering UI ---
    st.sidebar.header("Filter by Metadata")
    filter_cols = st.sidebar.multiselect(
        "Select metadata columns to filter:", options=df_meta.columns.tolist()
    )

    # Filters are applied cumulatively, in the order the columns were picked.
    filtered_df = df_meta.copy()
    for col in filter_cols:
        column = filtered_df[col]

        if pd.api.types.is_numeric_dtype(column):
            # Numeric filter (slider)
            min_val, max_val = float(column.min()), float(column.max())
            if min_val < max_val:
                selected_range = st.sidebar.slider(
                    f"Filter {col}:", min_val, max_val, (min_val, max_val)
                )
                filtered_df = filtered_df[column.between(selected_range[0], selected_range[1])]
            else:
                st.sidebar.caption(f"{col}: Single numeric value ({min_val}), no range filter.")
            continue

        unique_values = _unique_metadata_values(column)
        if 0 < len(unique_values) <= 20:  # Limit dropdown options
            # Categorical filter (multiselect)
            selected_values = st.sidebar.multiselect(
                f"Filter {col}:", options=unique_values, default=unique_values
            )
            if selected_values:
                filtered_df = filtered_df[column.isin(selected_values)]
            else:
                # Nothing selected: keep only null rows, which yields an
                # (almost always) empty frame with the same columns.
                filtered_df = filtered_df[column.isnull()]
        else:
            # Free-text "contains" filter for high-cardinality, empty, or
            # unhashable columns.
            search_term = st.sidebar.text_input(
                f"Filter {col} (Text contains):", key=f"text_{col}"
            ).lower()
            if search_term:
                # Cast to str so .str works on any dtype; regex=False makes
                # the user's input a literal substring, so characters such as
                # "(" or "[" cannot raise a regex compilation error.
                matches = column.astype(str).str.lower().str.contains(
                    search_term, na=False, regex=False
                )
                filtered_df = filtered_df[matches]

    st.divider()
    st.subheader("Filtered Data View")
    st.write(f"{len(filtered_df)} records matching filters.")
    st.dataframe(filtered_df, use_container_width=True)

    # --- Tensor Preview & Visualization ---
    st.divider()
    st.subheader("Tensor Preview")

    if filtered_df.empty:
        st.info("No records match the current filters.")
    elif 'record_id' not in filtered_df.columns:
        # Without this guard the column lookup below raises KeyError.
        st.warning("Records have no 'record_id' metadata field; tensor preview is unavailable.")
    else:
        # Allow selecting a record ID from the filtered results
        record_ids = filtered_df['record_id'].tolist()
        selected_record_id = st.selectbox("Select Record ID to Preview Tensor:", record_ids)

        # Compare against None so falsy-but-valid ids (0, "") still preview.
        if selected_record_id is not None:
            # Find the full record data corresponding to the selected ID
            selected_record = next(
                (r for r in records if r['metadata'].get('record_id') == selected_record_id),
                None,
            )

            if selected_record:
                st.write("Metadata:")
                st.json(selected_record['metadata'])

                shape = selected_record.get("shape")
                dtype = selected_record.get("dtype")
                data_list = selected_record.get("data")

                st.write(f"Tensor Info: Shape={shape}, Dtype={dtype}")

                try:
                    # Recreate tensor for potential plotting/display.
                    # Be careful with large tensors in the Streamlit UI: only
                    # a short flattened prefix is printed.
                    # `shape is not None` (not truthiness) so a 0-dim tensor,
                    # whose shape is an empty list, is not reported as missing.
                    if shape is not None and dtype and data_list is not None:
                        tensor = torch.tensor(data_list, dtype=_resolve_torch_dtype(dtype))
                        st.write("Tensor Data (first few elements):")
                        st.code(f"{tensor.flatten()[:10].numpy()}...")

                        # --- Simple Visualizations ---
                        if tensor.ndim == 1 and tensor.numel() > 1:
                            st.line_chart(tensor.numpy())
                        elif tensor.ndim == 2 and tensor.shape[0] > 1 and tensor.shape[1] > 1:
                            # Simple heatmap using plotly
                            try:
                                fig = px.imshow(tensor.numpy(), title="Tensor Heatmap", aspect="auto")
                                st.plotly_chart(fig, use_container_width=True)
                            except Exception as plot_err:
                                st.warning(f"Could not generate heatmap: {plot_err}")
                        elif tensor.ndim == 3 and tensor.shape[0] in [1, 3]:
                            # Basic image check: treat as (C, H, W) with 1 or 3 channels.
                            try:
                                # C, H, W -> H, W, C; squeeze only the channel
                                # axis so an H or W of size 1 is preserved.
                                display_tensor = tensor.permute(1, 2, 0).squeeze(-1)
                                # Min-max normalize into [0, 1]; the epsilon
                                # avoids division by zero on constant tensors.
                                value_min = display_tensor.min()
                                value_max = display_tensor.max()
                                display_tensor = (display_tensor - value_min) / (value_max - value_min + 1e-6)
                                st.image(
                                    display_tensor.numpy(),
                                    caption="Tensor as Image (Attempted)",
                                    use_column_width=True,
                                )
                            except ImportError:
                                st.warning("Pillow needed for image display (`pip install Pillow`)")
                            except Exception as img_err:
                                st.warning(f"Could not display tensor as image: {img_err}")
                        else:
                            st.info("No specific visualization available for this tensor shape/dimension.")
                    else:
                        st.warning("Tensor data, shape, or dtype missing in the record.")
                except Exception as tensor_err:
                    # Page-level boundary: surface the problem instead of
                    # crashing the whole explorer on one malformed record.
                    st.error(f"Error processing tensor data for preview: {tensor_err}")
            else:
                st.warning("Selected record details not found (this shouldn't happen).")
        else:
            st.info("Select a record ID above to preview its tensor.")