Milad Alshomary committed on
Commit
0f2fc55
·
2 Parent(s): d8b3dc8 884a75c
app.py CHANGED
@@ -29,6 +29,7 @@ cfg = load_config()
29
  download_file_override(cfg.get('background_authors_df_url'), cfg.get('background_authors_df_path'))
30
  download_file_override(cfg.get('instances_to_explain_url'), cfg.get('instances_to_explain_path'))
31
  download_file_override(cfg.get('gram2vec_feats_url'), cfg.get('gram2vec_feats_path'))
 
32
  download_file_override(cfg.get('embeddings_cache_url'), cfg.get('embeddings_cache_path'))
33
  download_file_override(cfg.get('zoom_cache_url'), cfg.get('zoom_cache_path'))
34
  download_file_override(cfg.get('region_cache_url'), cfg.get('region_cache_path'))
@@ -142,13 +143,13 @@ def app(share=False):
142
  # ── Model Selection ─────────────────────────────────
143
  model_radio = gr.Radio(
144
  choices=[
 
145
  'gabrielloiseau/LUAR-MUD-sentence-transformers',
146
  'gabrielloiseau/LUAR-CRUD-sentence-transformers',
147
  'miladalsh/light-luar',
148
- 'AnnaWegmann/Style-Embedding',
149
  'Other'
150
  ],
151
- value='gabrielloiseau/LUAR-MUD-sentence-transformers',
152
  label='Choose a Model to inspect'
153
  )
154
  print(f"Model choices: {model_radio.choices}")
@@ -168,8 +169,8 @@ def app(share=False):
168
 
169
  # ── Task Source Selection ─────────────────────────────────
170
  task_mode = gr.Radio(
171
- choices=["Predefined HRS Task", "Upload Your Own Task"],
172
- value="Predefined HRS Task",
173
  label="Select Task Source"
174
  )
175
 
 
29
  download_file_override(cfg.get('background_authors_df_url'), cfg.get('background_authors_df_path'))
30
  download_file_override(cfg.get('instances_to_explain_url'), cfg.get('instances_to_explain_path'))
31
  download_file_override(cfg.get('gram2vec_feats_url'), cfg.get('gram2vec_feats_path'))
32
+ download_file_override(cfg.get('gram2vec_cache_url'), cfg.get('gram2vec_cache_path'))
33
  download_file_override(cfg.get('embeddings_cache_url'), cfg.get('embeddings_cache_path'))
34
  download_file_override(cfg.get('zoom_cache_url'), cfg.get('zoom_cache_path'))
35
  download_file_override(cfg.get('region_cache_url'), cfg.get('region_cache_path'))
 
143
  # ── Model Selection ─────────────────────────────────
144
  model_radio = gr.Radio(
145
  choices=[
146
+ 'AnnaWegmann/Style-Embedding',
147
  'gabrielloiseau/LUAR-MUD-sentence-transformers',
148
  'gabrielloiseau/LUAR-CRUD-sentence-transformers',
149
  'miladalsh/light-luar',
 
150
  'Other'
151
  ],
152
+ value='AnnaWegmann/Style-Embedding',
153
  label='Choose a Model to inspect'
154
  )
155
  print(f"Model choices: {model_radio.choices}")
 
169
 
170
  # ── Task Source Selection ─────────────────────────────────
171
  task_mode = gr.Radio(
172
+ choices=["Predefined Reddit Task", "Upload Your Own Task"],
173
+ value="Predefined Reddit Task",
174
  label="Select Task Source"
175
  )
176
 
config/config.yaml CHANGED
@@ -9,6 +9,9 @@ background_authors_df_url: "https://huggingface.co/datasets/miladalsh/explana
9
  gram2vec_feats_path: "./datasets/gram2vec_feats.csv"
10
  gram2vec_feats_url: "https://huggingface.co/datasets/miladalsh/explanation_tool_files/resolve/main/gram2vec_feats.csv?download=true"
11
 
 
 
 
12
  embeddings_cache_url: "https://huggingface.co/datasets/miladalsh/explanation_tool_files/resolve/main/embeddings_cache.zip?download=true"
13
  embeddings_cache_path: "./datasets/embeddings_cache/"
14
 
 
9
  gram2vec_feats_path: "./datasets/gram2vec_feats.csv"
10
  gram2vec_feats_url: "https://huggingface.co/datasets/miladalsh/explanation_tool_files/resolve/main/gram2vec_feats.csv?download=true"
11
 
12
+ gram2vec_cache_url: "https://huggingface.co/datasets/miladalsh/explanation_tool_files/resolve/main/gram2vec_cache.zip?download=true"
13
+ gram2vec_cache_path: "./datasets/gram2vec_cache/"
14
+
15
  embeddings_cache_url: "https://huggingface.co/datasets/miladalsh/explanation_tool_files/resolve/main/embeddings_cache.zip?download=true"
16
  embeddings_cache_path: "./datasets/embeddings_cache/"
17
 
precompute_caches.py CHANGED
@@ -8,18 +8,30 @@ import pandas as pd
8
  from datetime import datetime
9
  import yaml
10
 
11
- # Import your actual modules exactly as app.py does
12
- from utils.visualizations import get_instances, load_interp_space, trigger_precomputed_region, handle_zoom_with_retries
13
- from utils.ui import update_task_display
14
 
15
  def load_config(path="config/config.yaml"):
16
  with open(path, "r") as f:
17
  return yaml.safe_load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  def precompute_all_caches(
20
  models_to_test=None,
21
  instances_to_process=None,
22
- config_path="config/config.yaml"
23
  ):
24
  """
25
  Precompute all cache files using the EXACT same methods as app.py.
@@ -34,16 +46,12 @@ def precompute_all_caches(
34
  'AnnaWegmann/Style-Embedding'
35
  ]
36
 
37
- print("=" * 60)
38
  print("CACHE PRECOMPUTATION STARTED")
39
  print(f"Timestamp: {datetime.now()}")
40
  print(f"Models to test: {len(models_to_test)}")
41
  print("=" * 60)
42
-
43
- # Load configuration and instances EXACTLY like app.py
44
- cfg = load_config(config_path)
45
- print(f"Configuration loaded from {config_path}")
46
- print(f"config : \n{cfg}")
47
  instances, instance_ids = get_instances(cfg['instances_to_explain_path'])
48
  # interp = load_interp_space(cfg)
49
  # clustered_authors_df = interp['clustered_authors_df']
@@ -72,7 +80,9 @@ def precompute_all_caches(
72
  for instance_id in tqdm(instances_to_process, desc=f"Processing instances for {model_name.split('/')[-1]}"):
73
  current_combination += 1
74
  try:
75
- print(f"\n[{current_combination}/{total_combinations}] Processing Instance {instance_id}")
 
 
76
 
77
  # STEP 1: Replicate the exact flow from load_button.click()
78
  print(" → Replicating load_button.click() flow...")
@@ -82,7 +92,7 @@ def precompute_all_caches(
82
 
83
  # Call update_task_display EXACTLY like app.py does
84
  task_results = update_task_display(
85
- mode="Predefined HRS Task", # Always use predefined for caching
86
  iid=f"Task {instance_id}",
87
  instances=instances,
88
  background_df=clustered_authors_df,
@@ -137,6 +147,7 @@ def precompute_all_caches(
137
  if precomputed_regions_state:
138
  regions_dict = ast.literal_eval(precomputed_regions_state)
139
  test_regions = list(regions_dict.keys())
 
140
 
141
  for region_name in test_regions:
142
  try:
@@ -194,7 +205,7 @@ from utils.visualizations import visualize_clusters_plotly
194
 
195
  if __name__ == "__main__":
196
  # Test with a small subset first
197
- instances=[i for i in range(10)] # First 20 instances for testing
198
  cache_stats = precompute_all_caches(
199
  models_to_test=[
200
  'AnnaWegmann/Style-Embedding'
 
8
  from datetime import datetime
9
  import yaml
10
 
11
+ CONFIG_PATH="config/config.yaml"
 
 
12
 
13
  def load_config(path="config/config.yaml"):
14
  with open(path, "r") as f:
15
  return yaml.safe_load(f)
16
+
17
+ # Load configuration and instances EXACTLY like app.py
18
+ cfg = load_config(CONFIG_PATH)
19
+ print(f"Configuration loaded from {CONFIG_PATH}")
20
+ print(f"config : \n{cfg}")
21
+
22
+ # Import your actual modules exactly as app.py does
23
+ from utils.file_download import download_file_override
24
+
25
+ download_file_override(cfg.get('background_authors_df_url'), cfg.get('background_authors_df_path'))
26
+ download_file_override(cfg.get('instances_to_explain_url'), cfg.get('instances_to_explain_path'))
27
+ download_file_override(cfg.get('gram2vec_feats_url'), cfg.get('gram2vec_feats_path'))
28
+
29
+ from utils.visualizations import get_instances, trigger_precomputed_region, handle_zoom_with_retries
30
+ from utils.ui import update_task_display
31
 
32
  def precompute_all_caches(
33
  models_to_test=None,
34
  instances_to_process=None,
 
35
  ):
36
  """
37
  Precompute all cache files using the EXACT same methods as app.py.
 
46
  'AnnaWegmann/Style-Embedding'
47
  ]
48
 
49
+ print("\n\n" + "=" * 60)
50
  print("CACHE PRECOMPUTATION STARTED")
51
  print(f"Timestamp: {datetime.now()}")
52
  print(f"Models to test: {len(models_to_test)}")
53
  print("=" * 60)
54
+
 
 
 
 
55
  instances, instance_ids = get_instances(cfg['instances_to_explain_path'])
56
  # interp = load_interp_space(cfg)
57
  # clustered_authors_df = interp['clustered_authors_df']
 
80
  for instance_id in tqdm(instances_to_process, desc=f"Processing instances for {model_name.split('/')[-1]}"):
81
  current_combination += 1
82
  try:
83
+ # print(f"\n\n\n[{current_combination}/{total_combinations}] Processing Instance {instance_id}")
84
+ print(f"\n\n\n\033[1m\033[93m>>> [{current_combination}/{total_combinations}] Processing Instance {instance_id} <<<\033[0m\n")
85
+
86
 
87
  # STEP 1: Replicate the exact flow from load_button.click()
88
  print(" → Replicating load_button.click() flow...")
 
92
 
93
  # Call update_task_display EXACTLY like app.py does
94
  task_results = update_task_display(
95
+ mode="Predefined Reddit Task", # Always use predefined for caching
96
  iid=f"Task {instance_id}",
97
  instances=instances,
98
  background_df=clustered_authors_df,
 
147
  if precomputed_regions_state:
148
  regions_dict = ast.literal_eval(precomputed_regions_state)
149
  test_regions = list(regions_dict.keys())
150
+ print(f" → Found {len(test_regions)} regions to test")
151
 
152
  for region_name in test_regions:
153
  try:
 
205
 
206
  if __name__ == "__main__":
207
  # Test with a small subset first
208
+ instances=[i for i in range(20)] # First 10 instances for testing
209
  cache_stats = precompute_all_caches(
210
  models_to_test=[
211
  'AnnaWegmann/Style-Embedding'
utils/gram2vec_feat_utils.py CHANGED
@@ -49,7 +49,17 @@ def get_shorthand(feature_str: str) -> str:
49
  return None
50
  if category not in FEATURE_HANDLERS:
51
  return None
52
- code = load_code_map().get(human)
 
 
 
 
 
 
 
 
 
 
53
  if code is None:
54
  # print(f"Warning: No code found for human-readable feature '{human}'")
55
  return None # fallback to the human-readable name
@@ -78,6 +88,14 @@ def get_fullform(shorthand: str) -> str:
78
  if human is None:
79
  return None
80
 
 
 
 
 
 
 
 
 
81
  return f"{category}:{human}"
82
 
83
  def highlight_both_spans(text, llm_spans, gram_spans):
@@ -154,6 +172,8 @@ def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
154
  # print(llm_style_feats_analysis)
155
  print(f"{len(llm_style_feats_analysis['spans'].values())}")
156
  author_list = list(llm_style_feats_analysis['spans'].values())
 
 
157
  llm_spans_list = []
158
  for i, (_, txt) in enumerate(texts):
159
  print(f"{i}/{len(texts)}")
@@ -169,8 +189,9 @@ def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
169
  if selected_feature_g2v and selected_feature_g2v != "None":
170
  # get gram2vec spans
171
  gram_spans_list = []
172
- # clean the display string and get the feature name without the zscore
173
- selected_feature_g2v = selected_feature_g2v.split(" | [Z=")[0].strip()
 
174
  print(f"Selected Gram2Vec feature: {selected_feature_g2v}")
175
  short = get_shorthand(selected_feature_g2v)
176
  print(f"short hand: {short}")
@@ -209,12 +230,12 @@ def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
209
  bg_start = 4
210
  bg_indices = list(range(bg_start, len(texts)))
211
  kept_indices = [i for i in bg_indices if gram_spans_list[i]]
212
- print(f"\n---> {kept_indices}")
213
  filtered_texts_bg = [texts[i] for i in kept_indices]
214
  filtered_llm_bg = [llm_spans_list[i] for i in kept_indices]
215
  filtered_gram_bg = [gram_spans_list[i] for i in kept_indices]
216
 
217
- print(filtered_texts_bg)
218
 
219
  html_background_authors = create_html(
220
  filtered_texts_bg,
@@ -260,7 +281,7 @@ def get_label(label: str, predicted_author=None, ground_truth_author=None, bg_id
260
  def create_html(texts, llm_spans_list, gram_spans_list, selected_feature_llm, selected_feature_g2v, short=None, background = False, predicted_author=None, ground_truth_author=None):
261
  html = []
262
  for i, (label, txt) in enumerate(texts):
263
- print(i, label, txt[:30])
264
  label = get_label(label, predicted_author, ground_truth_author, i) if background else get_label(label, predicted_author, ground_truth_author)
265
  combined = highlight_both_spans(txt, llm_spans_list[i], gram_spans_list[i])
266
  notice = ""
 
49
  return None
50
  if category not in FEATURE_HANDLERS:
51
  return None
52
+ code_map = load_code_map()
53
+ code = code_map.get(human)
54
+ if code is None:
55
+ # Try normalizing terminology shown in UI
56
+ # Convert 'Preposition' phrasing back to 'Adposition' used in the code map
57
+ human_alt = (human
58
+ .replace("Preposition", "Adposition")
59
+ .replace("preposition", "adposition")
60
+ .replace("Prepositional", "Adpositional")
61
+ .replace("prepositional", "adpositional"))
62
+ code = code_map.get(human_alt)
63
  if code is None:
64
  # print(f"Warning: No code found for human-readable feature '{human}'")
65
  return None # fallback to the human-readable name
 
88
  if human is None:
89
  return None
90
 
91
+ # Normalize terminology for UI: prefer "Preposition" over "Adposition"
92
+ # Also handle potential "adpositional" variants if present
93
+ human = (human
94
+ .replace("Adposition", "Preposition")
95
+ .replace("adposition", "preposition")
96
+ .replace("Adpositional", "Prepositional")
97
+ .replace("adpositional", "prepositional"))
98
+
99
  return f"{category}:{human}"
100
 
101
  def highlight_both_spans(text, llm_spans, gram_spans):
 
172
  # print(llm_style_feats_analysis)
173
  print(f"{len(llm_style_feats_analysis['spans'].values())}")
174
  author_list = list(llm_style_feats_analysis['spans'].values())
175
+ # print(f"Author list length: {len(author_list)}")
176
+ # print(f"Author list: {author_list}")
177
  llm_spans_list = []
178
  for i, (_, txt) in enumerate(texts):
179
  print(f"{i}/{len(texts)}")
 
189
  if selected_feature_g2v and selected_feature_g2v != "None":
190
  # get gram2vec spans
191
  gram_spans_list = []
192
+ # In case any old label formatting with z-scores leaks through, strip it defensively
193
+ if "| [Z=" in selected_feature_g2v:
194
+ selected_feature_g2v = selected_feature_g2v.split(" | [Z=")[0].strip()
195
  print(f"Selected Gram2Vec feature: {selected_feature_g2v}")
196
  short = get_shorthand(selected_feature_g2v)
197
  print(f"short hand: {short}")
 
230
  bg_start = 4
231
  bg_indices = list(range(bg_start, len(texts)))
232
  kept_indices = [i for i in bg_indices if gram_spans_list[i]]
233
+ # print(f"\n---> {kept_indices}")
234
  filtered_texts_bg = [texts[i] for i in kept_indices]
235
  filtered_llm_bg = [llm_spans_list[i] for i in kept_indices]
236
  filtered_gram_bg = [gram_spans_list[i] for i in kept_indices]
237
 
238
+ # print(filtered_texts_bg)
239
 
240
  html_background_authors = create_html(
241
  filtered_texts_bg,
 
281
  def create_html(texts, llm_spans_list, gram_spans_list, selected_feature_llm, selected_feature_g2v, short=None, background = False, predicted_author=None, ground_truth_author=None):
282
  html = []
283
  for i, (label, txt) in enumerate(texts):
284
+ # print(i, label, txt[:30])
285
  label = get_label(label, predicted_author, ground_truth_author, i) if background else get_label(label, predicted_author, ground_truth_author)
286
  combined = highlight_both_spans(txt, llm_spans_list[i], gram_spans_list[i])
287
  notice = ""
utils/interp_space_utils.py CHANGED
@@ -17,16 +17,20 @@ from pydantic import BaseModel
17
  from pydantic import ValidationError
18
  import time
19
  from utils.llm_feat_utils import generate_feature_spans_cached
 
 
20
  from collections import Counter
21
  import numpy as np
22
  from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
23
  from sklearn.decomposition import PCA
24
 
25
  CACHE_DIR = "datasets/embeddings_cache"
 
26
  ZOOM_CACHE = "datasets/zoom_cache/features_cache.json"
27
  REGION_CACHE = "datasets/region_cache/regions_cache.pkl"
28
  SUMMARY_CACHE = "datasets/summary_cache/summaries.json"
29
  os.makedirs(CACHE_DIR, exist_ok=True)
 
30
  os.makedirs(os.path.dirname(ZOOM_CACHE), exist_ok=True)
31
  os.makedirs(os.path.dirname(REGION_CACHE), exist_ok=True)
32
  # Bump this whenever there is a change etc...
@@ -56,8 +60,8 @@ def compute_g2v_features(clustered_authors_df: pd.DataFrame, task_authors_df: pd
56
  print (f"concatenating task authors and background corpus authors")
57
  print(f"Number of task authors: {len(task_authors_df)}")
58
  print(f"task authors author_ids: {task_authors_df.authorID.tolist()}")
59
- print(f"task authors -->")
60
- print(task_authors_df)
61
  print(f"Number of background corpus authors: {len(clustered_authors_df)}")
62
  clustered_authors_df = pd.concat([task_authors_df, clustered_authors_df])
63
  print(f"Number of authors after concatenation: {len(clustered_authors_df)}")
@@ -65,10 +69,12 @@ def compute_g2v_features(clustered_authors_df: pd.DataFrame, task_authors_df: pd
65
  # Gather the input texts (preserves list-of-strings if any)
66
  #texts = background_corpus_df[text_clm].fillna("").tolist()
67
  author_texts = ['\n\n'.join(x) for x in clustered_authors_df.fullText.tolist()]
68
- print('author_text at 0:{}'.format(author_texts[0]))
69
  print(f"Number of author_texts: {len(author_texts)}")
70
 
71
  # Create a reproducible JSON serialization of the texts
 
 
72
  serialized = json.dumps({
73
  "col": text_clm,
74
  "texts": author_texts
@@ -76,15 +82,20 @@ def compute_g2v_features(clustered_authors_df: pd.DataFrame, task_authors_df: pd
76
 
77
  # Compute MD5 hash
78
  digest = hashlib.md5(serialized.encode("utf-8")).hexdigest()
79
- cache_path = os.path.join(CACHE_DIR, f"{digest}.pkl")
80
 
81
  # If cache hit, load and return
82
  if os.path.exists(cache_path):
83
- print(f"Cache hit...")
 
 
84
  with open(cache_path, "rb") as f:
85
  clustered_authors_df = pickle.load(f)
86
 
87
  else: # Else compute and cache
 
 
 
88
  g2v_feats_df = vectorizer.from_documents(author_texts, batch_size=8)
89
 
90
  print(f"Number of g2v features: {len(g2v_feats_df)}")
@@ -118,6 +129,9 @@ def compute_g2v_features(clustered_authors_df: pd.DataFrame, task_authors_df: pd
118
 
119
  with open(cache_path, "wb") as f:
120
  pickle.dump(clustered_authors_df, f)
 
 
 
121
 
122
  if task_authors_df is not None:
123
  task_authors_df = clustered_authors_df[clustered_authors_df.authorID.isin(task_authors_df.authorID.tolist())]
@@ -268,14 +282,14 @@ def cached_generate_style_embedding(background_corpus_df: pd.DataFrame,
268
 
269
  # If cache hit, load and return
270
  if os.path.exists(cache_path):
271
- print(f"Cache hit for {model_name} on column '{text_clm}'")
272
- print(cache_path)
273
  with open(cache_path, "rb") as f:
274
  background_corpus_df = pickle.load(f)
275
 
276
  else:
277
  # Otherwise, compute, cache, and return
278
- print(f"Computing embeddings for {model_name} on column '{text_clm}', saving to {cache_path}")
279
  task_and_background_embeddings = generate_style_embedding(background_corpus_df, text_clm, model_name, dimensionality_reduction=False)
280
  # Create a clean column name from the model name
281
  col_name = f'{model_name.split("/")[-1]}_style_embedding'
@@ -283,6 +297,7 @@ def cached_generate_style_embedding(background_corpus_df: pd.DataFrame,
283
 
284
  with open(cache_path, "wb") as f:
285
  pickle.dump(background_corpus_df, f)
 
286
 
287
  if task_authors_df is not None:
288
  task_authors_df = background_corpus_df[background_corpus_df.authorID.isin(task_authors_df.authorID.tolist())]
@@ -290,163 +305,167 @@ def cached_generate_style_embedding(background_corpus_df: pd.DataFrame,
290
 
291
  return background_corpus_df, task_authors_df
292
 
293
- def get_style_feats_distribution(documentIDs, style_feats_dict):
294
- style_feats = []
295
- for documentId in documentIDs:
296
- if documentId not in document_to_style_feats:
297
- #print(documentId)
298
- continue
299
-
300
- style_feats+= document_to_style_feats[documentId]
301
-
302
- tfidf = [style_feats.count(key) * val for key, val in style_feats_dict.items()]
303
-
304
- return tfidf
305
-
306
- def get_cluster_top_feats(style_feats_distribution, style_feats_list, top_k=5):
307
- sorted_feats = np.argsort(style_feats_distribution)[::-1]
308
- top_feats = [style_feats_list[x] for x in sorted_feats[:top_k] if style_feats_distribution[x] > 0]
309
- return top_feats
310
-
311
- def compute_clusters_style_representation(
312
- background_corpus_df: pd.DataFrame,
313
- cluster_ids: List[Any],
314
- other_cluster_ids: List[Any],
315
- features_clm_name: str,
316
- cluster_label_clm_name: str = 'cluster_label',
317
- top_n: int = 10
318
- ) -> List[str]:
319
- """
320
- Given a DataFrame with document IDs, cluster IDs, and feature lists,
321
- return the top N features that are most important in the specified `cluster_ids`
322
- while having low importance in `other_cluster_ids`.
323
- Importance is determined by TF-IDF scores. The final score for a feature is
324
- (summed TF-IDF in `cluster_ids`) - (summed TF-IDF in `other_cluster_ids`).
325
-
326
- Parameters:
327
- - background_corpus_df: pd.DataFrame. Must contain the columns specified by
328
- `cluster_label_clm_name` and `features_clm_name`.
329
- The column `features_clm_name` should contain lists of strings (features).
330
- - cluster_ids: List of cluster IDs for which to find representative features (target clusters).
331
- - other_cluster_ids: List of cluster IDs whose features should be down-weighted.
332
- Features prominent in these clusters will have their scores reduced.
333
- Pass an empty list or None if no contrastive clusters are needed.
334
- - features_clm_name: The name of the column in `background_corpus_df` that
335
- contains the list of features for each document.
336
- - cluster_label_clm_name: The name of the column in `background_corpus_df`
337
- that contains the cluster labels. Defaults to 'cluster_label'.
338
- - top_n: Number of top features to return.
339
- Returns:
340
- - List[str]: A list of feature names. These are up to `top_n` features
341
- ranked by their adjusted TF-IDF scores (score in `cluster_ids`
342
- minus score in `other_cluster_ids`). Only features with a final
343
- adjusted score > 0 are included.
344
- """
345
-
346
- assert background_corpus_df[features_clm_name].apply(
347
- lambda x: isinstance(x, list) and all(isinstance(feat, str) for feat in x)
348
- ).all(), f"Column '{features_clm_name}' must contain lists of strings."
349
-
350
- # Compute TF-IDF on the entire corpus
351
- vectorizer = TfidfVectorizer(
352
- tokenizer=lambda x: x,
353
- preprocessor=lambda x: x,
354
- token_pattern=None # Disable default token pattern, treat items in list as tokens
355
- )
356
- tfidf_matrix = vectorizer.fit_transform(background_corpus_df[features_clm_name])
357
- feature_names = vectorizer.get_feature_names_out()
358
-
359
- # Get boolean mask for documents in selected clusters
360
- selected_mask = background_corpus_df[cluster_label_clm_name].isin(cluster_ids).to_numpy()
361
-
362
- if not selected_mask.any():
363
- return [] # No documents found for the given cluster_ids
364
-
365
- # Subset the TF-IDF matrix using the boolean mask
366
- selected_tfidf = tfidf_matrix[selected_mask]
367
-
368
- # Sum TF-IDF scores across documents for each feature in the target clusters
369
- target_feature_scores_sum = selected_tfidf.sum(axis=0).A1 # Convert to 1D array
370
-
371
- # Initialize adjusted scores with target scores
372
- adjusted_feature_scores = target_feature_scores_sum.copy()
373
-
374
- # If other_cluster_ids are provided and not empty, subtract their TF-IDF sums
375
- if other_cluster_ids: # Checks if the list is not None and not empty
376
- other_selected_mask = background_corpus_df[cluster_label_clm_name].isin(other_cluster_ids).to_numpy()
377
-
378
- if other_selected_mask.any():
379
- other_selected_tfidf = tfidf_matrix[other_selected_mask]
380
- contrast_feature_scores_sum = other_selected_tfidf.sum(axis=0).A1
 
 
 
381
 
382
- # Element-wise subtraction; assumes feature_names aligns for both sums
383
- adjusted_feature_scores -= contrast_feature_scores_sum
384
-
385
- # Map scores to feature names
386
- feature_score_dict = dict(zip(feature_names, adjusted_feature_scores))
387
- # Sort features by score
388
- sorted_features = sorted(feature_score_dict.items(), key=lambda item: item[1], reverse=True)
389
-
390
- # Return the names of the top_n features that have a score > 0
391
- top_features = [feature for feature, score in sorted_features if score > 0][:top_n]
392
-
393
- return top_features
394
-
395
- def compute_clusters_style_representation_2(
396
- background_corpus_df: pd.DataFrame,
397
- cluster_ids: List[Any],
398
- cluster_label_clm_name: str = 'cluster_label',
399
- max_num_feats: int = 5,
400
- max_num_documents_per_author=3,
401
- max_num_authors=5):
402
- """
403
- Call openAI to analyze the common writing style features of the given list of texts
404
- """
405
- client = OpenAI(base_url=os.getenv("OPENAI_BASE_URL", None), pi_key=os.getenv("OPENAI_API_KEY"))
406
-
407
- background_corpus_df['fullText'] = background_corpus_df['fullText'].map(lambda x: '\n\n'.join(x[:max_num_documents_per_author]) if isinstance(x, list) else x)
408
- background_corpus_df = background_corpus_df[background_corpus_df[cluster_label_clm_name].isin(cluster_ids)]
 
409
 
410
- author_texts = background_corpus_df['fullText'].tolist()[:max_num_authors]
411
- author_texts = "\n\n".join(["""Author {}:\n""".format(i+1) + text for i, text in enumerate(author_texts)])
412
- author_names = background_corpus_df[cluster_label_clm_name].tolist()[:max_num_authors]
413
- print(f"Number of authors: {len(background_corpus_df)}")
414
- print(author_names)
415
- print(author_texts)
416
- print(f"Number of authors: {len(author_names)}")
417
- print(f"Number of authors: {len(author_texts)}")
418
-
419
- prompt = f"""First identify a list of {max_num_feats} writing style features that are common between the given texts. Second for every author text and style feature, extract all spans that represent the feature. Output for every author all style features with their spans.
420
- Author Texts:
421
- \"\"\"{author_texts}\"\"\"
422
- """
423
-
424
- # Compute MD5 hash
425
- digest = hashlib.md5(prompt.encode("utf-8")).hexdigest()
426
- cache_path = os.path.join(CACHE_DIR, f"{digest}.pkl")
427
-
428
- # If cache hit, load and return
429
- if os.path.exists(cache_path):
430
- print(f"Loading authors writing style from cache ...")
431
- with open(cache_path, "rb") as f:
432
- parsed_response = pickle.load(f)
433
 
434
- else: # Else compute and cache
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
435
 
436
- response = client.chat.completions.create(
437
- model="gpt-4o",
438
- messages=[
439
- {"role":"assistant","content":"You are a forensic linguistic who knows how to analyze similarites in writing styles."},
440
- {"role":"user","content":prompt}],
441
- response_format={"type": "json_schema", "json_schema": {"name": "style_analysis_schema", "schema": to_strict_json_schema(style_analysis_schema)}}
442
- )
443
 
444
- parsed_response = json.loads(response.choices[0].message.content)
445
 
446
- with open(cache_path, "wb") as f:
447
- pickle.dump(parsed_response, f)
448
 
449
- return parsed_response
450
 
451
  def generate_cache_key(author_names: List[str], max_num_feats: int) -> str:
452
  """Generate a unique cache key based on author names and max features"""
@@ -472,10 +491,11 @@ def identify_style_features(author_texts: list[str], author_names: list[str], ma
472
 
473
  if cache_key in cache:
474
  print(f"\nCache hit! Using cached features for authors: {author_names}")
 
475
  return cache[cache_key]["features"]
476
  else:
477
- print(f"Cache miss. Computing features for authors: {author_names}")
478
-
479
  client = OpenAI(base_url=os.getenv("OPENAI_BASE_URL", None), api_key=os.getenv("OPENAI_API_KEY"))
480
  prompt = f"""Identify {max_num_feats} writing style features that are common between the authors texts.
481
  Author Texts:
@@ -483,9 +503,9 @@ def identify_style_features(author_texts: list[str], author_names: list[str], ma
483
  {author_texts}
484
  """
485
 
486
- print('==================>>>>>>>>>>')
487
- print(prompt)
488
- print('==================>>>>>>>>>>')
489
  def _make_call():
490
  response = client.chat.completions.create(
491
  model="gpt-4o",
@@ -512,6 +532,8 @@ def identify_style_features(author_texts: list[str], author_names: list[str], ma
512
  # save_cache(cache)
513
  with open(ZOOM_CACHE, 'w') as f:
514
  json.dump(cache, f, indent=2)
 
 
515
 
516
  print(f"Cached features for authors: {author_names}")
517
 
@@ -540,7 +562,7 @@ def extract_all_spans(authors_df: pd.DataFrame, features: list[str], cluster_lab
540
 
541
  for _, row in authors_df.iterrows():
542
  author_name = str(row[cluster_label_clm_name])
543
- print(author_name)
544
  role = f"{author_name}"
545
  full_text = row['fullText']
546
  spans = generate_feature_spans_cached(client, full_text, features, role)
@@ -569,18 +591,18 @@ def compute_clusters_style_representation_3(
569
  author_texts = "\n\n".join(["""Author {}:\n""".format(i+1) + text for i, text in enumerate(author_texts)])
570
  author_names = background_corpus_df_feat_id[cluster_label_clm_name].tolist()[:max_num_authors]
571
  print(f"Number of authors: {len(background_corpus_df_feat_id)}")
572
- print(author_names)
573
  features = identify_style_features(author_texts, author_names, max_num_feats=max_num_feats)
574
 
575
  if return_only_feats:
576
  return features
577
 
578
- print("Features: ", features)
579
  # STEP 2: Prepare author pool for span extraction
580
  span_df = background_corpus_df.iloc[:max_authors_for_span_extraction]
581
  author_names = span_df[cluster_label_clm_name].tolist()[:max_authors_for_span_extraction]
582
  print(f"Number of authors for span detection : {len(span_df)}")
583
- print(author_names)
584
  spans_by_author = extract_all_spans(span_df, features, cluster_label_clm_name)
585
 
586
  # Filter-in only task authors that are part of the current selection
@@ -597,7 +619,7 @@ def compute_clusters_style_representation_3(
597
  for feature, spans in feature_map.items():
598
  if spans:
599
  feature_importance[feature] -= len(spans)
600
- print(feature_importance)
601
  selected_features_ranked = sorted(feature_importance, key=lambda f: -feature_importance[f])[:int(top_k)]
602
 
603
  #print('filtered set of features (min coverage', len(author_present_feature_sets), '): ', selected_features_ranked)
@@ -716,6 +738,7 @@ def compute_clusters_g2v_representation(
716
  other_author_ids: List[Any],
717
  features_clm_name: str,
718
  top_n: int = 10,
 
719
  ) -> List[tuple]: # Changed return type to List[tuple] to include scores
720
 
721
  # 1) Identify selected authors in the zoom region
@@ -749,67 +772,114 @@ def compute_clusters_g2v_representation(
749
  # 5) Rank features by mean z-score, keep positives only
750
  feature_scores = [(feat, float(score)) for feat, score in zip(all_features, selected_mean) if score > 0]
751
  feature_scores.sort(key=lambda x: x[1], reverse=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
752
 
753
- return feature_scores[:top_n]
754
-
755
- def generate_interpretable_space_representation(interp_space_path, styles_df_path, feat_clm, output_clm, num_feats=5):
756
 
757
- styles_df = pd.read_csv(styles_df_path)[[feat_clm, "documentID"]]
758
 
759
- # A dictionary of style features and their IDF
760
- style_feats_agg_df = styles_df.groupby(feat_clm).agg({'documentID': lambda x : len(list(x))}).reset_index()
761
- style_feats_agg_df['document_freq'] = style_feats_agg_df.documentID
762
- style_to_feats_dfreq = {x[0]: math.log(styles_df.documentID.nunique()/x[1]) for x in zip(style_feats_agg_df[feat_clm].tolist(), style_feats_agg_df.document_freq.tolist())}
763
 
764
- # A list of style features we work with
765
- style_feats_list = style_feats_agg_df[feat_clm].tolist()
766
- print('Number of style feats ', len(style_feats_list))
767
 
768
- # A list of documents and what list of style features each has
769
- doc_style_agg_df = styles_df.groupby('documentID').agg({feat_clm: lambda x : list(x)}).reset_index()
770
- document_to_feats_dict = {x[0]: x[1] for x in zip(doc_style_agg_df.documentID.tolist(), doc_style_agg_df[feat_clm].tolist())}
771
 
772
 
773
 
774
- # Load the clustering information
775
- df = pd.read_pickle(interp_space_path)
776
- df = df[df.cluster_label != -1]
777
- # A cluster to list of documents
778
- clusterd_df = df.groupby('cluster_label').agg({
779
- 'documentID': lambda x: [d_id for doc_ids in x for d_id in doc_ids]
780
- }).reset_index()
781
 
782
- # Filter-in only documents that has a style description
783
- clusterd_df['documentID'] = clusterd_df.documentID.apply(lambda documentIDs: [documentID for documentID in documentIDs if documentID in document_to_feats_dict])
784
- # Map from cluster label to list of features through the document information
785
- clusterd_df[feat_clm] = clusterd_df.documentID.apply(lambda doc_ids: [f for d_id in doc_ids for f in document_to_feats_dict[d_id]])
786
-
787
- def compute_tfidf(row):
788
- style_counts = Counter(row[feat_clm])
789
- total_num_styles = sum(style_counts.values())
790
- #print(style_counts, total_num_styles)
791
- style_distribution = {
792
- style: math.log(1+count) * style_to_feats_dfreq[style] if style in style_to_feats_dfreq else 0 for style, count in style_counts.items()
793
- } #TF-IDF
794
 
795
- return style_distribution
796
 
797
- def create_tfidf_rep(tfidf_dist, num_feats):
798
- style_feats = sorted(tfidf_dist.items(), key=lambda x: -x[1])
799
- top_k_feats = [x[0] for x in style_feats[:num_feats] if str(x[0]) != 'nan']
800
- return top_k_feats
801
 
802
- clusterd_df[output_clm +'_dist'] = clusterd_df.apply(lambda row: compute_tfidf(row), axis=1)
803
- clusterd_df[output_clm] = clusterd_df[output_clm +'_dist'].apply(lambda dist: create_tfidf_rep(dist, num_feats))
804
 
805
 
806
- return clusterd_df
807
 
808
  def compute_predicted_author(task_authors_df: pd.DataFrame, col_name: str) -> int:
809
  """
810
  Computes the predicted author based on the style features.
811
  """
812
- print("Computing predicted author using LUAR-MUD-style embeddings...")
813
 
814
  # Extract LUAR embeddings from task authors dataframe
815
  mystery_embedding = np.array(task_authors_df.iloc[0][col_name]).reshape(1, -1)
@@ -850,11 +920,11 @@ def compute_precomputed_regions(bg_proj, bg_ids, q_proj, c_proj, pred_idx, model
850
  else:
851
  cache = {}
852
  if key in cache:
853
- print(f"\nCache hit! Using cached regions.")
854
  return cache[key]
855
  else:
856
- print(f"Cache miss. Computing regions.")
857
-
858
  regions = {}
859
 
860
  # All points for distance calculation (mystery + candidates + background)
@@ -969,22 +1039,23 @@ def compute_precomputed_regions(bg_proj, bg_ids, q_proj, c_proj, pred_idx, model
969
  response = json.dumps(serializable_regions, default=str)
970
  cache[key] = response
971
  with open(REGION_CACHE, 'wb') as f:
 
972
  pickle.dump(cache, f)
973
 
974
  return response
975
 
976
- if __name__ == "__main__":
977
- background_corpus = pd.read_pickle('../datasets/luar_interp_space_cluster_19/train_authors.pkl')
978
- print(background_corpus.columns)
979
- print(background_corpus[['authorID', 'fullText', 'cluster_label']].head())
980
- # # Example: Find features for clusters [2,3,4] that are NOT prominent in cluster [1]
981
- # feats = compute_clusters_style_representation(
982
- # background_corpus_df=background_corpus,
983
- # cluster_ids=['00005a5c-5c06-3a36-37f9-53c6422a31d8',],
984
- # other_cluster_ids=[], # Pass the contrastive cluster IDs here
985
- # cluster_label_clm_name='authorID',
986
- # features_clm_name='final_attribute_name'
987
- # )
988
- # print(feats)
989
- generate_style_embedding(background_corpus, 'fullText', 'AnnaWegmann/Style-Embedding')
990
- print(background_corpus.columns)
 
17
  from pydantic import ValidationError
18
  import time
19
  from utils.llm_feat_utils import generate_feature_spans_cached
20
+ from utils.gram2vec_feat_utils import get_shorthand, get_fullform
21
+ from gram2vec.feature_locator import find_feature_spans
22
  from collections import Counter
23
  import numpy as np
24
  from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
25
  from sklearn.decomposition import PCA
26
 
27
  CACHE_DIR = "datasets/embeddings_cache"
28
+ G2V_CACHE = "datasets/gram2vec_cache"
29
  ZOOM_CACHE = "datasets/zoom_cache/features_cache.json"
30
  REGION_CACHE = "datasets/region_cache/regions_cache.pkl"
31
  SUMMARY_CACHE = "datasets/summary_cache/summaries.json"
32
  os.makedirs(CACHE_DIR, exist_ok=True)
33
+ os.makedirs(G2V_CACHE, exist_ok=True)
34
  os.makedirs(os.path.dirname(ZOOM_CACHE), exist_ok=True)
35
  os.makedirs(os.path.dirname(REGION_CACHE), exist_ok=True)
36
  # Bump this whenever there is a change etc...
 
60
  print (f"concatenating task authors and background corpus authors")
61
  print(f"Number of task authors: {len(task_authors_df)}")
62
  print(f"task authors author_ids: {task_authors_df.authorID.tolist()}")
63
+ # print(f"task authors -->")
64
+ # print(task_authors_df)
65
  print(f"Number of background corpus authors: {len(clustered_authors_df)}")
66
  clustered_authors_df = pd.concat([task_authors_df, clustered_authors_df])
67
  print(f"Number of authors after concatenation: {len(clustered_authors_df)}")
 
69
  # Gather the input texts (preserves list-of-strings if any)
70
  #texts = background_corpus_df[text_clm].fillna("").tolist()
71
  author_texts = ['\n\n'.join(x) for x in clustered_authors_df.fullText.tolist()]
72
+ # print('author_text at 0:{}'.format(author_texts[0]))
73
  print(f"Number of author_texts: {len(author_texts)}")
74
 
75
  # Create a reproducible JSON serialization of the texts
76
+ # NOTE: g2v features used to be written as a new file inside embeddings_cache (CACHE_DIR), mixing them with the embedding caches.
77
+ # The g2v cache file is now written under G2V_CACHE instead.
78
  serialized = json.dumps({
79
  "col": text_clm,
80
  "texts": author_texts
 
82
 
83
  # Compute MD5 hash
84
  digest = hashlib.md5(serialized.encode("utf-8")).hexdigest()
85
+ cache_path = os.path.join(G2V_CACHE, f"{digest}.pkl")
86
 
87
  # If cache hit, load and return
88
  if os.path.exists(cache_path):
89
+ # print(f"Cache hit...")
90
+ # Printed in green so it stands out from the rest of the logs
91
+ print(f"\n\n\n\033[1m\033[92m>>> Cache hit for {cache_path} <<<\033[0m\n")
92
  with open(cache_path, "rb") as f:
93
  clustered_authors_df = pickle.load(f)
94
 
95
  else: # Else compute and cache
96
+ # Printed in red so it stands out from the rest of the logs
97
+ print(f"\n\n\n\033[1m\033[91m>>> Cache miss for {cache_path} => Computing fresh!! <<<\033[0m\n")
98
+
99
  g2v_feats_df = vectorizer.from_documents(author_texts, batch_size=8)
100
 
101
  print(f"Number of g2v features: {len(g2v_feats_df)}")
 
129
 
130
  with open(cache_path, "wb") as f:
131
  pickle.dump(clustered_authors_df, f)
132
+ # Printed in green so it stands out from the rest of the logs
133
+ print(f"\n\n\n\033[1m\033[92m>>> Saved to {cache_path} <<<\033[0m\n")
134
+ # the file generated here contains g2v + style embeddings.
135
 
136
  if task_authors_df is not None:
137
  task_authors_df = clustered_authors_df[clustered_authors_df.authorID.isin(task_authors_df.authorID.tolist())]
 
282
 
283
  # If cache hit, load and return
284
  if os.path.exists(cache_path):
285
+ # Printed in green so it stands out from the rest of the logs
286
+ print(f"\n\n\n\033[1m\033[92m>>> Cache hit for {cache_path} for {model_name} on column '{text_clm} <<<\033[0m\n")
287
  with open(cache_path, "rb") as f:
288
  background_corpus_df = pickle.load(f)
289
 
290
  else:
291
  # Otherwise, compute, cache, and return
292
+ print(f"\n\n\n\033[1m\033[91m>>> Cache miss for {cache_path} for {model_name} on column '{text_clm} <<<\033[0m\n")
293
  task_and_background_embeddings = generate_style_embedding(background_corpus_df, text_clm, model_name, dimensionality_reduction=False)
294
  # Create a clean column name from the model name
295
  col_name = f'{model_name.split("/")[-1]}_style_embedding'
 
297
 
298
  with open(cache_path, "wb") as f:
299
  pickle.dump(background_corpus_df, f)
300
+ print(f"\n\n\n\033[1m\033[92m>>> Cache saved for {cache_path} for {model_name} on column '{text_clm} <<<\033[0m\n")
301
 
302
  if task_authors_df is not None:
303
  task_authors_df = background_corpus_df[background_corpus_df.authorID.isin(task_authors_df.authorID.tolist())]
 
305
 
306
  return background_corpus_df, task_authors_df
307
 
308
+ # Note: the following function isn't actually referenced anywhere.
309
+ # def get_style_feats_distribution(documentIDs, style_feats_dict):
310
+ # style_feats = []
311
+ # for documentId in documentIDs:
312
+ # if documentId not in document_to_style_feats:
313
+ # #print(documentId)
314
+ # continue
315
+
316
+ # style_feats+= document_to_style_feats[documentId]
317
+
318
+ # tfidf = [style_feats.count(key) * val for key, val in style_feats_dict.items()]
319
+
320
+ # return tfidf
321
+ #
322
+ # Note: the following function isn't actually referenced anywhere.
323
+ # def get_cluster_top_feats(style_feats_distribution, style_feats_list, top_k=5):
324
+ # sorted_feats = np.argsort(style_feats_distribution)[::-1]
325
+ # top_feats = [style_feats_list[x] for x in sorted_feats[:top_k] if style_feats_distribution[x] > 0]
326
+ # return top_feats
327
+
328
+ # Note: the following function isn't actually referenced anywhere.
329
+ # def compute_clusters_style_representation(
330
+ # background_corpus_df: pd.DataFrame,
331
+ # cluster_ids: List[Any],
332
+ # other_cluster_ids: List[Any],
333
+ # features_clm_name: str,
334
+ # cluster_label_clm_name: str = 'cluster_label',
335
+ # top_n: int = 10
336
+ # ) -> List[str]:
337
+ # """
338
+ # Given a DataFrame with document IDs, cluster IDs, and feature lists,
339
+ # return the top N features that are most important in the specified `cluster_ids`
340
+ # while having low importance in `other_cluster_ids`.
341
+ # Importance is determined by TF-IDF scores. The final score for a feature is
342
+ # (summed TF-IDF in `cluster_ids`) - (summed TF-IDF in `other_cluster_ids`).
343
+
344
+ # Parameters:
345
+ # - background_corpus_df: pd.DataFrame. Must contain the columns specified by
346
+ # `cluster_label_clm_name` and `features_clm_name`.
347
+ # The column `features_clm_name` should contain lists of strings (features).
348
+ # - cluster_ids: List of cluster IDs for which to find representative features (target clusters).
349
+ # - other_cluster_ids: List of cluster IDs whose features should be down-weighted.
350
+ # Features prominent in these clusters will have their scores reduced.
351
+ # Pass an empty list or None if no contrastive clusters are needed.
352
+ # - features_clm_name: The name of the column in `background_corpus_df` that
353
+ # contains the list of features for each document.
354
+ # - cluster_label_clm_name: The name of the column in `background_corpus_df`
355
+ # that contains the cluster labels. Defaults to 'cluster_label'.
356
+ # - top_n: Number of top features to return.
357
+ # Returns:
358
+ # - List[str]: A list of feature names. These are up to `top_n` features
359
+ # ranked by their adjusted TF-IDF scores (score in `cluster_ids`
360
+ # minus score in `other_cluster_ids`). Only features with a final
361
+ # adjusted score > 0 are included.
362
+ # """
363
+
364
+ # assert background_corpus_df[features_clm_name].apply(
365
+ # lambda x: isinstance(x, list) and all(isinstance(feat, str) for feat in x)
366
+ # ).all(), f"Column '{features_clm_name}' must contain lists of strings."
367
+
368
+ # # Compute TF-IDF on the entire corpus
369
+ # vectorizer = TfidfVectorizer(
370
+ # tokenizer=lambda x: x,
371
+ # preprocessor=lambda x: x,
372
+ # token_pattern=None # Disable default token pattern, treat items in list as tokens
373
+ # )
374
+ # tfidf_matrix = vectorizer.fit_transform(background_corpus_df[features_clm_name])
375
+ # feature_names = vectorizer.get_feature_names_out()
376
+
377
+ # # Get boolean mask for documents in selected clusters
378
+ # selected_mask = background_corpus_df[cluster_label_clm_name].isin(cluster_ids).to_numpy()
379
+
380
+ # if not selected_mask.any():
381
+ # return [] # No documents found for the given cluster_ids
382
+
383
+ # # Subset the TF-IDF matrix using the boolean mask
384
+ # selected_tfidf = tfidf_matrix[selected_mask]
385
+
386
+ # # Sum TF-IDF scores across documents for each feature in the target clusters
387
+ # target_feature_scores_sum = selected_tfidf.sum(axis=0).A1 # Convert to 1D array
388
+
389
+ # # Initialize adjusted scores with target scores
390
+ # adjusted_feature_scores = target_feature_scores_sum.copy()
391
+
392
+ # # If other_cluster_ids are provided and not empty, subtract their TF-IDF sums
393
+ # if other_cluster_ids: # Checks if the list is not None and not empty
394
+ # other_selected_mask = background_corpus_df[cluster_label_clm_name].isin(other_cluster_ids).to_numpy()
395
+
396
+ # if other_selected_mask.any():
397
+ # other_selected_tfidf = tfidf_matrix[other_selected_mask]
398
+ # contrast_feature_scores_sum = other_selected_tfidf.sum(axis=0).A1
399
 
400
+ # # Element-wise subtraction; assumes feature_names aligns for both sums
401
+ # adjusted_feature_scores -= contrast_feature_scores_sum
402
+
403
+ # # Map scores to feature names
404
+ # feature_score_dict = dict(zip(feature_names, adjusted_feature_scores))
405
+ # # Sort features by score
406
+ # sorted_features = sorted(feature_score_dict.items(), key=lambda item: item[1], reverse=True)
407
+
408
+ # # Return the names of the top_n features that have a score > 0
409
+ # top_features = [feature for feature, score in sorted_features if score > 0][:top_n]
410
+
411
+ # return top_features
412
+
413
+ # Note: the following function isn't actually referenced anywhere.
414
+ # def compute_clusters_style_representation_2(
415
+ # background_corpus_df: pd.DataFrame,
416
+ # cluster_ids: List[Any],
417
+ # cluster_label_clm_name: str = 'cluster_label',
418
+ # max_num_feats: int = 5,
419
+ # max_num_documents_per_author=3,
420
+ # max_num_authors=5):
421
+ # """
422
+ # Call openAI to analyze the common writing style features of the given list of texts
423
+ # """
424
+ # client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
425
+
426
+ # background_corpus_df['fullText'] = background_corpus_df['fullText'].map(lambda x: '\n\n'.join(x[:max_num_documents_per_author]) if isinstance(x, list) else x)
427
+ # background_corpus_df = background_corpus_df[background_corpus_df[cluster_label_clm_name].isin(cluster_ids)]
428
 
429
+ # author_texts = background_corpus_df['fullText'].tolist()[:max_num_authors]
430
+ # author_texts = "\n\n".join(["""Author {}:\n""".format(i+1) + text for i, text in enumerate(author_texts)])
431
+ # author_names = background_corpus_df[cluster_label_clm_name].tolist()[:max_num_authors]
432
+ # print(f"Number of authors: {len(background_corpus_df)}")
433
+ # print(author_names)
434
+ # print(author_texts)
435
+ # print(f"Number of authors: {len(author_names)}")
436
+ # print(f"Number of authors: {len(author_texts)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
437
 
438
+ # prompt = f"""First identify a list of {max_num_feats} writing style features that are common between the given texts. Second for every author text and style feature, extract all spans that represent the feature. Output for every author all style features with their spans.
439
+ # Author Texts:
440
+ # \"\"\"{author_texts}\"\"\"
441
+ # """
442
+
443
+ # # Compute MD5 hash
444
+ # digest = hashlib.md5(prompt.encode("utf-8")).hexdigest()
445
+ # cache_path = os.path.join(CACHE_DIR, f"{digest}.pkl")
446
+
447
+ # # If cache hit, load and return
448
+ # if os.path.exists(cache_path):
449
+ # print(f"Loading authors writing style from cache ...")
450
+ # with open(cache_path, "rb") as f:
451
+ # parsed_response = pickle.load(f)
452
+
453
+ # else: # Else compute and cache
454
 
455
+ # response = client.chat.completions.create(
456
+ # model="gpt-4o-mini",
457
+ # messages=[
458
+ # {"role":"assistant","content":"You are a forensic linguistic who knows how to analyze similarites in writing styles."},
459
+ # {"role":"user","content":prompt}],
460
+ # response_format={"type": "json_schema", "json_schema": {"name": "style_analysis_schema", "schema": to_strict_json_schema(style_analysis_schema)}}
461
+ # )
462
 
463
+ # parsed_response = json.loads(response.choices[0].message.content)
464
 
465
+ # with open(cache_path, "wb") as f:
466
+ # pickle.dump(parsed_response, f)
467
 
468
+ # return parsed_response
469
 
470
  def generate_cache_key(author_names: List[str], max_num_feats: int) -> str:
471
  """Generate a unique cache key based on author names and max features"""
 
491
 
492
  if cache_key in cache:
493
  print(f"\nCache hit! Using cached features for authors: {author_names}")
494
+ print(f"\n\n\n\033[1m\033[92m>>> Cache hit for {cache_key} in {ZOOM_CACHE} <<<\033[0m\n")
495
  return cache[cache_key]["features"]
496
  else:
497
+ print(f"\n\n\n\033[1m\033[91m>>> Cache miss for {cache_key} in {ZOOM_CACHE} \nComputing features for authors: {author_names}<<<\033[0m\n")
498
+
499
  client = OpenAI(base_url=os.getenv("OPENAI_BASE_URL", None), api_key=os.getenv("OPENAI_API_KEY"))
500
  prompt = f"""Identify {max_num_feats} writing style features that are common between the authors texts.
501
  Author Texts:
 
503
  {author_texts}
504
  """
505
 
506
+ # print('==================>>>>>>>>>>')
507
+ # print(prompt)
508
+ # print('==================>>>>>>>>>>')
509
  def _make_call():
510
  response = client.chat.completions.create(
511
  model="gpt-4o",
 
532
  # save_cache(cache)
533
  with open(ZOOM_CACHE, 'w') as f:
534
  json.dump(cache, f, indent=2)
535
+ print(f"\n\n\n\033[1m\033[92m>>> Cache saved for {cache_key} in {ZOOM_CACHE}<<<\033[0m\n")
536
+
537
 
538
  print(f"Cached features for authors: {author_names}")
539
 
 
562
 
563
  for _, row in authors_df.iterrows():
564
  author_name = str(row[cluster_label_clm_name])
565
+ # print(author_name)
566
  role = f"{author_name}"
567
  full_text = row['fullText']
568
  spans = generate_feature_spans_cached(client, full_text, features, role)
 
591
  author_texts = "\n\n".join(["""Author {}:\n""".format(i+1) + text for i, text in enumerate(author_texts)])
592
  author_names = background_corpus_df_feat_id[cluster_label_clm_name].tolist()[:max_num_authors]
593
  print(f"Number of authors: {len(background_corpus_df_feat_id)}")
594
+ # print(author_names)
595
  features = identify_style_features(author_texts, author_names, max_num_feats=max_num_feats)
596
 
597
  if return_only_feats:
598
  return features
599
 
600
+ #print("Features: ", features)
601
  # STEP 2: Prepare author pool for span extraction
602
  span_df = background_corpus_df.iloc[:max_authors_for_span_extraction]
603
  author_names = span_df[cluster_label_clm_name].tolist()[:max_authors_for_span_extraction]
604
  print(f"Number of authors for span detection : {len(span_df)}")
605
+ # print(author_names)
606
  spans_by_author = extract_all_spans(span_df, features, cluster_label_clm_name)
607
 
608
  # Filter-in only task authors that are part of the current selection
 
619
  for feature, spans in feature_map.items():
620
  if spans:
621
  feature_importance[feature] -= len(spans)
622
+ # print(feature_importance)
623
  selected_features_ranked = sorted(feature_importance, key=lambda f: -feature_importance[f])[:int(top_k)]
624
 
625
  #print('filtered set of features (min coverage', len(author_present_feature_sets), '): ', selected_features_ranked)
 
738
  other_author_ids: List[Any],
739
  features_clm_name: str,
740
  top_n: int = 10,
741
+ max_candidates_for_span_sorting: int = 50,
742
  ) -> List[tuple]: # Changed return type to List[tuple] to include scores
743
 
744
  # 1) Identify selected authors in the zoom region
 
772
  # 5) Rank features by mean z-score, keep positives only
773
  feature_scores = [(feat, float(score)) for feat, score in zip(all_features, selected_mean) if score > 0]
774
  feature_scores.sort(key=lambda x: x[1], reverse=True)
775
+
776
+ # 6) Extract top candidates for span-based sorting
777
+ candidate_features = feature_scores[:max_candidates_for_span_sorting]
778
+
779
+ # 7) Extract spans for task authors to sort by frequency
780
+ task_author_names = {'Mystery author', 'Candidate Author 1', 'Candidate Author 2', 'Candidate Author 3'}
781
+ task_authors_in_selection = [aid for aid in author_ids if aid in task_author_names]
782
+
783
+ if not task_authors_in_selection:
784
+ # If no task authors in selection, just return the z-score sorted features
785
+ print("[INFO] No task authors in selection, returning z-score sorted features")
786
+ return feature_scores[:top_n]
787
+
788
+ # Get task author data
789
+ task_authors_df = background_corpus_df[background_corpus_df['authorID'].isin(task_authors_in_selection)]
790
+
791
+ # Count spans for each feature across task authors
792
+ feature_span_counts = {}
793
+ for feat_shorthand, z_score in candidate_features:
794
+ span_count = 0
795
+
796
+ # Convert shorthand to human-readable for display (if needed)
797
+ # Note: features in gram2vec dict are in shorthand format like "pos_unigrams:ADJ"
798
+
799
+ for _, author_row in task_authors_df.iterrows():
800
+ author_text = author_row['fullText']
801
+ if isinstance(author_text, list):
802
+ author_text = '\n\n'.join(author_text)
803
+
804
+ try:
805
+ # find_feature_spans expects shorthand format like "pos_unigrams:ADJ"
806
+ spans = find_feature_spans(author_text, feat_shorthand)
807
+ span_count += len(spans)
808
+ except Exception as e:
809
+ # If span extraction fails, continue with 0 spans for this author
810
+ pass
811
+
812
+ feature_span_counts[feat_shorthand] = span_count
813
+
814
+ # 8) Sort features by span frequency, then by z-score as tiebreaker
815
+ sorted_by_spans = sorted(
816
+ candidate_features,
817
+ key=lambda x: (-feature_span_counts.get(x[0], 0), -x[1])
818
+ )
819
+
820
+ # print(f"[INFO] Sorted gram2vec features by span frequency: {[(f, feature_span_counts.get(f, 0), z) for f, z in sorted_by_spans[:top_n]]}")
821
+
822
+ return sorted_by_spans[:top_n]
823
 
824
+ # Note: the following function isn't actually referenced anywhere.
825
+ # def generate_interpretable_space_representation(interp_space_path, styles_df_path, feat_clm, output_clm, num_feats=5):
 
826
 
827
+ # styles_df = pd.read_csv(styles_df_path)[[feat_clm, "documentID"]]
828
 
829
+ # # A dictionary of style features and their IDF
830
+ # style_feats_agg_df = styles_df.groupby(feat_clm).agg({'documentID': lambda x : len(list(x))}).reset_index()
831
+ # style_feats_agg_df['document_freq'] = style_feats_agg_df.documentID
832
+ # style_to_feats_dfreq = {x[0]: math.log(styles_df.documentID.nunique()/x[1]) for x in zip(style_feats_agg_df[feat_clm].tolist(), style_feats_agg_df.document_freq.tolist())}
833
 
834
+ # # A list of style features we work with
835
+ # style_feats_list = style_feats_agg_df[feat_clm].tolist()
836
+ # print('Number of style feats ', len(style_feats_list))
837
 
838
+ # # A list of documents and what list of style features each has
839
+ # doc_style_agg_df = styles_df.groupby('documentID').agg({feat_clm: lambda x : list(x)}).reset_index()
840
+ # document_to_feats_dict = {x[0]: x[1] for x in zip(doc_style_agg_df.documentID.tolist(), doc_style_agg_df[feat_clm].tolist())}
841
 
842
 
843
 
844
+ # # Load the clustering information
845
+ # df = pd.read_pickle(interp_space_path)
846
+ # df = df[df.cluster_label != -1]
847
+ # # A cluster to list of documents
848
+ # clusterd_df = df.groupby('cluster_label').agg({
849
+ # 'documentID': lambda x: [d_id for doc_ids in x for d_id in doc_ids]
850
+ # }).reset_index()
851
 
852
+ # # Filter-in only documents that has a style description
853
+ # clusterd_df['documentID'] = clusterd_df.documentID.apply(lambda documentIDs: [documentID for documentID in documentIDs if documentID in document_to_feats_dict])
854
+ # # Map from cluster label to list of features through the document information
855
+ # clusterd_df[feat_clm] = clusterd_df.documentID.apply(lambda doc_ids: [f for d_id in doc_ids for f in document_to_feats_dict[d_id]])
856
+
857
+ # def compute_tfidf(row):
858
+ # style_counts = Counter(row[feat_clm])
859
+ # total_num_styles = sum(style_counts.values())
860
+ # #print(style_counts, total_num_styles)
861
+ # style_distribution = {
862
+ # style: math.log(1+count) * style_to_feats_dfreq[style] if style in style_to_feats_dfreq else 0 for style, count in style_counts.items()
863
+ # } #TF-IDF
864
 
865
+ # return style_distribution
866
 
867
+ # def create_tfidf_rep(tfidf_dist, num_feats):
868
+ # style_feats = sorted(tfidf_dist.items(), key=lambda x: -x[1])
869
+ # top_k_feats = [x[0] for x in style_feats[:num_feats] if str(x[0]) != 'nan']
870
+ # return top_k_feats
871
 
872
+ # clusterd_df[output_clm +'_dist'] = clusterd_df.apply(lambda row: compute_tfidf(row), axis=1)
873
+ # clusterd_df[output_clm] = clusterd_df[output_clm +'_dist'].apply(lambda dist: create_tfidf_rep(dist, num_feats))
874
 
875
 
876
+ # return clusterd_df
877
 
878
  def compute_predicted_author(task_authors_df: pd.DataFrame, col_name: str) -> int:
879
  """
880
  Computes the predicted author based on the style features.
881
  """
882
+ print("Computing predicted author using embeddings...")
883
 
884
  # Extract LUAR embeddings from task authors dataframe
885
  mystery_embedding = np.array(task_authors_df.iloc[0][col_name]).reshape(1, -1)
 
920
  else:
921
  cache = {}
922
  if key in cache:
923
+ print(f"\n\n\n\033[1m\033[92m>>> Cache hit for {key} in {REGION_CACHE}: Using cached regions<<<\033[0m\n")
924
  return cache[key]
925
  else:
926
+ print(f"\n\n\n\033[1m\033[91m>>> Cache miss for {key} in {REGION_CACHE}: Computing Regions<<<\033[0m\n")
927
+
928
  regions = {}
929
 
930
  # All points for distance calculation (mystery + candidates + background)
 
1039
  response = json.dumps(serializable_regions, default=str)
1040
  cache[key] = response
1041
  with open(REGION_CACHE, 'wb') as f:
1042
+ print(f"\n\n\n\033[1m\033[92m>>> Cache saved for {key} in {REGION_CACHE} <<<\033[0m\n")
1043
  pickle.dump(cache, f)
1044
 
1045
  return response
1046
 
1047
+ # if __name__ == "__main__":
1048
+ # background_corpus = pd.read_pickle('../datasets/luar_interp_space_cluster_19/train_authors.pkl')
1049
+ # print(background_corpus.columns)
1050
+ # print(background_corpus[['authorID', 'fullText', 'cluster_label']].head())
1051
+ # # # Example: Find features for clusters [2,3,4] that are NOT prominent in cluster [1]
1052
+ # # feats = compute_clusters_style_representation(
1053
+ # # background_corpus_df=background_corpus,
1054
+ # # cluster_ids=['00005a5c-5c06-3a36-37f9-53c6422a31d8',],
1055
+ # # other_cluster_ids=[], # Pass the contrastive cluster IDs here
1056
+ # # cluster_label_clm_name='authorID',
1057
+ # # features_clm_name='final_attribute_name'
1058
+ # # )
1059
+ # # print(feats)
1060
+ # generate_style_embedding(background_corpus, 'fullText', 'AnnaWegmann/Style-Embedding')
1061
+ # print(background_corpus.columns)
utils/llm_feat_utils.py CHANGED
@@ -32,19 +32,20 @@ def generate_feature_spans(client, text: str, features: list[str]) -> str:
32
  """
33
  Call to OpenAI to extract spans. Returns a JSON string.
34
  """
 
 
 
35
  prompt = f"""You are a linguistic specialist. Given a writing sample and a list of descriptive features, identify the exact text spans that demonstrate each feature.
36
 
37
  Important:
38
  - The headers like "Document 1:" etc are NOT part of the original text — ignore them.
39
  - For each feature, even if there is no match, return an empty list.
40
  - Only return exact phrases from the text.
 
41
 
42
- Respond in JSON format like:
43
- {{
44
- "feature1": ["span1", "span2"],
45
- "feature2": [],
46
- …
47
- }}
48
 
49
  Text:
50
  \"\"\"{text}\"\"\"
@@ -52,9 +53,9 @@ def generate_feature_spans(client, text: str, features: list[str]) -> str:
52
  Style Features:
53
  {features}
54
  """
55
- print('==================>>>>>>>>>>')
56
- print(prompt)
57
- print('==================>>>>>>>>>>')
58
  response = client.chat.completions.create(
59
  model="gpt-4o",
60
  messages=[{"role":"user","content":prompt}]
@@ -71,8 +72,14 @@ def generate_feature_spans_with_retries(client, text: str, features: list[str])
71
  for attempt in range(MAX_ATTEMPTS):
72
  try:
73
  response_str = generate_feature_spans(client, text, features)
74
- print(response_str)
75
  result = json.loads(response_str)
 
 
 
 
 
 
76
  return result
77
  except (JSONDecodeError, ValueError) as e:
78
  print(f"Attempt {attempt+1} failed: {e}")
@@ -116,7 +123,13 @@ def generate_feature_spans_cached(client, text: str, features: list[str], role:
116
  if h in cache:
117
  # print(f"Found feature: {feat}")
118
  found_feats_count += 1
119
- result[feat] = cache[h]["spans"]
 
 
 
 
 
 
120
  else:
121
  # print(f"Missing feature: {feat}")
122
  missing_feats_count += 1
 
32
  """
33
  Call to OpenAI to extract spans. Returns a JSON string.
34
  """
35
+ # For some of the longer features, openai client was truncating the feature names, resulting in downstream errors.
36
+ # Adding structured JSON template to ensure all features are included properly.
37
+ features_json_template = {feature: [] for feature in features}
38
  prompt = f"""You are a linguistic specialist. Given a writing sample and a list of descriptive features, identify the exact text spans that demonstrate each feature.
39
 
40
  Important:
41
  - The headers like "Document 1:" etc are NOT part of the original text β€” ignore them.
42
  - For each feature, even if there is no match, return an empty list.
43
  - Only return exact phrases from the text.
44
+ - Use the EXACT feature names as JSON keys - do not paraphrase or shorten them.
45
 
46
+
47
+ Respond in this EXACT JSON format (use these exact keys, populate the lists with the extracted text spans):
48
+ {json.dumps(features_json_template, indent=2)}
 
 
 
49
 
50
  Text:
51
  \"\"\"{text}\"\"\"
 
53
  Style Features:
54
  {features}
55
  """
56
+ # print('==================>>>>>>>>>>')
57
+ # print(prompt)
58
+ # print('==================>>>>>>>>>>')
59
  response = client.chat.completions.create(
60
  model="gpt-4o",
61
  messages=[{"role":"user","content":prompt}]
 
72
  for attempt in range(MAX_ATTEMPTS):
73
  try:
74
  response_str = generate_feature_spans(client, text, features)
75
+ # print(response_str)
76
  result = json.loads(response_str)
77
+ # Additional check to ensure all requested features are present in the response correctly
78
+ if result.keys() != set(features):
79
+ print("Response keys do not match requested features. Retrying!")
80
+ response_str = generate_feature_spans(client, text, features)
81
+ # print(response_str)
82
+ result = json.loads(response_str)
83
  return result
84
  except (JSONDecodeError, ValueError) as e:
85
  print(f"Attempt {attempt+1} failed: {e}")
 
123
  if h in cache:
124
  # print(f"Found feature: {feat}")
125
  found_feats_count += 1
126
+ if cache[h]["spans"] is None:
127
+ print(f"Missing feature: {feat}")
128
+ missing_feats_count += 1
129
+ missing_feats.append(feat)
130
+ else:
131
+ result[feat] = cache[h]["spans"]
132
+
133
  else:
134
  # print(f"Missing feature: {feat}")
135
  missing_feats_count += 1
utils/ui.py CHANGED
@@ -81,14 +81,14 @@ def read_txt(f):
81
  def toggle_task(mode):
82
  print(mode)
83
  return (
84
- gr.update(visible=(mode == "Predefined HRS Task")),
85
  gr.update(visible=(mode == "Upload Your Own Task"))
86
  )
87
 
88
  # Update displayed texts based on mode
89
  def update_task_display(mode, iid, instances, background_df, mystery_file, cand1_file, cand2_file, cand3_file, true_author, model_radio, custom_model_input):
90
  model_name = model_radio if model_radio != "Other" else custom_model_input
91
- if mode == "Predefined HRS Task":
92
  iid = int(iid.replace('Task ', ''))
93
  data = instances[iid]
94
  ground_truth_author = 100#data['gt_idx']
 
81
  def toggle_task(mode):
82
  print(mode)
83
  return (
84
+ gr.update(visible=(mode == "Predefined Reddit Task")),
85
  gr.update(visible=(mode == "Upload Your Own Task"))
86
  )
87
 
88
  # Update displayed texts based on mode
89
  def update_task_display(mode, iid, instances, background_df, mystery_file, cand1_file, cand2_file, cand3_file, true_author, model_radio, custom_model_input):
90
  model_name = model_radio if model_radio != "Other" else custom_model_input
91
+ if mode == "Predefined Reddit Task":
92
  iid = int(iid.replace('Task ', ''))
93
  data = instances[iid]
94
  ground_truth_author = 100#data['gt_idx']
utils/visualizations.py CHANGED
@@ -309,7 +309,7 @@ def handle_zoom(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors
309
 
310
  task_texts = [_to_text(x) for x in task_only_df['fullText'].tolist()]
311
 
312
- print(f"task_texts: {task_texts}")
313
  filtered_g2v_feats = []
314
  for feat in g2v_feats:
315
  try:
@@ -333,7 +333,7 @@ def handle_zoom(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors
333
  HR_g2v_list = []
334
  for feat in filtered_g2v_feats:
335
  HR_g2v = get_fullform(feat[0])
336
- print(f"\n\n feat: {feat} ---> Human Readable: {HR_g2v}")
337
  if HR_g2v is None:
338
  print(f"Skipping Gram2Vec feature without human readable form: {feat}")
339
  else:
@@ -342,11 +342,11 @@ def handle_zoom(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors
342
  HR_g2v_list = [("None", None)] + HR_g2v_list
343
 
344
  print(f"[INFO] Found {len(llm_feats)} LLM features and {len(g2v_feats)} Gram2Vec features in the zoomed region.")
345
- print(f"[INFO] unfiltered g2v features: {g2v_feats}")
346
 
347
  print(f"[INFO] LLM features: {llm_feats}")
348
  HR_g2v_list, _ = format_g2v_features_for_display(HR_g2v_list)
349
- print(f"[INFO] Gram2Vec features: {HR_g2v_list}")
350
 
351
  return (
352
  gr.update(choices=llm_feats, value=llm_feats[0]),
@@ -386,7 +386,7 @@ def handle_zoom_with_retries(event_json, bg_proj, bg_lbls, clustered_authors_df,
386
  def visualize_clusters_plotly(iid, cfg, instances, model_radio, custom_model_input, task_authors_df, background_authors_embeddings_df, pred_idx=None, gt_idx=None):
387
  model_name = model_radio if model_radio != "Other" else custom_model_input
388
  embedding_col_name = f'{model_name.split("/")[-1]}_style_embedding'
389
- print(background_authors_embeddings_df.columns)
390
  print("Generating cluster visualization")
391
  iid = int(iid)
392
  #interp = load_interp_space(cfg)
 
309
 
310
  task_texts = [_to_text(x) for x in task_only_df['fullText'].tolist()]
311
 
312
+ print(f"len task_texts: {len(task_texts)}")
313
  filtered_g2v_feats = []
314
  for feat in g2v_feats:
315
  try:
 
333
  HR_g2v_list = []
334
  for feat in filtered_g2v_feats:
335
  HR_g2v = get_fullform(feat[0])
336
+ # print(f"\n\n feat: {feat} ---> Human Readable: {HR_g2v}")
337
  if HR_g2v is None:
338
  print(f"Skipping Gram2Vec feature without human readable form: {feat}")
339
  else:
 
342
  HR_g2v_list = [("None", None)] + HR_g2v_list
343
 
344
  print(f"[INFO] Found {len(llm_feats)} LLM features and {len(g2v_feats)} Gram2Vec features in the zoomed region.")
345
+ # print(f"[INFO] unfiltered g2v features: {g2v_feats}")
346
 
347
  print(f"[INFO] LLM features: {llm_feats}")
348
  HR_g2v_list, _ = format_g2v_features_for_display(HR_g2v_list)
349
+ # print(f"[INFO] Gram2Vec features: {HR_g2v_list}")
350
 
351
  return (
352
  gr.update(choices=llm_feats, value=llm_feats[0]),
 
386
  def visualize_clusters_plotly(iid, cfg, instances, model_radio, custom_model_input, task_authors_df, background_authors_embeddings_df, pred_idx=None, gt_idx=None):
387
  model_name = model_radio if model_radio != "Other" else custom_model_input
388
  embedding_col_name = f'{model_name.split("/")[-1]}_style_embedding'
389
+ # print(background_authors_embeddings_df.columns)
390
  print("Generating cluster visualization")
391
  iid = int(iid)
392
  #interp = load_interp_space(cfg)