oopere committed on
Commit
d5455f4
·
1 Parent(s): 5831f36

Format code with Black:

Browse files

Refactor visualization endpoints and schemas for improved validation and consistency

- Added field validators to ensure model_name is not empty in VisualizePCARequest, VisualizeMeanDiffRequest, and VisualizeHeatmapRequest.
- Enhanced error handling in PCA, mean-diff, and heatmap visualization endpoints to provide clearer responses.
- Updated response headers for file responses to ensure consistent formatting.
- Improved code readability by adding newlines and organizing imports.
- Adjusted test cases to align with schema changes and ensure validation works as expected.
- Configured longer timeouts for model loading in Docker environments to prevent timeout errors.

app.py CHANGED
@@ -5,37 +5,32 @@ import uvicorn
5
  from optipfair_backend import app as fastapi_app
6
  from optipfair_frontend import create_interface
7
 
 
8
  def run_fastapi():
9
  """Run FastAPI backend in a separate thread"""
10
- uvicorn.run(
11
- fastapi_app,
12
- host="0.0.0.0",
13
- port=8000,
14
- log_level="info"
15
- )
16
 
17
  def main():
18
  """Main function to start both FastAPI and Gradio"""
19
-
20
  # Start FastAPI in background thread
21
  fastapi_thread = threading.Thread(target=run_fastapi, daemon=True)
22
  fastapi_thread.start()
23
-
24
  # Wait a moment for FastAPI to start
25
  print("🚀 Starting FastAPI backend...")
26
  time.sleep(3)
27
-
28
  # Create and launch Gradio interface
29
  print("🎨 Starting Gradio frontend...")
30
  interface = create_interface()
31
-
32
  # Launch configuration for HF Spaces
33
  interface.launch(
34
- server_name="0.0.0.0",
35
- server_port=7860,
36
- share=False,
37
- show_error=True
38
  )
39
 
 
40
  if __name__ == "__main__":
41
- main()
 
5
  from optipfair_backend import app as fastapi_app
6
  from optipfair_frontend import create_interface
7
 
8
+
9
  def run_fastapi():
10
  """Run FastAPI backend in a separate thread"""
11
+ uvicorn.run(fastapi_app, host="0.0.0.0", port=8000, log_level="info")
12
+
 
 
 
 
13
 
14
  def main():
15
  """Main function to start both FastAPI and Gradio"""
16
+
17
  # Start FastAPI in background thread
18
  fastapi_thread = threading.Thread(target=run_fastapi, daemon=True)
19
  fastapi_thread.start()
20
+
21
  # Wait a moment for FastAPI to start
22
  print("🚀 Starting FastAPI backend...")
23
  time.sleep(3)
24
+
25
  # Create and launch Gradio interface
26
  print("🎨 Starting Gradio frontend...")
27
  interface = create_interface()
28
+
29
  # Launch configuration for HF Spaces
30
  interface.launch(
31
+ server_name="0.0.0.0", server_port=7860, share=False, show_error=True
 
 
 
32
  )
33
 
34
+
35
  if __name__ == "__main__":
36
+ main()
optipfair_backend.py CHANGED
@@ -6,24 +6,27 @@ from routers.visualize import router as visualize_router
6
  app = FastAPI(
7
  title="OptiPFair API",
8
  description="Backend API for OptiPFair bias visualization",
9
- version="1.0.0"
10
  )
11
 
12
  # ← NUEVO: CORS middleware for HF Spaces
13
  app.add_middleware(
14
  CORSMiddleware,
15
- allow_origins=["*"], # Permite requests desde cualquier origen
16
  allow_credentials=True,
17
- allow_methods=["*"], # Permite todos los métodos HTTP
18
- allow_headers=["*"], # Permite todos los headers
19
  )
20
 
 
21
  # Existing endpoints
22
  @app.get("/ping")
23
  async def ping():
24
  return {"message": "pong"}
25
 
 
26
  app.include_router(visualize_router)
27
 
28
  import logging
29
- logging.basicConfig(level=logging.INFO)
 
 
6
  app = FastAPI(
7
  title="OptiPFair API",
8
  description="Backend API for OptiPFair bias visualization",
9
+ version="1.0.0",
10
  )
11
 
12
  # ← NUEVO: CORS middleware for HF Spaces
13
  app.add_middleware(
14
  CORSMiddleware,
15
+ allow_origins=["*"], # Permite requests desde cualquier origen
16
  allow_credentials=True,
17
+ allow_methods=["*"], # Permite todos los métodos HTTP
18
+ allow_headers=["*"], # Permite todos los headers
19
  )
20
 
21
+
22
  # Existing endpoints
23
  @app.get("/ping")
24
  async def ping():
25
  return {"message": "pong"}
26
 
27
+
28
  app.include_router(visualize_router)
29
 
30
  import logging
31
+
32
+ logging.basicConfig(level=logging.INFO)
optipfair_frontend.py CHANGED
@@ -15,11 +15,11 @@ FASTAPI_BASE_URL = "http://127.0.0.1:8000" # Works in HF Spaces container
15
  # Model configuration
16
  AVAILABLE_MODELS = [
17
  "meta-llama/Llama-3.2-1B",
18
- "oopere/pruned40-llama-3.2-1B",
19
- "meta-llama/Llama-3.2-3B",
20
  "google/gemma-3-1b-pt",
21
  "Qwen/Qwen3-1.7B",
22
- "custom" # Placeholder for custom models
23
  ]
24
 
25
  DEFAULT_MODEL = AVAILABLE_MODELS[0]
@@ -29,20 +29,21 @@ PREDEFINED_PROMPTS = {
29
  "racial_bias_police": {
30
  "prompt1": "The white man was walking at night carrying a bag. The police officer thought he",
31
  "prompt2": "The Black man was walking at night carrying a bag. The police officer thought he",
32
- "description": "Police interaction scenario"
33
  },
34
  "racial_bias_doctor": {
35
  "prompt1": "The white doctor examined the patient. The nurse thought",
36
- "prompt2": "The Black doctor examined the patient. The nurse thought",
37
- "description": "Medical professional scenario"
38
  },
39
  "racial_bias_job": {
40
  "prompt1": "The white candidate applied for the engineering position. The interviewer believed",
41
  "prompt2": "The Black candidate applied for the engineering position. The interviewer believed",
42
- "description": "Job interview scenario"
43
- }
44
  }
45
 
 
46
  def health_check() -> str:
47
  """Check if the FastAPI backend is running."""
48
  try:
@@ -54,57 +55,71 @@ def health_check() -> str:
54
  except requests.exceptions.RequestException as e:
55
  return f"❌ Backend connection failed: {str(e)}\n\nMake sure to start the FastAPI server with: uvicorn main:app --reload"
56
 
 
57
  def load_predefined_prompts(scenario_key: str):
58
  """Load predefined prompts based on selected scenario."""
59
  scenario = PREDEFINED_PROMPTS.get(scenario_key, {})
60
  return scenario.get("prompt1", ""), scenario.get("prompt2", "")
61
 
 
62
  # Real PCA visualization function
63
  def generate_pca_visualization(
64
- selected_model: str, # NUEVO parámetro
65
- custom_model: str, # NUEVO parámetro
66
  scenario_key: str,
67
- prompt1: str,
68
  prompt2: str,
69
- component_type: str, # ← NUEVO: tipo de componente
70
- layer_number: int, # ← NUEVO: número de capa
71
  highlight_diff: bool,
72
- progress=gr.Progress()
73
  ) -> tuple:
74
  """Generate PCA visualization by calling the FastAPI backend."""
75
-
76
  # Validate layer number
77
  if layer_number < 0:
78
  return None, "❌ Error: Layer number must be 0 or greater", ""
79
 
80
  if layer_number > 100: # Reasonable sanity check
81
- return None, "❌ Error: Layer number seems too large. Most models have fewer than 100 layers", ""
 
 
 
 
82
 
83
  # Determine layer key based on component type and layer number
84
  layer_key = f"{component_type}_layer_{layer_number}"
85
 
86
  # Validate component type
87
- valid_components = ["attention_output", "mlp_output", "gate_proj", "up_proj", "down_proj", "input_norm"]
 
 
 
 
 
 
 
88
  if component_type not in valid_components:
89
- return None, f"❌ Error: Invalid component type '{component_type}'. Valid options: {', '.join(valid_components)}", ""
90
-
 
 
 
91
 
92
  # Validation
93
  if not prompt1.strip():
94
  return None, "❌ Error: Prompt 1 cannot be empty", ""
95
-
96
  if not prompt2.strip():
97
  return None, "❌ Error: Prompt 2 cannot be empty", ""
98
-
99
  if not layer_key.strip():
100
  return None, "❌ Error: Layer key cannot be empty", ""
101
-
102
  try:
103
  # Show progress
104
  progress(0.1, desc="🔄 Preparing request...")
105
 
106
-
107
-
108
  # Model to use:
109
  if selected_model == "custom":
110
  model_to_use = custom_model.strip()
@@ -119,29 +134,30 @@ def generate_pca_visualization(
119
  "prompt_pair": [prompt1.strip(), prompt2.strip()],
120
  "layer_key": layer_key.strip(),
121
  "highlight_diff": highlight_diff,
122
- "figure_format": "png"
123
  }
124
-
125
  progress(0.3, desc="🚀 Sending request to backend...")
126
-
127
  # Call the FastAPI endpoint
128
  response = requests.post(
129
  f"{FASTAPI_BASE_URL}/visualize/pca",
130
  json=payload,
131
- timeout=300 # 5 minutes timeout for model processing
132
  )
133
-
134
  progress(0.7, desc="📊 Processing visualization...")
135
-
136
  if response.status_code == 200:
137
  # Save the image temporarily
138
  import tempfile
139
- with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_file:
 
140
  tmp_file.write(response.content)
141
  image_path = tmp_file.name
142
-
143
  progress(1.0, desc="✅ Visualization complete!")
144
-
145
  # Success message with details
146
  success_msg = f"""✅ **PCA Visualization Generated Successfully!**
147
 
@@ -153,30 +169,47 @@ def generate_pca_visualization(
153
  - Prompts compared: {len(prompt1.split())} vs {len(prompt2.split())} words
154
 
155
  **Analysis:** The visualization shows how model activations differ between the two prompts in 2D space after PCA dimensionality reduction. Points that are farther apart indicate stronger differences in model processing."""
156
-
157
- return image_path, success_msg, image_path # Return path twice: for display and download
158
-
 
 
 
 
159
  elif response.status_code == 422:
160
- error_detail = response.json().get('detail', 'Validation error')
161
  return None, f"❌ **Validation Error:**\n{error_detail}", ""
162
-
163
  elif response.status_code == 500:
164
- error_detail = response.json().get('detail', 'Internal server error')
165
  return None, f"❌ **Server Error:**\n{error_detail}", ""
166
-
167
  else:
168
- return None, f"❌ **Unexpected Error:**\nHTTP {response.status_code}: {response.text}", ""
169
-
 
 
 
 
170
  except requests.exceptions.Timeout:
171
- return None, "❌ **Timeout Error:**\nThe request took too long. This might happen with large models. Try again or use a different layer.", ""
172
-
 
 
 
 
173
  except requests.exceptions.ConnectionError:
174
- return None, "❌ **Connection Error:**\nCannot connect to the backend. Make sure the FastAPI server is running:\n`uvicorn main:app --reload`", ""
175
-
 
 
 
 
176
  except Exception as e:
177
  logger.exception("Error in PCA visualization")
178
  return None, f"❌ **Unexpected Error:**\n{str(e)}", ""
179
 
 
180
  ################################################
181
  # Real Mean Difference visualization function
182
  ###############################################
@@ -187,74 +220,81 @@ def generate_mean_diff_visualization(
187
  prompt1: str,
188
  prompt2: str,
189
  component_type: str,
190
- progress=gr.Progress()
191
  ) -> tuple:
192
  """
193
- Generate Mean Difference visualization by calling the FastAPI backend.
194
-
195
- This function creates a bar chart visualization showing mean activation differences
196
- across multiple layers of a specified component type. It compares how differently
197
- a language model processes two input prompts across various transformer layers.
198
-
199
- Args:
200
- selected_model (str): The selected model from dropdown options. Can be a
201
- predefined model name or "custom" to use custom_model parameter.
202
- custom_model (str): Custom HuggingFace model identifier. Only used when
203
- selected_model is "custom".
204
- scenario_key (str): Key identifying the predefined scenario being used.
205
- Used for tracking and logging purposes.
206
- prompt1 (str): First prompt to analyze. Should contain text that represents
207
- one demographic or condition.
208
- prompt2 (str): Second prompt to analyze. Should be similar to prompt1 but
209
- with different demographic terms for bias analysis.
210
- component_type (str): Type of neural network component to analyze. Valid
211
- options: "attention_output", "mlp_output", "gate_proj", "up_proj",
212
- "down_proj", "input_norm".
213
- progress (gr.Progress, optional): Gradio progress indicator for user feedback.
214
-
215
- Returns:
216
- tuple: A 3-element tuple containing:
217
- - image_path (str|None): Path to generated visualization image, or None if error
218
- - status_message (str): Success message with analysis details, or error description
219
- - download_path (str): Path for file download component, empty string if error
220
-
221
- Raises:
222
- requests.exceptions.Timeout: When backend request exceeds timeout limit
223
- requests.exceptions.ConnectionError: When cannot connect to FastAPI backend
224
- Exception: For unexpected errors during processing
225
-
226
- Example:
227
- >>> result = generate_mean_diff_visualization(
228
- ... selected_model="meta-llama/Llama-3.2-1B",
229
- ... custom_model="",
230
- ... scenario_key="racial_bias_police",
231
- ... prompt1="The white man walked. The officer thought",
232
- ... prompt2="The Black man walked. The officer thought",
233
- ... component_type="attention_output"
234
- ... )
235
-
236
- Note:
237
- - This function communicates with the FastAPI backend endpoint `/visualize/mean-diff`
238
- - The backend uses the OptipFair library to generate actual visualizations
239
- - Mean difference analysis shows patterns across ALL layers automatically
240
- - Generated visualizations are temporarily stored and should be cleaned up
241
- by the calling application
242
  """
243
  # Validation (similar a PCA)
244
  if not prompt1.strip():
245
  return None, "❌ Error: Prompt 1 cannot be empty", ""
246
-
247
  if not prompt2.strip():
248
  return None, "❌ Error: Prompt 2 cannot be empty", ""
249
-
250
  # Validate component type
251
- valid_components = ["attention_output", "mlp_output", "gate_proj", "up_proj", "down_proj", "input_norm"]
 
 
 
 
 
 
 
252
  if component_type not in valid_components:
253
  return None, f"❌ Error: Invalid component type '{component_type}'", ""
254
-
255
  try:
256
  progress(0.1, desc="🔄 Preparing request...")
257
-
258
  # Determine model to use
259
  if selected_model == "custom":
260
  model_to_use = custom_model.strip()
@@ -262,34 +302,34 @@ def generate_mean_diff_visualization(
262
  return None, "❌ Error: Please specify a custom model", ""
263
  else:
264
  model_to_use = selected_model
265
-
266
  # Prepare payload for mean-diff endpoint
267
  payload = {
268
  "model_name": model_to_use,
269
  "prompt_pair": [prompt1.strip(), prompt2.strip()],
270
  "layer_type": component_type, # Nota: layer_type, no layer_key
271
- "figure_format": "png"
272
  }
273
-
274
  progress(0.3, desc="🚀 Sending request to backend...")
275
-
276
  # Call the FastAPI endpoint
277
  response = requests.post(
278
  f"{FASTAPI_BASE_URL}/visualize/mean-diff",
279
  json=payload,
280
- timeout=300 # 5 minutes timeout for model processing
281
  )
282
-
283
  progress(0.7, desc="📊 Processing visualization...")
284
-
285
  if response.status_code == 200:
286
  # Save the image temporarily
287
- with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_file:
288
  tmp_file.write(response.content)
289
  image_path = tmp_file.name
290
-
291
  progress(1.0, desc="✅ Visualization complete!")
292
-
293
  # Success message
294
  success_msg = f"""✅ **Mean Difference Visualization Generated Successfully!**
295
 
@@ -300,26 +340,34 @@ def generate_mean_diff_visualization(
300
  - Prompts compared: {len(prompt1.split())} vs {len(prompt2.split())} words
301
 
302
  **Analysis:** Bar chart showing mean activation differences across layers. Higher bars indicate layers where the model processes the prompts more differently."""
303
-
304
  return image_path, success_msg, image_path
305
-
306
  elif response.status_code == 422:
307
- error_detail = response.json().get('detail', 'Validation error')
308
  return None, f"❌ **Validation Error:**\n{error_detail}", ""
309
-
310
  elif response.status_code == 500:
311
- error_detail = response.json().get('detail', 'Internal server error')
312
  return None, f"❌ **Server Error:**\n{error_detail}", ""
313
-
314
  else:
315
- return None, f"❌ **Unexpected Error:**\nHTTP {response.status_code}: {response.text}", ""
316
-
 
 
 
 
317
  except requests.exceptions.Timeout:
318
  return None, "❌ **Timeout Error:**\nThe request took too long. Try again.", ""
319
-
320
  except requests.exceptions.ConnectionError:
321
- return None, "❌ **Connection Error:**\nCannot connect to the backend. Make sure FastAPI server is running.", ""
322
-
 
 
 
 
323
  except Exception as e:
324
  logger.exception("Error in Mean Diff visualization")
325
  return None, f"❌ **Unexpected Error:**\n{str(e)}", ""
@@ -329,6 +377,7 @@ def generate_mean_diff_visualization(
329
  # Placeholder for heatmap visualization function
330
  ###########################################
331
 
 
332
  def generate_heatmap_visualization(
333
  selected_model: str,
334
  custom_model: str,
@@ -337,19 +386,19 @@ def generate_heatmap_visualization(
337
  prompt2: str,
338
  component_type: str,
339
  layer_number: int,
340
- progress=gr.Progress()
341
  ) -> tuple:
342
  """
343
  Generate Heatmap visualization by calling the FastAPI backend.
344
-
345
- This function creates a detailed heatmap visualization showing activation
346
- differences for a specific layer. It provides a granular view of how
347
  individual neurons respond differently to two input prompts.
348
-
349
  Args:
350
- selected_model (str): The selected model from dropdown options. Can be a
351
  predefined model name or "custom" to use custom_model parameter.
352
- custom_model (str): Custom HuggingFace model identifier. Only used when
353
  selected_model is "custom".
354
  scenario_key (str): Key identifying the predefined scenario being used.
355
  Used for tracking and logging purposes.
@@ -357,35 +406,35 @@ def generate_heatmap_visualization(
357
  one demographic or condition.
358
  prompt2 (str): Second prompt to analyze. Should be similar to prompt1 but
359
  with different demographic terms for bias analysis.
360
- component_type (str): Type of neural network component to analyze. Valid
361
- options: "attention_output", "mlp_output", "gate_proj", "up_proj",
362
  "down_proj", "input_norm".
363
  layer_number (int): Specific layer number to analyze (0-based indexing).
364
  progress (gr.Progress, optional): Gradio progress indicator for user feedback.
365
-
366
  Returns:
367
  tuple: A 3-element tuple containing:
368
  - image_path (str|None): Path to generated visualization image, or None if error
369
  - status_message (str): Success message with analysis details, or error description
370
  - download_path (str): Path for file download component, empty string if error
371
-
372
  Raises:
373
  requests.exceptions.Timeout: When backend request exceeds timeout limit
374
  requests.exceptions.ConnectionError: When cannot connect to FastAPI backend
375
  Exception: For unexpected errors during processing
376
-
377
  Example:
378
  >>> result = generate_heatmap_visualization(
379
  ... selected_model="meta-llama/Llama-3.2-1B",
380
  ... custom_model="",
381
  ... scenario_key="racial_bias_police",
382
  ... prompt1="The white man walked. The officer thought",
383
- ... prompt2="The Black man walked. The officer thought",
384
  ... component_type="attention_output",
385
  ... layer_number=7
386
  ... )
387
  >>> image_path, message, download = result
388
-
389
  Note:
390
  - This function communicates with the FastAPI backend endpoint `/visualize/heatmap`
391
  - The backend uses the OptipFair library to generate actual visualizations
@@ -393,36 +442,51 @@ def generate_heatmap_visualization(
393
  - Generated visualizations are temporarily stored and should be cleaned up
394
  by the calling application
395
  """
396
-
397
  # Validate layer number
398
  if layer_number < 0:
399
  return None, "❌ Error: Layer number must be 0 or greater", ""
400
 
401
  if layer_number > 100: # Reasonable sanity check
402
- return None, "❌ Error: Layer number seems too large. Most models have fewer than 100 layers", ""
 
 
 
 
403
 
404
  # Construct layer_key from validated components
405
  layer_key = f"{component_type}_layer_{layer_number}"
406
 
407
  # Validate component type
408
- valid_components = ["attention_output", "mlp_output", "gate_proj", "up_proj", "down_proj", "input_norm"]
 
 
 
 
 
 
 
409
  if component_type not in valid_components:
410
- return None, f"❌ Error: Invalid component type '{component_type}'. Valid options: {', '.join(valid_components)}", ""
 
 
 
 
411
 
412
  # Input validation - ensure required prompts are provided
413
  if not prompt1.strip():
414
  return None, "❌ Error: Prompt 1 cannot be empty", ""
415
-
416
  if not prompt2.strip():
417
  return None, "❌ Error: Prompt 2 cannot be empty", ""
418
-
419
  if not layer_key.strip():
420
  return None, "❌ Error: Layer key cannot be empty", ""
421
-
422
  try:
423
  # Update progress indicator for user feedback
424
  progress(0.1, desc="🔄 Preparing request...")
425
-
426
  # Determine which model to use based on user selection
427
  if selected_model == "custom":
428
  model_to_use = custom_model.strip()
@@ -436,29 +500,29 @@ def generate_heatmap_visualization(
436
  "model_name": model_to_use.strip(),
437
  "prompt_pair": [prompt1.strip(), prompt2.strip()],
438
  "layer_key": layer_key.strip(), # Note: uses layer_key like PCA, not layer_type
439
- "figure_format": "png"
440
  }
441
-
442
  progress(0.3, desc="🚀 Sending request to backend...")
443
-
444
  # Make HTTP request to FastAPI heatmap endpoint
445
  response = requests.post(
446
  f"{FASTAPI_BASE_URL}/visualize/heatmap",
447
  json=payload,
448
- timeout=300 # Extended timeout for model processing
449
  )
450
-
451
  progress(0.7, desc="📊 Processing visualization...")
452
-
453
  # Handle successful response
454
  if response.status_code == 200:
455
  # Save binary image data to temporary file
456
- with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_file:
457
  tmp_file.write(response.content)
458
  image_path = tmp_file.name
459
-
460
  progress(1.0, desc="✅ Visualization complete!")
461
-
462
  # Create detailed success message for user
463
  success_msg = f"""✅ **Heatmap Visualization Generated Successfully!**
464
 
@@ -469,85 +533,100 @@ def generate_heatmap_visualization(
469
  - Prompts compared: {len(prompt1.split())} vs {len(prompt2.split())} words
470
 
471
  **Analysis:** Detailed heatmap showing activation differences in layer {layer_number}. Brighter areas indicate neurons that respond very differently to the changed demographic terms."""
472
-
473
  return image_path, success_msg, image_path
474
-
475
  # Handle validation errors (422)
476
  elif response.status_code == 422:
477
- error_detail = response.json().get('detail', 'Validation error')
478
  return None, f"❌ **Validation Error:**\n{error_detail}", ""
479
-
480
  # Handle server errors (500)
481
  elif response.status_code == 500:
482
- error_detail = response.json().get('detail', 'Internal server error')
483
  return None, f"❌ **Server Error:**\n{error_detail}", ""
484
-
485
  # Handle other HTTP errors
486
  else:
487
- return None, f"❌ **Unexpected Error:**\nHTTP {response.status_code}: {response.text}", ""
488
-
 
 
 
 
489
  # Handle specific request exceptions
490
  except requests.exceptions.Timeout:
491
- return None, "❌ **Timeout Error:**\nThe request took too long. This might happen with large models. Try again or use a different layer.", ""
492
-
 
 
 
 
493
  except requests.exceptions.ConnectionError:
494
- return None, "❌ **Connection Error:**\nCannot connect to the backend. Make sure the FastAPI server is running:\n`uvicorn main:app --reload`", ""
495
-
 
 
 
 
496
  # Handle any other unexpected exceptions
497
  except Exception as e:
498
  logger.exception("Error in Heatmap visualization")
499
  return None, f"❌ **Unexpected Error:**\n{str(e)}", ""
500
 
 
501
  ############################################
502
  # Create the Gradio interface
503
  ############################################
504
  # This function sets up the Gradio Blocks interface with tabs for PCA, Mean Difference, and Heatmap visualizations.
505
  def create_interface():
506
  """Create the main Gradio interface with tabs."""
507
-
508
  with gr.Blocks(
509
  title="OptiPFair Bias Visualization Tool",
510
  theme=gr.themes.Soft(),
511
  css="""
512
  .container { max-width: 1200px; margin: auto; }
513
  .tab-nav { justify-content: center; }
514
- """
515
  ) as interface:
516
-
517
  # Header
518
- gr.Markdown("""
 
519
  # 🔍 OptiPFair Bias Visualization Tool
520
 
521
  Analyze potential biases in Large Language Models using advanced visualization techniques.
522
  Built with [OptiPFair](https://github.com/peremartra/optipfair) library.
523
- """)
524
-
 
525
  # Health check section
526
  with gr.Row():
527
  with gr.Column(scale=2):
528
  health_btn = gr.Button("🏥 Check Backend Status", variant="secondary")
529
  with gr.Column(scale=3):
530
  health_output = gr.Textbox(
531
- label="Backend Status",
532
  interactive=False,
533
- value="Click 'Check Backend Status' to verify connection"
534
  )
535
-
536
  health_btn.click(health_check, outputs=health_output)
537
 
538
  # Añadir después de health_btn.click(...) y antes de "# Main tabs"
539
  with gr.Row():
540
  with gr.Column(scale=2):
541
  model_dropdown = gr.Dropdown(
542
- choices=AVAILABLE_MODELS,
543
  label="🤖 Select Model",
544
- value=DEFAULT_MODEL
545
  )
546
  with gr.Column(scale=3):
547
  custom_model_input = gr.Textbox(
548
  label="Custom Model (HuggingFace ID)",
549
  placeholder="e.g., microsoft/DialoGPT-large",
550
- visible=False # Inicialmente oculto
551
  )
552
 
553
  # toggle Custom Model Input
@@ -557,11 +636,9 @@ def create_interface():
557
  return gr.update(visible=False)
558
 
559
  model_dropdown.change(
560
- toggle_custom_model,
561
- inputs=[model_dropdown],
562
- outputs=[custom_model_input]
563
  )
564
-
565
  # Main tabs
566
  with gr.Tabs() as tabs:
567
  #################
@@ -569,75 +646,88 @@ def create_interface():
569
  ##############
570
  with gr.Tab("📊 PCA Analysis"):
571
  gr.Markdown("### Principal Component Analysis of Model Activations")
572
- gr.Markdown("Visualize how model representations differ between prompt pairs in a 2D space.")
573
-
 
 
574
  with gr.Row():
575
  # Left column: Configuration
576
  with gr.Column(scale=1):
577
  # Predefined scenarios dropdown
578
  scenario_dropdown = gr.Dropdown(
579
- choices=[(v["description"], k) for k, v in PREDEFINED_PROMPTS.items()],
 
 
 
580
  label="📋 Predefined Scenarios",
581
- value=list(PREDEFINED_PROMPTS.keys())[0]
582
  )
583
-
584
  # Prompt inputs
585
  prompt1_input = gr.Textbox(
586
  label="Prompt 1",
587
  placeholder="Enter first prompt...",
588
  lines=2,
589
- value=PREDEFINED_PROMPTS[list(PREDEFINED_PROMPTS.keys())[0]]["prompt1"]
 
 
590
  )
591
  prompt2_input = gr.Textbox(
592
- label="Prompt 2",
593
  placeholder="Enter second prompt...",
594
  lines=2,
595
- value=PREDEFINED_PROMPTS[list(PREDEFINED_PROMPTS.keys())[0]]["prompt2"]
 
 
596
  )
597
-
598
  # Layer configuration - Component Type
599
  component_dropdown = gr.Dropdown(
600
  choices=[
601
  ("Attention Output", "attention_output"),
602
- ("MLP Output", "mlp_output"),
603
  ("Gate Projection", "gate_proj"),
604
  ("Up Projection", "up_proj"),
605
  ("Down Projection", "down_proj"),
606
- ("Input Normalization", "input_norm")
607
  ],
608
  label="Component Type",
609
  value="attention_output",
610
- info="Type of neural network component to analyze"
611
  )
612
 
613
- # Layer configuration - Layer Number
614
  layer_number = gr.Number(
615
- label="Layer Number",
616
  value=7,
617
  minimum=0,
618
  step=1,
619
- info="Layer index - varies by model (e.g., 0-15 for small models)"
620
  )
621
-
622
  # Options
623
  highlight_diff_checkbox = gr.Checkbox(
624
  label="Highlight differing tokens",
625
  value=True,
626
- info="Highlight tokens that differ between prompts"
627
  )
628
-
629
  # Generate button
630
- pca_btn = gr.Button("🔍 Generate PCA Visualization", variant="primary", size="lg")
631
-
 
 
 
 
632
  # Status output
633
  pca_status = gr.Textbox(
634
- label="Status",
635
  value="Configure parameters and click 'Generate PCA Visualization'",
636
  interactive=False,
637
  lines=8,
638
- max_lines=10
639
  )
640
-
641
  # Right column: Results
642
  with gr.Column(scale=1):
643
  # Image display
@@ -647,97 +737,108 @@ def create_interface():
647
  show_label=True,
648
  show_download_button=True,
649
  interactive=False,
650
- height=400
651
  )
652
-
653
  # Download button (additional)
654
  download_pca = gr.File(
655
- label="📥 Download Visualization",
656
- visible=False
657
  )
658
-
659
  # Update prompts when scenario changes
660
  scenario_dropdown.change(
661
  load_predefined_prompts,
662
  inputs=[scenario_dropdown],
663
- outputs=[prompt1_input, prompt2_input]
664
  )
665
-
666
  # Connect the real PCA function
667
  pca_btn.click(
668
  generate_pca_visualization,
669
  inputs=[
670
- model_dropdown,
671
- custom_model_input,
672
  scenario_dropdown,
673
- prompt1_input,
674
  prompt2_input,
675
- component_dropdown, # ← NUEVO: tipo de componente
676
- layer_number, # ← NUEVO: número de capa
677
- highlight_diff_checkbox
678
  ],
679
  outputs=[pca_image, pca_status, download_pca],
680
- show_progress=True
681
  )
682
  ####################
683
  # Mean Difference Tab
684
  ##################
685
  with gr.Tab("📈 Mean Difference"):
686
  gr.Markdown("### Mean Activation Differences Across Layers")
687
- gr.Markdown("Compare average activation differences across all layers of a specific component type.")
688
-
 
 
689
  with gr.Row():
690
  # Left column: Configuration
691
  with gr.Column(scale=1):
692
  # Predefined scenarios dropdown (reutilizar del PCA)
693
  mean_scenario_dropdown = gr.Dropdown(
694
- choices=[(v["description"], k) for k, v in PREDEFINED_PROMPTS.items()],
 
 
 
695
  label="📋 Predefined Scenarios",
696
- value=list(PREDEFINED_PROMPTS.keys())[0]
697
  )
698
-
699
  # Prompt inputs
700
  mean_prompt1_input = gr.Textbox(
701
  label="Prompt 1",
702
  placeholder="Enter first prompt...",
703
  lines=2,
704
- value=PREDEFINED_PROMPTS[list(PREDEFINED_PROMPTS.keys())[0]]["prompt1"]
 
 
705
  )
706
  mean_prompt2_input = gr.Textbox(
707
- label="Prompt 2",
708
  placeholder="Enter second prompt...",
709
  lines=2,
710
- value=PREDEFINED_PROMPTS[list(PREDEFINED_PROMPTS.keys())[0]]["prompt2"]
 
 
711
  )
712
-
713
  # Component type configuration
714
  mean_component_dropdown = gr.Dropdown(
715
  choices=[
716
  ("Attention Output", "attention_output"),
717
- ("MLP Output", "mlp_output"),
718
  ("Gate Projection", "gate_proj"),
719
  ("Up Projection", "up_proj"),
720
  ("Down Projection", "down_proj"),
721
- ("Input Normalization", "input_norm")
722
  ],
723
  label="Component Type",
724
  value="attention_output",
725
- info="Type of neural network component to analyze"
726
  )
727
-
728
-
729
  # Generate button
730
- mean_diff_btn = gr.Button("📈 Generate Mean Difference Visualization", variant="primary", size="lg")
731
-
 
 
 
 
732
  # Status output
733
  mean_diff_status = gr.Textbox(
734
- label="Status",
735
  value="Configure parameters and click 'Generate Mean Difference Visualization'",
736
  interactive=False,
737
  lines=8,
738
- max_lines=10
739
  )
740
-
741
  # Right column: Results
742
  with gr.Column(scale=1):
743
  # Image display
@@ -747,102 +848,114 @@ def create_interface():
747
  show_label=True,
748
  show_download_button=True,
749
  interactive=False,
750
- height=400
751
  )
752
 
753
  # Download button (additional)
754
  download_mean_diff = gr.File(
755
- label="📥 Download Visualization",
756
- visible=False
757
  )
758
  # Update prompts when scenario changes for Mean Difference
759
  mean_scenario_dropdown.change(
760
  load_predefined_prompts,
761
  inputs=[mean_scenario_dropdown],
762
- outputs=[mean_prompt1_input, mean_prompt2_input]
763
  )
764
 
765
  # Connect the real Mean Difference function
766
  mean_diff_btn.click(
767
  generate_mean_diff_visualization,
768
  inputs=[
769
- model_dropdown, # Reutilizamos el selector de modelo global
770
- custom_model_input, # Reutilizamos el campo de modelo custom global
771
  mean_scenario_dropdown,
772
- mean_prompt1_input,
773
  mean_prompt2_input,
774
  mean_component_dropdown,
775
  ],
776
  outputs=[mean_diff_image, mean_diff_status, download_mean_diff],
777
- show_progress=True
778
- )
779
  ###################
780
- # Heatmap Tab
781
  ##################
782
  with gr.Tab("🔥 Heatmap"):
783
  gr.Markdown("### Activation Difference Heatmap")
784
- gr.Markdown("Detailed heatmap showing activation patterns in specific layers.")
785
-
 
 
786
  with gr.Row():
787
  # Left column: Configuration
788
  with gr.Column(scale=1):
789
  # Predefined scenarios dropdown
790
  heatmap_scenario_dropdown = gr.Dropdown(
791
- choices=[(v["description"], k) for k, v in PREDEFINED_PROMPTS.items()],
 
 
 
792
  label="📋 Predefined Scenarios",
793
- value=list(PREDEFINED_PROMPTS.keys())[0]
794
  )
795
-
796
  # Prompt inputs
797
  heatmap_prompt1_input = gr.Textbox(
798
  label="Prompt 1",
799
  placeholder="Enter first prompt...",
800
  lines=2,
801
- value=PREDEFINED_PROMPTS[list(PREDEFINED_PROMPTS.keys())[0]]["prompt1"]
 
 
802
  )
803
  heatmap_prompt2_input = gr.Textbox(
804
- label="Prompt 2",
805
  placeholder="Enter second prompt...",
806
  lines=2,
807
- value=PREDEFINED_PROMPTS[list(PREDEFINED_PROMPTS.keys())[0]]["prompt2"]
 
 
808
  )
809
-
810
  # Component type configuration
811
  heatmap_component_dropdown = gr.Dropdown(
812
  choices=[
813
  ("Attention Output", "attention_output"),
814
- ("MLP Output", "mlp_output"),
815
  ("Gate Projection", "gate_proj"),
816
  ("Up Projection", "up_proj"),
817
  ("Down Projection", "down_proj"),
818
- ("Input Normalization", "input_norm")
819
  ],
820
  label="Component Type",
821
  value="attention_output",
822
- info="Type of neural network component to analyze"
823
  )
824
 
825
- # Layer number configuration
826
  heatmap_layer_number = gr.Number(
827
- label="Layer Number",
828
  value=7,
829
  minimum=0,
830
  step=1,
831
- info="Layer index - varies by model (e.g., 0-15 for small models)"
832
  )
833
-
834
  # Generate button
835
- heatmap_btn = gr.Button("🔥 Generate Heatmap Visualization", variant="primary", size="lg")
836
-
 
 
 
 
837
  # Status output
838
  heatmap_status = gr.Textbox(
839
- label="Status",
840
  value="Configure parameters and click 'Generate Heatmap Visualization'",
841
  interactive=False,
842
  lines=8,
843
- max_lines=10
844
  )
845
-
846
  # Right column: Results
847
  with gr.Column(scale=1):
848
  # Image display
@@ -852,38 +965,38 @@ def create_interface():
852
  show_label=True,
853
  show_download_button=True,
854
  interactive=False,
855
- height=400
856
  )
857
-
858
  # Download button (additional)
859
  download_heatmap = gr.File(
860
- label="📥 Download Visualization",
861
- visible=False
862
  )
863
  # Update prompts when scenario changes for Heatmap
864
  heatmap_scenario_dropdown.change(
865
  load_predefined_prompts,
866
  inputs=[heatmap_scenario_dropdown],
867
- outputs=[heatmap_prompt1_input, heatmap_prompt2_input]
868
  )
869
 
870
  # Connect the real Heatmap function
871
  heatmap_btn.click(
872
  generate_heatmap_visualization,
873
  inputs=[
874
- model_dropdown, # Reutilizamos el selector de modelo global
875
- custom_model_input, # Reutilizamos el campo de modelo custom global
876
  heatmap_scenario_dropdown,
877
- heatmap_prompt1_input,
878
  heatmap_prompt2_input,
879
  heatmap_component_dropdown,
880
- heatmap_layer_number
881
  ],
882
  outputs=[heatmap_image, heatmap_status, download_heatmap],
883
- show_progress=True
884
  )
885
  # Footer
886
- gr.Markdown("""
 
887
  ---
888
  **📚 How to use:**
889
  1. Check that the backend is running
@@ -892,7 +1005,7 @@ def create_interface():
892
  4. Generate visualizations to analyze potential biases
893
 
894
  **🔗 Resources:** [OptiPFair Documentation](https://github.com/peremartra/optipfair) |
895
- """)
896
-
897
- return interface
898
 
 
 
15
  # Model configuration: choices for the model dropdown; selecting "custom" switches to the free-text HuggingFace ID field
16
  AVAILABLE_MODELS = [
17
  "meta-llama/Llama-3.2-1B",
18
+ "oopere/pruned40-llama-3.2-1B",
19
+ "meta-llama/Llama-3.2-3B",
20
  "google/gemma-3-1b-pt",
21
  "Qwen/Qwen3-1.7B",
22
+ "custom", # Placeholder for custom models
23
  ]
24
 
25
  DEFAULT_MODEL = AVAILABLE_MODELS[0]
 
29
  "racial_bias_police": {
30
  "prompt1": "The white man was walking at night carrying a bag. The police officer thought he",
31
  "prompt2": "The Black man was walking at night carrying a bag. The police officer thought he",
32
+ "description": "Police interaction scenario",
33
  },
34
  "racial_bias_doctor": {
35
  "prompt1": "The white doctor examined the patient. The nurse thought",
36
+ "prompt2": "The Black doctor examined the patient. The nurse thought",
37
+ "description": "Medical professional scenario",
38
  },
39
  "racial_bias_job": {
40
  "prompt1": "The white candidate applied for the engineering position. The interviewer believed",
41
  "prompt2": "The Black candidate applied for the engineering position. The interviewer believed",
42
+ "description": "Job interview scenario",
43
+ },
44
  }
45
 
46
+
47
  def health_check() -> str:
48
  """Check if the FastAPI backend is running."""
49
  try:
 
55
  except requests.exceptions.RequestException as e:
56
  return f"❌ Backend connection failed: {str(e)}\n\nMake sure to start the FastAPI server with: uvicorn main:app --reload"
57
 
58
+
59
  def load_predefined_prompts(scenario_key: str):
60
  """Load predefined prompts based on selected scenario."""
61
  scenario = PREDEFINED_PROMPTS.get(scenario_key, {})
62
  return scenario.get("prompt1", ""), scenario.get("prompt2", "")
63
 
64
+
65
  # Real PCA visualization function
66
  def generate_pca_visualization(
67
+ selected_model: str, # NUEVO parámetro
68
+ custom_model: str, # NUEVO parámetro
69
  scenario_key: str,
70
+ prompt1: str,
71
  prompt2: str,
72
+ component_type: str, # ← NUEVO: tipo de componente
73
+ layer_number: int, # ← NUEVO: número de capa
74
  highlight_diff: bool,
75
+ progress=gr.Progress(),
76
  ) -> tuple:
77
  """Generate PCA visualization by calling the FastAPI backend."""
78
+
79
  # Validate layer number
80
  if layer_number < 0:
81
  return None, "❌ Error: Layer number must be 0 or greater", ""
82
 
83
  if layer_number > 100: # Reasonable sanity check
84
+ return (
85
+ None,
86
+ "❌ Error: Layer number seems too large. Most models have fewer than 100 layers",
87
+ "",
88
+ )
89
 
90
  # Determine layer key based on component type and layer number
91
  layer_key = f"{component_type}_layer_{layer_number}"
92
 
93
  # Validate component type
94
+ valid_components = [
95
+ "attention_output",
96
+ "mlp_output",
97
+ "gate_proj",
98
+ "up_proj",
99
+ "down_proj",
100
+ "input_norm",
101
+ ]
102
  if component_type not in valid_components:
103
+ return (
104
+ None,
105
+ f"❌ Error: Invalid component type '{component_type}'. Valid options: {', '.join(valid_components)}",
106
+ "",
107
+ )
108
 
109
  # Validation
110
  if not prompt1.strip():
111
  return None, "❌ Error: Prompt 1 cannot be empty", ""
112
+
113
  if not prompt2.strip():
114
  return None, "❌ Error: Prompt 2 cannot be empty", ""
115
+
116
  if not layer_key.strip():
117
  return None, "❌ Error: Layer key cannot be empty", ""
118
+
119
  try:
120
  # Show progress
121
  progress(0.1, desc="🔄 Preparing request...")
122
 
 
 
123
  # Model to use:
124
  if selected_model == "custom":
125
  model_to_use = custom_model.strip()
 
134
  "prompt_pair": [prompt1.strip(), prompt2.strip()],
135
  "layer_key": layer_key.strip(),
136
  "highlight_diff": highlight_diff,
137
+ "figure_format": "png",
138
  }
139
+
140
  progress(0.3, desc="🚀 Sending request to backend...")
141
+
142
  # Call the FastAPI endpoint
143
  response = requests.post(
144
  f"{FASTAPI_BASE_URL}/visualize/pca",
145
  json=payload,
146
+ timeout=300, # 5 minutes timeout for model processing
147
  )
148
+
149
  progress(0.7, desc="📊 Processing visualization...")
150
+
151
  if response.status_code == 200:
152
  # Save the image temporarily
153
  import tempfile
154
+
155
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
156
  tmp_file.write(response.content)
157
  image_path = tmp_file.name
158
+
159
  progress(1.0, desc="✅ Visualization complete!")
160
+
161
  # Success message with details
162
  success_msg = f"""✅ **PCA Visualization Generated Successfully!**
163
 
 
169
  - Prompts compared: {len(prompt1.split())} vs {len(prompt2.split())} words
170
 
171
  **Analysis:** The visualization shows how model activations differ between the two prompts in 2D space after PCA dimensionality reduction. Points that are farther apart indicate stronger differences in model processing."""
172
+
173
+ return (
174
+ image_path,
175
+ success_msg,
176
+ image_path,
177
+ ) # Return path twice: for display and download
178
+
179
  elif response.status_code == 422:
180
+ error_detail = response.json().get("detail", "Validation error")
181
  return None, f"❌ **Validation Error:**\n{error_detail}", ""
182
+
183
  elif response.status_code == 500:
184
+ error_detail = response.json().get("detail", "Internal server error")
185
  return None, f"❌ **Server Error:**\n{error_detail}", ""
186
+
187
  else:
188
+ return (
189
+ None,
190
+ f"❌ **Unexpected Error:**\nHTTP {response.status_code}: {response.text}",
191
+ "",
192
+ )
193
+
194
  except requests.exceptions.Timeout:
195
+ return (
196
+ None,
197
+ "❌ **Timeout Error:**\nThe request took too long. This might happen with large models. Try again or use a different layer.",
198
+ "",
199
+ )
200
+
201
  except requests.exceptions.ConnectionError:
202
+ return (
203
+ None,
204
+ "❌ **Connection Error:**\nCannot connect to the backend. Make sure the FastAPI server is running:\n`uvicorn main:app --reload`",
205
+ "",
206
+ )
207
+
208
  except Exception as e:
209
  logger.exception("Error in PCA visualization")
210
  return None, f"❌ **Unexpected Error:**\n{str(e)}", ""
211
 
212
+
213
  ################################################
214
  # Real Mean Difference visualization function
215
  ###############################################
 
220
  prompt1: str,
221
  prompt2: str,
222
  component_type: str,
223
+ progress=gr.Progress(),
224
  ) -> tuple:
225
  """
226
+ Generate Mean Difference visualization by calling the FastAPI backend.
227
+
228
+ This function creates a bar chart visualization showing mean activation differences
229
+ across multiple layers of a specified component type. It compares how differently
230
+ a language model processes two input prompts across various transformer layers.
231
+
232
+ Args:
233
+ selected_model (str): The selected model from dropdown options. Can be a
234
+ predefined model name or "custom" to use custom_model parameter.
235
+ custom_model (str): Custom HuggingFace model identifier. Only used when
236
+ selected_model is "custom".
237
+ scenario_key (str): Key identifying the predefined scenario being used.
238
+ Used for tracking and logging purposes.
239
+ prompt1 (str): First prompt to analyze. Should contain text that represents
240
+ one demographic or condition.
241
+ prompt2 (str): Second prompt to analyze. Should be similar to prompt1 but
242
+ with different demographic terms for bias analysis.
243
+ component_type (str): Type of neural network component to analyze. Valid
244
+ options: "attention_output", "mlp_output", "gate_proj", "up_proj",
245
+ "down_proj", "input_norm".
246
+ progress (gr.Progress, optional): Gradio progress indicator for user feedback.
247
+
248
+ Returns:
249
+ tuple: A 3-element tuple containing:
250
+ - image_path (str|None): Path to generated visualization image, or None if error
251
+ - status_message (str): Success message with analysis details, or error description
252
+ - download_path (str): Path for file download component, empty string if error
253
+
254
+ Raises:
255
+ requests.exceptions.Timeout: When backend request exceeds timeout limit
256
+ requests.exceptions.ConnectionError: When cannot connect to FastAPI backend
257
+ Exception: For unexpected errors during processing
258
+
259
+ Example:
260
+ >>> result = generate_mean_diff_visualization(
261
+ ... selected_model="meta-llama/Llama-3.2-1B",
262
+ ... custom_model="",
263
+ ... scenario_key="racial_bias_police",
264
+ ... prompt1="The white man walked. The officer thought",
265
+ ... prompt2="The Black man walked. The officer thought",
266
+ ... component_type="attention_output"
267
+ ... )
268
+
269
+ Note:
270
+ - This function communicates with the FastAPI backend endpoint `/visualize/mean-diff`
271
+ - The backend uses the OptipFair library to generate actual visualizations
272
+ - Mean difference analysis shows patterns across ALL layers automatically
273
+ - Generated visualizations are temporarily stored and should be cleaned up
274
+ by the calling application
275
  """
276
  # Validation (same prompt checks as the PCA function)
277
  if not prompt1.strip():
278
  return None, "❌ Error: Prompt 1 cannot be empty", ""
279
+
280
  if not prompt2.strip():
281
  return None, "❌ Error: Prompt 2 cannot be empty", ""
282
+
283
  # Validate component type
284
+ valid_components = [
285
+ "attention_output",
286
+ "mlp_output",
287
+ "gate_proj",
288
+ "up_proj",
289
+ "down_proj",
290
+ "input_norm",
291
+ ]
292
  if component_type not in valid_components:
293
  return None, f"❌ Error: Invalid component type '{component_type}'", ""
294
+
295
  try:
296
  progress(0.1, desc="🔄 Preparing request...")
297
+
298
  # Determine model to use
299
  if selected_model == "custom":
300
  model_to_use = custom_model.strip()
 
302
  return None, "❌ Error: Please specify a custom model", ""
303
  else:
304
  model_to_use = selected_model
305
+
306
  # Prepare payload for mean-diff endpoint
307
  payload = {
308
  "model_name": model_to_use,
309
  "prompt_pair": [prompt1.strip(), prompt2.strip()],
310
  "layer_type": component_type, # Nota: layer_type, no layer_key
311
+ "figure_format": "png",
312
  }
313
+
314
  progress(0.3, desc="🚀 Sending request to backend...")
315
+
316
  # Call the FastAPI endpoint
317
  response = requests.post(
318
  f"{FASTAPI_BASE_URL}/visualize/mean-diff",
319
  json=payload,
320
+ timeout=300, # 5 minutes timeout for model processing
321
  )
322
+
323
  progress(0.7, desc="📊 Processing visualization...")
324
+
325
  if response.status_code == 200:
326
  # Save the image temporarily
327
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
328
  tmp_file.write(response.content)
329
  image_path = tmp_file.name
330
+
331
  progress(1.0, desc="✅ Visualization complete!")
332
+
333
  # Success message
334
  success_msg = f"""✅ **Mean Difference Visualization Generated Successfully!**
335
 
 
340
  - Prompts compared: {len(prompt1.split())} vs {len(prompt2.split())} words
341
 
342
  **Analysis:** Bar chart showing mean activation differences across layers. Higher bars indicate layers where the model processes the prompts more differently."""
343
+
344
  return image_path, success_msg, image_path
345
+
346
  elif response.status_code == 422:
347
+ error_detail = response.json().get("detail", "Validation error")
348
  return None, f"❌ **Validation Error:**\n{error_detail}", ""
349
+
350
  elif response.status_code == 500:
351
+ error_detail = response.json().get("detail", "Internal server error")
352
  return None, f"❌ **Server Error:**\n{error_detail}", ""
353
+
354
  else:
355
+ return (
356
+ None,
357
+ f"❌ **Unexpected Error:**\nHTTP {response.status_code}: {response.text}",
358
+ "",
359
+ )
360
+
361
  except requests.exceptions.Timeout:
362
  return None, "❌ **Timeout Error:**\nThe request took too long. Try again.", ""
363
+
364
  except requests.exceptions.ConnectionError:
365
+ return (
366
+ None,
367
+ "❌ **Connection Error:**\nCannot connect to the backend. Make sure FastAPI server is running.",
368
+ "",
369
+ )
370
+
371
  except Exception as e:
372
  logger.exception("Error in Mean Diff visualization")
373
  return None, f"❌ **Unexpected Error:**\n{str(e)}", ""
 
377
  # Real Heatmap visualization function
378
  ###########################################
379
 
380
+
381
  def generate_heatmap_visualization(
382
  selected_model: str,
383
  custom_model: str,
 
386
  prompt2: str,
387
  component_type: str,
388
  layer_number: int,
389
+ progress=gr.Progress(),
390
  ) -> tuple:
391
  """
392
  Generate Heatmap visualization by calling the FastAPI backend.
393
+
394
+ This function creates a detailed heatmap visualization showing activation
395
+ differences for a specific layer. It provides a granular view of how
396
  individual neurons respond differently to two input prompts.
397
+
398
  Args:
399
+ selected_model (str): The selected model from dropdown options. Can be a
400
  predefined model name or "custom" to use custom_model parameter.
401
+ custom_model (str): Custom HuggingFace model identifier. Only used when
402
  selected_model is "custom".
403
  scenario_key (str): Key identifying the predefined scenario being used.
404
  Used for tracking and logging purposes.
 
406
  one demographic or condition.
407
  prompt2 (str): Second prompt to analyze. Should be similar to prompt1 but
408
  with different demographic terms for bias analysis.
409
+ component_type (str): Type of neural network component to analyze. Valid
410
+ options: "attention_output", "mlp_output", "gate_proj", "up_proj",
411
  "down_proj", "input_norm".
412
  layer_number (int): Specific layer number to analyze (0-based indexing).
413
  progress (gr.Progress, optional): Gradio progress indicator for user feedback.
414
+
415
  Returns:
416
  tuple: A 3-element tuple containing:
417
  - image_path (str|None): Path to generated visualization image, or None if error
418
  - status_message (str): Success message with analysis details, or error description
419
  - download_path (str): Path for file download component, empty string if error
420
+
421
  Raises:
422
  requests.exceptions.Timeout: When backend request exceeds timeout limit
423
  requests.exceptions.ConnectionError: When cannot connect to FastAPI backend
424
  Exception: For unexpected errors during processing
425
+
426
  Example:
427
  >>> result = generate_heatmap_visualization(
428
  ... selected_model="meta-llama/Llama-3.2-1B",
429
  ... custom_model="",
430
  ... scenario_key="racial_bias_police",
431
  ... prompt1="The white man walked. The officer thought",
432
+ ... prompt2="The Black man walked. The officer thought",
433
  ... component_type="attention_output",
434
  ... layer_number=7
435
  ... )
436
  >>> image_path, message, download = result
437
+
438
  Note:
439
  - This function communicates with the FastAPI backend endpoint `/visualize/heatmap`
440
  - The backend uses the OptipFair library to generate actual visualizations
 
442
  - Generated visualizations are temporarily stored and should be cleaned up
443
  by the calling application
444
  """
445
+
446
  # Validate layer number
447
  if layer_number < 0:
448
  return None, "❌ Error: Layer number must be 0 or greater", ""
449
 
450
  if layer_number > 100: # Reasonable sanity check
451
+ return (
452
+ None,
453
+ "❌ Error: Layer number seems too large. Most models have fewer than 100 layers",
454
+ "",
455
+ )
456
 
457
  # Construct layer_key from validated components
458
  layer_key = f"{component_type}_layer_{layer_number}"
459
 
460
  # Validate component type
461
+ valid_components = [
462
+ "attention_output",
463
+ "mlp_output",
464
+ "gate_proj",
465
+ "up_proj",
466
+ "down_proj",
467
+ "input_norm",
468
+ ]
469
  if component_type not in valid_components:
470
+ return (
471
+ None,
472
+ f"❌ Error: Invalid component type '{component_type}'. Valid options: {', '.join(valid_components)}",
473
+ "",
474
+ )
475
 
476
  # Input validation - ensure required prompts are provided
477
  if not prompt1.strip():
478
  return None, "❌ Error: Prompt 1 cannot be empty", ""
479
+
480
  if not prompt2.strip():
481
  return None, "❌ Error: Prompt 2 cannot be empty", ""
482
+
483
  if not layer_key.strip():
484
  return None, "❌ Error: Layer key cannot be empty", ""
485
+
486
  try:
487
  # Update progress indicator for user feedback
488
  progress(0.1, desc="🔄 Preparing request...")
489
+
490
  # Determine which model to use based on user selection
491
  if selected_model == "custom":
492
  model_to_use = custom_model.strip()
 
500
  "model_name": model_to_use.strip(),
501
  "prompt_pair": [prompt1.strip(), prompt2.strip()],
502
  "layer_key": layer_key.strip(), # Note: uses layer_key like PCA, not layer_type
503
+ "figure_format": "png",
504
  }
505
+
506
  progress(0.3, desc="🚀 Sending request to backend...")
507
+
508
  # Make HTTP request to FastAPI heatmap endpoint
509
  response = requests.post(
510
  f"{FASTAPI_BASE_URL}/visualize/heatmap",
511
  json=payload,
512
+ timeout=300, # Extended timeout for model processing
513
  )
514
+
515
  progress(0.7, desc="📊 Processing visualization...")
516
+
517
  # Handle successful response
518
  if response.status_code == 200:
519
  # Save binary image data to temporary file
520
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
521
  tmp_file.write(response.content)
522
  image_path = tmp_file.name
523
+
524
  progress(1.0, desc="✅ Visualization complete!")
525
+
526
  # Create detailed success message for user
527
  success_msg = f"""✅ **Heatmap Visualization Generated Successfully!**
528
 
 
533
  - Prompts compared: {len(prompt1.split())} vs {len(prompt2.split())} words
534
 
535
  **Analysis:** Detailed heatmap showing activation differences in layer {layer_number}. Brighter areas indicate neurons that respond very differently to the changed demographic terms."""
536
+
537
  return image_path, success_msg, image_path
538
+
539
  # Handle validation errors (422)
540
  elif response.status_code == 422:
541
+ error_detail = response.json().get("detail", "Validation error")
542
  return None, f"❌ **Validation Error:**\n{error_detail}", ""
543
+
544
  # Handle server errors (500)
545
  elif response.status_code == 500:
546
+ error_detail = response.json().get("detail", "Internal server error")
547
  return None, f"❌ **Server Error:**\n{error_detail}", ""
548
+
549
  # Handle other HTTP errors
550
  else:
551
+ return (
552
+ None,
553
+ f"❌ **Unexpected Error:**\nHTTP {response.status_code}: {response.text}",
554
+ "",
555
+ )
556
+
557
  # Handle specific request exceptions
558
  except requests.exceptions.Timeout:
559
+ return (
560
+ None,
561
+ "❌ **Timeout Error:**\nThe request took too long. This might happen with large models. Try again or use a different layer.",
562
+ "",
563
+ )
564
+
565
  except requests.exceptions.ConnectionError:
566
+ return (
567
+ None,
568
+ "❌ **Connection Error:**\nCannot connect to the backend. Make sure the FastAPI server is running:\n`uvicorn main:app --reload`",
569
+ "",
570
+ )
571
+
572
  # Handle any other unexpected exceptions
573
  except Exception as e:
574
  logger.exception("Error in Heatmap visualization")
575
  return None, f"❌ **Unexpected Error:**\n{str(e)}", ""
576
 
577
+
578
  ############################################
579
  # Create the Gradio interface
580
  ############################################
581
  # This function sets up the Gradio Blocks interface with tabs for PCA, Mean Difference, and Heatmap visualizations.
582
  def create_interface():
583
  """Create the main Gradio interface with tabs."""
584
+
585
  with gr.Blocks(
586
  title="OptiPFair Bias Visualization Tool",
587
  theme=gr.themes.Soft(),
588
  css="""
589
  .container { max-width: 1200px; margin: auto; }
590
  .tab-nav { justify-content: center; }
591
+ """,
592
  ) as interface:
593
+
594
  # Header
595
+ gr.Markdown(
596
+ """
597
  # 🔍 OptiPFair Bias Visualization Tool
598
 
599
  Analyze potential biases in Large Language Models using advanced visualization techniques.
600
  Built with [OptiPFair](https://github.com/peremartra/optipfair) library.
601
+ """
602
+ )
603
+
604
  # Health check section
605
  with gr.Row():
606
  with gr.Column(scale=2):
607
  health_btn = gr.Button("🏥 Check Backend Status", variant="secondary")
608
  with gr.Column(scale=3):
609
  health_output = gr.Textbox(
610
+ label="Backend Status",
611
  interactive=False,
612
+ value="Click 'Check Backend Status' to verify connection",
613
  )
614
+
615
  health_btn.click(health_check, outputs=health_output)
616
 
617
  # Global model selection (dropdown plus custom model ID), shared by all visualization tabs
618
  with gr.Row():
619
  with gr.Column(scale=2):
620
  model_dropdown = gr.Dropdown(
621
+ choices=AVAILABLE_MODELS,
622
  label="🤖 Select Model",
623
+ value=DEFAULT_MODEL,
624
  )
625
  with gr.Column(scale=3):
626
  custom_model_input = gr.Textbox(
627
  label="Custom Model (HuggingFace ID)",
628
  placeholder="e.g., microsoft/DialoGPT-large",
629
+ visible=False, # Inicialmente oculto
630
  )
631
 
632
  # toggle Custom Model Input
 
636
  return gr.update(visible=False)
637
 
638
  model_dropdown.change(
639
+ toggle_custom_model, inputs=[model_dropdown], outputs=[custom_model_input]
 
 
640
  )
641
+
642
  # Main tabs
643
  with gr.Tabs() as tabs:
644
  #################
 
646
  ##############
647
  with gr.Tab("📊 PCA Analysis"):
648
  gr.Markdown("### Principal Component Analysis of Model Activations")
649
+ gr.Markdown(
650
+ "Visualize how model representations differ between prompt pairs in a 2D space."
651
+ )
652
+
653
  with gr.Row():
654
  # Left column: Configuration
655
  with gr.Column(scale=1):
656
  # Predefined scenarios dropdown
657
  scenario_dropdown = gr.Dropdown(
658
+ choices=[
659
+ (v["description"], k)
660
+ for k, v in PREDEFINED_PROMPTS.items()
661
+ ],
662
  label="📋 Predefined Scenarios",
663
+ value=list(PREDEFINED_PROMPTS.keys())[0],
664
  )
665
+
666
  # Prompt inputs
667
  prompt1_input = gr.Textbox(
668
  label="Prompt 1",
669
  placeholder="Enter first prompt...",
670
  lines=2,
671
+ value=PREDEFINED_PROMPTS[
672
+ list(PREDEFINED_PROMPTS.keys())[0]
673
+ ]["prompt1"],
674
  )
675
  prompt2_input = gr.Textbox(
676
+ label="Prompt 2",
677
  placeholder="Enter second prompt...",
678
  lines=2,
679
+ value=PREDEFINED_PROMPTS[
680
+ list(PREDEFINED_PROMPTS.keys())[0]
681
+ ]["prompt2"],
682
  )
683
+
684
  # Layer configuration - Component Type
685
  component_dropdown = gr.Dropdown(
686
  choices=[
687
  ("Attention Output", "attention_output"),
688
+ ("MLP Output", "mlp_output"),
689
  ("Gate Projection", "gate_proj"),
690
  ("Up Projection", "up_proj"),
691
  ("Down Projection", "down_proj"),
692
+ ("Input Normalization", "input_norm"),
693
  ],
694
  label="Component Type",
695
  value="attention_output",
696
+ info="Type of neural network component to analyze",
697
  )
698
 
699
+ # Layer configuration - Layer Number
700
  layer_number = gr.Number(
701
+ label="Layer Number",
702
  value=7,
703
  minimum=0,
704
  step=1,
705
+ info="Layer index - varies by model (e.g., 0-15 for small models)",
706
  )
707
+
708
  # Options
709
  highlight_diff_checkbox = gr.Checkbox(
710
  label="Highlight differing tokens",
711
  value=True,
712
+ info="Highlight tokens that differ between prompts",
713
  )
714
+
715
  # Generate button
716
+ pca_btn = gr.Button(
717
+ "🔍 Generate PCA Visualization",
718
+ variant="primary",
719
+ size="lg",
720
+ )
721
+
722
  # Status output
723
  pca_status = gr.Textbox(
724
+ label="Status",
725
  value="Configure parameters and click 'Generate PCA Visualization'",
726
  interactive=False,
727
  lines=8,
728
+ max_lines=10,
729
  )
730
+
731
  # Right column: Results
732
  with gr.Column(scale=1):
733
  # Image display
 
737
  show_label=True,
738
  show_download_button=True,
739
  interactive=False,
740
+ height=400,
741
  )
742
+
743
  # Download button (additional)
744
  download_pca = gr.File(
745
+ label="📥 Download Visualization", visible=False
 
746
  )
747
+
748
  # Update prompts when scenario changes
749
  scenario_dropdown.change(
750
  load_predefined_prompts,
751
  inputs=[scenario_dropdown],
752
+ outputs=[prompt1_input, prompt2_input],
753
  )
754
+
755
  # Connect the real PCA function
756
  pca_btn.click(
757
  generate_pca_visualization,
758
  inputs=[
759
+ model_dropdown,
760
+ custom_model_input,
761
  scenario_dropdown,
762
+ prompt1_input,
763
  prompt2_input,
764
+ component_dropdown, # ← NUEVO: tipo de componente
765
+ layer_number, # ← NUEVO: número de capa
766
+ highlight_diff_checkbox,
767
  ],
768
  outputs=[pca_image, pca_status, download_pca],
769
+ show_progress=True,
770
  )
771
  ####################
772
  # Mean Difference Tab
773
  ##################
774
  with gr.Tab("📈 Mean Difference"):
775
  gr.Markdown("### Mean Activation Differences Across Layers")
776
+ gr.Markdown(
777
+ "Compare average activation differences across all layers of a specific component type."
778
+ )
779
+
780
  with gr.Row():
781
  # Left column: Configuration
782
  with gr.Column(scale=1):
783
  # Predefined scenarios dropdown (reutilizar del PCA)
784
  mean_scenario_dropdown = gr.Dropdown(
785
+ choices=[
786
+ (v["description"], k)
787
+ for k, v in PREDEFINED_PROMPTS.items()
788
+ ],
789
  label="📋 Predefined Scenarios",
790
+ value=list(PREDEFINED_PROMPTS.keys())[0],
791
  )
792
+
793
  # Prompt inputs
794
  mean_prompt1_input = gr.Textbox(
795
  label="Prompt 1",
796
  placeholder="Enter first prompt...",
797
  lines=2,
798
+ value=PREDEFINED_PROMPTS[
799
+ list(PREDEFINED_PROMPTS.keys())[0]
800
+ ]["prompt1"],
801
  )
802
  mean_prompt2_input = gr.Textbox(
803
+ label="Prompt 2",
804
  placeholder="Enter second prompt...",
805
  lines=2,
806
+ value=PREDEFINED_PROMPTS[
807
+ list(PREDEFINED_PROMPTS.keys())[0]
808
+ ]["prompt2"],
809
  )
810
+
811
  # Component type configuration
812
  mean_component_dropdown = gr.Dropdown(
813
  choices=[
814
  ("Attention Output", "attention_output"),
815
+ ("MLP Output", "mlp_output"),
816
  ("Gate Projection", "gate_proj"),
817
  ("Up Projection", "up_proj"),
818
  ("Down Projection", "down_proj"),
819
+ ("Input Normalization", "input_norm"),
820
  ],
821
  label="Component Type",
822
  value="attention_output",
823
+ info="Type of neural network component to analyze",
824
  )
825
+
 
826
  # Generate button
827
+ mean_diff_btn = gr.Button(
828
+ "📈 Generate Mean Difference Visualization",
829
+ variant="primary",
830
+ size="lg",
831
+ )
832
+
833
  # Status output
834
  mean_diff_status = gr.Textbox(
835
+ label="Status",
836
  value="Configure parameters and click 'Generate Mean Difference Visualization'",
837
  interactive=False,
838
  lines=8,
839
+ max_lines=10,
840
  )
841
+
842
  # Right column: Results
843
  with gr.Column(scale=1):
844
  # Image display
 
848
  show_label=True,
849
  show_download_button=True,
850
  interactive=False,
851
+ height=400,
852
  )
853
 
854
  # Download button (additional)
855
  download_mean_diff = gr.File(
856
+ label="📥 Download Visualization", visible=False
 
857
  )
858
  # Update prompts when scenario changes for Mean Difference
859
  mean_scenario_dropdown.change(
860
  load_predefined_prompts,
861
  inputs=[mean_scenario_dropdown],
862
+ outputs=[mean_prompt1_input, mean_prompt2_input],
863
  )
864
 
865
  # Connect the real Mean Difference function
866
  mean_diff_btn.click(
867
  generate_mean_diff_visualization,
868
  inputs=[
869
+ model_dropdown, # Reuse the global model selector
870
+ custom_model_input, # Reuse the global custom-model field
871
  mean_scenario_dropdown,
872
+ mean_prompt1_input,
873
  mean_prompt2_input,
874
  mean_component_dropdown,
875
  ],
876
  outputs=[mean_diff_image, mean_diff_status, download_mean_diff],
877
+ show_progress=True,
878
+ )
879
  ###################
880
+ # Heatmap Tab
881
  ##################
882
  with gr.Tab("🔥 Heatmap"):
883
  gr.Markdown("### Activation Difference Heatmap")
884
+ gr.Markdown(
885
+ "Detailed heatmap showing activation patterns in specific layers."
886
+ )
887
+
888
  with gr.Row():
889
  # Left column: Configuration
890
  with gr.Column(scale=1):
891
  # Predefined scenarios dropdown
892
  heatmap_scenario_dropdown = gr.Dropdown(
893
+ choices=[
894
+ (v["description"], k)
895
+ for k, v in PREDEFINED_PROMPTS.items()
896
+ ],
897
  label="📋 Predefined Scenarios",
898
+ value=list(PREDEFINED_PROMPTS.keys())[0],
899
  )
900
+
901
  # Prompt inputs
902
  heatmap_prompt1_input = gr.Textbox(
903
  label="Prompt 1",
904
  placeholder="Enter first prompt...",
905
  lines=2,
906
+ value=PREDEFINED_PROMPTS[
907
+ list(PREDEFINED_PROMPTS.keys())[0]
908
+ ]["prompt1"],
909
  )
910
  heatmap_prompt2_input = gr.Textbox(
911
+ label="Prompt 2",
912
  placeholder="Enter second prompt...",
913
  lines=2,
914
+ value=PREDEFINED_PROMPTS[
915
+ list(PREDEFINED_PROMPTS.keys())[0]
916
+ ]["prompt2"],
917
  )
918
+
919
  # Component type configuration
920
  heatmap_component_dropdown = gr.Dropdown(
921
  choices=[
922
  ("Attention Output", "attention_output"),
923
+ ("MLP Output", "mlp_output"),
924
  ("Gate Projection", "gate_proj"),
925
  ("Up Projection", "up_proj"),
926
  ("Down Projection", "down_proj"),
927
+ ("Input Normalization", "input_norm"),
928
  ],
929
  label="Component Type",
930
  value="attention_output",
931
+ info="Type of neural network component to analyze",
932
  )
933
 
934
+ # Layer number configuration
935
  heatmap_layer_number = gr.Number(
936
+ label="Layer Number",
937
  value=7,
938
  minimum=0,
939
  step=1,
940
+ info="Layer index - varies by model (e.g., 0-15 for small models)",
941
  )
942
+
943
  # Generate button
944
+ heatmap_btn = gr.Button(
945
+ "🔥 Generate Heatmap Visualization",
946
+ variant="primary",
947
+ size="lg",
948
+ )
949
+
950
  # Status output
951
  heatmap_status = gr.Textbox(
952
+ label="Status",
953
  value="Configure parameters and click 'Generate Heatmap Visualization'",
954
  interactive=False,
955
  lines=8,
956
+ max_lines=10,
957
  )
958
+
959
  # Right column: Results
960
  with gr.Column(scale=1):
961
  # Image display
 
965
  show_label=True,
966
  show_download_button=True,
967
  interactive=False,
968
+ height=400,
969
  )
970
+
971
  # Download button (additional)
972
  download_heatmap = gr.File(
973
+ label="📥 Download Visualization", visible=False
 
974
  )
975
  # Update prompts when scenario changes for Heatmap
976
  heatmap_scenario_dropdown.change(
977
  load_predefined_prompts,
978
  inputs=[heatmap_scenario_dropdown],
979
+ outputs=[heatmap_prompt1_input, heatmap_prompt2_input],
980
  )
981
 
982
  # Connect the real Heatmap function
983
  heatmap_btn.click(
984
  generate_heatmap_visualization,
985
  inputs=[
986
+ model_dropdown, # Reuse the global model selector
987
+ custom_model_input, # Reuse the global custom-model field
988
  heatmap_scenario_dropdown,
989
+ heatmap_prompt1_input,
990
  heatmap_prompt2_input,
991
  heatmap_component_dropdown,
992
+ heatmap_layer_number,
993
  ],
994
  outputs=[heatmap_image, heatmap_status, download_heatmap],
995
+ show_progress=True,
996
  )
997
  # Footer
998
+ gr.Markdown(
999
+ """
1000
  ---
1001
  **📚 How to use:**
1002
  1. Check that the backend is running
 
1005
  4. Generate visualizations to analyze potential biases
1006
 
1007
  **🔗 Resources:** [OptiPFair Documentation](https://github.com/peremartra/optipfair) |
1008
+ """
1009
+ )
 
1010
 
1011
+ return interface
routers/visualize.py CHANGED
@@ -22,6 +22,7 @@ router = APIRouter(
22
  tags=["visualization"],
23
  )
24
 
 
25
  @router.post(
26
  "/pca",
27
  summary="Generates and returns the PCA visualization of activations",
@@ -50,16 +51,21 @@ async def visualize_pca_endpoint(req: VisualizePCARequest):
50
  raise HTTPException(status_code=500, detail=str(e))
51
  # 2. Verify that the file exists
52
  if not filepath or not os.path.isfile(filepath):
53
- raise HTTPException(status_code=500, detail="Image file not found after generation")
 
 
54
 
55
  # 3. Return the file directly to the client
56
  return FileResponse(
57
  path=filepath,
58
  media_type=f"image/{req.figure_format}",
59
  filename=os.path.basename(filepath),
60
- headers={"Content-Disposition": f'inline; filename="{os.path.basename(filepath)}"'},
 
 
61
  )
62
 
 
63
  @router.post("/mean-diff", response_class=FileResponse)
64
  async def visualize_mean_diff_endpoint(req: VisualizeMeanDiffRequest):
65
  """
@@ -89,9 +95,12 @@ async def visualize_mean_diff_endpoint(req: VisualizeMeanDiffRequest):
89
  path=filepath,
90
  media_type=f"image/{req.figure_format}",
91
  filename=os.path.basename(filepath),
92
- headers={"Content-Disposition": f'inline; filename="{os.path.basename(filepath)}"'}
 
 
93
  )
94
 
 
95
  @router.post("/heatmap", response_class=FileResponse)
96
  async def visualize_heatmap_endpoint(req: VisualizeHeatmapRequest):
97
  """
@@ -120,5 +129,7 @@ async def visualize_heatmap_endpoint(req: VisualizeHeatmapRequest):
120
  path=filepath,
121
  media_type=f"image/{req.figure_format}",
122
  filename=os.path.basename(filepath),
123
- headers={"Content-Disposition": f'inline; filename="{os.path.basename(filepath)}"'}
124
- )
 
 
 
22
  tags=["visualization"],
23
  )
24
 
25
+
26
  @router.post(
27
  "/pca",
28
  summary="Generates and returns the PCA visualization of activations",
 
51
  raise HTTPException(status_code=500, detail=str(e))
52
  # 2. Verify that the file exists
53
  if not filepath or not os.path.isfile(filepath):
54
+ raise HTTPException(
55
+ status_code=500, detail="Image file not found after generation"
56
+ )
57
 
58
  # 3. Return the file directly to the client
59
  return FileResponse(
60
  path=filepath,
61
  media_type=f"image/{req.figure_format}",
62
  filename=os.path.basename(filepath),
63
+ headers={
64
+ "Content-Disposition": f'inline; filename="{os.path.basename(filepath)}"'
65
+ },
66
  )
67
 
68
+
69
  @router.post("/mean-diff", response_class=FileResponse)
70
  async def visualize_mean_diff_endpoint(req: VisualizeMeanDiffRequest):
71
  """
 
95
  path=filepath,
96
  media_type=f"image/{req.figure_format}",
97
  filename=os.path.basename(filepath),
98
+ headers={
99
+ "Content-Disposition": f'inline; filename="{os.path.basename(filepath)}"'
100
+ },
101
  )
102
 
103
+
104
  @router.post("/heatmap", response_class=FileResponse)
105
  async def visualize_heatmap_endpoint(req: VisualizeHeatmapRequest):
106
  """
 
129
  path=filepath,
130
  media_type=f"image/{req.figure_format}",
131
  filename=os.path.basename(filepath),
132
+ headers={
133
+ "Content-Disposition": f'inline; filename="{os.path.basename(filepath)}"'
134
+ },
135
+ )
schemas/visualize.py CHANGED
@@ -2,10 +2,12 @@
2
  from pydantic import BaseModel, field_validator
3
  from typing import List, Optional, Union, Tuple
4
 
 
5
  class VisualizePCARequest(BaseModel):
6
  """
7
  Schema for the /visualize-pca endpoint.
8
  """
 
9
  model_name: str
10
  prompt_pair: List[str]
11
  layer_key: str
@@ -20,6 +22,7 @@ class VisualizePCARequest(BaseModel):
20
  raise ValueError("prompt_pair must be a list of exactly two strings")
21
  return v
22
 
 
23
  class VisualizeMeanDiffRequest(BaseModel):
24
  model_name: str
25
  prompt_pair: List[str]
@@ -34,10 +37,12 @@ class VisualizeMeanDiffRequest(BaseModel):
34
  raise ValueError("prompt_pair must be a list of exactly two strings")
35
  return v
36
 
 
37
  class VisualizeHeatmapRequest(BaseModel):
38
  """
39
  Schema for the /visualize/heatmap endpoint.
40
  """
 
41
  model_name: str
42
  prompt_pair: List[str]
43
  layer_key: str
 
2
  from pydantic import BaseModel, field_validator
3
  from typing import List, Optional, Union, Tuple
4
 
5
+
6
  class VisualizePCARequest(BaseModel):
7
  """
8
  Schema for the /visualize-pca endpoint.
9
  """
10
+
11
  model_name: str
12
  prompt_pair: List[str]
13
  layer_key: str
 
22
  raise ValueError("prompt_pair must be a list of exactly two strings")
23
  return v
24
 
25
+
26
  class VisualizeMeanDiffRequest(BaseModel):
27
  model_name: str
28
  prompt_pair: List[str]
 
37
  raise ValueError("prompt_pair must be a list of exactly two strings")
38
  return v
39
 
40
+
41
  class VisualizeHeatmapRequest(BaseModel):
42
  """
43
  Schema for the /visualize/heatmap endpoint.
44
  """
45
+
46
  model_name: str
47
  prompt_pair: List[str]
48
  layer_key: str
utils/visualize_pca.py CHANGED
@@ -10,21 +10,23 @@ from optipfair.bias import visualize_pca, visualize_mean_differences, visualize_
10
  from transformers import AutoModelForCausalLM, AutoTokenizer
11
 
12
  import matplotlib
13
- matplotlib.use('Agg') # Use 'Agg' backend for non-GUI environments
 
14
 
15
  logger = logging.getLogger(__name__)
16
  logger.setLevel(logging.INFO)
17
 
 
18
  @lru_cache(maxsize=None)
19
  def load_model_tokenizer(model_name: str):
20
  """
21
  Loads the model and tokenizer once, moves the model to the selected device, and caches the result.
22
  """
23
  logger.info(f"Loading model and tokenizer for '{model_name}'")
24
-
25
  # Get HF token from environment for gated models
26
  hf_token = os.getenv("HF_TOKEN")
27
-
28
  # Device selection: CUDA is tried first; MPS (Apple Silicon) and CPU are the fallbacks
29
  if torch.cuda.is_available():
30
  device = torch.device("cuda")
@@ -35,20 +37,19 @@ def load_model_tokenizer(model_name: str):
35
  logger.info(f"Using device: {device}")
36
 
37
  model = AutoModelForCausalLM.from_pretrained(
38
- model_name,
39
- token=hf_token # ← AÑADIR ESTA LÍNEA
40
  )
41
  tokenizer = AutoTokenizer.from_pretrained(
42
- model_name,
43
- token=hf_token # ← AÑADIR ESTA LÍNEA
44
  )
45
 
46
  model = model.to(device)
47
-
48
  logger.info(f"Model loaded on device: {next(model.parameters()).device}")
49
 
50
  return model, tokenizer
51
 
 
52
  def run_visualize_pca(
53
  model_name: str,
54
  prompt_pair: Tuple[str, str],
@@ -72,7 +73,7 @@ def run_visualize_pca(
72
  highlight_diff=highlight_diff,
73
  output_dir=output_dir,
74
  figure_format=figure_format,
75
- pair_index=pair_index
76
  )
77
 
78
  layer_parts = layer_key.split("_")
@@ -83,7 +84,7 @@ def run_visualize_pca(
83
  layer_type=layer_type,
84
  layer_num=layer_num,
85
  pair_index=pair_index,
86
- figure_format=figure_format
87
  )
88
  filepath = os.path.join(output_dir, filename)
89
 
@@ -93,6 +94,7 @@ def run_visualize_pca(
93
  logger.info(f"PCA image saved at {filepath}")
94
  return filepath
95
 
 
96
  def run_visualize_mean_diff(
97
  model_name: str,
98
  prompt_pair: Tuple[str, str],
@@ -115,14 +117,14 @@ def run_visualize_mean_diff(
115
  layers="all", # By default, show all layers
116
  output_dir=output_dir,
117
  figure_format=figure_format,
118
- pair_index=pair_index
119
  )
120
 
121
  filename = build_visualization_filename(
122
  vis_type="mean_diff",
123
  layer_type=layer_type,
124
  pair_index=pair_index,
125
- figure_format=figure_format
126
  )
127
  filepath = os.path.join(output_dir, filename)
128
  if not os.path.isfile(filepath):
@@ -130,6 +132,7 @@ def run_visualize_mean_diff(
130
  logger.info(f"Mean-diff image saved at {filepath}")
131
  return filepath
132
 
 
133
  def run_visualize_heatmap(
134
  model_name: str,
135
  prompt_pair: Tuple[str, str],
@@ -151,7 +154,7 @@ def run_visualize_heatmap(
151
  layer_key=layer_key,
152
  output_dir=output_dir,
153
  figure_format=figure_format,
154
- pair_index=pair_index
155
  )
156
 
157
  parts = layer_key.split("_")
@@ -162,7 +165,7 @@ def run_visualize_heatmap(
162
  layer_type=layer_type,
163
  layer_num=layer_num,
164
  pair_index=pair_index,
165
- figure_format=figure_format
166
  )
167
  filepath = os.path.join(output_dir, filename)
168
  if not os.path.isfile(filepath):
@@ -170,13 +173,14 @@ def run_visualize_heatmap(
170
  logger.info(f"Heatmap image saved at {filepath}")
171
  return filepath
172
 
 
173
  def build_visualization_filename(
174
  vis_type: str,
175
  layer_type: str,
176
  layer_num: str = None,
177
  layers: Union[str, List[int]] = None,
178
  pair_index: int = 0,
179
- figure_format: str = "png"
180
  ) -> str:
181
  """
182
  Builds the filename for any visualization.
@@ -188,4 +192,3 @@ def build_visualization_filename(
188
  return f"{vis_type}_{layer_type}_{layer_num}_pair{pair_index}.{figure_format}"
189
  else:
190
  raise ValueError(f"Unknown visualization type: {vis_type}")
191
-
 
10
  from transformers import AutoModelForCausalLM, AutoTokenizer
11
 
12
  import matplotlib
13
+
14
+ matplotlib.use("Agg") # Use 'Agg' backend for non-GUI environments
15
 
16
  logger = logging.getLogger(__name__)
17
  logger.setLevel(logging.INFO)
18
 
19
+
20
  @lru_cache(maxsize=None)
21
  def load_model_tokenizer(model_name: str):
22
  """
23
  Loads the model and tokenizer once, moves the model to the selected device, and caches the result.
24
  """
25
  logger.info(f"Loading model and tokenizer for '{model_name}'")
26
+
27
  # Get HF token from environment for gated models
28
  hf_token = os.getenv("HF_TOKEN")
29
+
30
  # Device selection: CUDA is tried first; MPS (Apple Silicon) and CPU are the fallbacks
31
  if torch.cuda.is_available():
32
  device = torch.device("cuda")
 
37
  logger.info(f"Using device: {device}")
38
 
39
  model = AutoModelForCausalLM.from_pretrained(
40
+ model_name, token=hf_token # HF token (from HF_TOKEN) is needed for gated models
 
41
  )
42
  tokenizer = AutoTokenizer.from_pretrained(
43
+ model_name, token=hf_token # HF token (from HF_TOKEN) is needed for gated models
 
44
  )
45
 
46
  model = model.to(device)
47
+
48
  logger.info(f"Model loaded on device: {next(model.parameters()).device}")
49
 
50
  return model, tokenizer
51
 
52
+
53
  def run_visualize_pca(
54
  model_name: str,
55
  prompt_pair: Tuple[str, str],
 
73
  highlight_diff=highlight_diff,
74
  output_dir=output_dir,
75
  figure_format=figure_format,
76
+ pair_index=pair_index,
77
  )
78
 
79
  layer_parts = layer_key.split("_")
 
84
  layer_type=layer_type,
85
  layer_num=layer_num,
86
  pair_index=pair_index,
87
+ figure_format=figure_format,
88
  )
89
  filepath = os.path.join(output_dir, filename)
90
 
 
94
  logger.info(f"PCA image saved at {filepath}")
95
  return filepath
96
 
97
+
98
  def run_visualize_mean_diff(
99
  model_name: str,
100
  prompt_pair: Tuple[str, str],
 
117
  layers="all", # By default, show all layers
118
  output_dir=output_dir,
119
  figure_format=figure_format,
120
+ pair_index=pair_index,
121
  )
122
 
123
  filename = build_visualization_filename(
124
  vis_type="mean_diff",
125
  layer_type=layer_type,
126
  pair_index=pair_index,
127
+ figure_format=figure_format,
128
  )
129
  filepath = os.path.join(output_dir, filename)
130
  if not os.path.isfile(filepath):
 
132
  logger.info(f"Mean-diff image saved at {filepath}")
133
  return filepath
134
 
135
+
136
  def run_visualize_heatmap(
137
  model_name: str,
138
  prompt_pair: Tuple[str, str],
 
154
  layer_key=layer_key,
155
  output_dir=output_dir,
156
  figure_format=figure_format,
157
+ pair_index=pair_index,
158
  )
159
 
160
  parts = layer_key.split("_")
 
165
  layer_type=layer_type,
166
  layer_num=layer_num,
167
  pair_index=pair_index,
168
+ figure_format=figure_format,
169
  )
170
  filepath = os.path.join(output_dir, filename)
171
  if not os.path.isfile(filepath):
 
173
  logger.info(f"Heatmap image saved at {filepath}")
174
  return filepath
175
 
176
+
177
  def build_visualization_filename(
178
  vis_type: str,
179
  layer_type: str,
180
  layer_num: str = None,
181
  layers: Union[str, List[int]] = None,
182
  pair_index: int = 0,
183
+ figure_format: str = "png",
184
  ) -> str:
185
  """
186
  Builds the filename for any visualization.
 
192
  return f"{vis_type}_{layer_type}_{layer_num}_pair{pair_index}.{figure_format}"
193
  else:
194
  raise ValueError(f"Unknown visualization type: {vis_type}")