# similarity_analysis/visualization/plot_generator.py
# Uploaded by DanJChong using huggingface_hub (commit 329d553, verified).
# visualization/plot_generator.py
"""Main plotting functionality for similarity analysis"""
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from typing import Tuple, Optional, Union
import sys
import os
# Add the parent of this file's directory to sys.path so modules that live
# outside this package can be imported. The membership check keeps the entry
# from being appended again when this module is re-imported or reloaded.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _PROJECT_ROOT not in sys.path:
    sys.path.append(_PROJECT_ROOT)
class PlotGenerator:
    """Build Plotly figures comparing human, brain and ML similarity scores.

    ``data_loader`` is expected to expose:
      * ``data`` - a pandas DataFrame with one row per image pair, holding
        ``human_judgement``, ``image_1``/``image_2`` (plus ``stim_1``/``stim_2``
        image URLs for hover previews), brain-measure columns and ML model
        columns;
      * ``model_categories`` - a mapping from category key to a sequence of
        entries whose first element is a model column name;
      * ``ml_models`` - an indexable sequence of model column names.
    """

    # Category-average selections accepted by get_model_data / create_*_plots:
    # selection -> (model_categories key, display label,
    #               noun used in the "No ... models available" errors).
    _CATEGORY_AVERAGES = {
        "avg_vision": ("vision", "Vision Models (Images) - Average", "vision"),
        "avg_captions_neural": (
            "captions_neural", "Neural Language (Captions) - Average", "neural language"),
        "avg_captions_statistical": (
            "captions_statistical", "Statistical Text (Captions) - Average", "statistical caption"),
        "avg_tags_statistical": (
            "tags_statistical", "Statistical Text (Tags) - Average", "statistical tag"),
    }

    # HTML/CSS/JS injected by add_image_hover_to_html. The script listens to
    # Plotly hover events and, when a point's customdata carries two image
    # URLs, shows both images next to the cursor.
    _IMAGE_HOVER_HTML = """
<style>
#image-hover-tooltip {
    position: fixed;
    display: none;
    background: white;
    border: 2px solid #333;
    border-radius: 8px;
    padding: 10px;
    box-shadow: 0 4px 6px rgba(0,0,0,0.1);
    z-index: 10000;
    max-width: 500px;
    pointer-events: none;
}
#image-hover-tooltip .tooltip-header {
    font-weight: bold;
    margin-bottom: 8px;
    font-size: 13px;
    color: #333;
}
#image-hover-tooltip .image-container {
    display: flex;
    gap: 10px;
    align-items: flex-start;
}
#image-hover-tooltip .image-wrapper {
    flex: 1;
    text-align: center;
}
#image-hover-tooltip .image-label {
    font-size: 10px;
    color: #666;
    margin-bottom: 3px;
    font-weight: bold;
}
#image-hover-tooltip img {
    max-width: 230px;
    max-height: 230px;
    display: block;
    margin: 0 auto;
    border: 1px solid #ddd;
}
</style>
<div id="image-hover-tooltip">
    <div class="tooltip-header"></div>
    <div class="image-container">
        <div class="image-wrapper">
            <div class="image-label">Image 1</div>
            <img id="tooltip-img1" src="" alt="Image 1">
        </div>
        <div class="image-wrapper">
            <div class="image-label">Image 2</div>
            <img id="tooltip-img2" src="" alt="Image 2">
        </div>
    </div>
</div>
<script>
document.addEventListener('DOMContentLoaded', function() {
    const tooltip = document.getElementById('image-hover-tooltip');
    const tooltipHeader = tooltip.querySelector('.tooltip-header');
    const tooltipImg1 = document.getElementById('tooltip-img1');
    const tooltipImg2 = document.getElementById('tooltip-img2');

    // Show the tooltip for one image pair (only when both URLs are present)
    function showPair(img1Name, img2Name, img1Url, img2Url) {
        if (!img1Url || !img2Url) {
            return;
        }
        tooltipHeader.textContent = `${img1Name} vs ${img2Name}`;
        tooltipImg1.src = img1Url;
        tooltipImg2.src = img2Url;
        tooltip.style.display = 'block';
    }

    document.querySelectorAll('.plotly-graph-div').forEach(plotDiv => {
        // Plotly adds .on() when it renders a graph; skip divs it has not initialised
        if (typeof plotDiv.on !== 'function') {
            return;
        }
        plotDiv.on('plotly_hover', function(data) {
            if (!data.points || data.points.length === 0) {
                return;
            }
            const cd = data.points[0].customdata;
            if (!cd) {
                return;
            }
            if (cd.length >= 5) {
                // 2D format: [idx, image_1_name, image_2_name, stim_1_url, stim_2_url]
                showPair(cd[1], cd[2], cd[3], cd[4]);
            } else if (cd.length === 4) {
                // 3D format: [image_1_name, image_2_name, stim_1_url, stim_2_url]
                showPair(cd[0], cd[1], cd[2], cd[3]);
            }
        });
        plotDiv.on('plotly_unhover', function(data) {
            tooltip.style.display = 'none';
        });
        // Follow the cursor while the tooltip is visible, keeping it on screen
        plotDiv.addEventListener('mousemove', function(e) {
            if (tooltip.style.display === 'block') {
                const x = e.clientX + 15;
                const y = e.clientY + 15;
                const tooltipRect = tooltip.getBoundingClientRect();
                const maxX = window.innerWidth - tooltipRect.width - 10;
                const maxY = window.innerHeight - tooltipRect.height - 10;
                tooltip.style.left = Math.min(x, maxX) + 'px';
                tooltip.style.top = Math.min(y, maxY) + 'px';
            }
        });
    });
});
</script>
"""

    def __init__(self, data_loader):
        """Store the data loader (see the class docstring for the attributes used)."""
        self.data_loader = data_loader

    def compute_category_correlation_method2(self, category_key: str, target_series: pd.Series) -> float:
        """Compute a category's correlation with ``target_series`` using Method 2.

        Method 2 correlates each model column of the category with the target
        separately and then averages those correlations (this matches the bar
        chart methodology), rather than correlating an averaged column.

        Args:
            category_key: Key into ``data_loader.model_categories``, e.g.
                'vision' or 'captions_neural'.
            target_series: Series to correlate with (e.g. a brain measure or
                the human judgement column).

        Returns:
            Mean of the non-NaN correlations of the category's models that are
            present in ``data``; 0.0 if there are no such models or every
            correlation is NaN.
        """
        import numpy as np

        models = [model[0] for model in self.data_loader.model_categories[category_key]]
        data = self.data_loader.data
        available_models = [m for m in models if m in data.columns]

        correlations = [
            corr
            for corr in (data[m].corr(target_series) for m in available_models)
            if not np.isnan(corr)
        ]
        return np.mean(correlations) if correlations else 0.0

    @staticmethod
    def add_image_hover_to_html(html_str: str) -> str:
        """Inject a tooltip that previews both images of a hovered pair.

        The snippet is inserted once, just before the last ``</body>`` tag.
        If the HTML has no ``</body>`` (e.g. a fragment exported without the
        full page wrapper), the snippet is appended instead of being dropped.

        The script uses each point's ``customdata``: 5 elements
        ``[idx, name1, name2, url1, url2]`` (2D plots) or 4 elements
        ``[name1, name2, url1, url2]`` (3D plot).
        """
        snippet = PlotGenerator._IMAGE_HOVER_HTML
        head, body_close, tail = html_str.rpartition('</body>')
        if not body_close:
            return html_str + snippet
        return head + snippet + body_close + tail

    def get_model_data(self, ml_model_selection: Union[str, int]) -> Tuple[pd.Series, str]:
        """Return ``(series, display_name)`` for an individual model or a category average.

        Args:
            ml_model_selection: Either an integer index into
                ``data_loader.ml_models`` (NumPy integers are accepted too), or
                one of the ``avg_*`` keys in ``_CATEGORY_AVERAGES``.

        Raises:
            ValueError: If the selection is not recognised, or a category has
                no models (or none present in ``data``).
        """
        import numbers

        data = self.data_loader.data
        if isinstance(ml_model_selection, str) and ml_model_selection in self._CATEGORY_AVERAGES:
            category_key, label, description = self._CATEGORY_AVERAGES[ml_model_selection]
            return self._category_average(ml_model_selection, category_key, label, description)
        # numbers.Integral also covers NumPy integer scalars (e.g. values read from an array).
        if isinstance(ml_model_selection, numbers.Integral):
            ml_column = self.data_loader.ml_models[ml_model_selection]
            return data[ml_column], ml_column
        raise ValueError(f"Invalid model selection: {ml_model_selection}")

    def _category_average(self, column: str, category_key: str, label: str,
                          description: str) -> Tuple[pd.Series, str]:
        """Return the per-pair mean over a category's models, stored as ``data[column]``."""
        data = self.data_loader.data
        # Reuse a previously stored column so every caller sees identical values.
        if column in data.columns:
            return data[column], label
        models = [model[0] for model in self.data_loader.model_categories[category_key]]
        if not models:
            raise ValueError(f"No {description} models available")
        available_models = [m for m in models if m in data.columns]
        if not available_models:
            raise ValueError(f"No {description} models available in data")
        # Store in the shared DataFrame so later calls reuse the same column.
        data[column] = data[available_models].mean(axis=1)
        return data[column], label

    @staticmethod
    def normalize_series(series: pd.Series) -> pd.Series:
        """Min-max scale ``series`` to the 0-1 range.

        A constant series (max == min) cannot be scaled, so every position is
        set to 0.5. The result always keeps the input's index and name, so it
        stays aligned with the DataFrame the series came from.
        """
        min_val = series.min()
        max_val = series.max()
        if max_val == min_val:
            return pd.Series(0.5, index=series.index, name=series.name)
        return (series - min_val) / (max_val - min_val)

    def _select_plot_data(self, data: pd.DataFrame, brain_measure: str, ml_data: pd.Series,
                          normalize: bool) -> Tuple[pd.Series, pd.Series, pd.Series]:
        """Return the (human, brain, ml) series to plot, min-max scaled if ``normalize``."""
        human_data = data['human_judgement']
        brain_data = data[brain_measure]
        if normalize:
            return (
                self.normalize_series(human_data),
                self.normalize_series(brain_data),
                self.normalize_series(ml_data),
            )
        return human_data, brain_data, ml_data

    @staticmethod
    def _describe_brain_measure(brain_measure: str) -> Tuple[str, str]:
        """Derive ``(measure_type, display_name)`` from a brain-measure column name.

        Recognised patterns (display names already include the measure type):
          * ``roi_<measure>_<roi>...``, e.g. ``roi_cosine_common_avg_sim``
          * ``voxel_...``, e.g. ``voxel_cosine_all_avg``, ``voxel_pearson_subj1``
          * ``voxel_to_roi_<measure>_<roi>...``, e.g. ``voxel_to_roi_pearson_late_avg_roi``
        Anything else, including names too short to parse, gets measure type
        "Unknown" and is displayed verbatim.
        """
        if brain_measure.startswith("voxel_to_roi_"):
            parts = brain_measure[len("voxel_to_roi_"):].split("_")
            if len(parts) >= 2:
                measure_type, roi_type = parts[0].title(), parts[1].title()
                if "avg_sim" in brain_measure:
                    return measure_type, f"{measure_type} {roi_type} (V→R Sim)"
                if "avg_roi" in brain_measure:
                    return measure_type, f"{measure_type} {roi_type} (V→R Pattern)"
                return measure_type, brain_measure
        elif brain_measure.startswith("voxel_"):
            measure_type = "Cosine" if "cosine" in brain_measure else "Pearson"
            if "all_avg" in brain_measure:
                return measure_type, f"{measure_type} All Voxels (Avg)"
            if "subj" in brain_measure:
                subj_num = brain_measure.split("subj")[1]
                return measure_type, f"{measure_type} Subject {subj_num} Voxels"
            return measure_type, brain_measure
        elif brain_measure.startswith("roi_"):
            parts = brain_measure.split("_")
            if len(parts) >= 3:
                measure_type, roi_type = parts[1].title(), parts[2].title()
                if "avg_sim" in brain_measure:
                    return measure_type, f"{measure_type} {roi_type} (Similarity)"
                if "avg_roi" in brain_measure:
                    return measure_type, f"{measure_type} {roi_type} (Pattern)"
                return measure_type, brain_measure
        return "Unknown", brain_measure

    @staticmethod
    def _build_3d_hover_text(data: pd.DataFrame, human_data: pd.Series,
                             brain_data: pd.Series, ml_data: pd.Series) -> list:
        """Build one HTML hover label per pair for the 3D scatter.

        Values are paired by position rather than by index label, so labels
        stay correct when ``data`` has a non-default index.
        """
        return [
            f"Pair #{idx}<br>"
            f"Images: {img1} vs {img2}<br>"
            f"Human: {human:.3f}<br>"
            f"Brain: {brain:.3f}<br>"
            f"ML: {ml:.3f}"
            for idx, img1, img2, human, brain, ml in zip(
                data.index, data['image_1'], data['image_2'], human_data, brain_data, ml_data
            )
        ]

    @staticmethod
    def _build_pair_customdata(data: pd.DataFrame) -> list:
        """Build per-pair customdata ``[idx, image_1, image_2(, stim_1, stim_2)]``.

        The stim URL columns are included when present so the injected hover
        script (which expects 5 elements for 2D points) can show the images.
        """
        columns = ['image_1', 'image_2']
        if 'stim_1' in data.columns and 'stim_2' in data.columns:
            columns += ['stim_1', 'stim_2']
        return [
            [idx, *values]
            for idx, values in zip(data.index, data[columns].itertuples(index=False, name=None))
        ]

    def create_3d_plot(self, brain_measure: str, ml_model_selection: Union[str, int],
                       normalize: bool = False) -> "Optional[go.Figure]":
        """Create a 3D scatter of human rating (x), brain measure (y) and ML model (z).

        Args:
            brain_measure: Column of ``data`` holding the brain similarity measure.
            ml_model_selection: Model index or ``avg_*`` key (see ``get_model_data``).
            normalize: If True, each axis is min-max scaled to 0-1.

        Returns:
            The figure, or None (after printing the error) if the model data
            cannot be resolved.
        """
        data = self.data_loader.data
        try:
            ml_data, ml_name = self.get_model_data(ml_model_selection)
        except ValueError as e:
            print(f"Error getting model data: {e}")
            return None

        human_data, brain_data, ml_plot_data = self._select_plot_data(
            data, brain_measure, ml_data, normalize
        )
        value_suffix = " (normalized)" if normalize else ""

        fig = go.Figure(data=go.Scatter3d(
            x=human_data,
            y=brain_data,
            z=ml_plot_data,
            mode='markers',
            marker=dict(
                size=6,
                color=human_data,
                colorscale='Viridis',
                opacity=0.7,
                colorbar=dict(title="Human Rating" + value_suffix),
            ),
            text=self._build_3d_hover_text(data, human_data, brain_data, ml_plot_data),
            hovertemplate='%{text}<extra></extra>',
            # 4-element customdata read by the injected hover script:
            # [image_1, image_2, stim_1 URL, stim_2 URL]
            customdata=data[['image_1', 'image_2', 'stim_1', 'stim_2']].values,
        ))

        # brain_name already includes the measure type (e.g. "Cosine Common (Similarity)").
        _, brain_name = self._describe_brain_measure(brain_measure)
        normalized_tag = " (Normalized)" if normalize else ""
        fig.update_layout(
            title=f'3D Analysis: Human vs {brain_name} Brain vs {ml_name}{normalized_tag}',
            scene=dict(
                xaxis_title=f'Human Rating{value_suffix}',
                yaxis_title=f'Brain Similarity ({brain_name}){value_suffix}',
                zaxis_title=f'ML Model: {ml_name}{value_suffix}',
                camera=dict(eye=dict(x=1.5, y=1.5, z=1.5)),
            ),
            width=800,
            height=600,
        )
        return fig

    def create_2d_plots(self, brain_measure: str, ml_model_selection: Union[str, int],
                        normalize: bool = False) -> "Optional[go.Figure]":
        """Create three side-by-side 2D scatters: Human vs Brain, Human vs ML, Brain vs ML.

        Subplot titles report correlations computed on the raw (never
        normalized) values. For an ``avg_*`` category selection, the
        Human-vs-ML and Brain-vs-ML correlations use Method 2 (see
        ``compute_category_correlation_method2``); otherwise they correlate
        directly with the selected model column.

        Returns:
            The figure, or None (after printing the error) if the model data
            cannot be resolved.
        """
        data = self.data_loader.data
        try:
            ml_data, ml_name = self.get_model_data(ml_model_selection)
        except ValueError as e:
            print(f"Error getting model data: {e}")
            return None

        human_data, brain_data, ml_plot_data = self._select_plot_data(
            data, brain_measure, ml_data, normalize
        )
        value_suffix = " (norm)" if normalize else ""

        # Correlations are always computed on the raw data.
        corr_hb = data['human_judgement'].corr(data[brain_measure])
        category = (
            self._CATEGORY_AVERAGES.get(ml_model_selection)
            if isinstance(ml_model_selection, str) else None
        )
        if category is not None:
            # Method 2: correlate each model individually, then average (matches the bar chart).
            category_key = category[0]
            corr_hm = self.compute_category_correlation_method2(category_key, data['human_judgement'])
            corr_bm = self.compute_category_correlation_method2(category_key, data[brain_measure])
            ml_name = ml_name + " (Method 2: Avg of Correlations)"
            print(f"[METHOD 2] Using correlate-then-average for {ml_model_selection}")
            print(f" Human vs Category: r = {corr_hm:.4f}")
            print(f" Brain vs Category: r = {corr_bm:.4f}")
        else:
            corr_hm = data['human_judgement'].corr(ml_data)
            corr_bm = data[brain_measure].corr(ml_data)

        measure_type, brain_name = self._describe_brain_measure(brain_measure)

        fig = make_subplots(
            rows=1, cols=3,
            subplot_titles=[
                f'Human vs Brain (r={corr_hb:.3f})',
                f'Human vs ML (r={corr_hm:.3f})',
                f'Brain vs ML (r={corr_bm:.3f})',
            ],
            horizontal_spacing=0.1,
        )

        customdata = self._build_pair_customdata(data)
        brain_label = f'Brain ({measure_type}){value_suffix}'
        human_label = f'Human{value_suffix}'
        ml_label = f'ML Model{value_suffix}'
        panels = [
            (human_data, brain_data, 'blue', human_label, brain_label),
            (human_data, ml_plot_data, 'red', human_label, ml_label),
            (brain_data, ml_plot_data, 'green', brain_label, ml_label),
        ]
        for col, (x_values, y_values, color, x_label, y_label) in enumerate(panels, start=1):
            fig.add_trace(
                go.Scatter(
                    x=x_values,
                    y=y_values,
                    mode='markers',
                    marker=dict(color=color, opacity=0.6, size=3),
                    hovertemplate=(
                        f'Pair #%{{customdata[0]}}<br>{x_label}: %{{x:.3f}}<br>'
                        f'{y_label}: %{{y:.3f}}<br>%{{customdata[1]}} vs %{{customdata[2]}}<extra></extra>'
                    ),
                    customdata=customdata,
                    showlegend=False,
                ),
                row=1, col=col,
            )

        normalized_tag = " (Normalized)" if normalize else ""
        fig.update_layout(
            title=f'2D Comparisons: {brain_name} Brain vs {ml_name}{normalized_tag}',
            width=1300,
            height=500,
            margin=dict(l=60, r=60, t=80, b=80),
        )

        fig.update_xaxes(title_text="Human Similarity", row=1, col=1)
        fig.update_yaxes(title_text=f"Brain Similarity ({measure_type})", row=1, col=1)
        fig.update_xaxes(title_text="Human Similarity", row=1, col=2)
        fig.update_yaxes(title_text="ML Model Similarity", row=1, col=2)
        fig.update_xaxes(title_text=f"Brain Similarity ({measure_type})", row=1, col=3)
        fig.update_yaxes(title_text="ML Model Similarity", row=1, col=3)
        return fig