Milad Alshomary committed on
Commit
912be5c
·
2 Parent(s): a721fcf 5cc0947
app.py CHANGED
@@ -462,13 +462,6 @@ def app(share=False):
462
  ">
463
  Gram2Vec Features prominent in the zoomed-in region
464
  </div>
465
- <div style="
466
- font-size: 0.9em;
467
- color: #666;
468
- margin-bottom: 1em;
469
- ">
470
- Features shown with normalized z-scores
471
- </div>
472
  """)
473
  gram2vec_rb = gr.Radio(choices=[], label="Gram2Vec features for this zoomed-in region")#, label="Top-10 Gram2Vec Features most likely to occur in Mystery Author", info="Most prominent Gram2Vec features in the mystery text")
474
  gram2vec_state = gr.State()
@@ -562,7 +555,7 @@ def app(share=False):
562
 
563
  axis_ranges.change(
564
  fn=handle_zoom_with_retries,
565
- inputs=[axis_ranges, bg_proj_state, bg_lbls_state, bg_authors_df, task_authors_embeddings_df],
566
  outputs=[features_rb, gram2vec_rb , llm_style_feats_analysis, feature_list_state, visible_zoomed_authors]
567
  )
568
 
 
462
  ">
463
  Gram2Vec Features prominent in the zoomed-in region
464
  </div>
 
 
 
 
 
 
 
465
  """)
466
  gram2vec_rb = gr.Radio(choices=[], label="Gram2Vec features for this zoomed-in region")#, label="Top-10 Gram2Vec Features most likely to occur in Mystery Author", info="Most prominent Gram2Vec features in the mystery text")
467
  gram2vec_state = gr.State()
 
555
 
556
  axis_ranges.change(
557
  fn=handle_zoom_with_retries,
558
+ inputs=[axis_ranges, bg_proj_state, bg_lbls_state, bg_authors_df, task_authors_embeddings_df, predicted_author],
559
  outputs=[features_rb, gram2vec_rb , llm_style_feats_analysis, feature_list_state, visible_zoomed_authors]
560
  )
561
 
utils/gram2vec_feat_utils.py CHANGED
@@ -284,6 +284,54 @@ def create_html(texts, llm_spans_list, gram_spans_list, selected_feature_llm, se
284
  # print(i, label, txt[:30])
285
  label = get_label(label, predicted_author, ground_truth_author, i) if background else get_label(label, predicted_author, ground_truth_author)
286
  combined = highlight_both_spans(txt, llm_spans_list[i], gram_spans_list[i])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  notice = ""
288
  if selected_feature_llm == "None":
289
  notice += f"""
@@ -317,6 +365,7 @@ def create_html(texts, llm_spans_list, gram_spans_list, selected_feature_llm, se
317
  """
318
  html.append(f"""
319
  <h3>{label}</h3>
 
320
  {notice}
321
  <div style="border:1px solid #ccc; padding:8px; margin-bottom:1em;">
322
  {combined}
 
284
  # print(i, label, txt[:30])
285
  label = get_label(label, predicted_author, ground_truth_author, i) if background else get_label(label, predicted_author, ground_truth_author)
286
  combined = highlight_both_spans(txt, llm_spans_list[i], gram_spans_list[i])
287
+
288
+ # Count spans for display
289
+ llm_span_count = len(llm_spans_list[i])
290
+ gram_span_count = len(gram_spans_list[i])
291
+
292
+ # Build span count display
293
+ span_count_info = ""
294
+ if selected_feature_llm != "None" or selected_feature_g2v != "None":
295
+ span_count_info = """
296
+ <div style="
297
+ background: #f5f5f5;
298
+ border: 1px solid #ddd;
299
+ border-radius: 4px;
300
+ padding: 8px;
301
+ margin-bottom: 8px;
302
+ font-size: 0.95em;
303
+ display: flex;
304
+ gap: 1em;
305
+ ">
306
+ """
307
+ if selected_feature_llm != "None":
308
+ span_count_info += f"""
309
+ <div style="flex: 1;">
310
+ <strong>LLM Feature Spans:</strong>
311
+ <span style="
312
+ background: #FFEB3B;
313
+ padding: 2px 8px;
314
+ border-radius: 3px;
315
+ margin-left: 4px;
316
+ font-weight: bold;
317
+ ">{llm_span_count}</span>
318
+ </div>
319
+ """
320
+ if selected_feature_g2v != "None":
321
+ span_count_info += f"""
322
+ <div style="flex: 1;">
323
+ <strong>G2V Feature Spans:</strong>
324
+ <span style="
325
+ background: #5CB3FF;
326
+ padding: 2px 8px;
327
+ border-radius: 3px;
328
+ margin-left: 4px;
329
+ font-weight: bold;
330
+ ">{gram_span_count}</span>
331
+ </div>
332
+ """
333
+ span_count_info += "</div>"
334
+
335
  notice = ""
336
  if selected_feature_llm == "None":
337
  notice += f"""
 
365
  """
366
  html.append(f"""
367
  <h3>{label}</h3>
368
+ {span_count_info}
369
  {notice}
370
  <div style="border:1px solid #ccc; padding:8px; margin-bottom:1em;">
371
  {combined}
utils/interp_space_utils.py CHANGED
@@ -580,9 +580,11 @@ def compute_clusters_style_representation_3(
580
  max_authors_for_span_extraction=4,
581
  top_k: int = 10,
582
  return_only_feats= False,
 
583
  ):
584
 
585
  print(f"Computing style representation for visible clusters: {len(cluster_ids)}")
 
586
  # STEP 1: Identify features on max_num_authors's max_num_documents_per_author number of documents
587
  background_corpus_df['fullText'] = background_corpus_df['fullText'].map(lambda x: '\n\n'.join(x[:max_num_documents_per_author]) if isinstance(x, list) else x)
588
  background_corpus_df_feat_id = background_corpus_df[background_corpus_df[cluster_label_clm_name].isin(cluster_ids)]
@@ -607,19 +609,30 @@ def compute_clusters_style_representation_3(
607
 
608
  # Filter-in only task authors that are part of the current selection
609
  task_author_names = {'Mystery author', 'Candidate Author 1', 'Candidate Author 2', 'Candidate Author 3'}
610
-
611
-
 
 
 
 
612
  feature_importance = {f : 0 for f in features}
613
  for author, feature_map in spans_by_author.items():
614
  if author in task_author_names.intersection(set(cluster_ids)):
615
  for feature, spans in feature_map.items():
616
  if spans:
617
- feature_importance[feature] += len(spans)
 
 
 
 
 
618
  else:
 
619
  for feature, spans in feature_map.items():
620
  if spans:
621
  feature_importance[feature] -= len(spans)
622
- # print(feature_importance)
 
623
  selected_features_ranked = sorted(feature_importance, key=lambda f: -feature_importance[f])[:int(top_k)]
624
 
625
  #print('filtered set of features (min coverage', len(author_present_feature_sets), '): ', selected_features_ranked)
@@ -739,8 +752,10 @@ def compute_clusters_g2v_representation(
739
  features_clm_name: str,
740
  top_n: int = 10,
741
  max_candidates_for_span_sorting: int = 50,
 
742
  ) -> List[tuple]: # Changed return type to List[tuple] to include scores
743
 
 
744
  # 1) Identify selected authors in the zoom region
745
  selected_mask = background_corpus_df['authorID'].isin(author_ids).to_numpy()
746
 
@@ -788,15 +803,17 @@ def compute_clusters_g2v_representation(
788
  # Get task author data
789
  task_authors_df = background_corpus_df[background_corpus_df['authorID'].isin(task_authors_in_selection)]
790
 
791
- # Count spans for each feature across task authors
792
- feature_span_counts = {}
 
 
 
 
793
  for feat_shorthand, z_score in candidate_features:
794
- span_count = 0
795
-
796
- # Convert shorthand to human-readable for display (if needed)
797
- # Note: features in gram2vec dict are in shorthand format like "pos_unigrams:ADJ"
798
 
799
  for _, author_row in task_authors_df.iterrows():
 
800
  author_text = author_row['fullText']
801
  if isinstance(author_text, list):
802
  author_text = '\n\n'.join(author_text)
@@ -804,20 +821,27 @@ def compute_clusters_g2v_representation(
804
  try:
805
  # find_feature_spans expects shorthand format like "pos_unigrams:ADJ"
806
  spans = find_feature_spans(author_text, feat_shorthand)
807
- span_count += len(spans)
 
 
 
 
 
 
 
808
  except Exception as e:
809
  # If span extraction fails, continue with 0 spans for this author
810
  pass
811
 
812
- feature_span_counts[feat_shorthand] = span_count
813
 
814
- # 8) Sort features by span frequency, then by z-score as tiebreaker
815
  sorted_by_spans = sorted(
816
  candidate_features,
817
- key=lambda x: (-feature_span_counts.get(x[0], 0), -x[1])
818
  )
819
 
820
- # print(f"[INFO] Sorted gram2vec features by span frequency: {[(f, feature_span_counts.get(f, 0), z) for f, z in sorted_by_spans[:top_n]]}")
821
 
822
  return sorted_by_spans[:top_n]
823
 
 
580
  max_authors_for_span_extraction=4,
581
  top_k: int = 10,
582
  return_only_feats= False,
583
+ predicted_author: int = None
584
  ):
585
 
586
  print(f"Computing style representation for visible clusters: {len(cluster_ids)}")
587
+ print(f"Predicted author: {predicted_author}")
588
  # STEP 1: Identify features on max_num_authors's max_num_documents_per_author number of documents
589
  background_corpus_df['fullText'] = background_corpus_df['fullText'].map(lambda x: '\n\n'.join(x[:max_num_documents_per_author]) if isinstance(x, list) else x)
590
  background_corpus_df_feat_id = background_corpus_df[background_corpus_df[cluster_label_clm_name].isin(cluster_ids)]
 
609
 
610
  # Filter-in only task authors that are part of the current selection
611
  task_author_names = {'Mystery author', 'Candidate Author 1', 'Candidate Author 2', 'Candidate Author 3'}
612
+
613
+ # Define mystery and predicted author names
614
+ mystery_author = 'Mystery author'
615
+ predicted_author_name = f'Candidate Author {predicted_author + 1}' if predicted_author is not None else None
616
+
617
+ # Compute feature importance based on Mystery + Predicted author vs. other candidates
618
  feature_importance = {f : 0 for f in features}
619
  for author, feature_map in spans_by_author.items():
620
  if author in task_author_names.intersection(set(cluster_ids)):
621
  for feature, spans in feature_map.items():
622
  if spans:
623
+ # Add span count if Mystery or Predicted author, subtract if other candidate
624
+ if author == mystery_author or (predicted_author is not None and author == predicted_author_name):
625
+ feature_importance[feature] += len(spans)
626
+ else:
627
+ # Other candidates - subtract their span counts
628
+ feature_importance[feature] -= len(spans)
629
  else:
630
+ # Background authors - subtract their span counts
631
  for feature, spans in feature_map.items():
632
  if spans:
633
  feature_importance[feature] -= len(spans)
634
+
635
+ print(f"Feature importance scores: {feature_importance}")
636
  selected_features_ranked = sorted(feature_importance, key=lambda f: -feature_importance[f])[:int(top_k)]
637
 
638
  #print('filtered set of features (min coverage', len(author_present_feature_sets), '): ', selected_features_ranked)
 
752
  features_clm_name: str,
753
  top_n: int = 10,
754
  max_candidates_for_span_sorting: int = 50,
755
+ predicted_author: int = None
756
  ) -> List[tuple]: # Changed return type to List[tuple] to include scores
757
 
758
+ print(f"[INFO] Computing G2V representation with predicted_author: {predicted_author}")
759
  # 1) Identify selected authors in the zoom region
760
  selected_mask = background_corpus_df['authorID'].isin(author_ids).to_numpy()
761
 
 
803
  # Get task author data
804
  task_authors_df = background_corpus_df[background_corpus_df['authorID'].isin(task_authors_in_selection)]
805
 
806
+ # Define mystery and predicted author names
807
+ mystery_author = 'Mystery author'
808
+ predicted_author_name = f'Candidate Author {predicted_author + 1}' if predicted_author is not None else None
809
+
810
+ # Count spans for each feature: +1 for Mystery/Predicted, -1 for other candidates
811
+ feature_span_scores = {}
812
  for feat_shorthand, z_score in candidate_features:
813
+ span_score = 0
 
 
 
814
 
815
  for _, author_row in task_authors_df.iterrows():
816
+ author_name = author_row['authorID']
817
  author_text = author_row['fullText']
818
  if isinstance(author_text, list):
819
  author_text = '\n\n'.join(author_text)
 
821
  try:
822
  # find_feature_spans expects shorthand format like "pos_unigrams:ADJ"
823
  spans = find_feature_spans(author_text, feat_shorthand)
824
+ span_count = len(spans)
825
+
826
+ # Add span count if Mystery or Predicted author, subtract if other candidate
827
+ if author_name == mystery_author or (predicted_author is not None and author_name == predicted_author_name):
828
+ span_score += span_count
829
+ else:
830
+ # Other candidates - subtract their span counts
831
+ span_score -= span_count
832
  except Exception as e:
833
  # If span extraction fails, continue with 0 spans for this author
834
  pass
835
 
836
+ feature_span_scores[feat_shorthand] = span_score
837
 
838
+ # 8) Sort features by span score (Mystery+Predicted vs Others), then by z-score as tiebreaker
839
  sorted_by_spans = sorted(
840
  candidate_features,
841
+ key=lambda x: (-feature_span_scores.get(x[0], 0), -x[1])
842
  )
843
 
844
+ print(f"[INFO] Top 5 gram2vec features by span score: {[(f, feature_span_scores.get(f, 0), z) for f, z in sorted_by_spans[:5]]}")
845
 
846
  return sorted_by_spans[:top_n]
847
 
utils/visualizations.py CHANGED
@@ -204,11 +204,11 @@ def load_interp_space(cfg):
204
  # Function to process G2V features and create display choices
205
  def format_g2v_features_for_display(g2v_features_with_scores):
206
  """
207
- Convert G2V features with z-scores into display format for Gradio radio buttons.
208
 
209
  Args:
210
  g2v_features_with_scores: List of tuples like:
211
- [('None', None), ('Feature Name', z_score), ...]
212
 
213
  Returns:
214
  tuple: (display_choices, original_values)
@@ -218,22 +218,15 @@ def format_g2v_features_for_display(g2v_features_with_scores):
218
 
219
  for item in g2v_features_with_scores:
220
  if len(item) == 2:
221
- feature_name, z_score = item
222
 
223
  # Handle None case
224
- if feature_name == "None" or z_score is None:
225
  display_choices.append("None")
226
  original_values.append("None")
227
  else:
228
- # Convert numpy float to regular float if needed
229
- if hasattr(z_score, 'item'):
230
- z_score = float(z_score.item())
231
- else:
232
- z_score = float(z_score)
233
-
234
- # Create display string with z-score
235
- display_string = f"{feature_name} | [Z={z_score:.2f}]"
236
- display_choices.append(display_string)
237
  original_values.append(feature_name)
238
  else:
239
  # Handle unexpected format
@@ -243,14 +236,17 @@ def format_g2v_features_for_display(g2v_features_with_scores):
243
  return display_choices, original_values
244
 
245
  #function to handle zoom events
246
- def handle_zoom(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors_df):
247
  """
248
  event_json – stringified JSON from JS listener
249
  bg_proj – (N,2) numpy array with 2D coordinates
250
  bg_lbls – list of N author IDs
251
  clustered_authors_df – pd.DataFrame containing authorID and final_attribute_name
 
 
252
  """
253
  print("[INFO] Handling zoom event")
 
254
 
255
  if not event_json:
256
  return gr.update(value=""), gr.update(value=""), None, None, None
@@ -280,6 +276,7 @@ def handle_zoom(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors
280
  background_corpus_df=merged_authors_df,
281
  cluster_ids=visible_authors,
282
  cluster_label_clm_name='authorID',
 
283
  )
284
 
285
  llm_feats = ['None'] + style_analysis_response['features']
@@ -292,7 +289,8 @@ def handle_zoom(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors
292
  author_ids=visible_authors,
293
  other_author_ids=[],
294
  features_clm_name='g2v_vector',
295
- top_n=50
 
296
  )
297
 
298
  # ── Span-existence filter on task authors in the zoom ───────────────────
@@ -357,19 +355,20 @@ def handle_zoom(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors
357
  )
358
  # return gr.update(value="\n".join(llm_feats).join("\n").join(g2v_feats)), llm_feats, g2v_feats
359
 
360
- def handle_zoom_with_retries(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors_df):
361
  """
362
  event_json – stringified JSON from JS listener
363
  bg_proj – (N,2) numpy array with 2D coordinates
364
  bg_lbls – list of N author IDs
365
  clustered_authors_df – pd.DataFrame containing authorID and final_attribute_name
366
  task_authors_df – pd.DataFrame containing authorID and final_attribute_name
 
367
  """
368
  print("[INFO] Handling zoom event with retries")
369
 
370
  for attempt in range(3):
371
  try:
372
- return handle_zoom(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors_df)
373
  except Exception as e:
374
  print(f"[ERROR] Attempt {attempt + 1} failed: {e}")
375
  if attempt < 2:
 
204
  # Function to process G2V features and create display choices
205
  def format_g2v_features_for_display(g2v_features_with_scores):
206
  """
207
+ Convert G2V features into display format for Gradio radio buttons.
208
 
209
  Args:
210
  g2v_features_with_scores: List of tuples like:
211
+ [('None', None), ('Feature Name', score), ...]
212
 
213
  Returns:
214
  tuple: (display_choices, original_values)
 
218
 
219
  for item in g2v_features_with_scores:
220
  if len(item) == 2:
221
+ feature_name, score = item
222
 
223
  # Handle None case
224
+ if feature_name == "None" or score is None:
225
  display_choices.append("None")
226
  original_values.append("None")
227
  else:
228
+ # Just show the feature name without scores
229
+ display_choices.append(feature_name)
 
 
 
 
 
 
 
230
  original_values.append(feature_name)
231
  else:
232
  # Handle unexpected format
 
236
  return display_choices, original_values
237
 
238
  #function to handle zoom events
239
+ def handle_zoom(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors_df, predicted_author=None):
240
  """
241
  event_json – stringified JSON from JS listener
242
  bg_proj – (N,2) numpy array with 2D coordinates
243
  bg_lbls – list of N author IDs
244
  clustered_authors_df – pd.DataFrame containing authorID and final_attribute_name
245
+ task_authors_df – pd.DataFrame containing task authors
246
+ predicted_author – index of predicted author (0, 1, or 2)
247
  """
248
  print("[INFO] Handling zoom event")
249
+ print(f"[INFO] Predicted author: {predicted_author}")
250
 
251
  if not event_json:
252
  return gr.update(value=""), gr.update(value=""), None, None, None
 
276
  background_corpus_df=merged_authors_df,
277
  cluster_ids=visible_authors,
278
  cluster_label_clm_name='authorID',
279
+ predicted_author=predicted_author
280
  )
281
 
282
  llm_feats = ['None'] + style_analysis_response['features']
 
289
  author_ids=visible_authors,
290
  other_author_ids=[],
291
  features_clm_name='g2v_vector',
292
+ top_n=50,
293
+ predicted_author=predicted_author
294
  )
295
 
296
  # ── Span-existence filter on task authors in the zoom ───────────────────
 
355
  )
356
  # return gr.update(value="\n".join(llm_feats).join("\n").join(g2v_feats)), llm_feats, g2v_feats
357
 
358
+ def handle_zoom_with_retries(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors_df, predicted_author=None):
359
  """
360
  event_json – stringified JSON from JS listener
361
  bg_proj – (N,2) numpy array with 2D coordinates
362
  bg_lbls – list of N author IDs
363
  clustered_authors_df – pd.DataFrame containing authorID and final_attribute_name
364
  task_authors_df – pd.DataFrame containing authorID and final_attribute_name
365
+ predicted_author – index of predicted author (0, 1, or 2)
366
  """
367
  print("[INFO] Handling zoom event with retries")
368
 
369
  for attempt in range(3):
370
  try:
371
+ return handle_zoom(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors_df, predicted_author)
372
  except Exception as e:
373
  print(f"[ERROR] Attempt {attempt + 1} failed: {e}")
374
  if attempt < 2: