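# Hosting-page residue captured together with this file (not part of the program):
#   path: core / pages / 4_Data_Explorer.py
#   uploader: tensorus
#   upload note: "Upload 83 files" (edfa748, verified)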
# pages/4_Data_Explorer.py
import streamlit as st
import pandas as pd
import plotly.express as px
from .ui_utils import list_datasets, fetch_dataset_data # MODIFIED
import torch # Needed if we want to recreate tensors for inspection/plotting
# --- Page Setup ---
# set_page_config is the first Streamlit command issued by this page.
st.set_page_config(page_title="Data Explorer", layout="wide")
st.title("🔍 Interactive Data Explorer")
st.caption("Browse, filter, and visualize Tensorus datasets.")

# --- Dataset Selection ---
datasets = list_datasets()
if not datasets:
    # `not datasets` is true for None as well as for an empty collection; in
    # both cases there is nothing to choose from, so the page ends here.
    st.warning("No datasets found or API connection failed. Cannot explore data.")
    st.stop()  # Stop execution if no datasets

selected_dataset = st.selectbox("Select Dataset:", datasets)
# --- Helpers ---
def _rerun() -> None:
    """Restart the current script run so updated session state takes effect.

    Prefers ``st.rerun`` and only falls back to the older
    ``st.experimental_rerun`` when the installed Streamlit lacks the former.
    """
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def _resolve_torch_dtype(dtype_name) -> torch.dtype:
    """Translate a record's dtype name into a ``torch.dtype``.

    Both bare ("float32") and qualified ("torch.float32") spellings are
    accepted. Anything that does not name a torch dtype resolves to
    ``torch.float32``.
    """
    candidate = getattr(torch, str(dtype_name).rsplit(".", 1)[-1], None)
    # getattr may also return functions or modules; only accept real dtypes.
    return candidate if isinstance(candidate, torch.dtype) else torch.float32


def _filter_metadata(df_meta: pd.DataFrame) -> pd.DataFrame:
    """Draw the sidebar filter widgets and return the rows that pass them.

    One widget is drawn per column the user picks:
      * numeric column        -> range slider (only when min < max),
      * at most 20 distinct
        non-null values       -> multiselect of those values,
      * anything else         -> case-insensitive "contains" text box.

    Filters are applied cumulatively in the order the columns were selected.
    The input frame is not modified.
    """
    st.sidebar.header("Filter by Metadata")
    filter_cols = st.sidebar.multiselect(
        "Select metadata columns to filter:", options=df_meta.columns.tolist()
    )

    filtered_df = df_meta.copy()
    for col in filter_cols:
        column = filtered_df[col]

        if pd.api.types.is_numeric_dtype(column):
            # Numeric filter (slider)
            min_val, max_val = float(column.min()), float(column.max())
            if min_val < max_val:
                selected_range = st.sidebar.slider(
                    f"Filter {col}:", min_val, max_val, (min_val, max_val)
                )
                filtered_df = filtered_df[
                    column.between(selected_range[0], selected_range[1])
                ]
            else:
                st.sidebar.caption(
                    f"{col}: Single numeric value ({min_val}), no range filter."
                )
            continue

        try:
            unique_values = column.dropna().unique().tolist()
        except TypeError:
            # Cells holding unhashable values (lists, dicts) cannot be
            # enumerated; fall through to the text filter instead of crashing.
            unique_values = []

        if 0 < len(unique_values) <= 20:  # Limit dropdown options
            # Categorical filter (multiselect)
            selected_values = st.sidebar.multiselect(
                f"Filter {col}:", options=unique_values, default=unique_values
            )
            if selected_values:
                filtered_df = filtered_df[column.isin(selected_values)]
            else:
                # Nothing selected means nothing is shown; keep the columns
                # but drop every row.
                filtered_df = filtered_df.iloc[0:0]
        else:
            search_term = st.sidebar.text_input(
                f"Filter {col} (Text contains):", key=f"text_{col}"
            ).lower()
            if search_term:
                # Compare as lowercase strings; regex=False keeps user input
                # such as "(" or "*" a literal substring rather than a pattern.
                as_text = column.astype(str).str.lower()
                filtered_df = filtered_df[
                    as_text.str.contains(search_term, regex=False, na=False)
                ]

    return filtered_df


def _render_tensor_preview(record: dict) -> None:
    """Show one record's metadata, tensor info, and a simple visualization.

    The tensor is rebuilt from the record's "data" using its "dtype".
    1-D tensors get a line chart, 2-D tensors a heatmap, and 3-D tensors whose
    first axis has size 1 or 3 are shown as an image. Any failure while
    building or drawing the tensor is reported on the page instead of raised.
    """
    st.write("Metadata:")
    st.json(record['metadata'])

    shape = record.get("shape")
    dtype = record.get("dtype")
    data_list = record.get("data")
    st.write(f"Tensor Info: Shape={shape}, Dtype={dtype}")

    # `shape is None` rather than `not shape`: an empty shape is a valid 0-d tensor.
    if shape is None or not dtype or data_list is None:
        st.warning("Tensor data, shape, or dtype missing in the record.")
        return

    try:
        # Recreate tensor for potential plotting/display.
        # Be careful with large tensors in Streamlit UI!
        tensor = torch.tensor(data_list, dtype=_resolve_torch_dtype(dtype))
        st.write("Tensor Data (first few elements):")
        st.code(f"{tensor.flatten()[:10].numpy()}...")  # Show flattened start

        # --- Simple Visualizations ---
        if tensor.ndim == 1 and tensor.numel() > 1:
            st.line_chart(tensor.numpy())
        elif tensor.ndim == 2 and tensor.shape[0] > 1 and tensor.shape[1] > 1:
            try:
                fig = px.imshow(tensor.numpy(), title="Tensor Heatmap", aspect="auto")
                st.plotly_chart(fig, use_container_width=True)
            except Exception as plot_err:
                st.warning(f"Could not generate heatmap: {plot_err}")
        elif tensor.ndim == 3 and tensor.shape[0] in (1, 3):
            # Treated as (C, H, W); move channels last for display.
            try:
                # squeeze(-1) drops only a single-channel axis, never H or W.
                display_tensor = tensor.permute(1, 2, 0).squeeze(-1)
                # Rescale to [0, 1]; the epsilon avoids division by zero on
                # constant tensors.
                lowest = display_tensor.min()
                display_tensor = (display_tensor - lowest) / (
                    display_tensor.max() - lowest + 1e-6
                )
                st.image(
                    display_tensor.numpy(),
                    caption="Tensor as Image (Attempted)",
                    use_column_width=True,
                )
            except ImportError:
                st.warning("Pillow needed for image display (`pip install Pillow`)")
            except Exception as img_err:
                st.warning(f"Could not display tensor as image: {img_err}")
        else:
            st.info("No specific visualization available for this tensor shape/dimension.")
    except Exception as tensor_err:
        st.error(f"Error processing tensor data for preview: {tensor_err}")


# --- Data Fetching & Filtering ---
if selected_dataset:
    st.subheader(f"Exploring: {selected_dataset}")

    PAGE_SIZE = 20
    # The current page is remembered per dataset across reruns.
    page_key = f"page_{selected_dataset}"
    if page_key not in st.session_state:
        st.session_state[page_key] = 0
    page = st.session_state[page_key]
    offset = page * PAGE_SIZE

    records = fetch_dataset_data(selected_dataset, offset=offset, limit=PAGE_SIZE)

    # --- Pagination Controls ---
    prev_disabled = page == 0
    # A short (or failed) page means there is nothing further to fetch.
    next_disabled = records is None or len(records) < PAGE_SIZE
    col_prev, col_next = st.columns(2)
    with col_prev:
        if st.button("Previous", disabled=prev_disabled, key="prev_btn"):
            st.session_state[page_key] = max(0, page - 1)
            _rerun()
    with col_next:
        if st.button("Next", disabled=next_disabled, key="next_btn"):
            st.session_state[page_key] = page + 1
            _rerun()

    if records is None:
        st.error("Failed to fetch data for the selected dataset.")
        st.stop()
    elif not records:
        if page > 0:
            # An empty page past the first one does not mean the dataset is
            # empty; the Previous button above stays usable.
            st.info("No more records in this dataset. Use Previous to go back.")
        else:
            st.info("Selected dataset is empty.")
        st.stop()

    start_idx = offset + 1
    end_idx = offset + len(records)
    st.info(f"Displaying records {start_idx} - {end_idx} (page {page + 1})")

    # Create DataFrame from metadata for filtering/display
    metadata_list = [r['metadata'] for r in records]
    df_meta = pd.DataFrame(metadata_list)

    # --- Metadata Filtering UI ---
    filtered_df = _filter_metadata(df_meta)

    st.divider()
    st.subheader("Filtered Data View")
    st.write(f"{len(filtered_df)} records matching filters.")
    st.dataframe(filtered_df, use_container_width=True)

    # --- Tensor Preview & Visualization ---
    st.divider()
    st.subheader("Tensor Preview")

    if filtered_df.empty:
        st.info("No records match the current filters.")
    elif 'record_id' not in filtered_df.columns:
        st.warning("Records have no 'record_id' metadata field; cannot select a tensor to preview.")
    else:
        # Allow selecting a record ID from the filtered results
        record_ids = filtered_df['record_id'].tolist()
        selected_record_id = st.selectbox("Select Record ID to Preview Tensor:", record_ids)

        # `is not None` so that falsy IDs such as 0 or "" remain selectable.
        if selected_record_id is not None:
            # Find the full record data corresponding to the selected ID
            selected_record = next(
                (r for r in records if r['metadata'].get('record_id') == selected_record_id),
                None,
            )
            if selected_record is not None:
                _render_tensor_preview(selected_record)
            else:
                st.warning("Selected record details not found (this shouldn't happen).")
        else:
            st.info("Select a record ID above to preview its tensor.")