File size: 24,258 Bytes
18a9b1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cb9a15
18a9b1c
 
4cb9a15
329d553
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cb9a15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18a9b1c
 
 
 
4cb9a15
 
 
 
 
 
 
 
 
 
18a9b1c
4cb9a15
 
 
 
 
 
 
 
 
 
 
 
18a9b1c
 
4cb9a15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18a9b1c
4cb9a15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18a9b1c
4cb9a15
 
18a9b1c
 
 
 
4cb9a15
18a9b1c
 
 
befcceb
 
 
 
 
 
 
 
 
 
18a9b1c
 
 
 
 
 
 
 
 
befcceb
 
 
 
 
 
 
 
 
 
 
 
18a9b1c
 
 
 
 
befcceb
 
 
18a9b1c
 
 
befcceb
 
 
18a9b1c
 
 
befcceb
18a9b1c
 
befcceb
18a9b1c
 
 
 
 
 
4cb9a15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18a9b1c
befcceb
 
 
 
18a9b1c
befcceb
18a9b1c
befcceb
 
 
18a9b1c
 
 
 
 
 
 
befcceb
 
18a9b1c
 
 
 
 
 
 
 
 
befcceb
 
 
 
 
 
 
 
 
 
 
 
 
18a9b1c
329d553
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18a9b1c
4cb9a15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18a9b1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
befcceb
 
18a9b1c
befcceb
 
18a9b1c
 
befcceb
 
18a9b1c
befcceb
 
18a9b1c
 
befcceb
 
18a9b1c
befcceb
 
18a9b1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
befcceb
 
18a9b1c
befcceb
18a9b1c
 
 
 
 
befcceb
 
 
 
 
 
 
 
 
 
4cb9a15
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
# visualization/plot_generator.py
"""Main plotting functionality for similarity analysis"""

import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from typing import Tuple, Optional, Union
import sys
import os

# Put the directory one level above this file's package on sys.path so modules
# living there (presumably the project root) can be imported by absolute name.
_project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(_project_root)

class PlotGenerator:
    """Handles creation of plotly visualizations"""

    def __init__(self, data_loader):
        self.data_loader = data_loader

    def compute_category_correlation_method2(self, category_key: str, target_series: pd.Series) -> float:
        """

        Compute correlation using Method 2: Correlate each model, then average correlations.

        This matches the bar chart methodology.



        Args:

            category_key: Category name like 'vision', 'captions_neural', etc.

            target_series: The series to correlate with (e.g., brain measure or human judgement)



        Returns:

            Average correlation across all models in the category

        """
        import numpy as np

        # Get models in this category
        models = [model[0] for model in self.data_loader.model_categories[category_key]]

        if not models:
            return 0.0

        # Filter to available models
        data = self.data_loader.data
        available_models = [m for m in models if m in data.columns]

        if not available_models:
            return 0.0

        # Compute correlation for each model
        correlations = []
        for model in available_models:
            corr = data[model].corr(target_series)
            if not np.isnan(corr):
                correlations.append(corr)

        # Return average correlation
        if correlations:
            return np.mean(correlations)
        else:
            return 0.0

    @staticmethod
    def add_image_hover_to_html(html_str: str) -> str:
        """Add custom JavaScript to enable image preview on hover for image pairs"""
        custom_code = """

<style>

    #image-hover-tooltip {

        position: fixed;

        display: none;

        background: white;

        border: 2px solid #333;

        border-radius: 8px;

        padding: 10px;

        box-shadow: 0 4px 6px rgba(0,0,0,0.1);

        z-index: 10000;

        max-width: 500px;

        pointer-events: none;

    }

    #image-hover-tooltip .tooltip-header {

        font-weight: bold;

        margin-bottom: 8px;

        font-size: 13px;

        color: #333;

    }

    #image-hover-tooltip .image-container {

        display: flex;

        gap: 10px;

        align-items: flex-start;

    }

    #image-hover-tooltip .image-wrapper {

        flex: 1;

        text-align: center;

    }

    #image-hover-tooltip .image-label {

        font-size: 10px;

        color: #666;

        margin-bottom: 3px;

        font-weight: bold;

    }

    #image-hover-tooltip img {

        max-width: 230px;

        max-height: 230px;

        display: block;

        margin: 0 auto;

        border: 1px solid #ddd;

    }

</style>



<div id="image-hover-tooltip">

    <div class="tooltip-header"></div>

    <div class="image-container">

        <div class="image-wrapper">

            <div class="image-label">Image 1</div>

            <img id="tooltip-img1" src="" alt="Image 1">

        </div>

        <div class="image-wrapper">

            <div class="image-label">Image 2</div>

            <img id="tooltip-img2" src="" alt="Image 2">

        </div>

    </div>

</div>



<script>

document.addEventListener('DOMContentLoaded', function() {

    const tooltip = document.getElementById('image-hover-tooltip');

    const tooltipHeader = tooltip.querySelector('.tooltip-header');

    const tooltipImg1 = document.getElementById('tooltip-img1');

    const tooltipImg2 = document.getElementById('tooltip-img2');



    // Get all plotly divs

    const plotDivs = document.querySelectorAll('.plotly-graph-div');



    plotDivs.forEach(plotDiv => {

        plotDiv.on('plotly_hover', function(data) {

            if (data.points && data.points.length > 0) {

                const point = data.points[0];



                // Check if customdata exists and has image URLs

                // Format: [idx, image_1_name, image_2_name, stim_1_url, stim_2_url]

                if (point.customdata && point.customdata.length >= 5) {

                    const img1Name = point.customdata[1];

                    const img2Name = point.customdata[2];

                    const img1Url = point.customdata[3];

                    const img2Url = point.customdata[4];



                    if (img1Url && img2Url) {

                        tooltipHeader.textContent = `${img1Name} vs ${img2Name}`;

                        tooltipImg1.src = img1Url;

                        tooltipImg2.src = img2Url;

                        tooltip.style.display = 'block';

                    }

                }

                // Handle 4-element customdata (old format from 3D plot)

                else if (point.customdata && point.customdata.length === 4) {

                    const img1Name = point.customdata[0];

                    const img2Name = point.customdata[1];

                    const img1Url = point.customdata[2];

                    const img2Url = point.customdata[3];



                    if (img1Url && img2Url) {

                        tooltipHeader.textContent = `${img1Name} vs ${img2Name}`;

                        tooltipImg1.src = img1Url;

                        tooltipImg2.src = img2Url;

                        tooltip.style.display = 'block';

                    }

                }

            }

        });



        plotDiv.on('plotly_unhover', function(data) {

            tooltip.style.display = 'none';

        });



        // Update tooltip position on mouse move

        plotDiv.addEventListener('mousemove', function(e) {

            if (tooltip.style.display === 'block') {

                const x = e.clientX + 15;

                const y = e.clientY + 15;



                // Keep tooltip on screen

                const tooltipRect = tooltip.getBoundingClientRect();

                const maxX = window.innerWidth - tooltipRect.width - 10;

                const maxY = window.innerHeight - tooltipRect.height - 10;



                tooltip.style.left = Math.min(x, maxX) + 'px';

                tooltip.style.top = Math.min(y, maxY) + 'px';

            }

        });

    });

});

</script>

"""
        # Insert custom code before closing body tag
        return html_str.replace('</body>', custom_code + '</body>')
    
    def get_model_data(self, ml_model_selection: Union[str, int]) -> Tuple[pd.Series, str]:
        """Get model data - either individual model or category average"""
        data = self.data_loader.data

        # Define category display names
        category_labels = {
            'vision': 'Vision Models (Images) - Average',
            'captions_neural': 'Neural Language (Captions) - Average',
            'captions_statistical': 'Statistical Text (Captions) - Average',
            'tags_statistical': 'Statistical Text (Tags) - Average'
        }

        # Handle category averages - USE STORED COLUMNS IF THEY EXIST FOR CONSISTENCY
        if ml_model_selection == "avg_vision":
            # Check if pre-calculated column exists
            if 'avg_vision' in self.data_loader.data.columns:
                return self.data_loader.data['avg_vision'], category_labels['vision']
            # Otherwise calculate AND STORE for consistency
            models = [model[0] for model in self.data_loader.model_categories['vision']]
            if models:
                available_models = [m for m in models if m in data.columns]
                if not available_models:
                    raise ValueError("No vision models available in data")
                avg_data = data[available_models].mean(axis=1)
                self.data_loader.data['avg_vision'] = avg_data  # Store in original DataFrame
                return self.data_loader.data['avg_vision'], category_labels['vision']
            else:
                raise ValueError("No vision models available")

        elif ml_model_selection == "avg_captions_neural":
            # Check if pre-calculated column exists
            if 'avg_captions_neural' in self.data_loader.data.columns:
                return self.data_loader.data['avg_captions_neural'], category_labels['captions_neural']
            # Otherwise calculate AND STORE for consistency
            models = [model[0] for model in self.data_loader.model_categories['captions_neural']]
            if models:
                available_models = [m for m in models if m in data.columns]
                if not available_models:
                    raise ValueError("No neural language models available in data")
                avg_data = data[available_models].mean(axis=1)
                self.data_loader.data['avg_captions_neural'] = avg_data  # Store in original DataFrame
                return self.data_loader.data['avg_captions_neural'], category_labels['captions_neural']
            else:
                raise ValueError("No neural language models available")

        elif ml_model_selection == "avg_captions_statistical":
            # Check if pre-calculated column exists
            if 'avg_captions_statistical' in self.data_loader.data.columns:
                return self.data_loader.data['avg_captions_statistical'], category_labels['captions_statistical']
            # Otherwise calculate AND STORE for consistency
            models = [model[0] for model in self.data_loader.model_categories['captions_statistical']]
            if models:
                available_models = [m for m in models if m in data.columns]
                if not available_models:
                    raise ValueError("No statistical caption models available in data")
                avg_data = data[available_models].mean(axis=1)
                self.data_loader.data['avg_captions_statistical'] = avg_data  # Store in original DataFrame
                return self.data_loader.data['avg_captions_statistical'], category_labels['captions_statistical']
            else:
                raise ValueError("No statistical caption models available")

        elif ml_model_selection == "avg_tags_statistical":
            # Check if pre-calculated column exists
            if 'avg_tags_statistical' in self.data_loader.data.columns:
                return self.data_loader.data['avg_tags_statistical'], category_labels['tags_statistical']
            # Otherwise calculate AND STORE for consistency
            models = [model[0] for model in self.data_loader.model_categories['tags_statistical']]
            if models:
                available_models = [m for m in models if m in data.columns]
                if not available_models:
                    raise ValueError("No statistical tag models available in data")
                avg_data = data[available_models].mean(axis=1)
                self.data_loader.data['avg_tags_statistical'] = avg_data  # Store in original DataFrame
                return self.data_loader.data['avg_tags_statistical'], category_labels['tags_statistical']
            else:
                raise ValueError("No statistical tag models available")

        # Handle individual models
        elif isinstance(ml_model_selection, int):
            ml_column = self.data_loader.ml_models[ml_model_selection]
            return data[ml_column], ml_column

        else:
            raise ValueError(f"Invalid model selection: {ml_model_selection}")
    
    @staticmethod
    def normalize_series(series: pd.Series) -> pd.Series:
        """Normalize a pandas series to 0-1 range"""
        min_val = series.min()
        max_val = series.max()
        if max_val == min_val:
            return pd.Series([0.5] * len(series))
        return (series - min_val) / (max_val - min_val)

    def create_3d_plot(self, brain_measure: str, ml_model_selection: Union[str, int], normalize: bool = False) -> Optional[go.Figure]:
        """
        Create a 3D scatter plot of human rating vs brain measure vs ML model similarity.

        Args:
            brain_measure: Column of ``data_loader.data`` holding the brain measure.
            ml_model_selection: Index into ``data_loader.ml_models`` or one of the
                ``avg_*`` category keys accepted by ``get_model_data``.
            normalize: Min-max scale each axis to [0, 1] before plotting.

        Returns:
            The figure, or None when the model selection cannot be resolved.
        """
        data = self.data_loader.data

        try:
            ml_data, ml_name = self.get_model_data(ml_model_selection)
        except ValueError as e:
            print(f"Error getting model data: {e}")
            return None

        # Get data (normalized or raw)
        if normalize:
            human_data = self.normalize_series(data['human_judgement'])
            brain_data = self.normalize_series(data[brain_measure])
            ml_plot_data = self.normalize_series(ml_data)
            value_suffix = " (normalized)"
        else:
            human_data = data['human_judgement']
            brain_data = data[brain_measure]
            ml_plot_data = ml_data
            value_suffix = ""

        # Hover text. Values are paired with rows positionally (zip); indexing
        # with ``.iloc[label]`` would pick wrong rows or raise as soon as the
        # data is not indexed 0..n-1. The pair number shown is the row label.
        hover_text = [
            f"Pair #{idx}<br>"
            f"Images: {img1} vs {img2}<br>"
            f"Human: {human:.3f}<br>"
            f"Brain: {brain:.3f}<br>"
            f"ML: {ml:.3f}"
            for idx, img1, img2, human, brain, ml in zip(
                data.index, data['image_1'], data['image_2'],
                human_data, brain_data, ml_plot_data,
            )
        ]

        fig = go.Figure(data=go.Scatter3d(
            x=human_data,
            y=brain_data,
            z=ml_plot_data,
            mode='markers',
            marker=dict(
                size=6,
                color=human_data,
                colorscale='Viridis',
                opacity=0.7,
                colorbar=dict(title="Human Rating" + value_suffix)
            ),
            text=hover_text,
            hovertemplate='%{text}<extra></extra>',
            # 4-element format [image_1, image_2, stim_1, stim_2] read by the
            # image-hover script injected by add_image_hover_to_html
            customdata=data[['image_1', 'image_2', 'stim_1', 'stim_2']].values
        ))

        # Derive a readable brain-measure name from the column naming scheme.
        # For recognised patterns brain_name already starts with the measure type.
        if brain_measure.startswith("roi_"):
            parts = brain_measure.split("_")
            measure_type = parts[1].title()  # cosine or pearson
            roi_type = parts[2].title()  # common, early, late
            if "avg_sim" in brain_measure:
                brain_name = f"{measure_type} {roi_type} (Similarity)"
            elif "avg_roi" in brain_measure:
                brain_name = f"{measure_type} {roi_type} (Pattern)"
            else:
                brain_name = brain_measure
        elif brain_measure.startswith("voxel_") and not brain_measure.startswith("voxel_to_roi_"):
            if "cosine" in brain_measure:
                measure_type = "Cosine"
            else:
                measure_type = "Pearson"
            if "all_avg" in brain_measure:
                brain_name = f"{measure_type} All Voxels (Avg)"
            elif "subj" in brain_measure:
                subj_num = brain_measure.split("subj")[1]
                brain_name = f"{measure_type} Subject {subj_num} Voxels"
            else:
                brain_name = brain_measure
        elif brain_measure.startswith("voxel_to_roi_"):
            parts = brain_measure.replace("voxel_to_roi_", "").split("_")
            measure_type = parts[0].title()
            roi_type = parts[1].title()
            if "avg_sim" in brain_measure:
                brain_name = f"{measure_type} {roi_type} (V→R Sim)"
            elif "avg_roi" in brain_measure:
                brain_name = f"{measure_type} {roi_type} (V→R Pattern)"
            else:
                brain_name = brain_measure
        else:
            brain_name = brain_measure

        # brain_name is used on its own: prefixing measure_type again produced
        # duplicated labels such as "Cosine Cosine Common (Similarity)".
        x_title = f'Human Rating{value_suffix}'
        y_title = f'Brain Similarity ({brain_name}){value_suffix}'
        z_title = f'ML Model: {ml_name}{value_suffix}'
        norm_tag = " (Normalized)" if normalize else ""

        fig.update_layout(
            title=f'3D Analysis: Human vs {brain_name} Brain vs {ml_name}{norm_tag}',
            scene=dict(
                xaxis_title=x_title,
                yaxis_title=y_title,
                zaxis_title=z_title,
                camera=dict(eye=dict(x=1.5, y=1.5, z=1.5))
            ),
            width=800,
            height=600
        )

        return fig

    def create_2d_plots(self, brain_measure: str, ml_model_selection: Union[str, int], normalize: bool = False) -> Optional[go.Figure]:
        """
        Create three side-by-side 2D scatter plots: Human vs Brain, Human vs ML, Brain vs ML.

        Subplot titles report correlations computed on the raw (never normalized)
        data. For the ``avg_*`` category selections, the Human/ML and Brain/ML
        correlations use "Method 2": correlate each model, then average the
        correlations, matching the bar chart. Otherwise the selected column is
        correlated directly.

        Args:
            brain_measure: Column of ``data_loader.data`` holding the brain measure.
            ml_model_selection: Index into ``data_loader.ml_models`` or one of the
                ``avg_*`` category keys accepted by ``get_model_data``.
            normalize: Min-max scale the plotted values to [0, 1]; the
                correlations in the titles are unaffected.

        Returns:
            The figure, or None when the model selection cannot be resolved.
        """
        import logging

        logger = logging.getLogger(__name__)
        data = self.data_loader.data

        try:
            ml_data, ml_name = self.get_model_data(ml_model_selection)
        except ValueError as e:
            print(f"Error getting model data: {e}")
            return None

        # Get data (normalized or raw)
        if normalize:
            human_data = self.normalize_series(data['human_judgement'])
            brain_data = self.normalize_series(data[brain_measure])
            ml_plot_data = self.normalize_series(ml_data)
            value_suffix = " (norm)"
        else:
            human_data = data['human_judgement']
            brain_data = data[brain_measure]
            ml_plot_data = ml_data
            value_suffix = ""

        # Calculate correlations (always on raw data)
        corr_hb = data['human_judgement'].corr(data[brain_measure])

        logger.debug(
            "create_2d_plots: ml_model_selection=%r (type %s)",
            ml_model_selection, type(ml_model_selection).__name__,
        )

        # Category averages use Method 2 (correlate then average), matching the bar chart.
        category_map = {
            'avg_vision': 'vision',
            'avg_captions_neural': 'captions_neural',
            'avg_captions_statistical': 'captions_statistical',
            'avg_tags_statistical': 'tags_statistical'
        }
        category_key = category_map.get(ml_model_selection) if isinstance(ml_model_selection, str) else None

        if category_key is not None:
            corr_hm = self.compute_category_correlation_method2(category_key, data['human_judgement'])
            corr_bm = self.compute_category_correlation_method2(category_key, data[brain_measure])

            # Make it visible in the title that Method 2 is used
            ml_name = ml_name + " (Method 2: Avg of Correlations)"

            logger.debug(
                "Method 2 (correlate-then-average) for %s: human r=%.4f, brain r=%.4f",
                ml_model_selection, corr_hm, corr_bm,
            )
        else:
            # Individual model (or unknown avg_ key): Method 1, correlate the column directly
            corr_hm = data['human_judgement'].corr(ml_data)
            corr_bm = data[brain_measure].corr(ml_data)

        # Derive a readable brain-measure name from the column naming scheme.
        # For recognised patterns brain_name already starts with the measure type.
        if brain_measure.startswith("roi_"):
            parts = brain_measure.split("_")
            measure_type = parts[1].title()  # cosine or pearson
            roi_type = parts[2].title()  # common, early, late
            if "avg_sim" in brain_measure:
                brain_name = f"{measure_type} {roi_type} (Similarity)"
            elif "avg_roi" in brain_measure:
                brain_name = f"{measure_type} {roi_type} (Pattern)"
            else:
                brain_name = brain_measure
        elif brain_measure.startswith("voxel_") and not brain_measure.startswith("voxel_to_roi_"):
            if "cosine" in brain_measure:
                measure_type = "Cosine"
            else:
                measure_type = "Pearson"
            if "all_avg" in brain_measure:
                brain_name = f"{measure_type} All Voxels (Avg)"
            elif "subj" in brain_measure:
                subj_num = brain_measure.split("subj")[1]
                brain_name = f"{measure_type} Subject {subj_num} Voxels"
            else:
                brain_name = brain_measure
        elif brain_measure.startswith("voxel_to_roi_"):
            parts = brain_measure.replace("voxel_to_roi_", "").split("_")
            measure_type = parts[0].title()
            roi_type = parts[1].title()
            if "avg_sim" in brain_measure:
                brain_name = f"{measure_type} {roi_type} (V→R Sim)"
            elif "avg_roi" in brain_measure:
                brain_name = f"{measure_type} {roi_type} (V→R Pattern)"
            else:
                brain_name = brain_measure
        else:
            brain_name = brain_measure
            measure_type = "Unknown"

        # Create subplot
        fig = make_subplots(
            rows=1, cols=3,
            subplot_titles=[
                f'Human vs Brain (r={corr_hb:.3f})',
                f'Human vs ML (r={corr_hm:.3f})',
                f'Brain vs ML (r={corr_bm:.3f})'
            ],
            horizontal_spacing=0.1
        )

        # Custom data for hover: [idx, image_1, image_2] plus the stim URLs when
        # present, i.e. the 5-element format the image-hover script from
        # add_image_hover_to_html needs to show previews.
        custom_columns = ['image_1', 'image_2']
        if 'stim_1' in data.columns and 'stim_2' in data.columns:
            custom_columns += ['stim_1', 'stim_2']
        customdata = [
            [idx, *values]
            for idx, values in zip(data.index, data[custom_columns].itertuples(index=False, name=None))
        ]

        # Add scatter plots with proper labels
        plot_configs = [
            {
                'x': human_data,
                'y': brain_data,
                'color': 'blue',
                'x_label': f'Human{value_suffix}',
                'y_label': f'Brain ({measure_type}){value_suffix}'
            },
            {
                'x': human_data,
                'y': ml_plot_data,
                'color': 'red',
                'x_label': f'Human{value_suffix}',
                'y_label': f'ML Model{value_suffix}'
            },
            {
                'x': brain_data,
                'y': ml_plot_data,
                'color': 'green',
                'x_label': f'Brain ({measure_type}){value_suffix}',
                'y_label': f'ML Model{value_suffix}'
            }
        ]

        for i, config in enumerate(plot_configs):
            fig.add_trace(
                go.Scatter(
                    x=config['x'],
                    y=config['y'],
                    mode='markers',
                    marker=dict(color=config['color'], opacity=0.6, size=3),
                    hovertemplate=f'Pair #%{{customdata[0]}}<br>{config["x_label"]}: %{{x:.3f}}<br>{config["y_label"]}: %{{y:.3f}}<br>%{{customdata[1]}} vs %{{customdata[2]}}<extra></extra>',
                    customdata=customdata,
                    showlegend=False
                ),
                row=1, col=i+1
            )

        # brain_name is used on its own: prefixing measure_type again produced
        # duplicated labels such as "Cosine Cosine Common (Similarity)".
        norm_tag = " (Normalized)" if normalize else ""
        fig.update_layout(
            title=f'2D Comparisons: {brain_name} Brain vs {ml_name}{norm_tag}',
            width=1300,
            height=500,
            margin=dict(l=60, r=60, t=80, b=80)
        )

        # Add axis labels to each subplot
        fig.update_xaxes(title_text="Human Similarity", row=1, col=1)
        fig.update_yaxes(title_text=f"Brain Similarity ({measure_type})", row=1, col=1)

        fig.update_xaxes(title_text="Human Similarity", row=1, col=2)
        fig.update_yaxes(title_text="ML Model Similarity", row=1, col=2)

        fig.update_xaxes(title_text=f"Brain Similarity ({measure_type})", row=1, col=3)
        fig.update_yaxes(title_text="ML Model Similarity", row=1, col=3)

        return fig