Spaces:
Sleeping
Sleeping
| # ==================== gui/corner_cases_tab.py ==================== | |
| """Corner Cases tab components""" | |
| import gradio as gr | |
| from typing import TYPE_CHECKING | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
| import matplotlib.pyplot as plt | |
| import matplotlib | |
| matplotlib.use('Agg') # Use non-interactive backend | |
| import io | |
| import base64 | |
| if TYPE_CHECKING: | |
| from similarity_analysis.app import SimilarityApp | |
class CornerCasesTab:
    """Handles the Corner Cases tab.

    Builds the tab's controls and wires the "Analyze Corner Cases" button to a
    callback that asks ``app.get_corner_cases`` for the image pairs closest to
    each corner of the normalized (Human, Brain, ML) space and renders them as
    an HTML report.
    """

    # Dropdown values that only organize the option list and are never real
    # selections: section headers use values starting with this prefix, and
    # dividers use the literal "separator" value.
    _HEADER_PREFIX = 'header'
    _SEPARATOR_VALUE = 'separator'

    # Inline style shared by every header/data cell of the text-only table.
    _CELL_STYLE = 'border: 1px solid #ddd; padding: 8px;'

    # Built from explicit lines (not an indented triple-quoted string) so that
    # source indentation can never leak into the rendered markdown.
    _INSTRUCTIONS_MD = (
        "This analysis finds image pairs closest to each of the 8 corners in normalized 3D space:\n"
        "- Each axis (Human, Brain, ML) is normalized to 0-1\n"
        "- Distance is calculated using Euclidean distance\n"
        "- The closest pairs to each corner are shown\n"
        "\n"
        "**Corners represent:**\n"
        "- (0,0,0): All low - all measures agree the pair is dissimilar\n"
        "- (1,1,1): All high - strong agreement on similarity\n"
        "- Mixed corners show interesting disagreements between measures\n"
    )

    def __init__(self, app: 'SimilarityApp'):
        self.app = app

    # ------------------------------------------------------------------ #
    # Small value helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _escape_html(value) -> str:
        """Return ``str(value)`` HTML-escaped (quotes included) for safe interpolation."""
        from html import escape
        return escape(str(value), quote=True)

    @staticmethod
    def _is_missing(value) -> bool:
        """Return True for None, blank strings and NaN-like values."""
        if value is None:
            return True
        if isinstance(value, str):
            return not value.strip()
        try:
            # NaN is the only value that compares unequal to itself.
            return bool(value != value)
        except (TypeError, ValueError):
            # Values whose comparison has no boolean result (e.g. pandas.NA)
            # are treated as missing rather than crashing the report.
            return True

    @classmethod
    def _is_placeholder_option(cls, value) -> bool:
        """Return True if a dropdown value is empty, a header, or a divider."""
        if value is None or value == cls._SEPARATOR_VALUE:
            return True
        return isinstance(value, str) and value.startswith(cls._HEADER_PREFIX)

    @classmethod
    def _first_selectable(cls, options):
        """Return the value of the first ``(label, value)`` option that is a real choice, else None."""
        return next((opt[1] for opt in options if not cls._is_placeholder_option(opt[1])), None)

    # ------------------------------------------------------------------ #
    # UI construction
    # ------------------------------------------------------------------ #
    def create_tab(self) -> dict:
        """Create the Corner Cases tab.

        Components are created without an explicit parent, so this is
        presumably called inside an enclosing ``gr.Blocks``/``gr.Tab`` context
        (TODO confirm at the call site).

        Returns:
            dict mapping 'brain_dropdown', 'ml_dropdown', 'top_n_slider',
            'show_images_checkbox', 'analyze_btn' and 'results_display' to the
            created components -- the ``components`` argument expected by
            ``connect_events``.
        """
        brain_options = self.app.data_loader.get_brain_measure_options()
        ml_options = self.app.data_loader.get_ml_model_options()
        # Keep all options including headers/dividers for visual organization,
        # but never preselect one of them.
        default_brain = self._first_selectable(brain_options)
        default_ml = self._first_selectable(ml_options)

        gr.Markdown("## Corner Cases Analysis")
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Controls")
                brain_dropdown = gr.Dropdown(
                    choices=brain_options,
                    value=default_brain,
                    label="Brain Response Type"
                )
                ml_dropdown = gr.Dropdown(
                    choices=ml_options,
                    value=default_ml,
                    label="ML Model"
                )
                top_n_slider = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=5,
                    step=1,
                    label="Number of pairs to show per corner",
                    info=""
                )
                show_images_checkbox = gr.Checkbox(
                    label="Show images",
                    value=True,
                    info=""
                )
                analyze_btn = gr.Button("Analyze Corner Cases", variant="primary")
            with gr.Column(scale=2):
                gr.Markdown("### Instructions")
                gr.Markdown(self._INSTRUCTIONS_MD)

        # Results are rendered as HTML so images can be laid out in a grid.
        gr.Markdown("---")
        results_display = gr.HTML("")
        return {
            'brain_dropdown': brain_dropdown,
            'ml_dropdown': ml_dropdown,
            'top_n_slider': top_n_slider,
            'show_images_checkbox': show_images_checkbox,
            'analyze_btn': analyze_btn,
            'results_display': results_display
        }

    def create_single_pair_bar_plot(self, result, pair_index):
        """Create a bar plot showing normalized values for a single pair using matplotlib.

        Args:
            result: corner-case record; only its 'human_norm', 'brain_norm'
                and 'ml_norm' entries are plotted.
            pair_index: pair identifier shown in the plot title.

        Returns:
            An HTML ``<img>`` tag embedding the plot as a base64 PNG data URI.
        """
        categories = ['Human', 'Brain', 'ML']
        values = [result['human_norm'], result['brain_norm'], result['ml_norm']]
        colors = ['#4A90E2', '#50C878', '#E24A4A']

        fig, ax = plt.subplots(figsize=(6, 4))
        try:
            bars = ax.bar(categories, values, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
            # Value label just above each bar; the 1.15 ylim leaves headroom
            # for labels on bars close to 1.0.
            for bar, val in zip(bars, values):
                ax.text(bar.get_x() + bar.get_width() / 2.0, bar.get_height() + 0.02,
                        f'{val:.3f}',
                        ha='center', va='bottom', fontsize=11, fontweight='bold')

            ax.set_ylabel('Normalized Value (0-1)', fontsize=11, fontweight='bold')
            ax.set_xlabel('Measure', fontsize=11, fontweight='bold')
            ax.set_title(f'Normalized Values for Pair #{pair_index}', fontsize=12, fontweight='bold')
            ax.set_ylim(0, 1.15)
            ax.grid(axis='y', alpha=0.3, linestyle='--')
            ax.set_axisbelow(True)
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

            fig.tight_layout()
            buf = io.BytesIO()
            # Save *this* figure explicitly instead of pyplot's implicit
            # "current" figure, which other code or concurrent requests may
            # have changed.
            fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        finally:
            # Release the figure even if drawing fails so pyplot does not
            # accumulate open figures across calls.
            plt.close(fig)

        img_base64 = base64.b64encode(buf.getvalue()).decode('ascii')
        return f'<img src="data:image/png;base64,{img_base64}" style="width: 100%; max-width: 450px; margin: 10px auto; display: block; border: 1px solid #ddd; border-radius: 5px; padding: 5px; background: white;" />'

    # ------------------------------------------------------------------ #
    # HTML rendering helpers
    # ------------------------------------------------------------------ #
    @classmethod
    def _format_caption_html(cls, caption_text):
        """Format a '|'-separated caption field as HTML.

        One caption is returned as escaped text; several become a numbered
        list. Missing values (None, NaN, blank, or the 'No caption available'
        placeholder) yield 'No caption available'.
        """
        if cls._is_missing(caption_text) or caption_text == 'No caption available':
            return 'No caption available'
        captions = [c.strip() for c in str(caption_text).split('|')]
        captions = [c for c in captions if c]
        if not captions:
            return 'No caption available'
        if len(captions) == 1:
            return cls._escape_html(captions[0])
        items = "".join(f"<li>{cls._escape_html(cap)}</li>" for cap in captions)
        return (f"<strong>{len(captions)} descriptions:</strong>"
                f"<ol style='margin: 5px 0; padding-left: 20px;'>{items}</ol>")

    @classmethod
    def _format_tags_html(cls, tags_text):
        """Format a '|'-separated tag field as escaped badge spans (or a 'No tags' marker)."""
        no_tags = '<span style="color:#999;">No tags</span>'
        if cls._is_missing(tags_text) or tags_text == 'No tags available':
            return no_tags
        tags = [t.strip() for t in str(tags_text).split('|')]
        tags = [t for t in tags if t]
        if not tags:
            return no_tags
        return ''.join(
            f'<span style="background:#ffc107; color:#000; padding:2px 6px; border-radius:3px; margin:2px; display:inline-block; font-size:10px;">{cls._escape_html(tag)}</span>'
            for tag in tags
        )

    @classmethod
    def _render_image_column(cls, img_url, image_name, caption_data, tags_data):
        """Render one image of a pair with its file name, caption(s) and tags."""
        esc = cls._escape_html
        return (
            "<div style='flex: 1;'>"
            f"<img src='{esc(img_url)}' style='max-width: 100%; height: 150px; object-fit: contain; display: block; margin: 0 auto;' />"
            f"<p style='margin: 5px 0; font-size: 12px; text-align: center;'><code>{esc(image_name)}</code></p>"
            f"<p style='margin: 5px 0; font-size: 11px; color: #666; text-align: left;'><strong>Caption:</strong> {cls._format_caption_html(caption_data)}</p>"
            f"<p style='margin: 5px 0; font-size: 11px; text-align: left;'><strong>Tags:</strong> {cls._format_tags_html(tags_data)}</p>"
            "</div>"
        )

    @staticmethod
    def _render_metrics_table(result):
        """Render distance-to-corner plus raw and normalized Human/Brain/ML values."""
        return "".join([
            "<table style='width: 100%; font-size: 12px; margin-top: 10px;'>",
            f"<tr><td><strong>Distance to Corner:</strong></td><td>{result['distance']:.4f}</td></tr>",
            "<tr><td colspan='2' style='padding-top: 8px; font-weight: bold;'>Raw Values:</td></tr>",
            f"<tr><td>Human:</td><td>{result['human']:.3f}</td></tr>",
            f"<tr><td>Brain:</td><td>{result['brain']:.3f}</td></tr>",
            f"<tr><td>ML:</td><td>{result['ml']:.3f}</td></tr>",
            "<tr><td colspan='2' style='padding-top: 8px; font-weight: bold;'>Normalized (0-1):</td></tr>",
            f"<tr><td>Human:</td><td>{result['human_norm']:.3f}</td></tr>",
            f"<tr><td>Brain:</td><td>{result['brain_norm']:.3f}</td></tr>",
            f"<tr><td>ML:</td><td>{result['ml_norm']:.3f}</td></tr>",
            "</table>",
        ])

    def _render_pair_card(self, rank, result, pair_row):
        """Render one pair as a card: both images, a metrics table and a bar plot."""
        esc = self._escape_html
        parts = [
            "<div style='border: 1px solid #ddd; padding: 15px; border-radius: 8px; background: #f9f9f9;'>",
            f"<h4>Rank {rank} - Pair #{esc(result['index'])}</h4>",
            # Captions and tags are read directly per image - no swapping needed.
            "<div style='display: flex; gap: 10px; margin: 10px 0;'>",
            self._render_image_column(
                pair_row['stim_1'],
                result['image_1'],
                pair_row.get('image_1_description', 'No caption available'),
                pair_row.get('image_1_tags', 'No tags available'),
            ),
            self._render_image_column(
                pair_row['stim_2'],
                result['image_2'],
                pair_row.get('image_2_description', 'No caption available'),
                pair_row.get('image_2_tags', 'No tags available'),
            ),
            "</div>",
            self._render_metrics_table(result),
        ]
        # The plot is best-effort: a failure is reported inline instead of
        # aborting the whole report.
        try:
            bar_plot_html = self.create_single_pair_bar_plot(result, result['index'])
        except Exception as exc:
            parts.append(f"<div style='color: red; padding: 10px;'>Error creating plot: {esc(exc)}</div>")
        else:
            parts.append(f"<div style='margin: 15px 0;'>{bar_plot_html}</div>")
        parts.append("</div>")
        return "".join(parts)

    def _render_image_grid(self, results):
        """Render one corner's pairs as a responsive grid of cards with images."""
        # NOTE(review): result['index'] is used as a *positional* row index
        # (DataFrame.iloc); confirm get_corner_cases returns positions, not labels.
        data = self.app.data_loader.data
        parts = ["<div style='display: grid; grid-template-columns: repeat(auto-fill, minmax(400px, 1fr)); gap: 20px; margin-bottom: 30px; margin-top: 20px;'>"]
        for rank, result in enumerate(results, 1):
            parts.append(self._render_pair_card(rank, result, data.iloc[result['index']]))
        parts.append("</div>")
        return "".join(parts)

    @classmethod
    def _render_results_table(cls, results):
        """Render one corner's pairs as a compact text-only table (no images)."""
        esc = cls._escape_html
        cell = f"style='{cls._CELL_STYLE}'"
        headers = ['Rank', 'Pair #', 'Images', 'Distance', 'Human', 'Brain', 'ML']
        parts = [
            "<table style='width: 100%; border-collapse: collapse; margin-bottom: 30px; margin-top: 20px;'>",
            "<thead><tr style='background: #f0f0f0;'>",
        ]
        parts.extend(f"<th {cell}>{header}</th>" for header in headers)
        parts.append("</tr></thead><tbody>")
        for rank, result in enumerate(results, 1):
            cells = [
                rank,
                esc(result['index']),
                f"<code>{esc(result['image_1'])}</code> vs <code>{esc(result['image_2'])}</code>",
                f"{result['distance']:.4f}",
                f"{result['human']:.3f}",
                f"{result['brain']:.3f}",
                f"{result['ml']:.3f}",
            ]
            parts.append("<tr>" + "".join(f"<td {cell}>{value}</td>" for value in cells) + "</tr>")
        parts.append("</tbody></table>")
        return "".join(parts)

    def _render_corner_report(self, brain_measure, ml_model_selection, top_n, show_images):
        """Validate the selections, fetch corner cases and return them as one HTML string.

        Returns a plain message string instead of a report when a dropdown
        holds a header/divider/empty value or when no corner cases come back.
        """
        if self._is_placeholder_option(ml_model_selection):
            return "Please select a valid model or average option"
        if self._is_placeholder_option(brain_measure):
            return "Please select a valid brain response type"
        top_n = int(top_n)

        corner_results = self.app.get_corner_cases(brain_measure, ml_model_selection, top_n=top_n)
        if not corner_results:
            return "Error: Could not calculate corner cases"

        parts = ["<h2>Corner Cases Results</h2>"]
        # Walk corners in the analyzer's order; corners absent from the
        # results are skipped.
        for corner_name in self.app.corner_analyzer.get_corner_order():
            if corner_name not in corner_results:
                continue
            results = corner_results[corner_name]
            interpretation = self.app.get_corner_interpretation(corner_name)
            # corner_name / interpretation come from the app (not the dataset)
            # and are inserted as-is. NOTE(review): escape them too if they can
            # ever contain '<' or '&'.
            parts.append(f"<h3>Corner {corner_name}</h3>")
            parts.append(f"<p><strong>{interpretation}</strong></p>")
            if show_images:
                parts.append(self._render_image_grid(results))
            else:
                parts.append(self._render_results_table(results))
        return "".join(parts)

    # ------------------------------------------------------------------ #
    # Event wiring
    # ------------------------------------------------------------------ #
    def connect_events(self, components):
        """Connect event handlers for this tab.

        Args:
            components: the dict returned by ``create_tab``.
        """
        def analyze_corners(brain_measure, ml_model_selection, top_n, show_images):
            return self._render_corner_report(brain_measure, ml_model_selection, top_n, show_images)

        components['analyze_btn'].click(
            fn=analyze_corners,
            inputs=[
                components['brain_dropdown'],
                components['ml_dropdown'],
                components['top_n_slider'],
                components['show_images_checkbox']
            ],
            outputs=[components['results_display']]
        )