similarity_analysis / gui /corner_cases_tab.py
DanJChong's picture
Upload folder using huggingface_hub
e0ee7d2 verified
# ==================== gui/corner_cases_tab.py ====================
"""Corner Cases tab components"""
import base64
import html
import io
import math
from typing import TYPE_CHECKING

import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots

matplotlib.use('Agg')  # Use non-interactive backend

if TYPE_CHECKING:
    from similarity_analysis.app import SimilarityApp
class CornerCasesTab:
    """Build and wire up the "Corner Cases" tab.

    The tab lets the user pick a brain measure and an ML model, asks the
    application object for the image pairs closest to each corner, and renders
    them as an HTML fragment (either image cards or a plain table).
    """

    # Placeholder values used when a pair row carries no caption / tag data.
    _NO_CAPTION = 'No caption available'
    _NO_TAGS = 'No tags available'

    # Shared inline styles for the text-only results table.
    _CELL_STYLE = "border: 1px solid #ddd; padding: 8px;"
    _TABLE_HEADERS = ('Rank', 'Pair #', 'Images', 'Distance', 'Human', 'Brain', 'ML')

    # Help text shown next to the controls. Blank lines separate the paragraphs
    # from the bullet lists so Markdown does not fold the bold heading into the
    # preceding list item.
    _INSTRUCTIONS_MD = (
        "\n"
        "This analysis finds image pairs closest to each of the 8 corners in normalized 3D space:\n"
        "\n"
        "- Each axis (Human, Brain, ML) is normalized to 0-1\n"
        "- Distance is calculated using Euclidean distance\n"
        "- The closest pairs to each corner are shown\n"
        "\n"
        "**Corners represent:**\n"
        "\n"
        "- (0,0,0): All low - general disagreement on similarity\n"
        "- (1,1,1): All high - strong agreement on similarity\n"
        "- Mixed corners show interesting disagreements between measures\n"
    )

    def __init__(self, app: 'SimilarityApp'):
        """Keep a reference to the application object used for all data access."""
        self.app = app

    # ------------------------------------------------------------------ layout

    @staticmethod
    def _first_selectable(options):
        """Return the value of the first option that is not a 'header...' entry.

        Each option is indexed as ``opt[1]`` for its value; returns None when
        every option is a header (or the list is empty).
        """
        for opt in options:
            value = opt[1]
            if not (isinstance(value, str) and value.startswith('header')):
                return value
        return None

    def create_tab(self) -> dict:
        """Create the Corner Cases tab.

        Must be called inside an active Gradio layout context.

        Returns:
            Mapping of component name to the created Gradio component, with
            keys 'brain_dropdown', 'ml_dropdown', 'top_n_slider',
            'show_images_checkbox', 'analyze_btn' and 'results_display'.
        """
        brain_options = self.app.data_loader.get_brain_measure_options()
        ml_options = self.app.data_loader.get_ml_model_options()

        # All options (including headers/dividers) stay in the dropdowns for
        # visual organization; only the default skips header entries.
        gr.Markdown("## Corner Cases Analysis")
        default_brain = self._first_selectable(brain_options)
        default_ml = self._first_selectable(ml_options)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Controls")
                brain_dropdown = gr.Dropdown(
                    choices=brain_options,
                    value=default_brain,
                    label="Brain Response Type"
                )
                ml_dropdown = gr.Dropdown(
                    choices=ml_options,
                    value=default_ml,
                    label="ML Model"
                )
                top_n_slider = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=5,
                    step=1,
                    label="Number of pairs to show per corner",
                    info=""
                )
                show_images_checkbox = gr.Checkbox(
                    label="Show images",
                    value=True,
                    info=""
                )
                analyze_btn = gr.Button("Analyze Corner Cases", variant="primary")
            with gr.Column(scale=2):
                gr.Markdown("### Instructions")
                gr.Markdown(self._INSTRUCTIONS_MD)

        # Results are rendered as raw HTML so images can be laid out in a grid.
        gr.Markdown("---")
        results_display = gr.HTML("")

        return {
            'brain_dropdown': brain_dropdown,
            'ml_dropdown': ml_dropdown,
            'top_n_slider': top_n_slider,
            'show_images_checkbox': show_images_checkbox,
            'analyze_btn': analyze_btn,
            'results_display': results_display
        }

    # ------------------------------------------------------------------- plots

    def create_single_pair_bar_plot(self, result, pair_index):
        """Render the normalized Human/Brain/ML values of one pair as a bar chart.

        Args:
            result: Mapping providing 'human_norm', 'brain_norm' and 'ml_norm'.
            pair_index: Identifier shown in the plot title.

        Returns:
            An ``<img>`` tag whose source is the PNG chart as a base64 data URI.
        """
        fig, ax = plt.subplots(figsize=(6, 4))
        try:
            categories = ['Human', 'Brain', 'ML']
            values = [result['human_norm'], result['brain_norm'], result['ml_norm']]
            colors = ['#4A90E2', '#50C878', '#E24A4A']

            bars = ax.bar(categories, values, color=colors, alpha=0.8,
                          edgecolor='black', linewidth=1.5)

            # Numeric label just above each bar.
            for bar, val in zip(bars, values):
                ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 0.02,
                        f'{val:.3f}',
                        ha='center', va='bottom', fontsize=11, fontweight='bold')

            ax.set_ylabel('Normalized Value (0-1)', fontsize=11, fontweight='bold')
            ax.set_xlabel('Measure', fontsize=11, fontweight='bold')
            ax.set_title(f'Normalized Values for Pair #{pair_index}', fontsize=12, fontweight='bold')
            ax.set_ylim(0, 1.15)  # headroom above 1.0 for the value labels
            ax.grid(axis='y', alpha=0.3, linestyle='--')
            ax.set_axisbelow(True)
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

            # Operate on this figure explicitly instead of pyplot's implicit
            # "current figure".
            fig.tight_layout()
            buf = io.BytesIO()
            fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        finally:
            # Always release the figure, even when drawing/saving fails;
            # the caller swallows plot errors, so a leak would go unnoticed.
            plt.close(fig)

        img_base64 = base64.b64encode(buf.getvalue()).decode()
        return f'<img src="data:image/png;base64,{img_base64}" style="width: 100%; max-width: 450px; margin: 10px auto; display: block; border: 1px solid #ddd; border-radius: 5px; padding: 5px; background: white;" />'

    # -------------------------------------------------------- text formatting

    @staticmethod
    def _split_fragments(value):
        """Split a '|'-separated metadata value into stripped, non-empty parts.

        Falsy values and float NaN yield an empty list. NaN is handled
        explicitly because it is truthy and would otherwise be rendered as the
        literal text "nan" (presumably what a pandas row returns for an empty
        cell - verify against the data loader).
        """
        if not value or (isinstance(value, float) and math.isnan(value)):
            return []
        pieces = (piece.strip() for piece in str(value).split('|'))
        return [piece for piece in pieces if piece]

    @staticmethod
    def _format_caption_html(caption_text):
        """Return HTML for one image's caption(s); text is HTML-escaped.

        A single caption is returned as plain (escaped) text; several captions
        become a numbered list preceded by a count.
        """
        captions = CornerCasesTab._split_fragments(caption_text)
        if not captions or caption_text == CornerCasesTab._NO_CAPTION:
            return CornerCasesTab._NO_CAPTION
        if len(captions) == 1:
            return html.escape(captions[0])
        items = "".join(f"<li>{html.escape(cap)}</li>" for cap in captions)
        return (
            f"<strong>{len(captions)} descriptions:</strong>"
            f"<ol style='margin: 5px 0; padding-left: 20px;'>{items}</ol>"
        )

    @staticmethod
    def _format_tags_html(tags_text):
        """Return HTML badges for one image's tags; text is HTML-escaped."""
        tags = CornerCasesTab._split_fragments(tags_text)
        if not tags or tags_text == CornerCasesTab._NO_TAGS:
            return '<span style="color:#999;">No tags</span>'
        return "".join(
            f'<span style="background:#ffc107; color:#000; padding:2px 6px; border-radius:3px; margin:2px; display:inline-block; font-size:10px;">{html.escape(tag)}</span>'
            for tag in tags
        )

    # ------------------------------------------------------------- rendering

    @staticmethod
    def _render_image_column(url, name, caption_html, tags_html):
        """Return one image column (picture, file name, caption, tags).

        ``url`` and ``name`` are escaped here; ``caption_html`` and
        ``tags_html`` must already be safe HTML.
        """
        safe_url = html.escape(str(url), quote=True)
        safe_name = html.escape(str(name))
        return "".join([
            "<div style='flex: 1;'>",
            f"<img src='{safe_url}' style='max-width: 100%; height: 150px; object-fit: contain; display: block; margin: 0 auto;' />",
            f"<p style='margin: 5px 0; font-size: 12px; text-align: center;'><code>{safe_name}</code></p>",
            f"<p style='margin: 5px 0; font-size: 11px; color: #666; text-align: left;'><strong>Caption:</strong> {caption_html}</p>",
            f"<p style='margin: 5px 0; font-size: 11px; text-align: left;'><strong>Tags:</strong> {tags_html}</p>",
            "</div>",
        ])

    @staticmethod
    def _render_metrics_table(result):
        """Return the distance / raw / normalized metrics table for one pair."""
        return "".join([
            "<table style='width: 100%; font-size: 12px; margin-top: 10px;'>",
            f"<tr><td><strong>Distance to Corner:</strong></td><td>{result['distance']:.4f}</td></tr>",
            "<tr><td colspan='2' style='padding-top: 8px; font-weight: bold;'>Raw Values:</td></tr>",
            f"<tr><td>Human:</td><td>{result['human']:.3f}</td></tr>",
            f"<tr><td>Brain:</td><td>{result['brain']:.3f}</td></tr>",
            f"<tr><td>ML:</td><td>{result['ml']:.3f}</td></tr>",
            "<tr><td colspan='2' style='padding-top: 8px; font-weight: bold;'>Normalized (0-1):</td></tr>",
            f"<tr><td>Human:</td><td>{result['human_norm']:.3f}</td></tr>",
            f"<tr><td>Brain:</td><td>{result['brain_norm']:.3f}</td></tr>",
            f"<tr><td>ML:</td><td>{result['ml_norm']:.3f}</td></tr>",
            "</table>",
        ])

    def _render_pair_card(self, rank, result, pair_row):
        """Return the full card (images, metrics, bar plot) for one ranked pair.

        ``pair_row`` is the data row for the pair; it is indexed by column name
        and queried with ``.get(name, default)`` for the optional metadata.
        """
        caption1 = self._format_caption_html(pair_row.get('image_1_description', self._NO_CAPTION))
        caption2 = self._format_caption_html(pair_row.get('image_2_description', self._NO_CAPTION))
        tags1 = self._format_tags_html(pair_row.get('image_1_tags', self._NO_TAGS))
        tags2 = self._format_tags_html(pair_row.get('image_2_tags', self._NO_TAGS))

        parts = [
            "<div style='border: 1px solid #ddd; padding: 15px; border-radius: 8px; background: #f9f9f9;'>",
            f"<h4>Rank {rank} - Pair #{result['index']}</h4>",
            # Two images side by side.
            "<div style='display: flex; gap: 10px; margin: 10px 0;'>",
            self._render_image_column(pair_row['stim_1'], result['image_1'], caption1, tags1),
            self._render_image_column(pair_row['stim_2'], result['image_2'], caption2, tags2),
            "</div>",
            self._render_metrics_table(result),
        ]

        # Best effort: a failed plot must not discard the rest of the results,
        # so the error is shown in place of the chart.
        try:
            bar_plot_html = self.create_single_pair_bar_plot(result, result['index'])
            parts.append(f"<div style='margin: 15px 0;'>{bar_plot_html}</div>")
        except Exception as e:
            parts.append(
                f"<div style='color: red; padding: 10px;'>Error creating plot: {html.escape(str(e))}</div>"
            )

        parts.append("</div>")
        return "".join(parts)

    def _render_image_grid(self, results):
        """Return a responsive grid of pair cards for one corner."""
        data = self.app.data_loader.data
        parts = [
            "<div style='display: grid; grid-template-columns: repeat(auto-fill, minmax(400px, 1fr)); gap: 20px; margin-bottom: 30px; margin-top: 20px;'>"
        ]
        for rank, result in enumerate(results, 1):
            parts.append(self._render_pair_card(rank, result, data.iloc[result['index']]))
        parts.append("</div>")
        return "".join(parts)

    @staticmethod
    def _render_text_table(results):
        """Return the text-only results table for one corner (no images)."""
        cell = CornerCasesTab._CELL_STYLE
        parts = [
            "<table style='width: 100%; border-collapse: collapse; margin-bottom: 30px; margin-top: 20px;'>",
            "<thead><tr style='background: #f0f0f0;'>",
        ]
        parts.extend(
            f"<th style='{cell}'>{label}</th>" for label in CornerCasesTab._TABLE_HEADERS
        )
        parts.append("</tr></thead><tbody>")
        for rank, result in enumerate(results, 1):
            image_1 = html.escape(str(result['image_1']))
            image_2 = html.escape(str(result['image_2']))
            parts.extend([
                "<tr>",
                f"<td style='{cell}'>{rank}</td>",
                f"<td style='{cell}'>{result['index']}</td>",
                f"<td style='{cell}'><code>{image_1}</code> vs <code>{image_2}</code></td>",
                f"<td style='{cell}'>{result['distance']:.4f}</td>",
                f"<td style='{cell}'>{result['human']:.3f}</td>",
                f"<td style='{cell}'>{result['brain']:.3f}</td>",
                f"<td style='{cell}'>{result['ml']:.3f}</td>",
                "</tr>",
            ])
        parts.append("</tbody></table>")
        return "".join(parts)

    # ------------------------------------------------------------------ events

    def analyze_corners(self, brain_measure, ml_model_selection, top_n, show_images):
        """Handle the "Analyze" button: compute corner cases and return HTML.

        Args:
            brain_measure: Selected brain-measure dropdown value.
            ml_model_selection: Selected ML-model dropdown value.
            top_n: Number of pairs per corner (slider value; coerced to int).
            show_images: True for image cards, False for a plain table.

        Returns:
            An HTML string, or a short plain-text message when the selection is
            the "separator" entry or no corner cases could be calculated.
        """
        if ml_model_selection == "separator":
            return "Please select a valid model or average option"

        top_n = int(top_n)
        corner_results = self.app.get_corner_cases(brain_measure, ml_model_selection, top_n=top_n)
        if not corner_results:
            return "Error: Could not calculate corner cases"

        parts = ["<h2>Corner Cases Results</h2>"]
        # Corners are emitted in the analyzer's order; corners without results
        # are skipped.
        for corner_name in self.app.corner_analyzer.get_corner_order():
            if corner_name not in corner_results:
                continue
            results = corner_results[corner_name]
            interpretation = self.app.get_corner_interpretation(corner_name)
            parts.append(f"<h3>Corner {corner_name}</h3>")
            parts.append(f"<p><strong>{interpretation}</strong></p>")
            if show_images:
                parts.append(self._render_image_grid(results))
            else:
                parts.append(self._render_text_table(results))
        return "".join(parts)

    def connect_events(self, components):
        """Connect event handlers for this tab.

        Args:
            components: The mapping returned by ``create_tab``.
        """
        components['analyze_btn'].click(
            fn=self.analyze_corners,
            inputs=[
                components['brain_dropdown'],
                components['ml_dropdown'],
                components['top_n_slider'],
                components['show_images_checkbox']
            ],
            outputs=[components['results_display']]
        )