similarity_analysis / gui /analysis_tab.py
DanJChong's picture
Upload folder using huggingface_hub
e0ee7d2 verified
# ==================== gui/analysis_tab.py ====================
"""Analysis & Plots tab components"""
import gradio as gr
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from similarity_analysis.app import SimilarityApp
class AnalysisTab:
    """Handles the Analysis & Plots tab.

    The tab lets the user pick a brain response type, an ML model and whether
    to use normalized values, then shows a 3D scatter plot, 2D pairwise
    scatter plots, correlation statistics and a corner-distribution summary.
    All data work is delegated to the owning app object (``self.app``).
    """

    # Status messages shown instead of plots when a dropdown holds an entry
    # that cannot be plotted (header / separator / nothing selected).
    _INVALID_ML_MSG = "Please select a valid model or average option"
    _INVALID_BRAIN_MSG = "Please select a valid brain response type"

    def __init__(self, app: 'SimilarityApp'):
        """Store the application object that supplies plots and statistics."""
        self.app = app

    @staticmethod
    def _option_value(option):
        """Return the value of a dropdown choice.

        Gradio accepts choices either as plain values or as ``(label, value)``
        pairs; both forms are handled here.
        """
        if isinstance(option, (tuple, list)):
            return option[1]
        return option

    @staticmethod
    def _is_selectable(value) -> bool:
        """Return True if *value* can be passed to the app for plotting.

        This tab treats ``None`` (nothing selected), string values starting
        with ``'header'`` and the ``'separator'`` value as non-selectable
        placeholder entries.
        """
        if value is None:
            return False
        if isinstance(value, str):
            return not (value.startswith('header') or value == 'separator')
        return True

    @classmethod
    def _first_selectable(cls, options):
        """Return the value of the first selectable choice in *options*, or None."""
        return next(
            (v for v in map(cls._option_value, options) if cls._is_selectable(v)),
            None,
        )

    def create_tab(self, brain_options, ml_options) -> dict:
        """Create the Analysis & Plots tab.

        Args:
            brain_options: Dropdown choices for the brain response type.
            ml_options: Dropdown choices for the ML model.

        Returns:
            Mapping of component name to the created Gradio component, for
            use by :meth:`connect_events`.
        """
        # Default each dropdown to its first real (non-placeholder) option.
        default_brain = self._first_selectable(brain_options)
        default_ml = self._first_selectable(ml_options)
        with gr.Row():
            # Controls panel
            with gr.Column(scale=1):
                gr.Markdown("### Controls")
                brain_dropdown = gr.Dropdown(
                    choices=brain_options,
                    value=default_brain,
                    label="Brain Response Type",
                    info=""
                )
                ml_dropdown = gr.Dropdown(
                    choices=ml_options,
                    value=default_ml,
                    label="ML Model",
                    info=""
                )
                normalize_checkbox = gr.Checkbox(
                    label="Use Normalized Values (0-1)",
                    value=False,
                    info="Normalize all axes to 0-1 range"
                )
                update_btn = gr.Button("Update Plots", variant="primary")
                stats_display = gr.Markdown("Select parameters and click 'Update Plots'")
            # 3D visualization
            with gr.Column(scale=2):
                gr.Markdown("### 3D Visualization")
                plot_3d = gr.Plot(label="3D Scatter Plot")
        # 2D plots section
        with gr.Row():
            with gr.Column():
                gr.Markdown("### 2D Pairwise Comparisons")
                plot_2d = gr.Plot(label="2D Scatter Plots", show_label=False)
        # Corner distribution section
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Corner Distribution Analysis")
                gr.Markdown("Shows how many image pairs are closest to each corner of the 3D space (Human × Brain × ML)")
                corner_plot = gr.Plot(label="Corner Distribution")
                corner_stats = gr.HTML("<div>Click 'Update Plots' to see corner distribution</div>")
        return {
            'brain_dropdown': brain_dropdown,
            'ml_dropdown': ml_dropdown,
            'normalize_checkbox': normalize_checkbox,
            'update_btn': update_btn,
            'stats_display': stats_display,
            'plot_3d': plot_3d,
            'plot_2d': plot_2d,
            'corner_plot': corner_plot,
            'corner_stats': corner_stats
        }

    def connect_events(self, components):
        """Connect event handlers for this tab.

        The "Update Plots" button and any change to either dropdown or the
        normalize checkbox all trigger the same full refresh.

        Args:
            components: The mapping returned by :meth:`create_tab`.
        """
        def update_all(brain_measure, ml_model_selection, normalize):
            """Recompute every output; returns (3D fig, 2D fig, stats, corner fig, corner HTML)."""
            # Placeholder entries cannot be plotted: clear the figures and show
            # a message instead of passing them to the app.
            if not self._is_selectable(brain_measure):
                return None, None, self._INVALID_BRAIN_MSG, None, ""
            if not self._is_selectable(ml_model_selection):
                return None, None, self._INVALID_ML_MSG, None, ""
            fig_3d, fig_2d = self.app.update_plots(brain_measure, ml_model_selection, normalize)
            stats = self.app.get_correlations(brain_measure, ml_model_selection)
            corner_fig, corner_html = self.app.get_corner_distribution(brain_measure, ml_model_selection)
            return fig_3d, fig_2d, stats, corner_fig, corner_html

        inputs = [components['brain_dropdown'], components['ml_dropdown'],
                  components['normalize_checkbox']]
        outputs = [components[name] for name in
                   ('plot_3d', 'plot_2d', 'stats_display', 'corner_plot', 'corner_stats')]

        # Explicit refresh via the button.
        components['update_btn'].click(fn=update_all, inputs=list(inputs), outputs=list(outputs))
        # Automatic refresh whenever any control changes; fresh list copies per
        # listener mirror the original, which built separate lists each time.
        for name in ('brain_dropdown', 'ml_dropdown', 'normalize_checkbox'):
            components[name].change(fn=update_all, inputs=list(inputs), outputs=list(outputs))