DanJChong committed on
Commit
e11b896
·
verified ·
1 Parent(s): 401fd13

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -38,3 +38,8 @@ Final_similarity_matrix2.csv filter=lfs diff=lfs merge=lfs -text
38
  overall_database.csv filter=lfs diff=lfs merge=lfs -text
39
  overall_database2.csv filter=lfs diff=lfs merge=lfs -text
40
  overall_database3.csv filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
38
  overall_database.csv filter=lfs diff=lfs merge=lfs -text
39
  overall_database2.csv filter=lfs diff=lfs merge=lfs -text
40
  overall_database3.csv filter=lfs diff=lfs merge=lfs -text
41
+ data/Final_similarity_matrix.csv filter=lfs diff=lfs merge=lfs -text
42
+ data/Final_similarity_matrix2.csv filter=lfs diff=lfs merge=lfs -text
43
+ data/overall_database.csv filter=lfs diff=lfs merge=lfs -text
44
+ data/overall_database2.csv filter=lfs diff=lfs merge=lfs -text
45
+ data/overall_database3.csv filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -15,6 +15,8 @@ from visualization.image_viewer import ImageViewer
15
  from gui.gradio_interface import GradioInterface
16
  from typing import Tuple, Optional, Union, Dict, Any
17
  import pandas as pd
 
 
18
 
19
  class SimilarityApp:
20
  """Main application class that orchestrates all components"""
@@ -135,272 +137,370 @@ class SimilarityApp:
135
  print(f"Error getting model rankings: {e}")
136
  return {}
137
 
138
- # Replace the show_image_pair method in SimilarityApp class
139
- # This version adds normalized values to the display
140
-
141
- # Add this new method to SimilarityApp class in app.py
142
- # This returns multiple separate outputs for better Gradio layout
143
-
144
  def show_image_pair_multi(self, row_index: int):
145
- """Show image pair with separate outputs for better layout"""
146
- try:
147
- data = self.data_loader.data
148
- if row_index >= len(data):
149
- return (None, None, "Invalid index", "Invalid index",
150
- "Invalid row index", "", "", "", None)
151
-
152
- row = data.iloc[row_index]
153
- img1, img2 = self.image_viewer.get_image_pair(data, row_index)
154
-
155
- # Format captions
156
- def format_captions_html(caption_text):
157
- if pd.isna(caption_text) or caption_text == 'No caption available':
158
- return '<div style="padding:10px; background:#f8f9fa; border-radius:5px;">No caption available</div>'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
- captions = [c.strip() for c in str(caption_text).split('|')]
 
161
 
162
- if len(captions) == 1:
163
- return f'<div style="padding:10px; background:#f8f9fa; border-radius:5px;">{captions[0]}</div>'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  else:
165
- html = f'<div style="padding:10px; background:#f8f9fa; border-radius:5px;"><strong>{len(captions)} descriptions:</strong><ol style="margin:5px 0; padding-left:20px;">'
166
- for cap in captions:
167
- html += f'<li>{cap}</li>'
168
- html += '</ol></div>'
169
- return html
170
-
171
- caption1_html = format_captions_html(row.get('image_1_description', 'No caption available'))
172
- caption2_html = format_captions_html(row.get('image_2_description', 'No caption available'))
173
-
174
- # Get normalized values
175
- from analysis.corner_analyzer import CornerAnalyzer
176
-
177
- # Calculate averages if needed
178
- if 'avg_vision' not in data.columns:
179
- vision_models = [col for col in data.columns if 'BOLD5000_timm_' in col]
180
- language_models = [col for col in data.columns if 'bert-' in col or 'deberta-' in col or 'sup-simcse' in col]
181
- semantic_models = [col for col in data.columns if any(x in col for x in ["bm25", "rouge", "tf-idf", "co-occurrence"])]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
- def normalize_models(model_list):
184
- if not model_list:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  return pd.Series([0] * len(data))
186
- normalized_data = []
187
- for model in model_list:
188
- if model in data.columns:
189
- model_data = data[model]
190
- normalized = (model_data - model_data.min()) / (model_data.max() - model_data.min())
191
- normalized_data.append(normalized)
192
- if normalized_data:
193
- return pd.concat(normalized_data, axis=1).mean(axis=1)
194
- return pd.Series([0] * len(data))
195
-
196
- data['avg_vision'] = normalize_models(vision_models)
197
- data['avg_language'] = normalize_models(language_models)
198
- data['avg_semantic'] = normalize_models(semantic_models)
199
-
200
- # Get current model
201
- current_ml_model = getattr(self, '_current_ml_model', None)
202
- current_ml_name = getattr(self, '_current_ml_name', 'No model selected')
203
- current_ml_score = 'N/A'
204
- current_ml_norm = 'N/A'
205
-
206
- if current_ml_model is not None:
207
- try:
208
- if isinstance(current_ml_model, int) and current_ml_model < len(self.data_loader.ml_models):
209
- ml_column = self.data_loader.ml_models[current_ml_model]
210
- current_ml_score = f"{row[ml_column]:.3f}"
211
- current_ml_norm = f"{CornerAnalyzer.normalize_series(data[ml_column]).iloc[row_index]:.3f}"
212
- elif isinstance(current_ml_model, str) and current_ml_model.startswith('avg_'):
213
- current_ml_score = f"{row[current_ml_model]:.3f}"
214
- current_ml_norm = f"{CornerAnalyzer.normalize_series(data[current_ml_model]).iloc[row_index]:.3f}"
215
- except Exception:
216
- pass
217
-
218
- # Summary card HTML
219
- summary_html = f"""
220
- <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 20px; border-radius: 10px; box-shadow: 0 4px 6px rgba(0,0,0,0.1);">
221
- <h3 style="margin: 0 0 10px 0;">Image Pair #{row_index} Summary</h3>
222
- <div><strong>Images:</strong> {row['image_1']} vs {row['image_2']}</div>
223
- <div style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 10px; margin-top: 15px;">
224
- <div style="background: rgba(255,255,255,0.2); padding: 10px; border-radius: 5px; text-align: center;">
225
- <div style="font-size: 11px; opacity: 0.9; margin-bottom: 5px;">Human Rating</div>
226
- <div style="font-size: 20px; font-weight: bold;">{row['human_judgement']:.2f}/6</div>
227
- </div>
228
- <div style="background: rgba(255,255,255,0.2); padding: 10px; border-radius: 5px; text-align: center;">
229
- <div style="font-size: 11px; opacity: 0.9; margin-bottom: 5px;">Brain Similarity</div>
230
- <div style="font-size: 20px; font-weight: bold;">{row.get('cosine_similarity_roi_values_common', 0):.3f}</div>
231
  </div>
232
- <div style="background: rgba(255,255,255,0.2); padding: 10px; border-radius: 5px; text-align: center;">
233
- <div style="font-size: 11px; opacity: 0.9; margin-bottom: 5px;">ML Model</div>
234
- <div style="font-size: 20px; font-weight: bold;">{current_ml_score}</div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  </div>
236
  </div>
237
- </div>
238
- """
239
-
240
- # Brain measures table
241
- brain_html = f"""
242
- <table style="width: 100%; border-collapse: collapse;">
243
- <thead>
244
- <tr style="background: #f8f9fa;">
245
- <th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Measure</th>
246
- <th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Raw</th>
247
- <th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Norm</th>
248
- <th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Type</th>
249
- </tr>
250
- </thead>
251
- <tbody>
252
- <tr style="border-bottom: 1px solid #e9ecef;">
253
- <td style="padding: 8px 10px;"><strong>Cosine - Common</strong></td>
254
- <td style="padding: 8px 10px;">{row.get('cosine_similarity_roi_values_common', 0):.3f}</td>
255
- <td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['cosine_similarity_roi_values_common']).iloc[row_index]:.3f}</td>
256
- <td style="padding: 8px 10px;">All regions</td>
257
- </tr>
258
- <tr style="border-bottom: 1px solid #e9ecef;">
259
- <td style="padding: 8px 10px;"><strong>Cosine - Early</strong></td>
260
- <td style="padding: 8px 10px;">{row.get('cosine_similarity_roi_values_early', 0):.3f}</td>
261
- <td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['cosine_similarity_roi_values_early']).iloc[row_index]:.3f}</td>
262
- <td style="padding: 8px 10px;">Low-level</td>
263
- </tr>
264
- <tr style="border-bottom: 1px solid #e9ecef;">
265
- <td style="padding: 8px 10px;"><strong>Cosine - Late</strong></td>
266
- <td style="padding: 8px 10px;">{row.get('cosine_similarity_roi_values_late', 0):.3f}</td>
267
- <td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['cosine_similarity_roi_values_late']).iloc[row_index]:.3f}</td>
268
- <td style="padding: 8px 10px;">High-level</td>
269
- </tr>
270
- <tr style="border-bottom: 1px solid #e9ecef;">
271
- <td style="padding: 8px 10px;"><strong>Pearson - Common</strong></td>
272
- <td style="padding: 8px 10px;">{row.get('pearson_correlation_roi_values_common', 0):.3f}</td>
273
- <td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['pearson_correlation_roi_values_common']).iloc[row_index]:.3f}</td>
274
- <td style="padding: 8px 10px;">All regions</td>
275
- </tr>
276
- <tr style="border-bottom: 1px solid #e9ecef;">
277
- <td style="padding: 8px 10px;"><strong>Pearson - Early</strong></td>
278
- <td style="padding: 8px 10px;">{row.get('pearson_correlation_roi_values_early', 0):.3f}</td>
279
- <td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['pearson_correlation_roi_values_early']).iloc[row_index]:.3f}</td>
280
- <td style="padding: 8px 10px;">Low-level</td>
281
- </tr>
282
- <tr style="border-bottom: 1px solid #e9ecef;">
283
- <td style="padding: 8px 10px;"><strong>Pearson - Late</strong></td>
284
- <td style="padding: 8px 10px;">{row.get('pearson_correlation_roi_values_late', 0):.3f}</td>
285
- <td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['pearson_correlation_roi_values_late']).iloc[row_index]:.3f}</td>
286
- <td style="padding: 8px 10px;">High-level</td>
287
- </tr>
288
- </tbody>
289
- </table>
290
- """
291
-
292
- # Model performance HTML
293
- model_html = f"""
294
- <div style="margin-bottom: 15px;">
295
- <strong>Category Averages</strong>
296
- <table style="width: 100%; border-collapse: collapse; margin: 10px 0;">
297
  <thead>
298
  <tr style="background: #f8f9fa;">
299
- <th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Category</th>
300
  <th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Raw</th>
301
  <th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Norm</th>
 
302
  </tr>
303
  </thead>
304
  <tbody>
305
  <tr style="border-bottom: 1px solid #e9ecef;">
306
- <td style="padding: 8px 10px;">Vision</td>
307
- <td style="padding: 8px 10px;">{row.get('avg_vision', 0):.3f}</td>
308
- <td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['avg_vision']).iloc[row_index]:.3f}</td>
 
309
  </tr>
310
  <tr style="border-bottom: 1px solid #e9ecef;">
311
- <td style="padding: 8px 10px;">Language</td>
312
- <td style="padding: 8px 10px;">{row.get('avg_language', 0):.3f}</td>
313
- <td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['avg_language']).iloc[row_index]:.3f}</td>
 
314
  </tr>
315
  <tr style="border-bottom: 1px solid #e9ecef;">
316
- <td style="padding: 8px 10px;">Semantic</td>
317
- <td style="padding: 8px 10px;">{row.get('avg_semantic', 0):.3f}</td>
318
- <td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['avg_semantic']).iloc[row_index]:.3f}</td>
319
- </tr>
320
- </tbody>
321
- </table>
322
- </div>
323
- <div>
324
- <strong>Current Selection</strong>
325
- <table style="width: 100%; border-collapse: collapse; margin: 10px 0;">
326
- <thead>
327
- <tr style="background: #f8f9fa;">
328
- <th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Info</th>
329
- <th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Value</th>
330
  </tr>
331
- </thead>
332
- <tbody>
333
  <tr style="border-bottom: 1px solid #e9ecef;">
334
- <td style="padding: 8px 10px;">Model</td>
335
- <td style="padding: 8px 10px;">{current_ml_name}</td>
 
 
336
  </tr>
337
  <tr style="border-bottom: 1px solid #e9ecef;">
338
- <td style="padding: 8px 10px;">Raw</td>
339
- <td style="padding: 8px 10px;">{current_ml_score}</td>
 
 
340
  </tr>
341
  <tr style="border-bottom: 1px solid #e9ecef;">
342
- <td style="padding: 8px 10px;">Norm</td>
343
- <td style="padding: 8px 10px;">{current_ml_norm}</td>
 
 
344
  </tr>
345
  </tbody>
346
  </table>
347
- <div style="margin-top: 15px; font-size: 12px; color: #666;">
348
- <strong>Dataset:</strong> Vision: {len([col for col in data.columns if 'BOLD5000_timm_' in col])},
349
- Language: {len([col for col in data.columns if 'bert-' in col or 'deberta-' in col or 'sup-simcse' in col])},
350
- Semantic: {len([col for col in data.columns if any(x in col for x in ["bm25", "rouge", "tf-idf", "co-occurrence"])])} models
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351
  </div>
352
- </div>
353
- """
354
-
355
- # Model rankings
356
- rankings = self.get_model_rankings_for_pair(row_index)
357
- rankings_html = ""
358
- for category in ['vision', 'language', 'semantic']:
359
- if category in rankings:
360
- category_name = category.title()
361
- rankings_html += f"""
362
- <div style="margin: 15px 0;">
363
- <strong>{category_name} Models:</strong>
364
- <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 15px; margin-top: 10px;">
365
- <div>
366
- <table style="width: 100%; border-collapse: collapse;">
367
- <thead><tr style="background: #f8f9fa;"><th style="padding: 8px; text-align: left; border-bottom: 2px solid #dee2e6;">Top 3 Best</th><th style="padding: 8px; text-align: left; border-bottom: 2px solid #dee2e6;">Score</th></tr></thead>
368
- <tbody>
369
- """
370
- for model, score in rankings[category]['best']:
371
- clean_name = model.replace('BOLD5000_timm_', '').replace('_sim_partial', '') if 'BOLD5000_timm_' in model else model
372
- rankings_html += f'<tr style="border-bottom: 1px solid #e9ecef;"><td style="padding: 6px 8px;">{clean_name}</td><td style="padding: 6px 8px;">{score:.3f}</td></tr>'
373
-
374
- rankings_html += """
375
- </tbody>
376
- </table>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377
  </div>
378
- <div>
379
- <table style="width: 100%; border-collapse: collapse;">
380
- <thead><tr style="background: #f8f9fa;"><th style="padding: 8px; text-align: left; border-bottom: 2px solid #dee2e6;">Top 3 Worst</th><th style="padding: 8px; text-align: left; border-bottom: 2px solid #dee2e6;">Score</th></tr></thead>
381
- <tbody>
382
- """
383
- for model, score in rankings[category]['worst']:
384
- clean_name = model.replace('BOLD5000_timm_', '').replace('_sim_partial', '') if 'BOLD5000_timm_' in model else model
385
- rankings_html += f'<tr style="border-bottom: 1px solid #e9ecef;"><td style="padding: 6px 8px;">{clean_name}</td><td style="padding: 6px 8px;">{score:.3f}</td></tr>'
386
-
387
- rankings_html += """
388
- </tbody>
389
- </table>
390
  </div>
391
- </div>
392
- </div>
393
- """
394
-
395
- # ROI plot
396
- roi_plot = self.roi_analyzer.create_roi_comparison_plot(data, row_index)
397
-
398
- return (img1, img2, caption1_html, caption2_html, summary_html,
399
- brain_html, model_html, rankings_html, roi_plot)
400
-
401
- except Exception as e:
402
- error_msg = f"<div style='color: red;'>Error: {e}</div>"
403
- return (None, None, error_msg, error_msg, error_msg, error_msg, error_msg, error_msg, None)
 
 
404
 
405
  def set_current_model(self, ml_model_selection, ml_name):
406
  """Store the current ML model selection for display in image viewer"""
@@ -429,12 +529,248 @@ class SimilarityApp:
429
  def get_corner_interpretation(self, corner_name: str) -> str:
430
  """Get interpretation of a corner - delegates to CornerAnalyzer"""
431
  return self.corner_analyzer.get_interpretation(corner_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
 
433
  def main():
434
  """Main function to run the application"""
435
  try:
436
  # Create and launch the app
437
- app = SimilarityApp('overall_database3.csv')
438
  app.launch(
439
  server_name="0.0.0.0",
440
  server_port=7860,
 
15
  from gui.gradio_interface import GradioInterface
16
  from typing import Tuple, Optional, Union, Dict, Any
17
  import pandas as pd
18
+ import numpy as np
19
+
20
 
21
  class SimilarityApp:
22
  """Main application class that orchestrates all components"""
 
137
  print(f"Error getting model rankings: {e}")
138
  return {}
139
 
 
 
 
 
 
 
140
  def show_image_pair_multi(self, row_index: int):
141
+ """Show image pair with separate outputs for better layout"""
142
+ try:
143
+ import matplotlib.pyplot as plt
144
+ import matplotlib
145
+ matplotlib.use('Agg')
146
+ import io
147
+ import base64
148
+
149
+ data = self.data_loader.data
150
+ if row_index >= len(data):
151
+ return (None, None, "Invalid index", "Invalid index",
152
+ "Invalid row index", "", "", "", "", None)
153
+
154
+ row = data.iloc[row_index]
155
+
156
+ # DEBUG: Print what images we're loading
157
+ print(f"\n{'*'*60}")
158
+ print(f"APP.PY - Loading Pair #{row_index}")
159
+ print(f"{'*'*60}")
160
+ print(f"Image 1 filename: {row.get('image_1', 'MISSING')}")
161
+ print(f"Image 2 filename: {row.get('image_2', 'MISSING')}")
162
+ print(f"Image 1 URL: {row.get('stim_1', 'MISSING')}")
163
+ print(f"Image 2 URL: {row.get('stim_2', 'MISSING')}")
164
+ print(f"{'*'*60}\n")
165
+
166
+ # Get images - now returns swap information
167
+ img1, img2, was_swapped = self.image_viewer.get_image_pair(data, row_index)
168
+
169
+ # Verify which images were loaded
170
+ print(f"Image 1 loaded: {'SUCCESS' if img1 is not None else 'FAILED'}")
171
+ print(f"Image 2 loaded: {'SUCCESS' if img2 is not None else 'FAILED'}")
172
+ print(f"URLs were swapped: {was_swapped}")
173
+
174
+ # Format captions
175
+ def format_captions_html(caption_text):
176
+ if pd.isna(caption_text) or caption_text == 'No caption available':
177
+ return '<div style="padding:10px; background:#f8f9fa; border-radius:5px;">No caption available</div>'
178
+
179
+ captions = [c.strip() for c in str(caption_text).split('|')]
180
+
181
+ if len(captions) == 1:
182
+ return f'<div style="padding:10px; background:#f8f9fa; border-radius:5px;">{captions[0]}</div>'
183
+ else:
184
+ html = f'<div style="padding:10px; background:#f8f9fa; border-radius:5px;"><strong>{len(captions)} descriptions:</strong><ol style="margin:5px 0; padding-left:20px;">'
185
+ for cap in captions:
186
+ html += f'<li>{cap}</li>'
187
+ html += '</ol></div>'
188
+ return html
189
+
190
+ # Get captions in correct order (swap if images were swapped)
191
+ if was_swapped:
192
+ print("[APP.PY] Swapping captions to match swapped images")
193
+ caption1_html = format_captions_html(row.get('image_2_description', 'No caption available'))
194
+ caption2_html = format_captions_html(row.get('image_1_description', 'No caption available'))
195
+ else:
196
+ caption1_html = format_captions_html(row.get('image_1_description', 'No caption available'))
197
+ caption2_html = format_captions_html(row.get('image_2_description', 'No caption available'))
198
+
199
+ # Get normalized values for bar plot
200
+ from analysis.corner_analyzer import CornerAnalyzer
201
 
202
+ # Get current brain measure from the last used one (or default to common)
203
+ brain_measure = getattr(self, '_current_brain_measure', 'cosine_similarity_roi_values_common')
204
 
205
+ # Normalize the three main values for this pair
206
+ human_norm = CornerAnalyzer.normalize_series(data['human_judgement']).iloc[row_index]
207
+ brain_norm = CornerAnalyzer.normalize_series(data[brain_measure]).iloc[row_index]
208
+
209
+ # Get ML model norm
210
+ current_ml_model = getattr(self, '_current_ml_model', None)
211
+ if current_ml_model is not None:
212
+ try:
213
+ if isinstance(current_ml_model, int) and current_ml_model < len(self.data_loader.ml_models):
214
+ ml_column = self.data_loader.ml_models[current_ml_model]
215
+ ml_norm = CornerAnalyzer.normalize_series(data[ml_column]).iloc[row_index]
216
+ elif isinstance(current_ml_model, str) and current_ml_model.startswith('avg_'):
217
+ ml_norm = CornerAnalyzer.normalize_series(data[current_ml_model]).iloc[row_index]
218
+ else:
219
+ ml_norm = 0.5 # default
220
+ except Exception:
221
+ ml_norm = 0.5
222
  else:
223
+ ml_norm = 0.5
224
+
225
+ # Create bar plot
226
+ fig, ax = plt.subplots(figsize=(6, 4))
227
+
228
+ categories = ['Human', 'Brain', 'ML']
229
+ values = [human_norm, brain_norm, ml_norm]
230
+ colors = ['#4A90E2', '#50C878', '#E24A4A']
231
+
232
+ # Create bars
233
+ bars = ax.bar(categories, values, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
234
+
235
+ # Add value labels on top of bars
236
+ for bar, val in zip(bars, values):
237
+ height = bar.get_height()
238
+ ax.text(bar.get_x() + bar.get_width()/2., height + 0.02,
239
+ f'{val:.3f}',
240
+ ha='center', va='bottom', fontsize=11, fontweight='bold')
241
+
242
+ # Styling
243
+ ax.set_ylabel('Normalized Value (0-1)', fontsize=11, fontweight='bold')
244
+ ax.set_xlabel('Measure', fontsize=11, fontweight='bold')
245
+ ax.set_title(f'Normalized Values for Pair #{row_index}', fontsize=12, fontweight='bold')
246
+ ax.set_ylim(0, 1.15)
247
+ ax.grid(axis='y', alpha=0.3, linestyle='--')
248
+ ax.set_axisbelow(True)
249
+
250
+ # Style the plot
251
+ ax.spines['top'].set_visible(False)
252
+ ax.spines['right'].set_visible(False)
253
+
254
+ plt.tight_layout()
255
 
256
+ # Convert to base64 image
257
+ buf = io.BytesIO()
258
+ plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
259
+ buf.seek(0)
260
+ img_base64 = base64.b64encode(buf.read()).decode()
261
+ plt.close(fig)
262
+
263
+ bar_plot_html = f'<img src="data:image/png;base64,{img_base64}" style="width: 100%; max-width: 500px; margin: 20px auto; display: block; border: 1px solid #ddd; border-radius: 5px; padding: 10px; background: white;" />'
264
+
265
+ # Calculate averages if needed
266
+ if 'avg_vision' not in data.columns:
267
+ vision_models = [col for col in data.columns if 'BOLD5000_timm_' in col]
268
+ language_models = [col for col in data.columns if 'bert-' in col or 'deberta-' in col or 'sup-simcse' in col]
269
+ semantic_models = [col for col in data.columns if any(x in col for x in ["bm25", "rouge", "tf-idf", "co-occurrence"])]
270
+
271
+ def normalize_models(model_list):
272
+ if not model_list:
273
+ return pd.Series([0] * len(data))
274
+ normalized_data = []
275
+ for model in model_list:
276
+ if model in data.columns:
277
+ model_data = data[model]
278
+ normalized = (model_data - model_data.min()) / (model_data.max() - model_data.min())
279
+ normalized_data.append(normalized)
280
+ if normalized_data:
281
+ return pd.concat(normalized_data, axis=1).mean(axis=1)
282
  return pd.Series([0] * len(data))
283
+
284
+ data['avg_vision'] = normalize_models(vision_models)
285
+ data['avg_language'] = normalize_models(language_models)
286
+ data['avg_semantic'] = normalize_models(semantic_models)
287
+
288
+ # Get current model
289
+ current_ml_model = getattr(self, '_current_ml_model', None)
290
+ current_ml_name = getattr(self, '_current_ml_name', 'No model selected')
291
+ current_ml_score = 'N/A'
292
+ current_ml_norm = 'N/A'
293
+
294
+ if current_ml_model is not None:
295
+ try:
296
+ if isinstance(current_ml_model, int) and current_ml_model < len(self.data_loader.ml_models):
297
+ ml_column = self.data_loader.ml_models[current_ml_model]
298
+ current_ml_score = f"{row[ml_column]:.3f}"
299
+ current_ml_norm = f"{CornerAnalyzer.normalize_series(data[ml_column]).iloc[row_index]:.3f}"
300
+ elif isinstance(current_ml_model, str) and current_ml_model.startswith('avg_'):
301
+ current_ml_score = f"{row[current_ml_model]:.3f}"
302
+ current_ml_norm = f"{CornerAnalyzer.normalize_series(data[current_ml_model]).iloc[row_index]:.3f}"
303
+ except Exception:
304
+ pass
305
+
306
+ # Summary card HTML - WITH DEBUG INFO AND NORMALIZED VALUES
307
+ summary_html = f"""
308
+ <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 20px; border-radius: 10px; box-shadow: 0 4px 6px rgba(0,0,0,0.1);">
309
+ <h3 style="margin: 0 0 10px 0;">Image Pair #{row_index} Summary</h3>
310
+ <div style="background: rgba(255,255,255,0.15); padding: 10px; border-radius: 5px; margin-bottom: 10px; font-family: monospace; font-size: 11px;">
311
+ <strong>🔍 DEBUG INFO:</strong><br>
312
+ Image 1 File: <code>{row['image_1']}</code><br>
313
+ Image 2 File: <code>{row['image_2']}</code>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
  </div>
315
+ <div><strong>Images:</strong> {row['image_1']} vs {row['image_2']}</div>
316
+ <div style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 10px; margin-top: 15px;">
317
+ <div style="background: rgba(255,255,255,0.2); padding: 10px; border-radius: 5px; text-align: center;">
318
+ <div style="font-size: 11px; opacity: 0.9; margin-bottom: 5px;">Human Rating</div>
319
+ <div style="font-size: 20px; font-weight: bold;">{row['human_judgement']:.2f}/6</div>
320
+ <div style="font-size: 11px; opacity: 0.8; margin-top: 5px;">Norm: {human_norm:.3f}</div>
321
+ </div>
322
+ <div style="background: rgba(255,255,255,0.2); padding: 10px; border-radius: 5px; text-align: center;">
323
+ <div style="font-size: 11px; opacity: 0.9; margin-bottom: 5px;">Brain Similarity</div>
324
+ <div style="font-size: 20px; font-weight: bold;">{row.get(brain_measure, 0):.3f}</div>
325
+ <div style="font-size: 11px; opacity: 0.8; margin-top: 5px;">Norm: {brain_norm:.3f}</div>
326
+ </div>
327
+ <div style="background: rgba(255,255,255,0.2); padding: 10px; border-radius: 5px; text-align: center;">
328
+ <div style="font-size: 11px; opacity: 0.9; margin-bottom: 5px;">ML Model</div>
329
+ <div style="font-size: 20px; font-weight: bold;">{current_ml_score}</div>
330
+ <div style="font-size: 11px; opacity: 0.8; margin-top: 5px;">Norm: {ml_norm:.3f}</div>
331
+ </div>
332
  </div>
333
  </div>
334
+ """
335
+
336
+ # Brain measures table
337
+ brain_html = f"""
338
+ <table style="width: 100%; border-collapse: collapse;">
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339
  <thead>
340
  <tr style="background: #f8f9fa;">
341
+ <th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Measure</th>
342
  <th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Raw</th>
343
  <th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Norm</th>
344
+ <th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Type</th>
345
  </tr>
346
  </thead>
347
  <tbody>
348
  <tr style="border-bottom: 1px solid #e9ecef;">
349
+ <td style="padding: 8px 10px;"><strong>Cosine - Common</strong></td>
350
+ <td style="padding: 8px 10px;">{row.get('cosine_similarity_roi_values_common', 0):.3f}</td>
351
+ <td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['cosine_similarity_roi_values_common']).iloc[row_index]:.3f}</td>
352
+ <td style="padding: 8px 10px;">All regions</td>
353
  </tr>
354
  <tr style="border-bottom: 1px solid #e9ecef;">
355
+ <td style="padding: 8px 10px;"><strong>Cosine - Early</strong></td>
356
+ <td style="padding: 8px 10px;">{row.get('cosine_similarity_roi_values_early', 0):.3f}</td>
357
+ <td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['cosine_similarity_roi_values_early']).iloc[row_index]:.3f}</td>
358
+ <td style="padding: 8px 10px;">Low-level</td>
359
  </tr>
360
  <tr style="border-bottom: 1px solid #e9ecef;">
361
+ <td style="padding: 8px 10px;"><strong>Cosine - Late</strong></td>
362
+ <td style="padding: 8px 10px;">{row.get('cosine_similarity_roi_values_late', 0):.3f}</td>
363
+ <td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['cosine_similarity_roi_values_late']).iloc[row_index]:.3f}</td>
364
+ <td style="padding: 8px 10px;">High-level</td>
 
 
 
 
 
 
 
 
 
 
365
  </tr>
 
 
366
  <tr style="border-bottom: 1px solid #e9ecef;">
367
+ <td style="padding: 8px 10px;"><strong>Pearson - Common</strong></td>
368
+ <td style="padding: 8px 10px;">{row.get('pearson_correlation_roi_values_common', 0):.3f}</td>
369
+ <td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['pearson_correlation_roi_values_common']).iloc[row_index]:.3f}</td>
370
+ <td style="padding: 8px 10px;">All regions</td>
371
  </tr>
372
  <tr style="border-bottom: 1px solid #e9ecef;">
373
+ <td style="padding: 8px 10px;"><strong>Pearson - Early</strong></td>
374
+ <td style="padding: 8px 10px;">{row.get('pearson_correlation_roi_values_early', 0):.3f}</td>
375
+ <td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['pearson_correlation_roi_values_early']).iloc[row_index]:.3f}</td>
376
+ <td style="padding: 8px 10px;">Low-level</td>
377
  </tr>
378
  <tr style="border-bottom: 1px solid #e9ecef;">
379
+ <td style="padding: 8px 10px;"><strong>Pearson - Late</strong></td>
380
+ <td style="padding: 8px 10px;">{row.get('pearson_correlation_roi_values_late', 0):.3f}</td>
381
+ <td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['pearson_correlation_roi_values_late']).iloc[row_index]:.3f}</td>
382
+ <td style="padding: 8px 10px;">High-level</td>
383
  </tr>
384
  </tbody>
385
  </table>
386
+ """
387
+
388
+ # Model performance HTML
389
+ model_html = f"""
390
+ <div style="margin-bottom: 15px;">
391
+ <strong>Category Averages</strong>
392
+ <table style="width: 100%; border-collapse: collapse; margin: 10px 0;">
393
+ <thead>
394
+ <tr style="background: #f8f9fa;">
395
+ <th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Category</th>
396
+ <th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Raw</th>
397
+ <th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Norm</th>
398
+ </tr>
399
+ </thead>
400
+ <tbody>
401
+ <tr style="border-bottom: 1px solid #e9ecef;">
402
+ <td style="padding: 8px 10px;">Vision</td>
403
+ <td style="padding: 8px 10px;">{row.get('avg_vision', 0):.3f}</td>
404
+ <td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['avg_vision']).iloc[row_index]:.3f}</td>
405
+ </tr>
406
+ <tr style="border-bottom: 1px solid #e9ecef;">
407
+ <td style="padding: 8px 10px;">Language</td>
408
+ <td style="padding: 8px 10px;">{row.get('avg_language', 0):.3f}</td>
409
+ <td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['avg_language']).iloc[row_index]:.3f}</td>
410
+ </tr>
411
+ <tr style="border-bottom: 1px solid #e9ecef;">
412
+ <td style="padding: 8px 10px;">Semantic</td>
413
+ <td style="padding: 8px 10px;">{row.get('avg_semantic', 0):.3f}</td>
414
+ <td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['avg_semantic']).iloc[row_index]:.3f}</td>
415
+ </tr>
416
+ </tbody>
417
+ </table>
418
  </div>
419
+ <div>
420
+ <strong>Current Selection</strong>
421
+ <table style="width: 100%; border-collapse: collapse; margin: 10px 0;">
422
+ <thead>
423
+ <tr style="background: #f8f9fa;">
424
+ <th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Info</th>
425
+ <th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Value</th>
426
+ </tr>
427
+ </thead>
428
+ <tbody>
429
+ <tr style="border-bottom: 1px solid #e9ecef;">
430
+ <td style="padding: 8px 10px;">Model</td>
431
+ <td style="padding: 8px 10px;">{current_ml_name}</td>
432
+ </tr>
433
+ <tr style="border-bottom: 1px solid #e9ecef;">
434
+ <td style="padding: 8px 10px;">Raw</td>
435
+ <td style="padding: 8px 10px;">{current_ml_score}</td>
436
+ </tr>
437
+ <tr style="border-bottom: 1px solid #e9ecef;">
438
+ <td style="padding: 8px 10px;">Norm</td>
439
+ <td style="padding: 8px 10px;">{current_ml_norm}</td>
440
+ </tr>
441
+ </tbody>
442
+ </table>
443
+ <div style="margin-top: 15px; font-size: 12px; color: #666;">
444
+ <strong>Dataset:</strong> Vision: {len([col for col in data.columns if 'BOLD5000_timm_' in col])},
445
+ Language: {len([col for col in data.columns if 'bert-' in col or 'deberta-' in col or 'sup-simcse' in col])},
446
+ Semantic: {len([col for col in data.columns if any(x in col for x in ["bm25", "rouge", "tf-idf", "co-occurrence"])])} models
447
+ </div>
448
+ </div>
449
+ """
450
+
451
+ # Model rankings
452
+ rankings = self.get_model_rankings_for_pair(row_index)
453
+ rankings_html = ""
454
+ for category in ['vision', 'language', 'semantic']:
455
+ if category in rankings:
456
+ category_name = category.title()
457
+ rankings_html += f"""
458
+ <div style="margin: 15px 0;">
459
+ <strong>{category_name} Models:</strong>
460
+ <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 15px; margin-top: 10px;">
461
+ <div>
462
+ <table style="width: 100%; border-collapse: collapse;">
463
+ <thead><tr style="background: #f8f9fa;"><th style="padding: 8px; text-align: left; border-bottom: 2px solid #dee2e6;">Top 3 Best</th><th style="padding: 8px; text-align: left; border-bottom: 2px solid #dee2e6;">Score</th></tr></thead>
464
+ <tbody>
465
+ """
466
+ for model, score in rankings[category]['best']:
467
+ clean_name = model.replace('BOLD5000_timm_', '').replace('_sim_partial', '') if 'BOLD5000_timm_' in model else model
468
+ rankings_html += f'<tr style="border-bottom: 1px solid #e9ecef;"><td style="padding: 6px 8px;">{clean_name}</td><td style="padding: 6px 8px;">{score:.3f}</td></tr>'
469
+
470
+ rankings_html += """
471
+ </tbody>
472
+ </table>
473
+ </div>
474
+ <div>
475
+ <table style="width: 100%; border-collapse: collapse;">
476
+ <thead><tr style="background: #f8f9fa;"><th style="padding: 8px; text-align: left; border-bottom: 2px solid #dee2e6;">Top 3 Worst</th><th style="padding: 8px; text-align: left; border-bottom: 2px solid #dee2e6;">Score</th></tr></thead>
477
+ <tbody>
478
+ """
479
+ for model, score in rankings[category]['worst']:
480
+ clean_name = model.replace('BOLD5000_timm_', '').replace('_sim_partial', '') if 'BOLD5000_timm_' in model else model
481
+ rankings_html += f'<tr style="border-bottom: 1px solid #e9ecef;"><td style="padding: 6px 8px;">{clean_name}</td><td style="padding: 6px 8px;">{score:.3f}</td></tr>'
482
+
483
+ rankings_html += """
484
+ </tbody>
485
+ </table>
486
+ </div>
487
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
488
  </div>
489
+ """
490
+
491
+ # ROI plot
492
+ print(f"\nCalling ROI analyzer for pair #{row_index}...")
493
+ roi_plot = self.roi_analyzer.create_roi_comparison_plot(data, row_index)
494
+ print(f"ROI plot created successfully\n")
495
+
496
+ return (img1, img2, caption1_html, caption2_html, summary_html,
497
+ bar_plot_html, brain_html, model_html, rankings_html, roi_plot)
498
+
499
+ except Exception as e:
500
+ error_msg = f"<div style='color: red;'>Error: {e}</div>"
501
+ import traceback
502
+ traceback.print_exc()
503
+ return (None, None, error_msg, error_msg, error_msg, error_msg, error_msg, error_msg, error_msg, None)
504
 
505
  def set_current_model(self, ml_model_selection, ml_name):
506
  """Store the current ML model selection for display in image viewer"""
 
529
  def get_corner_interpretation(self, corner_name: str) -> str:
530
  """Get interpretation of a corner - delegates to CornerAnalyzer"""
531
  return self.corner_analyzer.get_interpretation(corner_name)
532
+
533
+ # Corner-distance analysis: position of a single pair relative to the 8 corners of the normalized (Human, Brain, ML) cube
534
+
535
+
536
+
537
def get_point_corner_distances(self, row_index: int, brain_measure: str, ml_model_selection: Union[str, int]) -> Tuple[str, Optional[Any]]:
    """Locate one image pair in normalized (Human, Brain, ML) space.

    The human judgement column, the selected brain measure column and the
    selected ML model series are each passed through
    ``CornerAnalyzer.normalize_series``; the pair at ``row_index`` is then
    treated as a 3D point. Its Euclidean distance to each of the 8 corners
    of the unit cube is ranked in an HTML table and drawn in a Plotly figure.

    Args:
        row_index: Positional index of the pair in ``self.data_loader.data``.
        brain_measure: Name of the column holding the brain similarity measure.
        ml_model_selection: Model selector, forwarded unchanged to
            ``self.plot_generator.get_model_data``.

    Returns:
        ``(html, figure)``. For an out-of-range index (negative included) the
        result is ``("Invalid row index", None)``. On any other failure the
        traceback is printed and ``(red error <div>, None)`` is returned.
    """
    try:
        data = self.data_loader.data
        # iloc accepts negative positions (counting from the end), so without
        # the lower bound a negative index would silently show another pair.
        if not 0 <= row_index < len(data):
            return "Invalid row index", None

        ml_data, ml_name = self.plot_generator.get_model_data(ml_model_selection)

        from analysis.corner_analyzer import CornerAnalyzer
        import plotly.graph_objects as go

        # Normalize every series once; the same result feeds both the
        # current point and the background scatter.
        # NOTE(review): the corner coordinates and axis ranges below assume
        # normalize_series maps onto 0-1 - confirm in CornerAnalyzer.
        human_all_norm = CornerAnalyzer.normalize_series(data['human_judgement'])
        brain_all_norm = CornerAnalyzer.normalize_series(data[brain_measure])
        ml_all_norm = CornerAnalyzer.normalize_series(ml_data)

        human_norm = human_all_norm.iloc[row_index]
        brain_norm = brain_all_norm.iloc[row_index]
        ml_norm = ml_all_norm.iloc[row_index]

        # The 8 corners of the unit cube, keyed by their (H,B,M) label
        corners = {
            '(0,0,0)': (0, 0, 0),
            '(0,0,1)': (0, 0, 1),
            '(0,1,0)': (0, 1, 0),
            '(0,1,1)': (0, 1, 1),
            '(1,0,0)': (1, 0, 0),
            '(1,0,1)': (1, 0, 1),
            '(1,1,0)': (1, 1, 0),
            '(1,1,1)': (1, 1, 1)
        }

        corner_meanings = {
            '(0,0,0)': 'All Low - General disagreement on similarity',
            '(0,0,1)': 'ML High, Human & Brain Low',
            '(0,1,0)': 'Brain High, Human & ML Low',
            '(0,1,1)': 'Brain & ML High, Human Low',
            '(1,0,0)': 'Human High, Brain & ML Low',
            '(1,0,1)': 'Human & ML High, Brain Low',
            '(1,1,0)': 'Human & Brain High, ML Low',
            '(1,1,1)': 'All High - Strong agreement on similarity'
        }

        # Euclidean distance from the pair to each corner, nearest first
        point = np.array([human_norm, brain_norm, ml_norm])
        distances = {
            corner_name: np.linalg.norm(point - np.array(corner_coords))
            for corner_name, corner_coords in corners.items()
        }
        sorted_distances = sorted(distances.items(), key=lambda item: item[1])

        # ---- HTML table -------------------------------------------------
        table_head = f"""
        <div style="background: #f8f9fa; padding: 20px; border-radius: 10px; margin: 20px 0;">
            <h3 style="margin-top: 0;">Distance to Each Corner (Normalized Space)</h3>
            <div style="background: white; padding: 15px; border-radius: 5px; margin-bottom: 15px;">
                <strong>Current Point (Normalized 0-1):</strong><br>
                Human: {human_norm:.3f} | Brain: {brain_norm:.3f} | ML: {ml_norm:.3f}
            </div>
            <table style="width: 100%; border-collapse: collapse; background: white;">
                <thead>
                    <tr style="background: #667eea; color: white;">
                        <th style="padding: 12px; text-align: left;">Rank</th>
                        <th style="padding: 12px; text-align: left;">Corner (H,B,M)</th>
                        <th style="padding: 12px; text-align: left;">Distance</th>
                        <th style="padding: 12px; text-align: left;">Meaning</th>
                    </tr>
                </thead>
                <tbody>
        """

        table_rows = []
        for rank, (corner_name, distance) in enumerate(sorted_distances, 1):
            # Only the nearest corner (rank 1) is highlighted
            row_color = "#e8f4f8" if rank == 1 else "white"
            star = "⭐ " if rank == 1 else ""
            table_rows.append(f"""
                    <tr style="background: {row_color}; border-bottom: 1px solid #dee2e6;">
                        <td style="padding: 10px;"><strong>{star}{rank}</strong></td>
                        <td style="padding: 10px; font-family: monospace;">{corner_name}</td>
                        <td style="padding: 10px;"><strong>{distance:.4f}</strong></td>
                        <td style="padding: 10px; font-size: 12px;">{corner_meanings[corner_name]}</td>
                    </tr>
            """)

        table_tail = """
                </tbody>
            </table>
            <div style="margin-top: 15px; font-size: 12px; color: #666;">
                Distances calculated in normalized 0-1 space using Euclidean distance.
                Closer corners indicate which extreme this pair is nearest to.
            </div>
        </div>
        """

        html = table_head + "".join(table_rows) + table_tail

        # ---- 3D figure --------------------------------------------------
        fig = go.Figure()

        # Every other pair as faint background context
        other_indices = [i for i in range(len(data)) if i != row_index]
        fig.add_trace(go.Scatter3d(
            x=human_all_norm.iloc[other_indices],
            y=brain_all_norm.iloc[other_indices],
            z=ml_all_norm.iloc[other_indices],
            mode='markers',
            marker=dict(size=3, color='lightgray', opacity=0.3),
            name='Other pairs',
            hoverinfo='skip'
        ))

        # The selected pair: large red diamond with a yellow outline
        fig.add_trace(go.Scatter3d(
            x=[human_norm],
            y=[brain_norm],
            z=[ml_norm],
            mode='markers',
            marker=dict(
                size=25,
                color='#FF0000',
                symbol='diamond',
                line=dict(color='yellow', width=4),
                opacity=1.0
            ),
            name=f'⭐ Current Pair #{row_index}',
            hovertemplate=(
                f'<b>⭐ CURRENT PAIR #{row_index}</b><br>'
                f'Human: {human_norm:.3f}<br>'
                f'Brain: {brain_norm:.3f}<br>'
                f'ML: {ml_norm:.3f}<br>'
                '<extra></extra>'
            )
        ))

        # Labelled corner markers
        corner_labels = [f"{name}<br>{corner_meanings[name]}" for name in corners]
        fig.add_trace(go.Scatter3d(
            x=[coords[0] for coords in corners.values()],
            y=[coords[1] for coords in corners.values()],
            z=[coords[2] for coords in corners.values()],
            mode='markers+text',
            marker=dict(size=10, color='#4169E1', symbol='square', opacity=0.8),
            text=list(corners),
            textposition='top center',
            textfont=dict(size=9, color='darkblue', family='Arial Black'),
            name='Corners',
            hovertext=corner_labels,
            hoverinfo='text'
        ))

        # One dotted line per corner, coloured by how close that corner is.
        # sqrt(3) is the cube diagonal, i.e. the largest possible distance
        # between two points of the unit cube.
        max_distance = np.sqrt(3.0)
        for corner_name, distance in distances.items():
            corner_x, corner_y, corner_z = corners[corner_name]

            relative_distance = distance / max_distance
            if relative_distance < 0.33:
                line_color = '#FF4500'  # OrangeRed for closest
                line_width = 4
            elif relative_distance < 0.67:
                line_color = '#FFA500'  # Orange for medium
                line_width = 3
            else:
                line_color = '#C0C0C0'  # Silver for farthest
                line_width = 2

            fig.add_trace(go.Scatter3d(
                x=[human_norm, corner_x],
                y=[brain_norm, corner_y],
                z=[ml_norm, corner_z],
                mode='lines',
                line=dict(color=line_color, width=line_width, dash='dot'),
                showlegend=False,
                hoverinfo='skip'
            ))

            # Distance label placed at the midpoint of the line
            fig.add_trace(go.Scatter3d(
                x=[(human_norm + corner_x) / 2],
                y=[(brain_norm + corner_y) / 2],
                z=[(ml_norm + corner_z) / 2],
                mode='text',
                text=[f'{distance:.3f}'],
                textfont=dict(size=20, color=line_color),
                showlegend=False,
                hoverinfo='skip'
            ))

        measure_type = "Cosine" if "cosine" in brain_measure else "Pearson"

        fig.update_layout(
            # Subtitle describes what is actually drawn: dotted lines to all 8 corners
            title=f'Pair #{row_index} Position in Normalized 3D Space<br><sub>Red diamond shows current pair, blue squares show corners, dotted lines connect the pair to all 8 corners</sub>',
            scene=dict(
                xaxis_title='Human (norm)',
                yaxis_title=f'Brain {measure_type} (norm)',
                zaxis_title=f'{ml_name} (norm)',
                xaxis=dict(range=[0, 1]),
                yaxis=dict(range=[0, 1]),
                zaxis=dict(range=[0, 1]),
                camera=dict(eye=dict(x=1.5, y=1.5, z=1.5))
            ),
            height=600,
            showlegend=True,
            legend=dict(x=0.7, y=0.9)
        )

        return html, fig

    except Exception as e:
        # Best-effort UI path: report the failure in the panel instead of
        # breaking the whole callback; the traceback goes to the console.
        import traceback
        traceback.print_exc()
        return f"<div style='color: red;'>Error calculating distances: {e}</div>", None
768
 
769
  def main():
770
  """Main function to run the application"""
771
  try:
772
  # Create and launch the app
773
+ app = SimilarityApp('data/overall_database3.csv')
774
  app.launch(
775
  server_name="0.0.0.0",
776
  server_port=7860,
brain/__pycache__/roi_analyzer.cpython-311.pyc CHANGED
Binary files a/brain/__pycache__/roi_analyzer.cpython-311.pyc and b/brain/__pycache__/roi_analyzer.cpython-311.pyc differ
 
data/Final_similarity_matrix.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ff83586ea8a37e91968eb0b73f23edebd61fdd88f815bef34a0f727c6d5ef35
3
+ size 39820273
data/Final_similarity_matrix2.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0cdd2d60df4ebd8de7533d82f7df5d4f8733d06134c641795d2f014c5de561b2
3
+ size 64812304
data/overall_database.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08b27836f940efc44631b48b31964c02ae9ae0b7ef0c35ec4b33f1181e9e5480
3
+ size 44624446
data/overall_database2.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5dc09b2d3975398f84c9a734e0eac9f3a9db044a961600a91d8d34f5e1288bcb
3
+ size 46365658
data/overall_database3.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a3c421e159026ec81497882cb4bc156f85cf93b21e5a3799eaf83bd653537aa
3
+ size 46506013
gui/__pycache__/corner_cases_tab.cpython-311.pyc CHANGED
Binary files a/gui/__pycache__/corner_cases_tab.cpython-311.pyc and b/gui/__pycache__/corner_cases_tab.cpython-311.pyc differ
 
gui/__pycache__/viewer_tab.cpython-311.pyc CHANGED
Binary files a/gui/__pycache__/viewer_tab.cpython-311.pyc and b/gui/__pycache__/viewer_tab.cpython-311.pyc differ
 
gui/corner_cases_tab.py CHANGED
@@ -3,6 +3,13 @@
3
 
4
  import gradio as gr
5
  from typing import TYPE_CHECKING
 
 
 
 
 
 
 
6
 
7
  if TYPE_CHECKING:
8
  from similarity_analysis.app import SimilarityApp
@@ -19,7 +26,6 @@ class CornerCasesTab:
19
  ml_options = self.app.data_loader.get_ml_model_options()
20
 
21
  gr.Markdown("## Corner Cases Analysis")
22
- # gr.Markdown("Find the top 10 image pairs closest to each corner of the 3D space (Human × Brain × ML)")
23
 
24
  with gr.Row():
25
  with gr.Column(scale=1):
@@ -82,6 +88,49 @@ class CornerCasesTab:
82
  'results_display': results_display
83
  }
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  def connect_events(self, components):
86
  """Connect event handlers for this tab"""
87
 
@@ -115,7 +164,7 @@ class CornerCasesTab:
115
 
116
  if show_images:
117
  # Show results with images in a grid
118
- output += "<div style='display: grid; grid-template-columns: repeat(auto-fill, minmax(400px, 1fr)); gap: 20px; margin-bottom: 30px;'>"
119
 
120
  for rank, result in enumerate(results, 1):
121
  output += "<div style='border: 1px solid #ddd; padding: 15px; border-radius: 8px; background: #f9f9f9;'>"
@@ -123,8 +172,28 @@ class CornerCasesTab:
123
 
124
  # Get image URLs and captions
125
  data = self.app.data_loader.data
126
- img1_url = data.iloc[result['index']]['stim_1']
127
- img2_url = data.iloc[result['index']]['stim_2']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
  # Format captions (handle multiple descriptions separated by |)
130
  def format_caption_html(caption_text):
@@ -140,8 +209,8 @@ class CornerCasesTab:
140
  html += "</ol>"
141
  return html
142
 
143
- caption1 = format_caption_html(data.iloc[result['index']].get('image_1_description', 'No caption available'))
144
- caption2 = format_caption_html(data.iloc[result['index']].get('image_2_description', 'No caption available'))
145
 
146
  # Display images side by side
147
  output += "<div style='display: flex; gap: 10px; margin: 10px 0;'>"
@@ -170,12 +239,19 @@ class CornerCasesTab:
170
  output += f"<tr><td>ML:</td><td>{result['ml_norm']:.3f}</td></tr>"
171
  output += "</table>"
172
 
 
 
 
 
 
 
 
173
  output += "</div>"
174
 
175
  output += "</div>"
176
  else:
177
  # Text-only table format
178
- output += "<table style='width: 100%; border-collapse: collapse; margin-bottom: 30px;'>"
179
  output += "<thead><tr style='background: #f0f0f0;'>"
180
  output += "<th style='border: 1px solid #ddd; padding: 8px;'>Rank</th>"
181
  output += "<th style='border: 1px solid #ddd; padding: 8px;'>Pair #</th>"
 
3
 
4
  import gradio as gr
5
  from typing import TYPE_CHECKING
6
+ import plotly.graph_objects as go
7
+ from plotly.subplots import make_subplots
8
+ import matplotlib.pyplot as plt
9
+ import matplotlib
10
+ matplotlib.use('Agg') # Use non-interactive backend
11
+ import io
12
+ import base64
13
 
14
  if TYPE_CHECKING:
15
  from similarity_analysis.app import SimilarityApp
 
26
  ml_options = self.app.data_loader.get_ml_model_options()
27
 
28
  gr.Markdown("## Corner Cases Analysis")
 
29
 
30
  with gr.Row():
31
  with gr.Column(scale=1):
 
88
  'results_display': results_display
89
  }
90
 
91
def create_single_pair_bar_plot(self, result, pair_index):
    """Render one pair's normalized Human/Brain/ML values as an inline bar chart.

    Args:
        result: Mapping providing ``'human_norm'``, ``'brain_norm'`` and
            ``'ml_norm'``.
        pair_index: Pair number shown in the chart title.

    Returns:
        An ``<img>`` tag whose ``src`` is the chart as a base64 PNG data URI.

    Raises:
        KeyError: If ``result`` lacks one of the three normalized values.
    """
    categories = ['Human', 'Brain', 'ML']
    # Read the inputs before allocating a figure so that a bad `result`
    # cannot leave an orphaned figure behind.
    values = [result['human_norm'], result['brain_norm'], result['ml_norm']]
    colors = ['#4A90E2', '#50C878', '#E24A4A']

    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        bars = ax.bar(categories, values, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)

        # Value label centred just above each bar
        for bar, val in zip(bars, values):
            ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 0.02,
                    f'{val:.3f}',
                    ha='center', va='bottom', fontsize=11, fontweight='bold')

        ax.set_ylabel('Normalized Value (0-1)', fontsize=11, fontweight='bold')
        ax.set_xlabel('Measure', fontsize=11, fontweight='bold')
        ax.set_title(f'Normalized Values for Pair #{pair_index}', fontsize=12, fontweight='bold')
        ax.set_ylim(0, 1.15)  # headroom above 1.0 for the value labels
        ax.grid(axis='y', alpha=0.3, linestyle='--')
        ax.set_axisbelow(True)

        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        # Use the figure's own methods rather than pyplot's global
        # "current figure", so the output is always this figure.
        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    finally:
        # pyplot keeps figures alive until closed; close on failure too,
        # since this is called once per displayed pair.
        plt.close(fig)

    img_base64 = base64.b64encode(buf.getvalue()).decode()

    return f'<img src="data:image/png;base64,{img_base64}" style="width: 100%; max-width: 450px; margin: 10px auto; display: block; border: 1px solid #ddd; border-radius: 5px; padding: 5px; background: white;" />'
133
+
134
  def connect_events(self, components):
135
  """Connect event handlers for this tab"""
136
 
 
164
 
165
  if show_images:
166
  # Show results with images in a grid
167
+ output += "<div style='display: grid; grid-template-columns: repeat(auto-fill, minmax(400px, 1fr)); gap: 20px; margin-bottom: 30px; margin-top: 20px;'>"
168
 
169
  for rank, result in enumerate(results, 1):
170
  output += "<div style='border: 1px solid #ddd; padding: 15px; border-radius: 8px; background: #f9f9f9;'>"
 
172
 
173
  # Get image URLs and captions
174
  data = self.app.data_loader.data
175
+ pair_row = data.iloc[result['index']]
176
+
177
+ # Get URLs
178
+ img1_url = pair_row['stim_1']
179
+ img2_url = pair_row['stim_2']
180
+
181
+ # Check if we need to swap URLs to match image_1 and image_2 filenames
182
+ image_1_filename = str(pair_row.get('image_1', ''))
183
+ image_2_filename = str(pair_row.get('image_2', ''))
184
+
185
+ stim_1_matches_image_1 = image_1_filename in str(img1_url)
186
+ stim_1_matches_image_2 = image_2_filename in str(img1_url)
187
+
188
+ # Swap if needed
189
+ if stim_1_matches_image_2 and not stim_1_matches_image_1:
190
+ img1_url, img2_url = img2_url, img1_url
191
+ # Also swap captions
192
+ caption1_data = pair_row.get('image_2_description', 'No caption available')
193
+ caption2_data = pair_row.get('image_1_description', 'No caption available')
194
+ else:
195
+ caption1_data = pair_row.get('image_1_description', 'No caption available')
196
+ caption2_data = pair_row.get('image_2_description', 'No caption available')
197
 
198
  # Format captions (handle multiple descriptions separated by |)
199
  def format_caption_html(caption_text):
 
209
  html += "</ol>"
210
  return html
211
 
212
+ caption1 = format_caption_html(caption1_data)
213
+ caption2 = format_caption_html(caption2_data)
214
 
215
  # Display images side by side
216
  output += "<div style='display: flex; gap: 10px; margin: 10px 0;'>"
 
239
  output += f"<tr><td>ML:</td><td>{result['ml_norm']:.3f}</td></tr>"
240
  output += "</table>"
241
 
242
+ # Add bar plot for this specific pair
243
+ try:
244
+ bar_plot_html = self.create_single_pair_bar_plot(result, result['index'])
245
+ output += f"<div style='margin: 15px 0;'>{bar_plot_html}</div>"
246
+ except Exception as e:
247
+ output += f"<div style='color: red; padding: 10px;'>Error creating plot: {e}</div>"
248
+
249
  output += "</div>"
250
 
251
  output += "</div>"
252
  else:
253
  # Text-only table format
254
+ output += "<table style='width: 100%; border-collapse: collapse; margin-bottom: 30px; margin-top: 20px;'>"
255
  output += "<thead><tr style='background: #f0f0f0;'>"
256
  output += "<th style='border: 1px solid #ddd; padding: 8px;'>Rank</th>"
257
  output += "<th style='border: 1px solid #ddd; padding: 8px;'>Pair #</th>"
gui/viewer_tab.py CHANGED
@@ -26,6 +26,25 @@ class ViewerTab:
26
  info=f"Enter 0 to {len(self.app.data_loader.data)-1}",
27
  precision=0
28
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  show_btn = gr.Button("Show Images & Details", variant="primary", size="lg")
30
 
31
  # Images side by side
@@ -44,6 +63,12 @@ class ViewerTab:
44
  with gr.Row():
45
  summary_card = gr.HTML("<div>Select an image pair to see summary</div>")
46
 
 
 
 
 
 
 
47
  # Brain measures and Model performance side by side
48
  with gr.Row():
49
  with gr.Column(scale=1):
@@ -66,55 +91,99 @@ class ViewerTab:
66
  gr.Markdown("### ROI Brain Activation Comparison")
67
  roi_plot = gr.Plot(label="Side-by-Side ROI Values", show_label=False)
68
 
 
 
 
 
 
 
 
 
 
 
 
69
  return {
70
  'row_input': row_input,
 
 
71
  'show_btn': show_btn,
72
  'image1_display': image1_display,
73
  'image2_display': image2_display,
74
  'caption1_display': caption1_display,
75
  'caption2_display': caption2_display,
76
  'summary_card': summary_card,
 
77
  'brain_table': brain_table,
78
  'model_table': model_table,
79
  'rankings_display': rankings_display,
80
- 'roi_plot': roi_plot
 
 
81
  }
82
 
83
  def connect_events(self, components):
84
  """Connect event handlers for this tab"""
85
- def show_images_and_details(row_idx):
 
 
 
 
 
86
  results = self.app.show_image_pair_multi(int(row_idx) if row_idx is not None else 0)
87
- return results
 
 
 
 
 
 
 
 
 
88
 
89
- # Connect with multiple outputs
90
  components['show_btn'].click(
91
  fn=show_images_and_details,
92
- inputs=[components['row_input']],
 
 
 
 
93
  outputs=[
94
  components['image1_display'],
95
  components['image2_display'],
96
  components['caption1_display'],
97
  components['caption2_display'],
98
  components['summary_card'],
 
99
  components['brain_table'],
100
  components['model_table'],
101
  components['rankings_display'],
102
- components['roi_plot']
 
 
103
  ]
104
  )
105
 
106
  components['row_input'].change(
107
  fn=show_images_and_details,
108
- inputs=[components['row_input']],
 
 
 
 
109
  outputs=[
110
  components['image1_display'],
111
  components['image2_display'],
112
  components['caption1_display'],
113
  components['caption2_display'],
114
  components['summary_card'],
 
115
  components['brain_table'],
116
  components['model_table'],
117
  components['rankings_display'],
118
- components['roi_plot']
 
 
119
  ]
120
  )
 
26
  info=f"Enter 0 to {len(self.app.data_loader.data)-1}",
27
  precision=0
28
  )
29
+
30
+ # Add dropdowns for corner distance calculation
31
+ brain_options = self.app.data_loader.get_brain_measure_options()
32
+ ml_options = self.app.data_loader.get_ml_model_options()
33
+
34
+ brain_dropdown = gr.Dropdown(
35
+ choices=brain_options,
36
+ value=brain_options[0][1] if brain_options else None,
37
+ label="Brain Measure (for 3D position)",
38
+ info="Select brain measure for corner distance calculation"
39
+ )
40
+
41
+ ml_dropdown = gr.Dropdown(
42
+ choices=ml_options,
43
+ value=ml_options[0][1] if ml_options else None,
44
+ label="ML Model (for 3D position)",
45
+ info="Select ML model for corner distance calculation"
46
+ )
47
+
48
  show_btn = gr.Button("Show Images & Details", variant="primary", size="lg")
49
 
50
  # Images side by side
 
63
  with gr.Row():
64
  summary_card = gr.HTML("<div>Select an image pair to see summary</div>")
65
 
66
+ # Bar plot for normalized values
67
+ with gr.Row():
68
+ with gr.Column():
69
+ gr.Markdown("### Normalized Values Visualization")
70
+ bar_plot_display = gr.HTML("<div>Bar plot will appear here</div>")
71
+
72
  # Brain measures and Model performance side by side
73
  with gr.Row():
74
  with gr.Column(scale=1):
 
91
  gr.Markdown("### ROI Brain Activation Comparison")
92
  roi_plot = gr.Plot(label="Side-by-Side ROI Values", show_label=False)
93
 
94
+ # NEW: Corner distances section
95
+ with gr.Row():
96
+ with gr.Column():
97
+ gr.Markdown("### Position in 3D Space & Corner Distances")
98
+ gr.Markdown("This shows where this specific pair sits in the normalized 3D space (Human × Brain × ML) and its distance to each of the 8 corners.")
99
+ corner_distance_table = gr.HTML("<div>Select parameters and click 'Show Images & Details' to see corner distances</div>")
100
+
101
+ with gr.Row():
102
+ with gr.Column():
103
+ corner_3d_plot = gr.Plot(label="3D Position Visualization", show_label=False)
104
+
105
  return {
106
  'row_input': row_input,
107
+ 'brain_dropdown': brain_dropdown,
108
+ 'ml_dropdown': ml_dropdown,
109
  'show_btn': show_btn,
110
  'image1_display': image1_display,
111
  'image2_display': image2_display,
112
  'caption1_display': caption1_display,
113
  'caption2_display': caption2_display,
114
  'summary_card': summary_card,
115
+ 'bar_plot_display': bar_plot_display,
116
  'brain_table': brain_table,
117
  'model_table': model_table,
118
  'rankings_display': rankings_display,
119
+ 'roi_plot': roi_plot,
120
+ 'corner_distance_table': corner_distance_table,
121
+ 'corner_3d_plot': corner_3d_plot
122
  }
123
 
124
  def connect_events(self, components):
125
  """Connect event handlers for this tab"""
126
+ def show_images_and_details(row_idx, brain_measure, ml_model_selection):
127
+ # Store the current brain measure and ML model in the app for the bar plot
128
+ self.app._current_brain_measure = brain_measure
129
+ self.app._current_ml_model = ml_model_selection
130
+
131
+ # Get basic image pair info
132
  results = self.app.show_image_pair_multi(int(row_idx) if row_idx is not None else 0)
133
+
134
+ # Get corner distances and 3D plot
135
+ corner_html, corner_plot = self.app.get_point_corner_distances(
136
+ int(row_idx) if row_idx is not None else 0,
137
+ brain_measure,
138
+ ml_model_selection
139
+ )
140
+
141
+ # Return all outputs including the new corner distance outputs
142
+ return (*results, corner_html, corner_plot)
143
 
144
+ # Connect with multiple outputs (added bar_plot_display)
145
  components['show_btn'].click(
146
  fn=show_images_and_details,
147
+ inputs=[
148
+ components['row_input'],
149
+ components['brain_dropdown'],
150
+ components['ml_dropdown']
151
+ ],
152
  outputs=[
153
  components['image1_display'],
154
  components['image2_display'],
155
  components['caption1_display'],
156
  components['caption2_display'],
157
  components['summary_card'],
158
+ components['bar_plot_display'],
159
  components['brain_table'],
160
  components['model_table'],
161
  components['rankings_display'],
162
+ components['roi_plot'],
163
+ components['corner_distance_table'],
164
+ components['corner_3d_plot']
165
  ]
166
  )
167
 
168
  components['row_input'].change(
169
  fn=show_images_and_details,
170
+ inputs=[
171
+ components['row_input'],
172
+ components['brain_dropdown'],
173
+ components['ml_dropdown']
174
+ ],
175
  outputs=[
176
  components['image1_display'],
177
  components['image2_display'],
178
  components['caption1_display'],
179
  components['caption2_display'],
180
  components['summary_card'],
181
+ components['bar_plot_display'],
182
  components['brain_table'],
183
  components['model_table'],
184
  components['rankings_display'],
185
+ components['roi_plot'],
186
+ components['corner_distance_table'],
187
+ components['corner_3d_plot']
188
  ]
189
  )
visualization/__pycache__/image_viewer.cpython-311.pyc CHANGED
Binary files a/visualization/__pycache__/image_viewer.cpython-311.pyc and b/visualization/__pycache__/image_viewer.cpython-311.pyc differ
 
visualization/image_viewer.py CHANGED
@@ -24,16 +24,58 @@ class ImageViewer:
24
  return placeholder
25
 
26
  @staticmethod
27
- def get_image_pair(data: pd.DataFrame, row_index: int) -> Tuple[Optional[Image.Image], Optional[Image.Image]]:
28
- """Get image pair for a specific row"""
 
 
 
 
 
 
 
29
  if row_index >= len(data):
30
- return None, None
31
 
32
  row = data.iloc[row_index]
 
 
 
 
 
 
 
 
 
33
  img1_url = row.get('stim_1', '')
34
  img2_url = row.get('stim_2', '')
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  img1 = ImageViewer.load_image_from_url(img1_url) if img1_url else None
37
  img2 = ImageViewer.load_image_from_url(img2_url) if img2_url else None
38
 
39
- return img1, img2
 
24
  return placeholder
25
 
26
  @staticmethod
27
+ def get_image_pair(data: pd.DataFrame, row_index: int) -> Tuple[Optional[Image.Image], Optional[Image.Image], bool]:
28
+ """Get image pair for a specific row
29
+
30
+ Returns:
31
+ tuple: (img1, img2, was_swapped)
32
+ - img1: Image for image_1
33
+ - img2: Image for image_2
34
+ - was_swapped: True if URLs were swapped to match filenames
35
+ """
36
  if row_index >= len(data):
37
+ return None, None, False
38
 
39
  row = data.iloc[row_index]
40
+ was_swapped = False
41
+
42
+ # DEBUG: Check which columns exist
43
+ print(f"\n[IMAGE_VIEWER DEBUG] Available URL columns: {[col for col in row.index if 'stim' in col.lower() or 'url' in col.lower()]}")
44
+ print(f"[IMAGE_VIEWER DEBUG] image_1: {row.get('image_1', 'MISSING')}")
45
+ print(f"[IMAGE_VIEWER DEBUG] image_2: {row.get('image_2', 'MISSING')}")
46
+
47
+ # Try different possible column names for URLs
48
+ # Priority: use stim_1/stim_2 which are the URL columns
49
  img1_url = row.get('stim_1', '')
50
  img2_url = row.get('stim_2', '')
51
 
52
+ print(f"[IMAGE_VIEWER DEBUG] Loading stim_1 URL: {img1_url[:80] if img1_url else 'EMPTY'}...")
53
+ print(f"[IMAGE_VIEWER DEBUG] Loading stim_2 URL: {img2_url[:80] if img2_url else 'EMPTY'}...")
54
+
55
+ # Check if stim_1 corresponds to image_1 or image_2 by looking at the filename in the URL
56
+ # If the URL contains the image_1 filename, then stim_1 = image_1
57
+ # Otherwise they might be swapped
58
+ image_1_filename = str(row.get('image_1', ''))
59
+ image_2_filename = str(row.get('image_2', ''))
60
+
61
+ # Check if we need to swap
62
+ stim_1_matches_image_1 = image_1_filename in str(img1_url)
63
+ stim_1_matches_image_2 = image_2_filename in str(img1_url)
64
+
65
+ print(f"[IMAGE_VIEWER DEBUG] stim_1 contains image_1 filename? {stim_1_matches_image_1}")
66
+ print(f"[IMAGE_VIEWER DEBUG] stim_1 contains image_2 filename? {stim_1_matches_image_2}")
67
+
68
+ # If stim_1 contains image_2's filename, we need to swap
69
+ if stim_1_matches_image_2 and not stim_1_matches_image_1:
70
+ print("[IMAGE_VIEWER DEBUG] ⚠️ SWAPPING: stim_1 corresponds to image_2, swapping URLs")
71
+ img1_url, img2_url = img2_url, img1_url
72
+ was_swapped = True
73
+ elif stim_1_matches_image_1:
74
+ print("[IMAGE_VIEWER DEBUG] ✓ No swap needed: stim_1 corresponds to image_1")
75
+ else:
76
+ print("[IMAGE_VIEWER DEBUG] ⚠️ WARNING: Could not determine correspondence, assuming stim_1=image_1")
77
+
78
  img1 = ImageViewer.load_image_from_url(img1_url) if img1_url else None
79
  img2 = ImageViewer.load_image_from_url(img2_url) if img2_url else None
80
 
81
+ return img1, img2, was_swapped