Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- .gitattributes +5 -0
- app.py +570 -234
- brain/__pycache__/roi_analyzer.cpython-311.pyc +0 -0
- data/Final_similarity_matrix.csv +3 -0
- data/Final_similarity_matrix2.csv +3 -0
- data/overall_database.csv +3 -0
- data/overall_database2.csv +3 -0
- data/overall_database3.csv +3 -0
- gui/__pycache__/corner_cases_tab.cpython-311.pyc +0 -0
- gui/__pycache__/viewer_tab.cpython-311.pyc +0 -0
- gui/corner_cases_tab.py +83 -7
- gui/viewer_tab.py +77 -8
- visualization/__pycache__/image_viewer.cpython-311.pyc +0 -0
- visualization/image_viewer.py +46 -4
.gitattributes
CHANGED
|
@@ -38,3 +38,8 @@ Final_similarity_matrix2.csv filter=lfs diff=lfs merge=lfs -text
|
|
| 38 |
overall_database.csv filter=lfs diff=lfs merge=lfs -text
|
| 39 |
overall_database2.csv filter=lfs diff=lfs merge=lfs -text
|
| 40 |
overall_database3.csv filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
overall_database.csv filter=lfs diff=lfs merge=lfs -text
|
| 39 |
overall_database2.csv filter=lfs diff=lfs merge=lfs -text
|
| 40 |
overall_database3.csv filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
data/Final_similarity_matrix.csv filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
data/Final_similarity_matrix2.csv filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
data/overall_database.csv filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
data/overall_database2.csv filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
data/overall_database3.csv filter=lfs diff=lfs merge=lfs -text
|
app.py
CHANGED
|
@@ -15,6 +15,8 @@ from visualization.image_viewer import ImageViewer
|
|
| 15 |
from gui.gradio_interface import GradioInterface
|
| 16 |
from typing import Tuple, Optional, Union, Dict, Any
|
| 17 |
import pandas as pd
|
|
|
|
|
|
|
| 18 |
|
| 19 |
class SimilarityApp:
|
| 20 |
"""Main application class that orchestrates all components"""
|
|
@@ -135,272 +137,370 @@ class SimilarityApp:
|
|
| 135 |
print(f"Error getting model rankings: {e}")
|
| 136 |
return {}
|
| 137 |
|
| 138 |
-
# Replace the show_image_pair method in SimilarityApp class
|
| 139 |
-
# This version adds normalized values to the display
|
| 140 |
-
|
| 141 |
-
# Add this new method to SimilarityApp class in app.py
|
| 142 |
-
# This returns multiple separate outputs for better Gradio layout
|
| 143 |
-
|
| 144 |
def show_image_pair_multi(self, row_index: int):
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
-
|
|
|
|
| 161 |
|
| 162 |
-
|
| 163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
else:
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
-
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
return pd.Series([0] * len(data))
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
# Summary card HTML
|
| 219 |
-
summary_html = f"""
|
| 220 |
-
<div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 20px; border-radius: 10px; box-shadow: 0 4px 6px rgba(0,0,0,0.1);">
|
| 221 |
-
<h3 style="margin: 0 0 10px 0;">Image Pair #{row_index} Summary</h3>
|
| 222 |
-
<div><strong>Images:</strong> {row['image_1']} vs {row['image_2']}</div>
|
| 223 |
-
<div style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 10px; margin-top: 15px;">
|
| 224 |
-
<div style="background: rgba(255,255,255,0.2); padding: 10px; border-radius: 5px; text-align: center;">
|
| 225 |
-
<div style="font-size: 11px; opacity: 0.9; margin-bottom: 5px;">Human Rating</div>
|
| 226 |
-
<div style="font-size: 20px; font-weight: bold;">{row['human_judgement']:.2f}/6</div>
|
| 227 |
-
</div>
|
| 228 |
-
<div style="background: rgba(255,255,255,0.2); padding: 10px; border-radius: 5px; text-align: center;">
|
| 229 |
-
<div style="font-size: 11px; opacity: 0.9; margin-bottom: 5px;">Brain Similarity</div>
|
| 230 |
-
<div style="font-size: 20px; font-weight: bold;">{row.get('cosine_similarity_roi_values_common', 0):.3f}</div>
|
| 231 |
</div>
|
| 232 |
-
<div
|
| 233 |
-
|
| 234 |
-
<div style="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
</div>
|
| 236 |
</div>
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
<table style="width: 100%; border-collapse: collapse;">
|
| 243 |
-
<thead>
|
| 244 |
-
<tr style="background: #f8f9fa;">
|
| 245 |
-
<th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Measure</th>
|
| 246 |
-
<th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Raw</th>
|
| 247 |
-
<th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Norm</th>
|
| 248 |
-
<th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Type</th>
|
| 249 |
-
</tr>
|
| 250 |
-
</thead>
|
| 251 |
-
<tbody>
|
| 252 |
-
<tr style="border-bottom: 1px solid #e9ecef;">
|
| 253 |
-
<td style="padding: 8px 10px;"><strong>Cosine - Common</strong></td>
|
| 254 |
-
<td style="padding: 8px 10px;">{row.get('cosine_similarity_roi_values_common', 0):.3f}</td>
|
| 255 |
-
<td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['cosine_similarity_roi_values_common']).iloc[row_index]:.3f}</td>
|
| 256 |
-
<td style="padding: 8px 10px;">All regions</td>
|
| 257 |
-
</tr>
|
| 258 |
-
<tr style="border-bottom: 1px solid #e9ecef;">
|
| 259 |
-
<td style="padding: 8px 10px;"><strong>Cosine - Early</strong></td>
|
| 260 |
-
<td style="padding: 8px 10px;">{row.get('cosine_similarity_roi_values_early', 0):.3f}</td>
|
| 261 |
-
<td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['cosine_similarity_roi_values_early']).iloc[row_index]:.3f}</td>
|
| 262 |
-
<td style="padding: 8px 10px;">Low-level</td>
|
| 263 |
-
</tr>
|
| 264 |
-
<tr style="border-bottom: 1px solid #e9ecef;">
|
| 265 |
-
<td style="padding: 8px 10px;"><strong>Cosine - Late</strong></td>
|
| 266 |
-
<td style="padding: 8px 10px;">{row.get('cosine_similarity_roi_values_late', 0):.3f}</td>
|
| 267 |
-
<td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['cosine_similarity_roi_values_late']).iloc[row_index]:.3f}</td>
|
| 268 |
-
<td style="padding: 8px 10px;">High-level</td>
|
| 269 |
-
</tr>
|
| 270 |
-
<tr style="border-bottom: 1px solid #e9ecef;">
|
| 271 |
-
<td style="padding: 8px 10px;"><strong>Pearson - Common</strong></td>
|
| 272 |
-
<td style="padding: 8px 10px;">{row.get('pearson_correlation_roi_values_common', 0):.3f}</td>
|
| 273 |
-
<td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['pearson_correlation_roi_values_common']).iloc[row_index]:.3f}</td>
|
| 274 |
-
<td style="padding: 8px 10px;">All regions</td>
|
| 275 |
-
</tr>
|
| 276 |
-
<tr style="border-bottom: 1px solid #e9ecef;">
|
| 277 |
-
<td style="padding: 8px 10px;"><strong>Pearson - Early</strong></td>
|
| 278 |
-
<td style="padding: 8px 10px;">{row.get('pearson_correlation_roi_values_early', 0):.3f}</td>
|
| 279 |
-
<td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['pearson_correlation_roi_values_early']).iloc[row_index]:.3f}</td>
|
| 280 |
-
<td style="padding: 8px 10px;">Low-level</td>
|
| 281 |
-
</tr>
|
| 282 |
-
<tr style="border-bottom: 1px solid #e9ecef;">
|
| 283 |
-
<td style="padding: 8px 10px;"><strong>Pearson - Late</strong></td>
|
| 284 |
-
<td style="padding: 8px 10px;">{row.get('pearson_correlation_roi_values_late', 0):.3f}</td>
|
| 285 |
-
<td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['pearson_correlation_roi_values_late']).iloc[row_index]:.3f}</td>
|
| 286 |
-
<td style="padding: 8px 10px;">High-level</td>
|
| 287 |
-
</tr>
|
| 288 |
-
</tbody>
|
| 289 |
-
</table>
|
| 290 |
-
"""
|
| 291 |
-
|
| 292 |
-
# Model performance HTML
|
| 293 |
-
model_html = f"""
|
| 294 |
-
<div style="margin-bottom: 15px;">
|
| 295 |
-
<strong>Category Averages</strong>
|
| 296 |
-
<table style="width: 100%; border-collapse: collapse; margin: 10px 0;">
|
| 297 |
<thead>
|
| 298 |
<tr style="background: #f8f9fa;">
|
| 299 |
-
<th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">
|
| 300 |
<th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Raw</th>
|
| 301 |
<th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Norm</th>
|
|
|
|
| 302 |
</tr>
|
| 303 |
</thead>
|
| 304 |
<tbody>
|
| 305 |
<tr style="border-bottom: 1px solid #e9ecef;">
|
| 306 |
-
<td style="padding: 8px 10px;">
|
| 307 |
-
<td style="padding: 8px 10px;">{row.get('
|
| 308 |
-
<td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['
|
|
|
|
| 309 |
</tr>
|
| 310 |
<tr style="border-bottom: 1px solid #e9ecef;">
|
| 311 |
-
<td style="padding: 8px 10px;">
|
| 312 |
-
<td style="padding: 8px 10px;">{row.get('
|
| 313 |
-
<td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['
|
|
|
|
| 314 |
</tr>
|
| 315 |
<tr style="border-bottom: 1px solid #e9ecef;">
|
| 316 |
-
<td style="padding: 8px 10px;">
|
| 317 |
-
<td style="padding: 8px 10px;">{row.get('
|
| 318 |
-
<td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['
|
| 319 |
-
|
| 320 |
-
</tbody>
|
| 321 |
-
</table>
|
| 322 |
-
</div>
|
| 323 |
-
<div>
|
| 324 |
-
<strong>Current Selection</strong>
|
| 325 |
-
<table style="width: 100%; border-collapse: collapse; margin: 10px 0;">
|
| 326 |
-
<thead>
|
| 327 |
-
<tr style="background: #f8f9fa;">
|
| 328 |
-
<th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Info</th>
|
| 329 |
-
<th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Value</th>
|
| 330 |
</tr>
|
| 331 |
-
</thead>
|
| 332 |
-
<tbody>
|
| 333 |
<tr style="border-bottom: 1px solid #e9ecef;">
|
| 334 |
-
<td style="padding: 8px 10px;">
|
| 335 |
-
<td style="padding: 8px 10px;">{
|
|
|
|
|
|
|
| 336 |
</tr>
|
| 337 |
<tr style="border-bottom: 1px solid #e9ecef;">
|
| 338 |
-
<td style="padding: 8px 10px;">
|
| 339 |
-
<td style="padding: 8px 10px;">{
|
|
|
|
|
|
|
| 340 |
</tr>
|
| 341 |
<tr style="border-bottom: 1px solid #e9ecef;">
|
| 342 |
-
<td style="padding: 8px 10px;">
|
| 343 |
-
<td style="padding: 8px 10px;">{
|
|
|
|
|
|
|
| 344 |
</tr>
|
| 345 |
</tbody>
|
| 346 |
</table>
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 351 |
</div>
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
<
|
| 366 |
-
|
| 367 |
-
<
|
| 368 |
-
<
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
</div>
|
| 378 |
-
<div>
|
| 379 |
-
<table style="width: 100%; border-collapse: collapse;">
|
| 380 |
-
<thead><tr style="background: #f8f9fa;"><th style="padding: 8px; text-align: left; border-bottom: 2px solid #dee2e6;">Top 3 Worst</th><th style="padding: 8px; text-align: left; border-bottom: 2px solid #dee2e6;">Score</th></tr></thead>
|
| 381 |
-
<tbody>
|
| 382 |
-
"""
|
| 383 |
-
for model, score in rankings[category]['worst']:
|
| 384 |
-
clean_name = model.replace('BOLD5000_timm_', '').replace('_sim_partial', '') if 'BOLD5000_timm_' in model else model
|
| 385 |
-
rankings_html += f'<tr style="border-bottom: 1px solid #e9ecef;"><td style="padding: 6px 8px;">{clean_name}</td><td style="padding: 6px 8px;">{score:.3f}</td></tr>'
|
| 386 |
-
|
| 387 |
-
rankings_html += """
|
| 388 |
-
</tbody>
|
| 389 |
-
</table>
|
| 390 |
</div>
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
|
|
|
|
|
|
| 404 |
|
| 405 |
def set_current_model(self, ml_model_selection, ml_name):
|
| 406 |
"""Store the current ML model selection for display in image viewer"""
|
|
@@ -429,12 +529,248 @@ class SimilarityApp:
|
|
| 429 |
def get_corner_interpretation(self, corner_name: str) -> str:
|
| 430 |
"""Get interpretation of a corner - delegates to CornerAnalyzer"""
|
| 431 |
return self.corner_analyzer.get_interpretation(corner_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
|
| 433 |
def main():
|
| 434 |
"""Main function to run the application"""
|
| 435 |
try:
|
| 436 |
# Create and launch the app
|
| 437 |
-
app = SimilarityApp('overall_database3.csv')
|
| 438 |
app.launch(
|
| 439 |
server_name="0.0.0.0",
|
| 440 |
server_port=7860,
|
|
|
|
| 15 |
from gui.gradio_interface import GradioInterface
|
| 16 |
from typing import Tuple, Optional, Union, Dict, Any
|
| 17 |
import pandas as pd
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
|
| 21 |
class SimilarityApp:
|
| 22 |
"""Main application class that orchestrates all components"""
|
|
|
|
| 137 |
print(f"Error getting model rankings: {e}")
|
| 138 |
return {}
|
| 139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
def show_image_pair_multi(self, row_index: int):
|
| 141 |
+
"""Show image pair with separate outputs for better layout"""
|
| 142 |
+
try:
|
| 143 |
+
import matplotlib.pyplot as plt
|
| 144 |
+
import matplotlib
|
| 145 |
+
matplotlib.use('Agg')
|
| 146 |
+
import io
|
| 147 |
+
import base64
|
| 148 |
+
|
| 149 |
+
data = self.data_loader.data
|
| 150 |
+
if row_index >= len(data):
|
| 151 |
+
return (None, None, "Invalid index", "Invalid index",
|
| 152 |
+
"Invalid row index", "", "", "", "", None)
|
| 153 |
+
|
| 154 |
+
row = data.iloc[row_index]
|
| 155 |
+
|
| 156 |
+
# DEBUG: Print what images we're loading
|
| 157 |
+
print(f"\n{'*'*60}")
|
| 158 |
+
print(f"APP.PY - Loading Pair #{row_index}")
|
| 159 |
+
print(f"{'*'*60}")
|
| 160 |
+
print(f"Image 1 filename: {row.get('image_1', 'MISSING')}")
|
| 161 |
+
print(f"Image 2 filename: {row.get('image_2', 'MISSING')}")
|
| 162 |
+
print(f"Image 1 URL: {row.get('stim_1', 'MISSING')}")
|
| 163 |
+
print(f"Image 2 URL: {row.get('stim_2', 'MISSING')}")
|
| 164 |
+
print(f"{'*'*60}\n")
|
| 165 |
+
|
| 166 |
+
# Get images - now returns swap information
|
| 167 |
+
img1, img2, was_swapped = self.image_viewer.get_image_pair(data, row_index)
|
| 168 |
+
|
| 169 |
+
# Verify which images were loaded
|
| 170 |
+
print(f"Image 1 loaded: {'SUCCESS' if img1 is not None else 'FAILED'}")
|
| 171 |
+
print(f"Image 2 loaded: {'SUCCESS' if img2 is not None else 'FAILED'}")
|
| 172 |
+
print(f"URLs were swapped: {was_swapped}")
|
| 173 |
+
|
| 174 |
+
# Format captions
|
| 175 |
+
def format_captions_html(caption_text):
|
| 176 |
+
if pd.isna(caption_text) or caption_text == 'No caption available':
|
| 177 |
+
return '<div style="padding:10px; background:#f8f9fa; border-radius:5px;">No caption available</div>'
|
| 178 |
+
|
| 179 |
+
captions = [c.strip() for c in str(caption_text).split('|')]
|
| 180 |
+
|
| 181 |
+
if len(captions) == 1:
|
| 182 |
+
return f'<div style="padding:10px; background:#f8f9fa; border-radius:5px;">{captions[0]}</div>'
|
| 183 |
+
else:
|
| 184 |
+
html = f'<div style="padding:10px; background:#f8f9fa; border-radius:5px;"><strong>{len(captions)} descriptions:</strong><ol style="margin:5px 0; padding-left:20px;">'
|
| 185 |
+
for cap in captions:
|
| 186 |
+
html += f'<li>{cap}</li>'
|
| 187 |
+
html += '</ol></div>'
|
| 188 |
+
return html
|
| 189 |
+
|
| 190 |
+
# Get captions in correct order (swap if images were swapped)
|
| 191 |
+
if was_swapped:
|
| 192 |
+
print("[APP.PY] Swapping captions to match swapped images")
|
| 193 |
+
caption1_html = format_captions_html(row.get('image_2_description', 'No caption available'))
|
| 194 |
+
caption2_html = format_captions_html(row.get('image_1_description', 'No caption available'))
|
| 195 |
+
else:
|
| 196 |
+
caption1_html = format_captions_html(row.get('image_1_description', 'No caption available'))
|
| 197 |
+
caption2_html = format_captions_html(row.get('image_2_description', 'No caption available'))
|
| 198 |
+
|
| 199 |
+
# Get normalized values for bar plot
|
| 200 |
+
from analysis.corner_analyzer import CornerAnalyzer
|
| 201 |
|
| 202 |
+
# Get current brain measure from the last used one (or default to common)
|
| 203 |
+
brain_measure = getattr(self, '_current_brain_measure', 'cosine_similarity_roi_values_common')
|
| 204 |
|
| 205 |
+
# Normalize the three main values for this pair
|
| 206 |
+
human_norm = CornerAnalyzer.normalize_series(data['human_judgement']).iloc[row_index]
|
| 207 |
+
brain_norm = CornerAnalyzer.normalize_series(data[brain_measure]).iloc[row_index]
|
| 208 |
+
|
| 209 |
+
# Get ML model norm
|
| 210 |
+
current_ml_model = getattr(self, '_current_ml_model', None)
|
| 211 |
+
if current_ml_model is not None:
|
| 212 |
+
try:
|
| 213 |
+
if isinstance(current_ml_model, int) and current_ml_model < len(self.data_loader.ml_models):
|
| 214 |
+
ml_column = self.data_loader.ml_models[current_ml_model]
|
| 215 |
+
ml_norm = CornerAnalyzer.normalize_series(data[ml_column]).iloc[row_index]
|
| 216 |
+
elif isinstance(current_ml_model, str) and current_ml_model.startswith('avg_'):
|
| 217 |
+
ml_norm = CornerAnalyzer.normalize_series(data[current_ml_model]).iloc[row_index]
|
| 218 |
+
else:
|
| 219 |
+
ml_norm = 0.5 # default
|
| 220 |
+
except Exception:
|
| 221 |
+
ml_norm = 0.5
|
| 222 |
else:
|
| 223 |
+
ml_norm = 0.5
|
| 224 |
+
|
| 225 |
+
# Create bar plot
|
| 226 |
+
fig, ax = plt.subplots(figsize=(6, 4))
|
| 227 |
+
|
| 228 |
+
categories = ['Human', 'Brain', 'ML']
|
| 229 |
+
values = [human_norm, brain_norm, ml_norm]
|
| 230 |
+
colors = ['#4A90E2', '#50C878', '#E24A4A']
|
| 231 |
+
|
| 232 |
+
# Create bars
|
| 233 |
+
bars = ax.bar(categories, values, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
|
| 234 |
+
|
| 235 |
+
# Add value labels on top of bars
|
| 236 |
+
for bar, val in zip(bars, values):
|
| 237 |
+
height = bar.get_height()
|
| 238 |
+
ax.text(bar.get_x() + bar.get_width()/2., height + 0.02,
|
| 239 |
+
f'{val:.3f}',
|
| 240 |
+
ha='center', va='bottom', fontsize=11, fontweight='bold')
|
| 241 |
+
|
| 242 |
+
# Styling
|
| 243 |
+
ax.set_ylabel('Normalized Value (0-1)', fontsize=11, fontweight='bold')
|
| 244 |
+
ax.set_xlabel('Measure', fontsize=11, fontweight='bold')
|
| 245 |
+
ax.set_title(f'Normalized Values for Pair #{row_index}', fontsize=12, fontweight='bold')
|
| 246 |
+
ax.set_ylim(0, 1.15)
|
| 247 |
+
ax.grid(axis='y', alpha=0.3, linestyle='--')
|
| 248 |
+
ax.set_axisbelow(True)
|
| 249 |
+
|
| 250 |
+
# Style the plot
|
| 251 |
+
ax.spines['top'].set_visible(False)
|
| 252 |
+
ax.spines['right'].set_visible(False)
|
| 253 |
+
|
| 254 |
+
plt.tight_layout()
|
| 255 |
|
| 256 |
+
# Convert to base64 image
|
| 257 |
+
buf = io.BytesIO()
|
| 258 |
+
plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
|
| 259 |
+
buf.seek(0)
|
| 260 |
+
img_base64 = base64.b64encode(buf.read()).decode()
|
| 261 |
+
plt.close(fig)
|
| 262 |
+
|
| 263 |
+
bar_plot_html = f'<img src="data:image/png;base64,{img_base64}" style="width: 100%; max-width: 500px; margin: 20px auto; display: block; border: 1px solid #ddd; border-radius: 5px; padding: 10px; background: white;" />'
|
| 264 |
+
|
| 265 |
+
# Calculate averages if needed
|
| 266 |
+
if 'avg_vision' not in data.columns:
|
| 267 |
+
vision_models = [col for col in data.columns if 'BOLD5000_timm_' in col]
|
| 268 |
+
language_models = [col for col in data.columns if 'bert-' in col or 'deberta-' in col or 'sup-simcse' in col]
|
| 269 |
+
semantic_models = [col for col in data.columns if any(x in col for x in ["bm25", "rouge", "tf-idf", "co-occurrence"])]
|
| 270 |
+
|
| 271 |
+
def normalize_models(model_list):
|
| 272 |
+
if not model_list:
|
| 273 |
+
return pd.Series([0] * len(data))
|
| 274 |
+
normalized_data = []
|
| 275 |
+
for model in model_list:
|
| 276 |
+
if model in data.columns:
|
| 277 |
+
model_data = data[model]
|
| 278 |
+
normalized = (model_data - model_data.min()) / (model_data.max() - model_data.min())
|
| 279 |
+
normalized_data.append(normalized)
|
| 280 |
+
if normalized_data:
|
| 281 |
+
return pd.concat(normalized_data, axis=1).mean(axis=1)
|
| 282 |
return pd.Series([0] * len(data))
|
| 283 |
+
|
| 284 |
+
data['avg_vision'] = normalize_models(vision_models)
|
| 285 |
+
data['avg_language'] = normalize_models(language_models)
|
| 286 |
+
data['avg_semantic'] = normalize_models(semantic_models)
|
| 287 |
+
|
| 288 |
+
# Get current model
|
| 289 |
+
current_ml_model = getattr(self, '_current_ml_model', None)
|
| 290 |
+
current_ml_name = getattr(self, '_current_ml_name', 'No model selected')
|
| 291 |
+
current_ml_score = 'N/A'
|
| 292 |
+
current_ml_norm = 'N/A'
|
| 293 |
+
|
| 294 |
+
if current_ml_model is not None:
|
| 295 |
+
try:
|
| 296 |
+
if isinstance(current_ml_model, int) and current_ml_model < len(self.data_loader.ml_models):
|
| 297 |
+
ml_column = self.data_loader.ml_models[current_ml_model]
|
| 298 |
+
current_ml_score = f"{row[ml_column]:.3f}"
|
| 299 |
+
current_ml_norm = f"{CornerAnalyzer.normalize_series(data[ml_column]).iloc[row_index]:.3f}"
|
| 300 |
+
elif isinstance(current_ml_model, str) and current_ml_model.startswith('avg_'):
|
| 301 |
+
current_ml_score = f"{row[current_ml_model]:.3f}"
|
| 302 |
+
current_ml_norm = f"{CornerAnalyzer.normalize_series(data[current_ml_model]).iloc[row_index]:.3f}"
|
| 303 |
+
except Exception:
|
| 304 |
+
pass
|
| 305 |
+
|
| 306 |
+
# Summary card HTML - WITH DEBUG INFO AND NORMALIZED VALUES
|
| 307 |
+
summary_html = f"""
|
| 308 |
+
<div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 20px; border-radius: 10px; box-shadow: 0 4px 6px rgba(0,0,0,0.1);">
|
| 309 |
+
<h3 style="margin: 0 0 10px 0;">Image Pair #{row_index} Summary</h3>
|
| 310 |
+
<div style="background: rgba(255,255,255,0.15); padding: 10px; border-radius: 5px; margin-bottom: 10px; font-family: monospace; font-size: 11px;">
|
| 311 |
+
<strong>🔍 DEBUG INFO:</strong><br>
|
| 312 |
+
Image 1 File: <code>{row['image_1']}</code><br>
|
| 313 |
+
Image 2 File: <code>{row['image_2']}</code>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
</div>
|
| 315 |
+
<div><strong>Images:</strong> {row['image_1']} vs {row['image_2']}</div>
|
| 316 |
+
<div style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 10px; margin-top: 15px;">
|
| 317 |
+
<div style="background: rgba(255,255,255,0.2); padding: 10px; border-radius: 5px; text-align: center;">
|
| 318 |
+
<div style="font-size: 11px; opacity: 0.9; margin-bottom: 5px;">Human Rating</div>
|
| 319 |
+
<div style="font-size: 20px; font-weight: bold;">{row['human_judgement']:.2f}/6</div>
|
| 320 |
+
<div style="font-size: 11px; opacity: 0.8; margin-top: 5px;">Norm: {human_norm:.3f}</div>
|
| 321 |
+
</div>
|
| 322 |
+
<div style="background: rgba(255,255,255,0.2); padding: 10px; border-radius: 5px; text-align: center;">
|
| 323 |
+
<div style="font-size: 11px; opacity: 0.9; margin-bottom: 5px;">Brain Similarity</div>
|
| 324 |
+
<div style="font-size: 20px; font-weight: bold;">{row.get(brain_measure, 0):.3f}</div>
|
| 325 |
+
<div style="font-size: 11px; opacity: 0.8; margin-top: 5px;">Norm: {brain_norm:.3f}</div>
|
| 326 |
+
</div>
|
| 327 |
+
<div style="background: rgba(255,255,255,0.2); padding: 10px; border-radius: 5px; text-align: center;">
|
| 328 |
+
<div style="font-size: 11px; opacity: 0.9; margin-bottom: 5px;">ML Model</div>
|
| 329 |
+
<div style="font-size: 20px; font-weight: bold;">{current_ml_score}</div>
|
| 330 |
+
<div style="font-size: 11px; opacity: 0.8; margin-top: 5px;">Norm: {ml_norm:.3f}</div>
|
| 331 |
+
</div>
|
| 332 |
</div>
|
| 333 |
</div>
|
| 334 |
+
"""
|
| 335 |
+
|
| 336 |
+
# Brain measures table
|
| 337 |
+
brain_html = f"""
|
| 338 |
+
<table style="width: 100%; border-collapse: collapse;">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
<thead>
|
| 340 |
<tr style="background: #f8f9fa;">
|
| 341 |
+
<th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Measure</th>
|
| 342 |
<th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Raw</th>
|
| 343 |
<th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Norm</th>
|
| 344 |
+
<th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Type</th>
|
| 345 |
</tr>
|
| 346 |
</thead>
|
| 347 |
<tbody>
|
| 348 |
<tr style="border-bottom: 1px solid #e9ecef;">
|
| 349 |
+
<td style="padding: 8px 10px;"><strong>Cosine - Common</strong></td>
|
| 350 |
+
<td style="padding: 8px 10px;">{row.get('cosine_similarity_roi_values_common', 0):.3f}</td>
|
| 351 |
+
<td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['cosine_similarity_roi_values_common']).iloc[row_index]:.3f}</td>
|
| 352 |
+
<td style="padding: 8px 10px;">All regions</td>
|
| 353 |
</tr>
|
| 354 |
<tr style="border-bottom: 1px solid #e9ecef;">
|
| 355 |
+
<td style="padding: 8px 10px;"><strong>Cosine - Early</strong></td>
|
| 356 |
+
<td style="padding: 8px 10px;">{row.get('cosine_similarity_roi_values_early', 0):.3f}</td>
|
| 357 |
+
<td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['cosine_similarity_roi_values_early']).iloc[row_index]:.3f}</td>
|
| 358 |
+
<td style="padding: 8px 10px;">Low-level</td>
|
| 359 |
</tr>
|
| 360 |
<tr style="border-bottom: 1px solid #e9ecef;">
|
| 361 |
+
<td style="padding: 8px 10px;"><strong>Cosine - Late</strong></td>
|
| 362 |
+
<td style="padding: 8px 10px;">{row.get('cosine_similarity_roi_values_late', 0):.3f}</td>
|
| 363 |
+
<td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['cosine_similarity_roi_values_late']).iloc[row_index]:.3f}</td>
|
| 364 |
+
<td style="padding: 8px 10px;">High-level</td>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 365 |
</tr>
|
|
|
|
|
|
|
| 366 |
<tr style="border-bottom: 1px solid #e9ecef;">
|
| 367 |
+
<td style="padding: 8px 10px;"><strong>Pearson - Common</strong></td>
|
| 368 |
+
<td style="padding: 8px 10px;">{row.get('pearson_correlation_roi_values_common', 0):.3f}</td>
|
| 369 |
+
<td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['pearson_correlation_roi_values_common']).iloc[row_index]:.3f}</td>
|
| 370 |
+
<td style="padding: 8px 10px;">All regions</td>
|
| 371 |
</tr>
|
| 372 |
<tr style="border-bottom: 1px solid #e9ecef;">
|
| 373 |
+
<td style="padding: 8px 10px;"><strong>Pearson - Early</strong></td>
|
| 374 |
+
<td style="padding: 8px 10px;">{row.get('pearson_correlation_roi_values_early', 0):.3f}</td>
|
| 375 |
+
<td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['pearson_correlation_roi_values_early']).iloc[row_index]:.3f}</td>
|
| 376 |
+
<td style="padding: 8px 10px;">Low-level</td>
|
| 377 |
</tr>
|
| 378 |
<tr style="border-bottom: 1px solid #e9ecef;">
|
| 379 |
+
<td style="padding: 8px 10px;"><strong>Pearson - Late</strong></td>
|
| 380 |
+
<td style="padding: 8px 10px;">{row.get('pearson_correlation_roi_values_late', 0):.3f}</td>
|
| 381 |
+
<td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['pearson_correlation_roi_values_late']).iloc[row_index]:.3f}</td>
|
| 382 |
+
<td style="padding: 8px 10px;">High-level</td>
|
| 383 |
</tr>
|
| 384 |
</tbody>
|
| 385 |
</table>
|
| 386 |
+
"""
|
| 387 |
+
|
| 388 |
+
# Model performance HTML
|
| 389 |
+
model_html = f"""
|
| 390 |
+
<div style="margin-bottom: 15px;">
|
| 391 |
+
<strong>Category Averages</strong>
|
| 392 |
+
<table style="width: 100%; border-collapse: collapse; margin: 10px 0;">
|
| 393 |
+
<thead>
|
| 394 |
+
<tr style="background: #f8f9fa;">
|
| 395 |
+
<th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Category</th>
|
| 396 |
+
<th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Raw</th>
|
| 397 |
+
<th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Norm</th>
|
| 398 |
+
</tr>
|
| 399 |
+
</thead>
|
| 400 |
+
<tbody>
|
| 401 |
+
<tr style="border-bottom: 1px solid #e9ecef;">
|
| 402 |
+
<td style="padding: 8px 10px;">Vision</td>
|
| 403 |
+
<td style="padding: 8px 10px;">{row.get('avg_vision', 0):.3f}</td>
|
| 404 |
+
<td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['avg_vision']).iloc[row_index]:.3f}</td>
|
| 405 |
+
</tr>
|
| 406 |
+
<tr style="border-bottom: 1px solid #e9ecef;">
|
| 407 |
+
<td style="padding: 8px 10px;">Language</td>
|
| 408 |
+
<td style="padding: 8px 10px;">{row.get('avg_language', 0):.3f}</td>
|
| 409 |
+
<td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['avg_language']).iloc[row_index]:.3f}</td>
|
| 410 |
+
</tr>
|
| 411 |
+
<tr style="border-bottom: 1px solid #e9ecef;">
|
| 412 |
+
<td style="padding: 8px 10px;">Semantic</td>
|
| 413 |
+
<td style="padding: 8px 10px;">{row.get('avg_semantic', 0):.3f}</td>
|
| 414 |
+
<td style="padding: 8px 10px;">{CornerAnalyzer.normalize_series(data['avg_semantic']).iloc[row_index]:.3f}</td>
|
| 415 |
+
</tr>
|
| 416 |
+
</tbody>
|
| 417 |
+
</table>
|
| 418 |
</div>
|
| 419 |
+
<div>
|
| 420 |
+
<strong>Current Selection</strong>
|
| 421 |
+
<table style="width: 100%; border-collapse: collapse; margin: 10px 0;">
|
| 422 |
+
<thead>
|
| 423 |
+
<tr style="background: #f8f9fa;">
|
| 424 |
+
<th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Info</th>
|
| 425 |
+
<th style="padding: 10px; text-align: left; border-bottom: 2px solid #dee2e6;">Value</th>
|
| 426 |
+
</tr>
|
| 427 |
+
</thead>
|
| 428 |
+
<tbody>
|
| 429 |
+
<tr style="border-bottom: 1px solid #e9ecef;">
|
| 430 |
+
<td style="padding: 8px 10px;">Model</td>
|
| 431 |
+
<td style="padding: 8px 10px;">{current_ml_name}</td>
|
| 432 |
+
</tr>
|
| 433 |
+
<tr style="border-bottom: 1px solid #e9ecef;">
|
| 434 |
+
<td style="padding: 8px 10px;">Raw</td>
|
| 435 |
+
<td style="padding: 8px 10px;">{current_ml_score}</td>
|
| 436 |
+
</tr>
|
| 437 |
+
<tr style="border-bottom: 1px solid #e9ecef;">
|
| 438 |
+
<td style="padding: 8px 10px;">Norm</td>
|
| 439 |
+
<td style="padding: 8px 10px;">{current_ml_norm}</td>
|
| 440 |
+
</tr>
|
| 441 |
+
</tbody>
|
| 442 |
+
</table>
|
| 443 |
+
<div style="margin-top: 15px; font-size: 12px; color: #666;">
|
| 444 |
+
<strong>Dataset:</strong> Vision: {len([col for col in data.columns if 'BOLD5000_timm_' in col])},
|
| 445 |
+
Language: {len([col for col in data.columns if 'bert-' in col or 'deberta-' in col or 'sup-simcse' in col])},
|
| 446 |
+
Semantic: {len([col for col in data.columns if any(x in col for x in ["bm25", "rouge", "tf-idf", "co-occurrence"])])} models
|
| 447 |
+
</div>
|
| 448 |
+
</div>
|
| 449 |
+
"""
|
| 450 |
+
|
| 451 |
+
# Model rankings
|
| 452 |
+
rankings = self.get_model_rankings_for_pair(row_index)
|
| 453 |
+
rankings_html = ""
|
| 454 |
+
for category in ['vision', 'language', 'semantic']:
|
| 455 |
+
if category in rankings:
|
| 456 |
+
category_name = category.title()
|
| 457 |
+
rankings_html += f"""
|
| 458 |
+
<div style="margin: 15px 0;">
|
| 459 |
+
<strong>{category_name} Models:</strong>
|
| 460 |
+
<div style="display: grid; grid-template-columns: 1fr 1fr; gap: 15px; margin-top: 10px;">
|
| 461 |
+
<div>
|
| 462 |
+
<table style="width: 100%; border-collapse: collapse;">
|
| 463 |
+
<thead><tr style="background: #f8f9fa;"><th style="padding: 8px; text-align: left; border-bottom: 2px solid #dee2e6;">Top 3 Best</th><th style="padding: 8px; text-align: left; border-bottom: 2px solid #dee2e6;">Score</th></tr></thead>
|
| 464 |
+
<tbody>
|
| 465 |
+
"""
|
| 466 |
+
for model, score in rankings[category]['best']:
|
| 467 |
+
clean_name = model.replace('BOLD5000_timm_', '').replace('_sim_partial', '') if 'BOLD5000_timm_' in model else model
|
| 468 |
+
rankings_html += f'<tr style="border-bottom: 1px solid #e9ecef;"><td style="padding: 6px 8px;">{clean_name}</td><td style="padding: 6px 8px;">{score:.3f}</td></tr>'
|
| 469 |
+
|
| 470 |
+
rankings_html += """
|
| 471 |
+
</tbody>
|
| 472 |
+
</table>
|
| 473 |
+
</div>
|
| 474 |
+
<div>
|
| 475 |
+
<table style="width: 100%; border-collapse: collapse;">
|
| 476 |
+
<thead><tr style="background: #f8f9fa;"><th style="padding: 8px; text-align: left; border-bottom: 2px solid #dee2e6;">Top 3 Worst</th><th style="padding: 8px; text-align: left; border-bottom: 2px solid #dee2e6;">Score</th></tr></thead>
|
| 477 |
+
<tbody>
|
| 478 |
+
"""
|
| 479 |
+
for model, score in rankings[category]['worst']:
|
| 480 |
+
clean_name = model.replace('BOLD5000_timm_', '').replace('_sim_partial', '') if 'BOLD5000_timm_' in model else model
|
| 481 |
+
rankings_html += f'<tr style="border-bottom: 1px solid #e9ecef;"><td style="padding: 6px 8px;">{clean_name}</td><td style="padding: 6px 8px;">{score:.3f}</td></tr>'
|
| 482 |
+
|
| 483 |
+
rankings_html += """
|
| 484 |
+
</tbody>
|
| 485 |
+
</table>
|
| 486 |
+
</div>
|
| 487 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 488 |
</div>
|
| 489 |
+
"""
|
| 490 |
+
|
| 491 |
+
# ROI plot
|
| 492 |
+
print(f"\nCalling ROI analyzer for pair #{row_index}...")
|
| 493 |
+
roi_plot = self.roi_analyzer.create_roi_comparison_plot(data, row_index)
|
| 494 |
+
print(f"ROI plot created successfully\n")
|
| 495 |
+
|
| 496 |
+
return (img1, img2, caption1_html, caption2_html, summary_html,
|
| 497 |
+
bar_plot_html, brain_html, model_html, rankings_html, roi_plot)
|
| 498 |
+
|
| 499 |
+
except Exception as e:
|
| 500 |
+
error_msg = f"<div style='color: red;'>Error: {e}</div>"
|
| 501 |
+
import traceback
|
| 502 |
+
traceback.print_exc()
|
| 503 |
+
return (None, None, error_msg, error_msg, error_msg, error_msg, error_msg, error_msg, error_msg, None)
|
| 504 |
|
| 505 |
def set_current_model(self, ml_model_selection, ml_name):
|
| 506 |
"""Store the current ML model selection for display in image viewer"""
|
|
|
|
| 529 |
def get_corner_interpretation(self, corner_name: str) -> str:
|
| 530 |
"""Get interpretation of a corner - delegates to CornerAnalyzer"""
|
| 531 |
return self.corner_analyzer.get_interpretation(corner_name)
|
| 532 |
+
|
| 533 |
+
# Add this new method to the SimilarityApp class in app.py
|
| 534 |
+
|
| 535 |
+
# Add this new method to the SimilarityApp class in app.py
|
| 536 |
+
|
| 537 |
+
def get_point_corner_distances(self, row_index: int, brain_measure: str, ml_model_selection: Union[str, int]) -> Tuple[str, Optional[Any]]:
|
| 538 |
+
"""Get distances from a specific point to all 8 corners and create a 3D visualization"""
|
| 539 |
+
try:
|
| 540 |
+
data = self.data_loader.data
|
| 541 |
+
if row_index >= len(data):
|
| 542 |
+
return "Invalid row index", None
|
| 543 |
+
|
| 544 |
+
# Get the data for this point
|
| 545 |
+
row = data.iloc[row_index]
|
| 546 |
+
ml_data, ml_name = self.plot_generator.get_model_data(ml_model_selection)
|
| 547 |
+
|
| 548 |
+
# Get raw values
|
| 549 |
+
human_raw = row['human_judgement']
|
| 550 |
+
brain_raw = row[brain_measure]
|
| 551 |
+
ml_raw = ml_data.iloc[row_index]
|
| 552 |
+
|
| 553 |
+
# Normalize to 0-1 for distance calculations
|
| 554 |
+
from analysis.corner_analyzer import CornerAnalyzer
|
| 555 |
+
human_norm = CornerAnalyzer.normalize_series(data['human_judgement']).iloc[row_index]
|
| 556 |
+
brain_norm = CornerAnalyzer.normalize_series(data[brain_measure]).iloc[row_index]
|
| 557 |
+
ml_norm = CornerAnalyzer.normalize_series(ml_data).iloc[row_index]
|
| 558 |
+
|
| 559 |
+
# Define all 8 corners
|
| 560 |
+
corners = {
|
| 561 |
+
'(0,0,0)': (0, 0, 0),
|
| 562 |
+
'(0,0,1)': (0, 0, 1),
|
| 563 |
+
'(0,1,0)': (0, 1, 0),
|
| 564 |
+
'(0,1,1)': (0, 1, 1),
|
| 565 |
+
'(1,0,0)': (1, 0, 0),
|
| 566 |
+
'(1,0,1)': (1, 0, 1),
|
| 567 |
+
'(1,1,0)': (1, 1, 0),
|
| 568 |
+
'(1,1,1)': (1, 1, 1)
|
| 569 |
+
}
|
| 570 |
+
|
| 571 |
+
# Calculate distances to each corner
|
| 572 |
+
point = np.array([human_norm, brain_norm, ml_norm])
|
| 573 |
+
distances = {}
|
| 574 |
+
for corner_name, corner_coords in corners.items():
|
| 575 |
+
corner_array = np.array(corner_coords)
|
| 576 |
+
distance = np.linalg.norm(point - corner_array)
|
| 577 |
+
distances[corner_name] = distance
|
| 578 |
+
|
| 579 |
+
# Sort by distance
|
| 580 |
+
sorted_distances = sorted(distances.items(), key=lambda x: x[1])
|
| 581 |
+
|
| 582 |
+
# Get corner interpretations
|
| 583 |
+
corner_meanings = {
|
| 584 |
+
'(0,0,0)': 'All Low - General disagreement on similarity',
|
| 585 |
+
'(0,0,1)': 'ML High, Human & Brain Low',
|
| 586 |
+
'(0,1,0)': 'Brain High, Human & ML Low',
|
| 587 |
+
'(0,1,1)': 'Brain & ML High, Human Low',
|
| 588 |
+
'(1,0,0)': 'Human High, Brain & ML Low',
|
| 589 |
+
'(1,0,1)': 'Human & ML High, Brain Low',
|
| 590 |
+
'(1,1,0)': 'Human & Brain High, ML Low',
|
| 591 |
+
'(1,1,1)': 'All High - Strong agreement on similarity'
|
| 592 |
+
}
|
| 593 |
+
|
| 594 |
+
# Create HTML table
|
| 595 |
+
html = f"""
|
| 596 |
+
<div style="background: #f8f9fa; padding: 20px; border-radius: 10px; margin: 20px 0;">
|
| 597 |
+
<h3 style="margin-top: 0;">Distance to Each Corner (Normalized Space)</h3>
|
| 598 |
+
<div style="background: white; padding: 15px; border-radius: 5px; margin-bottom: 15px;">
|
| 599 |
+
<strong>Current Point (Normalized 0-1):</strong><br>
|
| 600 |
+
Human: {human_norm:.3f} | Brain: {brain_norm:.3f} | ML: {ml_norm:.3f}
|
| 601 |
+
</div>
|
| 602 |
+
<table style="width: 100%; border-collapse: collapse; background: white;">
|
| 603 |
+
<thead>
|
| 604 |
+
<tr style="background: #667eea; color: white;">
|
| 605 |
+
<th style="padding: 12px; text-align: left;">Rank</th>
|
| 606 |
+
<th style="padding: 12px; text-align: left;">Corner (H,B,M)</th>
|
| 607 |
+
<th style="padding: 12px; text-align: left;">Distance</th>
|
| 608 |
+
<th style="padding: 12px; text-align: left;">Meaning</th>
|
| 609 |
+
</tr>
|
| 610 |
+
</thead>
|
| 611 |
+
<tbody>
|
| 612 |
+
"""
|
| 613 |
+
|
| 614 |
+
for rank, (corner_name, distance) in enumerate(sorted_distances, 1):
|
| 615 |
+
row_color = "#e8f4f8" if rank == 1 else "white"
|
| 616 |
+
star = "⭐ " if rank == 1 else ""
|
| 617 |
+
html += f"""
|
| 618 |
+
<tr style="background: {row_color}; border-bottom: 1px solid #dee2e6;">
|
| 619 |
+
<td style="padding: 10px;"><strong>{star}{rank}</strong></td>
|
| 620 |
+
<td style="padding: 10px; font-family: monospace;">{corner_name}</td>
|
| 621 |
+
<td style="padding: 10px;"><strong>{distance:.4f}</strong></td>
|
| 622 |
+
<td style="padding: 10px; font-size: 12px;">{corner_meanings[corner_name]}</td>
|
| 623 |
+
</tr>
|
| 624 |
+
"""
|
| 625 |
+
|
| 626 |
+
html += """
|
| 627 |
+
</tbody>
|
| 628 |
+
</table>
|
| 629 |
+
<div style="margin-top: 15px; font-size: 12px; color: #666;">
|
| 630 |
+
Distances calculated in normalized 0-1 space using Euclidean distance.
|
| 631 |
+
Closer corners indicate which extreme this pair is nearest to.
|
| 632 |
+
</div>
|
| 633 |
+
</div>
|
| 634 |
+
"""
|
| 635 |
+
|
| 636 |
+
# Create 3D plot with the point highlighted
|
| 637 |
+
import plotly.graph_objects as go
|
| 638 |
+
|
| 639 |
+
# Normalize all data for the 3D plot
|
| 640 |
+
human_all_norm = CornerAnalyzer.normalize_series(data['human_judgement'])
|
| 641 |
+
brain_all_norm = CornerAnalyzer.normalize_series(data[brain_measure])
|
| 642 |
+
ml_all_norm = CornerAnalyzer.normalize_series(ml_data)
|
| 643 |
+
|
| 644 |
+
fig = go.Figure()
|
| 645 |
+
|
| 646 |
+
# Add all other points in gray
|
| 647 |
+
other_indices = [i for i in range(len(data)) if i != row_index]
|
| 648 |
+
fig.add_trace(go.Scatter3d(
|
| 649 |
+
x=human_all_norm.iloc[other_indices],
|
| 650 |
+
y=brain_all_norm.iloc[other_indices],
|
| 651 |
+
z=ml_all_norm.iloc[other_indices],
|
| 652 |
+
mode='markers',
|
| 653 |
+
marker=dict(size=3, color='lightgray', opacity=0.3),
|
| 654 |
+
name='Other pairs',
|
| 655 |
+
hoverinfo='skip'
|
| 656 |
+
))
|
| 657 |
+
|
| 658 |
+
# Add the current point in VERY VISIBLE bright color with larger size
|
| 659 |
+
fig.add_trace(go.Scatter3d(
|
| 660 |
+
x=[human_norm],
|
| 661 |
+
y=[brain_norm],
|
| 662 |
+
z=[ml_norm],
|
| 663 |
+
mode='markers',
|
| 664 |
+
marker=dict(
|
| 665 |
+
size=25, # Much larger
|
| 666 |
+
color='#FF0000', # Bright red
|
| 667 |
+
symbol='diamond',
|
| 668 |
+
line=dict(color='yellow', width=4), # Yellow outline for extra visibility
|
| 669 |
+
opacity=1.0
|
| 670 |
+
),
|
| 671 |
+
name=f'⭐ Current Pair #{row_index}',
|
| 672 |
+
hovertemplate=f'<b>⭐ CURRENT PAIR #{row_index}</b><br>' +
|
| 673 |
+
f'Human: {human_norm:.3f}<br>' +
|
| 674 |
+
f'Brain: {brain_norm:.3f}<br>' +
|
| 675 |
+
f'ML: {ml_norm:.3f}<br>' +
|
| 676 |
+
f'<extra></extra>'
|
| 677 |
+
))
|
| 678 |
+
|
| 679 |
+
# Add corner points with larger size
|
| 680 |
+
corner_x = [c[0] for c in corners.values()]
|
| 681 |
+
corner_y = [c[1] for c in corners.values()]
|
| 682 |
+
corner_z = [c[2] for c in corners.values()]
|
| 683 |
+
corner_labels = [f"{name}<br>{corner_meanings[name]}" for name in corners.keys()]
|
| 684 |
+
|
| 685 |
+
fig.add_trace(go.Scatter3d(
|
| 686 |
+
x=corner_x,
|
| 687 |
+
y=corner_y,
|
| 688 |
+
z=corner_z,
|
| 689 |
+
mode='markers+text',
|
| 690 |
+
marker=dict(size=10, color='#4169E1', symbol='square', opacity=0.8), # Larger, more visible
|
| 691 |
+
text=list(corners.keys()),
|
| 692 |
+
textposition='top center',
|
| 693 |
+
textfont=dict(size=9, color='darkblue', family='Arial Black'),
|
| 694 |
+
name='Corners',
|
| 695 |
+
hovertext=corner_labels,
|
| 696 |
+
hoverinfo='text'
|
| 697 |
+
))
|
| 698 |
+
|
| 699 |
+
# Add lines from current point to ALL corners with distance labels
|
| 700 |
+
for corner_name, distance in distances.items():
|
| 701 |
+
corner_coords = corners[corner_name]
|
| 702 |
+
|
| 703 |
+
# Determine line color based on distance (closer = more orange, farther = more gray)
|
| 704 |
+
# Normalize distance for color (distances range from 0 to sqrt(3) ≈ 1.732)
|
| 705 |
+
normalized_dist = distance / 1.732
|
| 706 |
+
if normalized_dist < 0.33:
|
| 707 |
+
line_color = '#FF4500' # OrangeRed for closest
|
| 708 |
+
line_width = 4
|
| 709 |
+
elif normalized_dist < 0.67:
|
| 710 |
+
line_color = '#FFA500' # Orange for medium
|
| 711 |
+
line_width = 3
|
| 712 |
+
else:
|
| 713 |
+
line_color = '#C0C0C0' # Silver for farthest
|
| 714 |
+
line_width = 2
|
| 715 |
+
|
| 716 |
+
# Add the line
|
| 717 |
+
fig.add_trace(go.Scatter3d(
|
| 718 |
+
x=[human_norm, corner_coords[0]],
|
| 719 |
+
y=[brain_norm, corner_coords[1]],
|
| 720 |
+
z=[ml_norm, corner_coords[2]],
|
| 721 |
+
mode='lines',
|
| 722 |
+
line=dict(color=line_color, width=line_width, dash='dot'),
|
| 723 |
+
showlegend=False,
|
| 724 |
+
hoverinfo='skip'
|
| 725 |
+
))
|
| 726 |
+
|
| 727 |
+
# Add text label at midpoint of line showing distance
|
| 728 |
+
mid_x = (human_norm + corner_coords[0]) / 2
|
| 729 |
+
mid_y = (brain_norm + corner_coords[1]) / 2
|
| 730 |
+
mid_z = (ml_norm + corner_coords[2]) / 2
|
| 731 |
+
|
| 732 |
+
fig.add_trace(go.Scatter3d(
|
| 733 |
+
x=[mid_x],
|
| 734 |
+
y=[mid_y],
|
| 735 |
+
z=[mid_z],
|
| 736 |
+
mode='text',
|
| 737 |
+
text=[f'{distance:.3f}'],
|
| 738 |
+
textfont=dict(size=20, color=line_color),
|
| 739 |
+
showlegend=False,
|
| 740 |
+
hoverinfo='skip'
|
| 741 |
+
))
|
| 742 |
+
|
| 743 |
+
brain_name = brain_measure.replace('cosine_similarity_roi_values_', '').replace('pearson_correlation_roi_values_', '').title()
|
| 744 |
+
measure_type = "Cosine" if "cosine" in brain_measure else "Pearson"
|
| 745 |
+
|
| 746 |
+
fig.update_layout(
|
| 747 |
+
title=f'Pair #{row_index} Position in Normalized 3D Space<br><sub>Red diamond shows current pair, blue squares show corners, dashed lines to 3 nearest corners</sub>',
|
| 748 |
+
scene=dict(
|
| 749 |
+
xaxis_title='Human (norm)',
|
| 750 |
+
yaxis_title=f'Brain {measure_type} (norm)',
|
| 751 |
+
zaxis_title=f'{ml_name} (norm)',
|
| 752 |
+
xaxis=dict(range=[0, 1]),
|
| 753 |
+
yaxis=dict(range=[0, 1]),
|
| 754 |
+
zaxis=dict(range=[0, 1]),
|
| 755 |
+
camera=dict(eye=dict(x=1.5, y=1.5, z=1.5))
|
| 756 |
+
),
|
| 757 |
+
height=600,
|
| 758 |
+
showlegend=True,
|
| 759 |
+
legend=dict(x=0.7, y=0.9)
|
| 760 |
+
)
|
| 761 |
+
|
| 762 |
+
return html, fig
|
| 763 |
+
|
| 764 |
+
except Exception as e:
|
| 765 |
+
import traceback
|
| 766 |
+
traceback.print_exc()
|
| 767 |
+
return f"<div style='color: red;'>Error calculating distances: {e}</div>", None
|
| 768 |
|
| 769 |
def main():
|
| 770 |
"""Main function to run the application"""
|
| 771 |
try:
|
| 772 |
# Create and launch the app
|
| 773 |
+
app = SimilarityApp('data/overall_database3.csv')
|
| 774 |
app.launch(
|
| 775 |
server_name="0.0.0.0",
|
| 776 |
server_port=7860,
|
brain/__pycache__/roi_analyzer.cpython-311.pyc
CHANGED
|
Binary files a/brain/__pycache__/roi_analyzer.cpython-311.pyc and b/brain/__pycache__/roi_analyzer.cpython-311.pyc differ
|
|
|
data/Final_similarity_matrix.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5ff83586ea8a37e91968eb0b73f23edebd61fdd88f815bef34a0f727c6d5ef35
|
| 3 |
+
size 39820273
|
data/Final_similarity_matrix2.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0cdd2d60df4ebd8de7533d82f7df5d4f8733d06134c641795d2f014c5de561b2
|
| 3 |
+
size 64812304
|
data/overall_database.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:08b27836f940efc44631b48b31964c02ae9ae0b7ef0c35ec4b33f1181e9e5480
|
| 3 |
+
size 44624446
|
data/overall_database2.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5dc09b2d3975398f84c9a734e0eac9f3a9db044a961600a91d8d34f5e1288bcb
|
| 3 |
+
size 46365658
|
data/overall_database3.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0a3c421e159026ec81497882cb4bc156f85cf93b21e5a3799eaf83bd653537aa
|
| 3 |
+
size 46506013
|
gui/__pycache__/corner_cases_tab.cpython-311.pyc
CHANGED
|
Binary files a/gui/__pycache__/corner_cases_tab.cpython-311.pyc and b/gui/__pycache__/corner_cases_tab.cpython-311.pyc differ
|
|
|
gui/__pycache__/viewer_tab.cpython-311.pyc
CHANGED
|
Binary files a/gui/__pycache__/viewer_tab.cpython-311.pyc and b/gui/__pycache__/viewer_tab.cpython-311.pyc differ
|
|
|
gui/corner_cases_tab.py
CHANGED
|
@@ -3,6 +3,13 @@
|
|
| 3 |
|
| 4 |
import gradio as gr
|
| 5 |
from typing import TYPE_CHECKING
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
if TYPE_CHECKING:
|
| 8 |
from similarity_analysis.app import SimilarityApp
|
|
@@ -19,7 +26,6 @@ class CornerCasesTab:
|
|
| 19 |
ml_options = self.app.data_loader.get_ml_model_options()
|
| 20 |
|
| 21 |
gr.Markdown("## Corner Cases Analysis")
|
| 22 |
-
# gr.Markdown("Find the top 10 image pairs closest to each corner of the 3D space (Human × Brain × ML)")
|
| 23 |
|
| 24 |
with gr.Row():
|
| 25 |
with gr.Column(scale=1):
|
|
@@ -82,6 +88,49 @@ class CornerCasesTab:
|
|
| 82 |
'results_display': results_display
|
| 83 |
}
|
| 84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
def connect_events(self, components):
|
| 86 |
"""Connect event handlers for this tab"""
|
| 87 |
|
|
@@ -115,7 +164,7 @@ class CornerCasesTab:
|
|
| 115 |
|
| 116 |
if show_images:
|
| 117 |
# Show results with images in a grid
|
| 118 |
-
output += "<div style='display: grid; grid-template-columns: repeat(auto-fill, minmax(400px, 1fr)); gap: 20px; margin-bottom: 30px;'>"
|
| 119 |
|
| 120 |
for rank, result in enumerate(results, 1):
|
| 121 |
output += "<div style='border: 1px solid #ddd; padding: 15px; border-radius: 8px; background: #f9f9f9;'>"
|
|
@@ -123,8 +172,28 @@ class CornerCasesTab:
|
|
| 123 |
|
| 124 |
# Get image URLs and captions
|
| 125 |
data = self.app.data_loader.data
|
| 126 |
-
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
# Format captions (handle multiple descriptions separated by |)
|
| 130 |
def format_caption_html(caption_text):
|
|
@@ -140,8 +209,8 @@ class CornerCasesTab:
|
|
| 140 |
html += "</ol>"
|
| 141 |
return html
|
| 142 |
|
| 143 |
-
caption1 = format_caption_html(
|
| 144 |
-
caption2 = format_caption_html(
|
| 145 |
|
| 146 |
# Display images side by side
|
| 147 |
output += "<div style='display: flex; gap: 10px; margin: 10px 0;'>"
|
|
@@ -170,12 +239,19 @@ class CornerCasesTab:
|
|
| 170 |
output += f"<tr><td>ML:</td><td>{result['ml_norm']:.3f}</td></tr>"
|
| 171 |
output += "</table>"
|
| 172 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
output += "</div>"
|
| 174 |
|
| 175 |
output += "</div>"
|
| 176 |
else:
|
| 177 |
# Text-only table format
|
| 178 |
-
output += "<table style='width: 100%; border-collapse: collapse; margin-bottom: 30px;'>"
|
| 179 |
output += "<thead><tr style='background: #f0f0f0;'>"
|
| 180 |
output += "<th style='border: 1px solid #ddd; padding: 8px;'>Rank</th>"
|
| 181 |
output += "<th style='border: 1px solid #ddd; padding: 8px;'>Pair #</th>"
|
|
|
|
| 3 |
|
| 4 |
import gradio as gr
|
| 5 |
from typing import TYPE_CHECKING
|
| 6 |
+
import plotly.graph_objects as go
|
| 7 |
+
from plotly.subplots import make_subplots
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import matplotlib
|
| 10 |
+
matplotlib.use('Agg') # Use non-interactive backend
|
| 11 |
+
import io
|
| 12 |
+
import base64
|
| 13 |
|
| 14 |
if TYPE_CHECKING:
|
| 15 |
from similarity_analysis.app import SimilarityApp
|
|
|
|
| 26 |
ml_options = self.app.data_loader.get_ml_model_options()
|
| 27 |
|
| 28 |
gr.Markdown("## Corner Cases Analysis")
|
|
|
|
| 29 |
|
| 30 |
with gr.Row():
|
| 31 |
with gr.Column(scale=1):
|
|
|
|
| 88 |
'results_display': results_display
|
| 89 |
}
|
| 90 |
|
| 91 |
+
def create_single_pair_bar_plot(self, result, pair_index):
|
| 92 |
+
"""Create a bar plot showing normalized values for a single pair using matplotlib"""
|
| 93 |
+
|
| 94 |
+
# Create figure
|
| 95 |
+
fig, ax = plt.subplots(figsize=(6, 4))
|
| 96 |
+
|
| 97 |
+
categories = ['Human', 'Brain', 'ML']
|
| 98 |
+
values = [result['human_norm'], result['brain_norm'], result['ml_norm']]
|
| 99 |
+
colors = ['#4A90E2', '#50C878', '#E24A4A']
|
| 100 |
+
|
| 101 |
+
# Create bars
|
| 102 |
+
bars = ax.bar(categories, values, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
|
| 103 |
+
|
| 104 |
+
# Add value labels on top of bars
|
| 105 |
+
for bar, val in zip(bars, values):
|
| 106 |
+
height = bar.get_height()
|
| 107 |
+
ax.text(bar.get_x() + bar.get_width()/2., height + 0.02,
|
| 108 |
+
f'{val:.3f}',
|
| 109 |
+
ha='center', va='bottom', fontsize=11, fontweight='bold')
|
| 110 |
+
|
| 111 |
+
# Styling
|
| 112 |
+
ax.set_ylabel('Normalized Value (0-1)', fontsize=11, fontweight='bold')
|
| 113 |
+
ax.set_xlabel('Measure', fontsize=11, fontweight='bold')
|
| 114 |
+
ax.set_title(f'Normalized Values for Pair #{pair_index}', fontsize=12, fontweight='bold')
|
| 115 |
+
ax.set_ylim(0, 1.15)
|
| 116 |
+
ax.grid(axis='y', alpha=0.3, linestyle='--')
|
| 117 |
+
ax.set_axisbelow(True)
|
| 118 |
+
|
| 119 |
+
# Style the plot
|
| 120 |
+
ax.spines['top'].set_visible(False)
|
| 121 |
+
ax.spines['right'].set_visible(False)
|
| 122 |
+
|
| 123 |
+
plt.tight_layout()
|
| 124 |
+
|
| 125 |
+
# Convert to base64 image
|
| 126 |
+
buf = io.BytesIO()
|
| 127 |
+
plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
|
| 128 |
+
buf.seek(0)
|
| 129 |
+
img_base64 = base64.b64encode(buf.read()).decode()
|
| 130 |
+
plt.close(fig)
|
| 131 |
+
|
| 132 |
+
return f'<img src="data:image/png;base64,{img_base64}" style="width: 100%; max-width: 450px; margin: 10px auto; display: block; border: 1px solid #ddd; border-radius: 5px; padding: 5px; background: white;" />'
|
| 133 |
+
|
| 134 |
def connect_events(self, components):
|
| 135 |
"""Connect event handlers for this tab"""
|
| 136 |
|
|
|
|
| 164 |
|
| 165 |
if show_images:
|
| 166 |
# Show results with images in a grid
|
| 167 |
+
output += "<div style='display: grid; grid-template-columns: repeat(auto-fill, minmax(400px, 1fr)); gap: 20px; margin-bottom: 30px; margin-top: 20px;'>"
|
| 168 |
|
| 169 |
for rank, result in enumerate(results, 1):
|
| 170 |
output += "<div style='border: 1px solid #ddd; padding: 15px; border-radius: 8px; background: #f9f9f9;'>"
|
|
|
|
| 172 |
|
| 173 |
# Get image URLs and captions
|
| 174 |
data = self.app.data_loader.data
|
| 175 |
+
pair_row = data.iloc[result['index']]
|
| 176 |
+
|
| 177 |
+
# Get URLs
|
| 178 |
+
img1_url = pair_row['stim_1']
|
| 179 |
+
img2_url = pair_row['stim_2']
|
| 180 |
+
|
| 181 |
+
# Check if we need to swap URLs to match image_1 and image_2 filenames
|
| 182 |
+
image_1_filename = str(pair_row.get('image_1', ''))
|
| 183 |
+
image_2_filename = str(pair_row.get('image_2', ''))
|
| 184 |
+
|
| 185 |
+
stim_1_matches_image_1 = image_1_filename in str(img1_url)
|
| 186 |
+
stim_1_matches_image_2 = image_2_filename in str(img1_url)
|
| 187 |
+
|
| 188 |
+
# Swap if needed
|
| 189 |
+
if stim_1_matches_image_2 and not stim_1_matches_image_1:
|
| 190 |
+
img1_url, img2_url = img2_url, img1_url
|
| 191 |
+
# Also swap captions
|
| 192 |
+
caption1_data = pair_row.get('image_2_description', 'No caption available')
|
| 193 |
+
caption2_data = pair_row.get('image_1_description', 'No caption available')
|
| 194 |
+
else:
|
| 195 |
+
caption1_data = pair_row.get('image_1_description', 'No caption available')
|
| 196 |
+
caption2_data = pair_row.get('image_2_description', 'No caption available')
|
| 197 |
|
| 198 |
# Format captions (handle multiple descriptions separated by |)
|
| 199 |
def format_caption_html(caption_text):
|
|
|
|
| 209 |
html += "</ol>"
|
| 210 |
return html
|
| 211 |
|
| 212 |
+
caption1 = format_caption_html(caption1_data)
|
| 213 |
+
caption2 = format_caption_html(caption2_data)
|
| 214 |
|
| 215 |
# Display images side by side
|
| 216 |
output += "<div style='display: flex; gap: 10px; margin: 10px 0;'>"
|
|
|
|
| 239 |
output += f"<tr><td>ML:</td><td>{result['ml_norm']:.3f}</td></tr>"
|
| 240 |
output += "</table>"
|
| 241 |
|
| 242 |
+
# Add bar plot for this specific pair
|
| 243 |
+
try:
|
| 244 |
+
bar_plot_html = self.create_single_pair_bar_plot(result, result['index'])
|
| 245 |
+
output += f"<div style='margin: 15px 0;'>{bar_plot_html}</div>"
|
| 246 |
+
except Exception as e:
|
| 247 |
+
output += f"<div style='color: red; padding: 10px;'>Error creating plot: {e}</div>"
|
| 248 |
+
|
| 249 |
output += "</div>"
|
| 250 |
|
| 251 |
output += "</div>"
|
| 252 |
else:
|
| 253 |
# Text-only table format
|
| 254 |
+
output += "<table style='width: 100%; border-collapse: collapse; margin-bottom: 30px; margin-top: 20px;'>"
|
| 255 |
output += "<thead><tr style='background: #f0f0f0;'>"
|
| 256 |
output += "<th style='border: 1px solid #ddd; padding: 8px;'>Rank</th>"
|
| 257 |
output += "<th style='border: 1px solid #ddd; padding: 8px;'>Pair #</th>"
|
gui/viewer_tab.py
CHANGED
|
@@ -26,6 +26,25 @@ class ViewerTab:
|
|
| 26 |
info=f"Enter 0 to {len(self.app.data_loader.data)-1}",
|
| 27 |
precision=0
|
| 28 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
show_btn = gr.Button("Show Images & Details", variant="primary", size="lg")
|
| 30 |
|
| 31 |
# Images side by side
|
|
@@ -44,6 +63,12 @@ class ViewerTab:
|
|
| 44 |
with gr.Row():
|
| 45 |
summary_card = gr.HTML("<div>Select an image pair to see summary</div>")
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
# Brain measures and Model performance side by side
|
| 48 |
with gr.Row():
|
| 49 |
with gr.Column(scale=1):
|
|
@@ -66,55 +91,99 @@ class ViewerTab:
|
|
| 66 |
gr.Markdown("### ROI Brain Activation Comparison")
|
| 67 |
roi_plot = gr.Plot(label="Side-by-Side ROI Values", show_label=False)
|
| 68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
return {
|
| 70 |
'row_input': row_input,
|
|
|
|
|
|
|
| 71 |
'show_btn': show_btn,
|
| 72 |
'image1_display': image1_display,
|
| 73 |
'image2_display': image2_display,
|
| 74 |
'caption1_display': caption1_display,
|
| 75 |
'caption2_display': caption2_display,
|
| 76 |
'summary_card': summary_card,
|
|
|
|
| 77 |
'brain_table': brain_table,
|
| 78 |
'model_table': model_table,
|
| 79 |
'rankings_display': rankings_display,
|
| 80 |
-
'roi_plot': roi_plot
|
|
|
|
|
|
|
| 81 |
}
|
| 82 |
|
| 83 |
def connect_events(self, components):
|
| 84 |
"""Connect event handlers for this tab"""
|
| 85 |
-
def show_images_and_details(row_idx):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
results = self.app.show_image_pair_multi(int(row_idx) if row_idx is not None else 0)
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
-
# Connect with multiple outputs
|
| 90 |
components['show_btn'].click(
|
| 91 |
fn=show_images_and_details,
|
| 92 |
-
inputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
outputs=[
|
| 94 |
components['image1_display'],
|
| 95 |
components['image2_display'],
|
| 96 |
components['caption1_display'],
|
| 97 |
components['caption2_display'],
|
| 98 |
components['summary_card'],
|
|
|
|
| 99 |
components['brain_table'],
|
| 100 |
components['model_table'],
|
| 101 |
components['rankings_display'],
|
| 102 |
-
components['roi_plot']
|
|
|
|
|
|
|
| 103 |
]
|
| 104 |
)
|
| 105 |
|
| 106 |
components['row_input'].change(
|
| 107 |
fn=show_images_and_details,
|
| 108 |
-
inputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
outputs=[
|
| 110 |
components['image1_display'],
|
| 111 |
components['image2_display'],
|
| 112 |
components['caption1_display'],
|
| 113 |
components['caption2_display'],
|
| 114 |
components['summary_card'],
|
|
|
|
| 115 |
components['brain_table'],
|
| 116 |
components['model_table'],
|
| 117 |
components['rankings_display'],
|
| 118 |
-
components['roi_plot']
|
|
|
|
|
|
|
| 119 |
]
|
| 120 |
)
|
|
|
|
| 26 |
info=f"Enter 0 to {len(self.app.data_loader.data)-1}",
|
| 27 |
precision=0
|
| 28 |
)
|
| 29 |
+
|
| 30 |
+
# Add dropdowns for corner distance calculation
|
| 31 |
+
brain_options = self.app.data_loader.get_brain_measure_options()
|
| 32 |
+
ml_options = self.app.data_loader.get_ml_model_options()
|
| 33 |
+
|
| 34 |
+
brain_dropdown = gr.Dropdown(
|
| 35 |
+
choices=brain_options,
|
| 36 |
+
value=brain_options[0][1] if brain_options else None,
|
| 37 |
+
label="Brain Measure (for 3D position)",
|
| 38 |
+
info="Select brain measure for corner distance calculation"
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
ml_dropdown = gr.Dropdown(
|
| 42 |
+
choices=ml_options,
|
| 43 |
+
value=ml_options[0][1] if ml_options else None,
|
| 44 |
+
label="ML Model (for 3D position)",
|
| 45 |
+
info="Select ML model for corner distance calculation"
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
show_btn = gr.Button("Show Images & Details", variant="primary", size="lg")
|
| 49 |
|
| 50 |
# Images side by side
|
|
|
|
| 63 |
with gr.Row():
|
| 64 |
summary_card = gr.HTML("<div>Select an image pair to see summary</div>")
|
| 65 |
|
| 66 |
+
# Bar plot for normalized values
|
| 67 |
+
with gr.Row():
|
| 68 |
+
with gr.Column():
|
| 69 |
+
gr.Markdown("### Normalized Values Visualization")
|
| 70 |
+
bar_plot_display = gr.HTML("<div>Bar plot will appear here</div>")
|
| 71 |
+
|
| 72 |
# Brain measures and Model performance side by side
|
| 73 |
with gr.Row():
|
| 74 |
with gr.Column(scale=1):
|
|
|
|
| 91 |
gr.Markdown("### ROI Brain Activation Comparison")
|
| 92 |
roi_plot = gr.Plot(label="Side-by-Side ROI Values", show_label=False)
|
| 93 |
|
| 94 |
+
# NEW: Corner distances section
|
| 95 |
+
with gr.Row():
|
| 96 |
+
with gr.Column():
|
| 97 |
+
gr.Markdown("### Position in 3D Space & Corner Distances")
|
| 98 |
+
gr.Markdown("This shows where this specific pair sits in the normalized 3D space (Human × Brain × ML) and its distance to each of the 8 corners.")
|
| 99 |
+
corner_distance_table = gr.HTML("<div>Select parameters and click 'Show Images & Details' to see corner distances</div>")
|
| 100 |
+
|
| 101 |
+
with gr.Row():
|
| 102 |
+
with gr.Column():
|
| 103 |
+
corner_3d_plot = gr.Plot(label="3D Position Visualization", show_label=False)
|
| 104 |
+
|
| 105 |
return {
|
| 106 |
'row_input': row_input,
|
| 107 |
+
'brain_dropdown': brain_dropdown,
|
| 108 |
+
'ml_dropdown': ml_dropdown,
|
| 109 |
'show_btn': show_btn,
|
| 110 |
'image1_display': image1_display,
|
| 111 |
'image2_display': image2_display,
|
| 112 |
'caption1_display': caption1_display,
|
| 113 |
'caption2_display': caption2_display,
|
| 114 |
'summary_card': summary_card,
|
| 115 |
+
'bar_plot_display': bar_plot_display,
|
| 116 |
'brain_table': brain_table,
|
| 117 |
'model_table': model_table,
|
| 118 |
'rankings_display': rankings_display,
|
| 119 |
+
'roi_plot': roi_plot,
|
| 120 |
+
'corner_distance_table': corner_distance_table,
|
| 121 |
+
'corner_3d_plot': corner_3d_plot
|
| 122 |
}
|
| 123 |
|
| 124 |
def connect_events(self, components):
|
| 125 |
"""Connect event handlers for this tab"""
|
| 126 |
+
def show_images_and_details(row_idx, brain_measure, ml_model_selection):
|
| 127 |
+
# Store the current brain measure and ML model in the app for the bar plot
|
| 128 |
+
self.app._current_brain_measure = brain_measure
|
| 129 |
+
self.app._current_ml_model = ml_model_selection
|
| 130 |
+
|
| 131 |
+
# Get basic image pair info
|
| 132 |
results = self.app.show_image_pair_multi(int(row_idx) if row_idx is not None else 0)
|
| 133 |
+
|
| 134 |
+
# Get corner distances and 3D plot
|
| 135 |
+
corner_html, corner_plot = self.app.get_point_corner_distances(
|
| 136 |
+
int(row_idx) if row_idx is not None else 0,
|
| 137 |
+
brain_measure,
|
| 138 |
+
ml_model_selection
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
# Return all outputs including the new corner distance outputs
|
| 142 |
+
return (*results, corner_html, corner_plot)
|
| 143 |
|
| 144 |
+
# Connect with multiple outputs (added bar_plot_display)
|
| 145 |
components['show_btn'].click(
|
| 146 |
fn=show_images_and_details,
|
| 147 |
+
inputs=[
|
| 148 |
+
components['row_input'],
|
| 149 |
+
components['brain_dropdown'],
|
| 150 |
+
components['ml_dropdown']
|
| 151 |
+
],
|
| 152 |
outputs=[
|
| 153 |
components['image1_display'],
|
| 154 |
components['image2_display'],
|
| 155 |
components['caption1_display'],
|
| 156 |
components['caption2_display'],
|
| 157 |
components['summary_card'],
|
| 158 |
+
components['bar_plot_display'],
|
| 159 |
components['brain_table'],
|
| 160 |
components['model_table'],
|
| 161 |
components['rankings_display'],
|
| 162 |
+
components['roi_plot'],
|
| 163 |
+
components['corner_distance_table'],
|
| 164 |
+
components['corner_3d_plot']
|
| 165 |
]
|
| 166 |
)
|
| 167 |
|
| 168 |
components['row_input'].change(
|
| 169 |
fn=show_images_and_details,
|
| 170 |
+
inputs=[
|
| 171 |
+
components['row_input'],
|
| 172 |
+
components['brain_dropdown'],
|
| 173 |
+
components['ml_dropdown']
|
| 174 |
+
],
|
| 175 |
outputs=[
|
| 176 |
components['image1_display'],
|
| 177 |
components['image2_display'],
|
| 178 |
components['caption1_display'],
|
| 179 |
components['caption2_display'],
|
| 180 |
components['summary_card'],
|
| 181 |
+
components['bar_plot_display'],
|
| 182 |
components['brain_table'],
|
| 183 |
components['model_table'],
|
| 184 |
components['rankings_display'],
|
| 185 |
+
components['roi_plot'],
|
| 186 |
+
components['corner_distance_table'],
|
| 187 |
+
components['corner_3d_plot']
|
| 188 |
]
|
| 189 |
)
|
visualization/__pycache__/image_viewer.cpython-311.pyc
CHANGED
|
Binary files a/visualization/__pycache__/image_viewer.cpython-311.pyc and b/visualization/__pycache__/image_viewer.cpython-311.pyc differ
|
|
|
visualization/image_viewer.py
CHANGED
|
@@ -24,16 +24,58 @@ class ImageViewer:
|
|
| 24 |
return placeholder
|
| 25 |
|
| 26 |
@staticmethod
|
| 27 |
-
def get_image_pair(data: pd.DataFrame, row_index: int) -> Tuple[Optional[Image.Image], Optional[Image.Image]]:
|
| 28 |
-
"""Get image pair for a specific row
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
if row_index >= len(data):
|
| 30 |
-
return None, None
|
| 31 |
|
| 32 |
row = data.iloc[row_index]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
img1_url = row.get('stim_1', '')
|
| 34 |
img2_url = row.get('stim_2', '')
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
img1 = ImageViewer.load_image_from_url(img1_url) if img1_url else None
|
| 37 |
img2 = ImageViewer.load_image_from_url(img2_url) if img2_url else None
|
| 38 |
|
| 39 |
-
return img1, img2
|
|
|
|
| 24 |
return placeholder
|
| 25 |
|
| 26 |
@staticmethod
|
| 27 |
+
def get_image_pair(data: pd.DataFrame, row_index: int) -> Tuple[Optional[Image.Image], Optional[Image.Image], bool]:
|
| 28 |
+
"""Get image pair for a specific row
|
| 29 |
+
|
| 30 |
+
Returns:
|
| 31 |
+
tuple: (img1, img2, was_swapped)
|
| 32 |
+
- img1: Image for image_1
|
| 33 |
+
- img2: Image for image_2
|
| 34 |
+
- was_swapped: True if URLs were swapped to match filenames
|
| 35 |
+
"""
|
| 36 |
if row_index >= len(data):
|
| 37 |
+
return None, None, False
|
| 38 |
|
| 39 |
row = data.iloc[row_index]
|
| 40 |
+
was_swapped = False
|
| 41 |
+
|
| 42 |
+
# DEBUG: Check which columns exist
|
| 43 |
+
print(f"\n[IMAGE_VIEWER DEBUG] Available URL columns: {[col for col in row.index if 'stim' in col.lower() or 'url' in col.lower()]}")
|
| 44 |
+
print(f"[IMAGE_VIEWER DEBUG] image_1: {row.get('image_1', 'MISSING')}")
|
| 45 |
+
print(f"[IMAGE_VIEWER DEBUG] image_2: {row.get('image_2', 'MISSING')}")
|
| 46 |
+
|
| 47 |
+
# Try different possible column names for URLs
|
| 48 |
+
# Priority: use stim_1/stim_2 which are the URL columns
|
| 49 |
img1_url = row.get('stim_1', '')
|
| 50 |
img2_url = row.get('stim_2', '')
|
| 51 |
|
| 52 |
+
print(f"[IMAGE_VIEWER DEBUG] Loading stim_1 URL: {img1_url[:80] if img1_url else 'EMPTY'}...")
|
| 53 |
+
print(f"[IMAGE_VIEWER DEBUG] Loading stim_2 URL: {img2_url[:80] if img2_url else 'EMPTY'}...")
|
| 54 |
+
|
| 55 |
+
# Check if stim_1 corresponds to image_1 or image_2 by looking at the filename in the URL
|
| 56 |
+
# If the URL contains the image_1 filename, then stim_1 = image_1
|
| 57 |
+
# Otherwise they might be swapped
|
| 58 |
+
image_1_filename = str(row.get('image_1', ''))
|
| 59 |
+
image_2_filename = str(row.get('image_2', ''))
|
| 60 |
+
|
| 61 |
+
# Check if we need to swap
|
| 62 |
+
stim_1_matches_image_1 = image_1_filename in str(img1_url)
|
| 63 |
+
stim_1_matches_image_2 = image_2_filename in str(img1_url)
|
| 64 |
+
|
| 65 |
+
print(f"[IMAGE_VIEWER DEBUG] stim_1 contains image_1 filename? {stim_1_matches_image_1}")
|
| 66 |
+
print(f"[IMAGE_VIEWER DEBUG] stim_1 contains image_2 filename? {stim_1_matches_image_2}")
|
| 67 |
+
|
| 68 |
+
# If stim_1 contains image_2's filename, we need to swap
|
| 69 |
+
if stim_1_matches_image_2 and not stim_1_matches_image_1:
|
| 70 |
+
print("[IMAGE_VIEWER DEBUG] ⚠️ SWAPPING: stim_1 corresponds to image_2, swapping URLs")
|
| 71 |
+
img1_url, img2_url = img2_url, img1_url
|
| 72 |
+
was_swapped = True
|
| 73 |
+
elif stim_1_matches_image_1:
|
| 74 |
+
print("[IMAGE_VIEWER DEBUG] ✓ No swap needed: stim_1 corresponds to image_1")
|
| 75 |
+
else:
|
| 76 |
+
print("[IMAGE_VIEWER DEBUG] ⚠️ WARNING: Could not determine correspondence, assuming stim_1=image_1")
|
| 77 |
+
|
| 78 |
img1 = ImageViewer.load_image_from_url(img1_url) if img1_url else None
|
| 79 |
img2 = ImageViewer.load_image_from_url(img2_url) if img2_url else None
|
| 80 |
|
| 81 |
+
return img1, img2, was_swapped
|