# Spaces:
# Sleeping
# Sleeping
# visualization/plot_generator.py
"""Main plotting functionality for similarity analysis"""
import logging
import os
import sys
from typing import Optional, Tuple, Union

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Add parent directory to path for imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
class PlotGenerator:
    """Handles creation of plotly visualizations for similarity analysis.

    The ``data_loader`` passed to the constructor is expected to expose:

    * ``data`` -- a pandas DataFrame with ``human_judgement``, ``image_1`` and
      ``image_2`` columns plus one column per brain measure and per ML model
      (``create_3d_plot`` additionally reads ``stim_1`` / ``stim_2``);
    * ``model_categories`` -- subscriptable by category key, yielding a
      sequence of model entries whose first element is a model column name;
    * ``ml_models`` -- integer-indexable sequence of model column names.
    """

    # Category-average selections accepted by get_model_data:
    #   selection (also the name of the cached average column) ->
    #   (model_categories key, display label, noun used in error messages)
    _CATEGORY_AVERAGES = {
        'avg_vision': ('vision', 'Vision Models (Images) - Average', 'vision'),
        'avg_captions_neural': (
            'captions_neural', 'Neural Language (Captions) - Average', 'neural language'),
        'avg_captions_statistical': (
            'captions_statistical', 'Statistical Text (Captions) - Average', 'statistical caption'),
        'avg_tags_statistical': (
            'tags_statistical', 'Statistical Text (Tags) - Average', 'statistical tag'),
    }

    _logger = logging.getLogger(__name__)

    # Snippet injected by add_image_hover_to_html. The script reads each hovered
    # point's customdata in one of two layouts:
    #   [idx, image_1_name, image_2_name, stim_1_url, stim_2_url]  (2D plots)
    #   [image_1_name, image_2_name, stim_1_url, stim_2_url]       (3D plot)
    _IMAGE_HOVER_HTML = """
<style>
#image-hover-tooltip {
    position: fixed;
    display: none;
    background: white;
    border: 2px solid #333;
    border-radius: 8px;
    padding: 10px;
    box-shadow: 0 4px 6px rgba(0,0,0,0.1);
    z-index: 10000;
    max-width: 500px;
    pointer-events: none;
}
#image-hover-tooltip .tooltip-header {
    font-weight: bold;
    margin-bottom: 8px;
    font-size: 13px;
    color: #333;
}
#image-hover-tooltip .image-container {
    display: flex;
    gap: 10px;
    align-items: flex-start;
}
#image-hover-tooltip .image-wrapper {
    flex: 1;
    text-align: center;
}
#image-hover-tooltip .image-label {
    font-size: 10px;
    color: #666;
    margin-bottom: 3px;
    font-weight: bold;
}
#image-hover-tooltip img {
    max-width: 230px;
    max-height: 230px;
    display: block;
    margin: 0 auto;
    border: 1px solid #ddd;
}
</style>
<div id="image-hover-tooltip">
    <div class="tooltip-header"></div>
    <div class="image-container">
        <div class="image-wrapper">
            <div class="image-label">Image 1</div>
            <img id="tooltip-img1" src="" alt="Image 1">
        </div>
        <div class="image-wrapper">
            <div class="image-label">Image 2</div>
            <img id="tooltip-img2" src="" alt="Image 2">
        </div>
    </div>
</div>
<script>
(function() {
    function initImageHover() {
        const tooltip = document.getElementById('image-hover-tooltip');
        if (!tooltip) {
            return;
        }
        const tooltipHeader = tooltip.querySelector('.tooltip-header');
        const tooltipImg1 = document.getElementById('tooltip-img1');
        const tooltipImg2 = document.getElementById('tooltip-img2');

        function showPair(img1Name, img2Name, img1Url, img2Url) {
            if (img1Url && img2Url) {
                tooltipHeader.textContent = `${img1Name} vs ${img2Name}`;
                tooltipImg1.src = img1Url;
                tooltipImg2.src = img2Url;
                tooltip.style.display = 'block';
            }
        }

        document.querySelectorAll('.plotly-graph-div').forEach(plotDiv => {
            // Plotly adds .on() to a div once it has been plotted; skip any that lack it.
            if (typeof plotDiv.on !== 'function') {
                return;
            }
            plotDiv.on('plotly_hover', function(data) {
                if (!data.points || data.points.length === 0) {
                    return;
                }
                const cd = data.points[0].customdata;
                if (!cd) {
                    return;
                }
                if (cd.length >= 5) {
                    // Format: [idx, image_1_name, image_2_name, stim_1_url, stim_2_url]
                    showPair(cd[1], cd[2], cd[3], cd[4]);
                } else if (cd.length === 4) {
                    // Format from the 3D plot: [image_1_name, image_2_name, stim_1_url, stim_2_url]
                    showPair(cd[0], cd[1], cd[2], cd[3]);
                }
            });
            plotDiv.on('plotly_unhover', function() {
                tooltip.style.display = 'none';
            });
            // Update tooltip position on mouse move, keeping it on screen
            plotDiv.addEventListener('mousemove', function(e) {
                if (tooltip.style.display === 'block') {
                    const tooltipRect = tooltip.getBoundingClientRect();
                    const maxX = window.innerWidth - tooltipRect.width - 10;
                    const maxY = window.innerHeight - tooltipRect.height - 10;
                    tooltip.style.left = Math.min(e.clientX + 15, maxX) + 'px';
                    tooltip.style.top = Math.min(e.clientY + 15, maxY) + 'px';
                }
            });
        });
    }

    // DOMContentLoaded may already have fired by the time this runs, in which
    // case a listener would never be called; initialise immediately instead.
    if (document.readyState === 'loading') {
        document.addEventListener('DOMContentLoaded', initImageHover);
    } else {
        initImageHover();
    }
})();
</script>
"""

    def __init__(self, data_loader):
        self.data_loader = data_loader

    def _category_models(self, category_key: str) -> list:
        """Return the model column names registered under ``category_key`` (empty if unknown)."""
        try:
            entries = self.data_loader.model_categories[category_key]
        except KeyError:
            return []
        return [entry[0] for entry in entries]

    def compute_category_correlation_method2(self, category_key: str, target_series: pd.Series) -> float:
        """
        Compute correlation using Method 2: Correlate each model, then average correlations.

        This matches the bar chart methodology.

        Args:
            category_key: Category name like 'vision', 'captions_neural', etc.
            target_series: The series to correlate with (e.g., brain measure or human judgement)

        Returns:
            Mean of the per-model correlations over the category's models that
            are present in the data, skipping NaN correlations (e.g. constant
            columns). 0.0 when the category is unknown/empty or no correlation
            could be computed.
        """
        data = self.data_loader.data
        available_models = [m for m in self._category_models(category_key) if m in data.columns]
        correlations = [
            corr
            for corr in (data[model].corr(target_series) for model in available_models)
            if not np.isnan(corr)
        ]
        return float(np.mean(correlations)) if correlations else 0.0

    @staticmethod
    def add_image_hover_to_html(html_str: str) -> str:
        """Add custom JavaScript to enable image preview on hover for image pairs.

        Args:
            html_str: HTML produced by plotly (a full page or a fragment).

        Returns:
            ``html_str`` with the tooltip markup/script inserted before the last
            ``</body>`` tag, or appended when there is no ``</body>`` (HTML
            fragments). Input that already contains the tooltip is returned
            unchanged, so applying this twice does not duplicate element ids.
        """
        if 'id="image-hover-tooltip"' in html_str:
            return html_str
        custom_code = PlotGenerator._IMAGE_HOVER_HTML
        head, body_close, tail = html_str.rpartition('</body>')
        if not body_close:
            return html_str + custom_code
        return head + custom_code + body_close + tail

    def get_model_data(self, ml_model_selection: Union[str, int]) -> Tuple[pd.Series, str]:
        """Get model data - either an individual model or a category average.

        Args:
            ml_model_selection: One of the ``avg_*`` keys of
                ``_CATEGORY_AVERAGES``, or an integer index into
                ``data_loader.ml_models``.

        Returns:
            ``(series, label)``. A category average is the row-wise mean of the
            category's model columns present in the data; it is stored as a
            column named after the selection in ``data_loader.data`` and that
            stored column is reused on later calls for consistency.

        Raises:
            ValueError: If the category has no models, none of them are present
                in the data, or the selection is not recognised.
        """
        data = self.data_loader.data
        if isinstance(ml_model_selection, str) and ml_model_selection in self._CATEGORY_AVERAGES:
            category_key, label, noun = self._CATEGORY_AVERAGES[ml_model_selection]
            column = ml_model_selection
            if column not in data.columns:
                models = self._category_models(category_key)
                if not models:
                    raise ValueError(f"No {noun} models available")
                available_models = [m for m in models if m in data.columns]
                if not available_models:
                    raise ValueError(f"No {noun} models available in data")
                # Store in the original DataFrame so repeated calls agree.
                data[column] = data[available_models].mean(axis=1)
            return data[column], label

        # Individual models; np.integer also accepts numpy ints (e.g. from arrays).
        if isinstance(ml_model_selection, (int, np.integer)):
            ml_column = self.data_loader.ml_models[ml_model_selection]
            return data[ml_column], ml_column

        raise ValueError(f"Invalid model selection: {ml_model_selection}")

    @staticmethod
    def normalize_series(series: pd.Series) -> pd.Series:
        """Normalize a pandas series to the 0-1 range (min-max scaling).

        A constant series maps to 0.5 everywhere. The result keeps the input's
        index and name so it stays aligned with the source DataFrame.
        """
        min_val = series.min()
        max_val = series.max()
        if max_val == min_val:
            return pd.Series(0.5, index=series.index, name=series.name)
        return (series - min_val) / (max_val - min_val)

    @staticmethod
    def _describe_brain_measure(brain_measure: str) -> Tuple[str, str]:
        """Derive ``(measure_type, brain_name)`` display strings from a brain-measure column name.

        Recognised prefixes are ``roi_``, ``voxel_`` and ``voxel_to_roi_``.
        Names that are not recognised (or too short to parse) fall back to the
        raw column name; an unknown prefix yields measure type ``"Unknown"``.
        """
        if brain_measure.startswith("voxel_to_roi_"):
            parts = brain_measure[len("voxel_to_roi_"):].split("_")
            sim_label, pattern_label = "V→R Sim", "V→R Pattern"
        elif brain_measure.startswith("roi_"):
            parts = brain_measure.split("_")[1:]
            sim_label, pattern_label = "Similarity", "Pattern"
        elif brain_measure.startswith("voxel_"):
            measure_type = "Cosine" if "cosine" in brain_measure else "Pearson"
            if "all_avg" in brain_measure:
                return measure_type, f"{measure_type} All Voxels (Avg)"
            if "subj" in brain_measure:
                subj_num = brain_measure.split("subj")[1]
                return measure_type, f"{measure_type} Subject {subj_num} Voxels"
            return measure_type, brain_measure
        else:
            return "Unknown", brain_measure

        # roi_ / voxel_to_roi_ names: <measure>_<roi>[_...]
        if len(parts) < 2:
            measure_type = parts[0].title() if parts and parts[0] else "Unknown"
            return measure_type, brain_measure
        measure_type, roi_type = parts[0].title(), parts[1].title()
        if "avg_sim" in brain_measure:
            return measure_type, f"{measure_type} {roi_type} ({sim_label})"
        if "avg_roi" in brain_measure:
            return measure_type, f"{measure_type} {roi_type} ({pattern_label})"
        return measure_type, brain_measure

    @staticmethod
    def _brain_label(measure_type: str, brain_name: str) -> str:
        """Combine measure type and brain name without repeating the measure type."""
        if brain_name.startswith(measure_type):
            return brain_name
        return f"{measure_type} {brain_name}"

    def create_3d_plot(self, brain_measure: str, ml_model_selection: Union[str, int],
                       normalize: bool = False) -> Optional["go.Figure"]:
        """Create a 3D scatter plot of human vs brain vs ML similarity.

        Args:
            brain_measure: Column of ``data_loader.data`` holding the brain measure.
            ml_model_selection: Model index or category-average key (see get_model_data).
            normalize: Min-max scale every axis to [0, 1] before plotting.

        Returns:
            The figure, or None if the model selection could not be resolved.
        """
        data = self.data_loader.data
        try:
            ml_data, ml_name = self.get_model_data(ml_model_selection)
        except ValueError as e:
            self._logger.error("Error getting model data: %s", e)
            return None

        # Get data (normalized or raw)
        if normalize:
            human_data = self.normalize_series(data['human_judgement'])
            brain_data = self.normalize_series(data[brain_measure])
            ml_plot_data = self.normalize_series(ml_data)
            value_suffix = " (normalized)"
        else:
            human_data = data['human_judgement']
            brain_data = data[brain_measure]
            ml_plot_data = ml_data
            value_suffix = ""

        # Pair values positionally: row labels need not be 0..n-1, so using an
        # iterrows() label with .iloc would misalign or go out of range.
        hover_text = [
            f"Pair #{idx}<br>"
            f"Images: {img1} vs {img2}<br>"
            f"Human: {human:.3f}<br>"
            f"Brain: {brain:.3f}<br>"
            f"ML: {ml:.3f}"
            for idx, img1, img2, human, brain, ml in zip(
                data.index, data['image_1'], data['image_2'],
                human_data.to_numpy(), brain_data.to_numpy(), ml_plot_data.to_numpy(),
            )
        ]

        fig = go.Figure(data=go.Scatter3d(
            x=human_data,
            y=brain_data,
            z=ml_plot_data,
            mode='markers',
            marker=dict(
                size=6,
                color=human_data,
                colorscale='Viridis',
                opacity=0.7,
                colorbar=dict(title="Human Rating" + value_suffix),
            ),
            text=hover_text,
            hovertemplate='%{text}<extra></extra>',
            # 4-element layout understood by the add_image_hover_to_html script.
            customdata=data[['image_1', 'image_2', 'stim_1', 'stim_2']].values,
        ))

        measure_type, brain_name = self._describe_brain_measure(brain_measure)
        brain_label = self._brain_label(measure_type, brain_name)
        fig.update_layout(
            title=f'3D Analysis: Human vs {brain_label} Brain vs {ml_name}{" (Normalized)" if normalize else ""}',
            scene=dict(
                xaxis_title=f'Human Rating{value_suffix}',
                yaxis_title=f'Brain Similarity ({brain_label}){value_suffix}',
                zaxis_title=f'ML Model: {ml_name}{value_suffix}',
                camera=dict(eye=dict(x=1.5, y=1.5, z=1.5)),
            ),
            width=800,
            height=600,
        )
        return fig

    def create_2d_plots(self, brain_measure: str, ml_model_selection: Union[str, int],
                        normalize: bool = False) -> Optional["go.Figure"]:
        """Create three side-by-side 2D scatter plots (human/brain/ML pairs).

        Subplot titles show correlations (pandas ``Series.corr`` default,
        Pearson) computed on the raw data. For a category-average selection the
        human-vs-ML and brain-vs-ML correlations use "Method 2" (correlate each
        model, then average -- matching the bar chart); otherwise they correlate
        directly with the selected series.

        Args:
            brain_measure: Column of ``data_loader.data`` holding the brain measure.
            ml_model_selection: Model index or category-average key (see get_model_data).
            normalize: Min-max scale the plotted values to [0, 1].

        Returns:
            The figure, or None if the model selection could not be resolved.
        """
        data = self.data_loader.data
        try:
            ml_data, ml_name = self.get_model_data(ml_model_selection)
        except ValueError as e:
            self._logger.error("Error getting model data: %s", e)
            return None

        # Get data (normalized or raw)
        if normalize:
            human_data = self.normalize_series(data['human_judgement'])
            brain_data = self.normalize_series(data[brain_measure])
            ml_plot_data = self.normalize_series(ml_data)
            value_suffix = " (norm)"
        else:
            human_data = data['human_judgement']
            brain_data = data[brain_measure]
            ml_plot_data = ml_data
            value_suffix = ""

        # Correlations are always computed on the raw data.
        corr_hb = data['human_judgement'].corr(data[brain_measure])
        self._logger.debug(
            "create_2d_plots: ml_model_selection=%r (%s)",
            ml_model_selection, type(ml_model_selection).__name__,
        )

        category = (self._CATEGORY_AVERAGES.get(ml_model_selection)
                    if isinstance(ml_model_selection, str) else None)
        if category is not None:
            # Method 2: correlate each model individually, then average correlations.
            category_key = category[0]
            corr_hm = self.compute_category_correlation_method2(category_key, data['human_judgement'])
            corr_bm = self.compute_category_correlation_method2(category_key, data[brain_measure])
            ml_name = ml_name + " (Method 2: Avg of Correlations)"
            self._logger.debug(
                "[METHOD 2] Using correlate-then-average for %s: human r=%.4f, brain r=%.4f",
                ml_model_selection, corr_hm, corr_bm,
            )
        else:
            # Individual model - Method 1: correlate directly with its scores.
            corr_hm = data['human_judgement'].corr(ml_data)
            corr_bm = data[brain_measure].corr(ml_data)

        measure_type, brain_name = self._describe_brain_measure(brain_measure)
        brain_label = self._brain_label(measure_type, brain_name)

        fig = make_subplots(
            rows=1, cols=3,
            subplot_titles=[
                f'Human vs Brain (r={corr_hb:.3f})',
                f'Human vs ML (r={corr_hm:.3f})',
                f'Brain vs ML (r={corr_bm:.3f})',
            ],
            horizontal_spacing=0.1,
        )

        # Hover payload per point: [index, image_1, image_2], extended with the
        # stim_1 / stim_2 URLs when present so the add_image_hover_to_html
        # script (5-element layout) can show image previews.
        payload_columns = [data.index, data['image_1'], data['image_2']]
        if 'stim_1' in data.columns and 'stim_2' in data.columns:
            payload_columns += [data['stim_1'], data['stim_2']]
        customdata = [list(values) for values in zip(*payload_columns)]

        human_axis = f'Human{value_suffix}'
        brain_axis = f'Brain ({measure_type}){value_suffix}'
        ml_axis = f'ML Model{value_suffix}'
        panels = [
            (human_data, brain_data, 'blue', human_axis, brain_axis),
            (human_data, ml_plot_data, 'red', human_axis, ml_axis),
            (brain_data, ml_plot_data, 'green', brain_axis, ml_axis),
        ]
        for col, (x_values, y_values, color, x_label, y_label) in enumerate(panels, start=1):
            fig.add_trace(
                go.Scatter(
                    x=x_values,
                    y=y_values,
                    mode='markers',
                    marker=dict(color=color, opacity=0.6, size=3),
                    hovertemplate=(
                        f'Pair #%{{customdata[0]}}<br>{x_label}: %{{x:.3f}}<br>'
                        f'{y_label}: %{{y:.3f}}<br>'
                        '%{customdata[1]} vs %{customdata[2]}<extra></extra>'
                    ),
                    customdata=customdata,
                    showlegend=False,
                ),
                row=1, col=col,
            )

        fig.update_layout(
            title=f'2D Comparisons: {brain_label} Brain vs {ml_name}{" (Normalized)" if normalize else ""}',
            width=1300,
            height=500,
            margin=dict(l=60, r=60, t=80, b=80),
        )

        # Add axis labels to each subplot
        axis_titles = [
            ("Human Similarity", f"Brain Similarity ({measure_type})"),
            ("Human Similarity", "ML Model Similarity"),
            (f"Brain Similarity ({measure_type})", "ML Model Similarity"),
        ]
        for col, (x_title, y_title) in enumerate(axis_titles, start=1):
            fig.update_xaxes(title_text=x_title, row=1, col=col)
            fig.update_yaxes(title_text=y_title, row=1, col=col)
        return fig