Anisha Bhatnagar committed on
Commit
ce95080
·
1 Parent(s): 40fde16

plot zoom working

Browse files
Files changed (3) hide show
  1. app.py +73 -7
  2. utils/interp_space_utils.py +38 -46
  3. utils/visualizations.py +28 -53
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import json
3
-
4
 
5
  import os
6
  os.environ["GRADIO_TEMP_DIR"] = "./datasets/temp" # Set a custom temp directory for Gradio
@@ -55,7 +55,7 @@ def validate_ground_truth(gt1, gt2, gt3):
55
  return index, f"Candidate {index+1} is marked as the ground truth author."
56
 
57
 
58
- def app(share=False):#, use_cluster_feats=False):
59
  instances, instance_ids = get_instances(cfg['instances_to_explain_path'])
60
 
61
  interp = load_interp_space(cfg)
@@ -392,9 +392,6 @@ def app(share=False):#, use_cluster_feats=False):
392
  visible_zoomed_authors = gr.State()
393
 
394
  gr.HTML(instruction_callout("Zoom in on the plot to select a set of background authors and see the presence of the top features from this set in candidate and mystery authors."))
395
-
396
- # State to store precomputed regions
397
- precomputed_regions_state = gr.State()
398
 
399
  # Add this after the plot generation
400
  gr.HTML("""
@@ -413,6 +410,17 @@ def app(share=False):#, use_cluster_feats=False):
413
  Select a precomputed region to analyze, or zoom manually on the plot above
414
  </div>
415
  """)
 
 
 
 
 
 
 
 
 
 
 
416
 
417
  precomputed_regions_radio = gr.Radio(
418
  choices=["None"],
@@ -420,6 +428,8 @@ def app(share=False):#, use_cluster_feats=False):
420
  label="Precomputed Regions",
421
  info="Select a region to automatically zoom and analyze"
422
  )
 
 
423
 
424
 
425
  with gr.Row():
@@ -471,9 +481,65 @@ def app(share=False):#, use_cluster_feats=False):
471
  )
472
 
473
  precomputed_regions_radio.change(
474
- fn=lambda region_name, precomputed_regions: trigger_precomputed_region(region_name, precomputed_regions),
475
  inputs=[precomputed_regions_radio, precomputed_regions_state],
476
- outputs=[axis_ranges]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
477
  )
478
 
479
  axis_ranges.change(
 
1
  import gradio as gr
2
  import json
3
+ import ast
4
 
5
  import os
6
  os.environ["GRADIO_TEMP_DIR"] = "./datasets/temp" # Set a custom temp directory for Gradio
 
55
  return index, f"Candidate {index+1} is marked as the ground truth author."
56
 
57
 
58
+ def app(share=False):
59
  instances, instance_ids = get_instances(cfg['instances_to_explain_path'])
60
 
61
  interp = load_interp_space(cfg)
 
392
  visible_zoomed_authors = gr.State()
393
 
394
  gr.HTML(instruction_callout("Zoom in on the plot to select a set of background authors and see the presence of the top features from this set in candidate and mystery authors."))
 
 
 
395
 
396
  # Add this after the plot generation
397
  gr.HTML("""
 
410
  Select a precomputed region to analyze, or zoom manually on the plot above
411
  </div>
412
  """)
413
+
414
+ # State to store precomputed regions
415
+ precomputed_regions_state = gr.Textbox(
416
+ visible=True, # Keep it visible to DOM
417
+ elem_id="precomputed-regions",
418
+ interactive=True,
419
+ show_label=False,
420
+ container=False,
421
+ value="",
422
+ elem_classes=["hidden-textbox"] # Add custom CSS class
423
+ )
424
 
425
  precomputed_regions_radio = gr.Radio(
426
  choices=["None"],
 
428
  label="Precomputed Regions",
429
  info="Select a region to automatically zoom and analyze"
430
  )
431
+ # Add a hidden HTML component for JavaScript execution
432
+ js_trigger = gr.HTML(visible=False, elem_id="js-trigger")
433
 
434
 
435
  with gr.Row():
 
481
  )
482
 
483
  precomputed_regions_radio.change(
484
+ fn=lambda region_name, precomputed_regions_json: trigger_precomputed_region(region_name, ast.literal_eval(precomputed_regions_json)),
485
  inputs=[precomputed_regions_radio, precomputed_regions_state],
486
+ outputs=[axis_ranges],
487
+ js="""
488
+ function(region_name, regions_json_str) {
489
+ console.log('=== ZOOM DEBUG START ===');
490
+ console.log('Region selected:', region_name);
491
+ console.log('Regions JSON string received:', typeof regions_json_str);
492
+
493
+ // Check if Plotly is available
494
+ console.log('Plotly available:', typeof window.Plotly);
495
+
496
+ // Find plot element
497
+ const plotDiv = document.querySelector('#feature-plot .js-plotly-plot');
498
+ console.log('Plot element found:', !!plotDiv);
499
+
500
+ if (plotDiv) {
501
+ console.log('Plot element exists');
502
+ }
503
+
504
+ // Try to parse regions
505
+ try {
506
+ const precomputed_regions = JSON.parse(regions_json_str);
507
+ console.log('Regions parsed successfully');
508
+ console.log('Available regions:', Object.keys(precomputed_regions));
509
+
510
+ if (region_name !== "None" && precomputed_regions[region_name]) {
511
+ const region = precomputed_regions[region_name];
512
+ const bbox = region.bbox;
513
+ console.log('Bbox to apply:', bbox);
514
+
515
+ if (window.Plotly && plotDiv) {
516
+ console.log('Calling Plotly.relayout...');
517
+
518
+ const update = {
519
+ 'xaxis.range': [bbox.xaxis[0], bbox.xaxis[1]],
520
+ 'yaxis.range': [bbox.yaxis[0], bbox.yaxis[1]],
521
+ 'xaxis.autorange': false,
522
+ 'yaxis.autorange': false
523
+ };
524
+ console.log('Update object:', update);
525
+
526
+ window.Plotly.relayout(plotDiv, update)
527
+ .then(() => console.log('✓ Relayout completed successfully'))
528
+ .catch(err => console.log('✗ Relayout failed:', err));
529
+ } else {
530
+ console.log('Missing requirements - Plotly:', !!window.Plotly, 'PlotDiv:', !!plotDiv);
531
+ }
532
+ } else {
533
+ console.log('Region not found or None selected');
534
+ }
535
+ } catch(e) {
536
+ console.log('Error in region processing:', e);
537
+ }
538
+
539
+ console.log('=== ZOOM DEBUG END ===');
540
+ return [region_name, regions_json_str];
541
+ }
542
+ """
543
  )
544
 
545
  axis_ranges.change(
utils/interp_space_utils.py CHANGED
@@ -828,7 +828,7 @@ def compute_predicted_author(task_authors_df: pd.DataFrame, col_name: str) -> in
828
  return predicted_author
829
 
830
 
831
- def compute_precomputed_regions(bg_proj, bg_ids, q_proj, c_proj, n_neighbors=7):
832
  """
833
  Compute precomputed regions for mystery author and candidates.
834
 
@@ -857,14 +857,19 @@ def compute_precomputed_regions(bg_proj, bg_ids, q_proj, c_proj, n_neighbors=7):
857
  print(f"Cache miss. Computing regions.")
858
 
859
  regions = {}
 
860
  # All points for distance calculation (mystery + candidates + background)
861
  all_points = np.vstack([q_proj.reshape(1, -1), c_proj, bg_proj])
862
  all_ids = ['mystery'] + [f'candidate_{i}' for i in range(3)] + bg_ids
863
 
864
- def get_region_around_point(center_point, region_name):
865
  """Get region around a specific point"""
 
 
 
 
866
  # Calculate distances from center point to all background authors
867
- distances = euclidean_distances([center_point], bg_proj)[0]
868
 
869
  # Get indices of closest neighbors
870
  closest_indices = np.argsort(distances)[:n_neighbors]
@@ -872,7 +877,14 @@ def compute_precomputed_regions(bg_proj, bg_ids, q_proj, c_proj, n_neighbors=7):
872
  closest_points = bg_proj[closest_indices]
873
 
874
  # Include the center point in the region
875
- region_points = np.vstack([center_point.reshape(1, -1), closest_points])
 
 
 
 
 
 
 
876
 
877
  # Calculate bounding box with some padding
878
  x_min, x_max = region_points[:, 0].min(), region_points[:, 0].max()
@@ -898,10 +910,11 @@ def compute_precomputed_regions(bg_proj, bg_ids, q_proj, c_proj, n_neighbors=7):
898
  """Get region around the midpoint between two points"""
899
  midpoint = (point1 + point2) / 2
900
  region_name = f"{name1} & {name2}"
901
- return get_region_around_point(midpoint, region_name)
902
-
 
 
903
  # Region 1: Around mystery author only
904
- print(f"Mystery author: {q_proj}")
905
  regions["Mystery Author Neighborhood"] = get_region_around_point(
906
  q_proj, "Mystery Author"
907
  )
@@ -913,10 +926,8 @@ def compute_precomputed_regions(bg_proj, bg_ids, q_proj, c_proj, n_neighbors=7):
913
  )
914
 
915
  # Regions 5-7: Between mystery and each candidate
916
-
917
  for i in range(3):
918
  region_name = f"Mystery & Candidate {i+1}"
919
- print(q_proj, c_proj[i])
920
  regions[region_name] = get_region_between_points(
921
  q_proj, c_proj[i], "Mystery", f"Candidate {i+1}"
922
  )
@@ -939,46 +950,27 @@ def compute_precomputed_regions(bg_proj, bg_ids, q_proj, c_proj, n_neighbors=7):
939
  task_centroid, "All Task Authors"
940
  )
941
 
942
- # Region 12: Wider region encompassing all task authors
943
- all_task_points = np.vstack([q_proj, c_proj])
944
- task_centroid = np.mean(all_task_points, axis=0)
945
-
946
- # Find distances from task centroid to all background authors
947
- distances_from_centroid = euclidean_distances([task_centroid], bg_proj)[0]
948
-
949
- # Take a larger number of neighbors (e.g., 20) for the expanded region
950
- n_expanded = min(20, len(bg_ids)) # Don't exceed available authors
951
- expanded_indices = np.argsort(distances_from_centroid)[:n_expanded]
952
- expanded_authors = [bg_ids[i] for i in expanded_indices]
953
- expanded_points = bg_proj[expanded_indices]
954
-
955
- # Include all task points in the bounding box calculation
956
- all_region_points = np.vstack([all_task_points, expanded_points])
957
-
958
- x_min, x_max = all_region_points[:, 0].min(), all_region_points[:, 0].max()
959
- y_min, y_max = all_region_points[:, 1].min(), all_region_points[:, 1].max()
960
-
961
- # Add moderate padding
962
- x_padding = (x_max - x_min) * 0.15
963
- y_padding = (y_max - y_min) * 0.15
964
-
965
- expanded_bbox = {
966
- 'xaxis': [x_min - x_padding, x_max + x_padding],
967
- 'yaxis': [y_min - y_padding, y_max + y_padding]
968
- }
969
-
970
- regions["Expanded Task Region"] = {
971
- 'bbox': expanded_bbox,
972
- 'authors': expanded_authors,
973
- 'center_point': task_centroid,
974
- 'description': f"Expanded region around all task authors ({len(expanded_authors)} authors)"
975
- }
976
 
977
- cache[key] = regions
 
 
978
  with open(REGION_CACHE, 'wb') as f:
979
  pickle.dump(cache, f)
980
-
981
- return regions
982
 
983
  if __name__ == "__main__":
984
  background_corpus = pd.read_pickle('../datasets/luar_interp_space_cluster_19/train_authors.pkl')
 
828
  return predicted_author
829
 
830
 
831
+ def compute_precomputed_regions(bg_proj, bg_ids, q_proj, c_proj, mystery_id, candidate_ids, n_neighbors=7):
832
  """
833
  Compute precomputed regions for mystery author and candidates.
834
 
 
857
  print(f"Cache miss. Computing regions.")
858
 
859
  regions = {}
860
+
861
  # All points for distance calculation (mystery + candidates + background)
862
  all_points = np.vstack([q_proj.reshape(1, -1), c_proj, bg_proj])
863
  all_ids = ['mystery'] + [f'candidate_{i}' for i in range(3)] + bg_ids
864
 
865
+ def get_region_around_point(center_point, region_name, include_points=None):
866
  """Get region around a specific point"""
867
+ # Ensure center_point is 2D for euclidean_distances
868
+ if center_point.ndim == 1:
869
+ center_point = center_point.reshape(1, -1)
870
+
871
  # Calculate distances from center point to all background authors
872
+ distances = euclidean_distances(center_point, bg_proj)[0]
873
 
874
  # Get indices of closest neighbors
875
  closest_indices = np.argsort(distances)[:n_neighbors]
 
877
  closest_points = bg_proj[closest_indices]
878
 
879
  # Include the center point in the region
880
+ # region_points = np.vstack([center_point.reshape(1, -1), closest_points])
881
+ if include_points is not None:
882
+ region_points = include_points.copy()
883
+ # Add center point and closest background authors
884
+ region_points = np.vstack([region_points, center_point, closest_points])
885
+ else:
886
+ # Standard case - just center point and neighbors
887
+ region_points = np.vstack([center_point, closest_points])
888
 
889
  # Calculate bounding box with some padding
890
  x_min, x_max = region_points[:, 0].min(), region_points[:, 0].max()
 
910
  """Get region around the midpoint between two points"""
911
  midpoint = (point1 + point2) / 2
912
  region_name = f"{name1} & {name2}"
913
+ # Include both original points in the region
914
+ include_points = np.vstack([point1.reshape(1, -1), point2.reshape(1, -1)])
915
+ return get_region_around_point(midpoint, region_name, include_points=include_points)
916
+
917
  # Region 1: Around mystery author only
 
918
  regions["Mystery Author Neighborhood"] = get_region_around_point(
919
  q_proj, "Mystery Author"
920
  )
 
926
  )
927
 
928
  # Regions 5-7: Between mystery and each candidate
 
929
  for i in range(3):
930
  region_name = f"Mystery & Candidate {i+1}"
 
931
  regions[region_name] = get_region_between_points(
932
  q_proj, c_proj[i], "Mystery", f"Candidate {i+1}"
933
  )
 
950
  task_centroid, "All Task Authors"
951
  )
952
 
953
+ def serialize_numpy_dtypes(obj):
954
+ if isinstance(obj, np.ndarray):
955
+ return obj.tolist()
956
+ elif isinstance(obj, (np.float32, np.float64)):
957
+ return float(obj)
958
+ elif isinstance(obj, (np.int32, np.int64)):
959
+ return int(obj)
960
+ elif isinstance(obj, dict):
961
+ return {key: serialize_numpy_dtypes(value) for key, value in obj.items()}
962
+ elif isinstance(obj, list):
963
+ return [serialize_numpy_dtypes(item) for item in obj]
964
+ else:
965
+ return obj
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
966
 
967
+ serializable_regions = serialize_numpy_dtypes(regions)
968
+ response = json.dumps(serializable_regions, default=str)
969
+ cache[key] = response
970
  with open(REGION_CACHE, 'wb') as f:
971
  pickle.dump(cache, f)
972
+
973
+ return response
974
 
975
  if __name__ == "__main__":
976
  background_corpus = pd.read_pickle('../datasets/luar_interp_space_cluster_19/train_authors.pkl')
utils/visualizations.py CHANGED
@@ -515,12 +515,16 @@ def visualize_clusters_plotly(iid, cfg, instances, model_radio, custom_model_inp
515
  bg_ids_for_regions = bg_ids[4:] # Background IDs
516
 
517
  # Compute precomputed regions
 
 
 
518
  precomputed_regions = compute_precomputed_regions(
519
- bg_proj_for_regions, bg_ids_for_regions, q_proj, c_proj
520
  )
521
 
522
  # Create choices for radio buttons
523
- region_choices = ["None"] + list(precomputed_regions.keys())
 
524
 
525
  print('Done processing....')
526
 
@@ -537,60 +541,31 @@ def visualize_clusters_plotly(iid, cfg, instances, model_radio, custom_model_inp
537
  )
538
  # return fig, update(choices=feature_list, value=feature_list[0]),feature_list
539
 
540
-
541
- def extract_cluster_key(display_label: str) -> int:
542
- """
543
- Given a dropdown label like
544
- "Cluster 5 (closest to mystery author; closest to Candidate 1 author)"
545
- returns the integer 5.
546
- """
547
- m = re.match(r"Cluster\s+(\d+)", display_label)
548
- if not m:
549
- raise ValueError(f"Unrecognized cluster label: {display_label}")
550
- return int(m.group(1))
551
-
552
  def trigger_precomputed_region(region_name, precomputed_regions):
553
  """
554
  Simulate a zoom event for a precomputed region.
555
  Returns the JSON payload that would be sent to axis_ranges.
556
  """
557
  print(f"[INFO] Triggering precomputed region: {region_name}")
558
- print(f"Available regions: {list(precomputed_regions.keys())}")
559
- if region_name == "None" or region_name not in precomputed_regions:
560
- return ""
561
-
562
- region = precomputed_regions[region_name]
563
- payload = region['bbox']
564
- json_payload = {
565
- 'xaxis': [float(payload['xaxis'][0]), float(payload['xaxis'][1])],
566
- 'yaxis': [float(payload['yaxis'][0]), float(payload['yaxis'][1])]
567
- }
568
- return json.dumps(json_payload)
569
-
570
- # When a cluster is selected, split features and populate radio buttons
571
- def on_cluster_change(selected_cluster, style_map):
572
- cluster_key = extract_cluster_key(selected_cluster)
573
- all_feats = style_map[cluster_key]
574
- llm_feats, g2v_feats = split_features(all_feats)
575
- # print(f"Selected cluster: {selected_cluster} ({cluster_key})")
576
- # print(f"LLM features: {llm_feats}")
577
-
578
- # Add "None" as a default selectable option
579
- llm_feats = ["None"] + llm_feats
580
-
581
- # filter out any g2v feature without a shorthand
582
- filtered_g2v = []
583
- for feat in g2v_feats:
584
- if get_shorthand(feat) is None:
585
- print(f"Skipping Gram2Vec feature without shorthand: {feat}")
586
- else:
587
- filtered_g2v.append(feat)
588
-
589
- # Add "None" as a default selectable option
590
- filtered_g2v = ["None"] + filtered_g2v
591
-
592
- return (
593
- gr.update(choices=llm_feats, value=llm_feats[0]),
594
- gr.update(choices=filtered_g2v, value=filtered_g2v[0]),
595
- llm_feats
596
- )
 
515
  bg_ids_for_regions = bg_ids[4:] # Background IDs
516
 
517
  # Compute precomputed regions
518
+ mystery_id = task_authors_df['authorID'].iloc[0] # Mystery author ID
519
+ candidate_ids = task_authors_df['authorID'].iloc[1:4].tolist() # 3 candidate IDs
520
+
521
  precomputed_regions = compute_precomputed_regions(
522
+ bg_proj_for_regions, bg_ids_for_regions, q_proj, c_proj, mystery_id, candidate_ids
523
  )
524
 
525
  # Create choices for radio buttons
526
+ pc=json.loads(precomputed_regions)
527
+ region_choices = ["None"] + list(pc.keys())
528
 
529
  print('Done processing....')
530
 
 
541
  )
542
  # return fig, update(choices=feature_list, value=feature_list[0]),feature_list
543
 
 
 
 
 
 
 
 
 
 
 
 
 
544
  def trigger_precomputed_region(region_name, precomputed_regions):
545
  """
546
  Simulate a zoom event for a precomputed region.
547
  Returns the JSON payload that would be sent to axis_ranges.
548
  """
549
  print(f"[INFO] Triggering precomputed region: {region_name}")
550
+ print(f"precomputed_regions type: {type(precomputed_regions)}")
551
+ # print(f"precomputed_regions content: {precomputed_regions}")
552
+ try:
553
+ # Parse the JSON string back to dictionary
554
+ # precomputed_regions = json.loads(precomputed_regions) if precomputed_regions else {}
555
+ print(f"Available regions: {len(list(precomputed_regions.keys()))}")
556
+ # print(f"Available regions: {list(precomputed_regions.keys())}")
557
+ if region_name == "None" or region_name not in precomputed_regions:
558
+ return ""
559
+
560
+ region = precomputed_regions[region_name]
561
+ payload = region['bbox']
562
+ json_payload = {
563
+ 'xaxis': [float(payload['xaxis'][0]), float(payload['xaxis'][1])],
564
+ 'yaxis': [float(payload['yaxis'][0]), float(payload['yaxis'][1])]
565
+ }
566
+
567
+ # js_code = trigger_plot_zoom_js(region_name, precomputed_regions)
568
+ return json.dumps(json_payload)#, js_code
569
+ except Exception as e:
570
+ print(f"[ERROR] Failed to trigger precomputed region: {e}")
571
+ return ""