Milad Alshomary committed
Commit a5e49c0 · 1 Parent(s): dcbbcbd
README.md CHANGED
@@ -13,3 +13,23 @@ short_description: Interpreting the latent space of Authorship Attribution
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+## Useful commands
+
+### Prepare training/test data
+
+### Cluster the background corpus
+
+python cluster_corpus.py ../../iarpa-hiatus/explanation_tool_files/reddit_cluster_training.pkl ../../iarpa-hiatus/explanation_tool_files/reddit_cluster_test.pkl "AnnaWegmann/Style-Embedding" ./datasets/reddit_clustered_authors.pkl --min_samples 2 --metric cosine --pca_dimensions 100 --eps 0.04
+
+### Generate an explainability sample
+
+python prepare_data.py ../explanation_tool_files/reddit_cluster_test.pkl ./datasets/reddit_explanation_sample.json
+
+### Generate static explanations for a sample
+
+python baseline_static_explanations.py generate_explanations ./datasets/reddit_explanation_sample.json ./datasets/reddit_explanation_sample_with_explanations.json --interp_space_path ./datasets/reddit_interp_space.json --model_name 'AnnaWegmann/Style-Embedding'
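
The reddit_interp_space.json consumed by the last command is produced by the build_static_interp_space task of baseline_static_explanations.py (added below); a sketch, reusing the paths from the commands above:

python baseline_static_explanations.py build_static_interp_space ./datasets/reddit_clustered_authors.pkl ./datasets/reddit_interp_space.json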
app.py CHANGED
@@ -42,7 +42,7 @@ from utils.interp_space_utils import *
 from utils.ui import *
 
 load_dotenv()
-client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+client = OpenAI(base_url=os.getenv("OPENAI_API_BASE"), api_key=os.getenv("OPENAI_API_KEY"))
 
 
 # ── load once at startup ────────────────────────────────────────
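
Note that app.py reads OPENAI_API_BASE while the clients added in utils/interp_space_utils.py below read OPENAI_BASE_URL; a minimal .env sketch covering both (all values are placeholders):

OPENAI_API_BASE=https://api.openai.com/v1
OPENAI_BASE_URL=https://api.openai.com/v1
OPENAI_API_KEY=sk-...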
baseline_static_explanations.py ADDED
@@ -0,0 +1,196 @@
+import argparse
+import pandas as pd
+import numpy as np
+import os, json
+
+from utils.interp_space_utils import cached_generate_style_embedding
+from utils.clustering_utils import clustering_author
+from utils.interp_space_utils import compute_clusters_style_representation_3, summarize_style_features_to_paragraph, find_closest_cluster_style
+
+from sklearn.metrics.pairwise import cosine_distances, cosine_similarity
+
+
+def build_static_interp_space(cluster_df):
+    """
+    Takes a dataframe whose `cluster_label` column assigns each author to a cluster,
+    and returns a dict mapping each cluster_label to a tuple of
+    (average style embedding, LLM-generated style description).
+
+    Example cluster_df:
+        fullText                                          authorID         Style-Embedding_style_embedding                     cluster_label
+    4   [I've play them all (D3, Torchlight 1&2, P...    HaxRyter         [0.7126333904811682, -0.5076461933032986, -0.1...  0
+    10  [Back in Texas. Buddy had a kid in an up and ...  OaklandHellBent  [0.11238726238181786, 0.9263576185812101, -0.2...  1
+    """
+    # Find the embedding column (assuming it's the only one ending with '_style_embedding')
+    embedding_clm = next((col for col in cluster_df.columns if col.endswith('_style_embedding')), None)
+    if not embedding_clm:
+        raise ValueError("No style embedding column found in the DataFrame.")
+
+    print(f"Using embedding column: {embedding_clm}")
+
+    # Group by cluster label and calculate the average embedding for each cluster.
+    # We also aggregate authorIDs to use them for style representation.
+    cluster_groups = cluster_df.groupby('cluster_label').agg({
+        embedding_clm: lambda embs: np.mean(np.vstack(embs), axis=0).tolist(),
+        'authorID': list
+    }).reset_index()
+
+    interpretable_space = {}
+
+    for _, row in cluster_groups.iterrows():
+        cluster_label = row['cluster_label']
+        avg_embedding = row[embedding_clm]
+        author_ids_in_cluster = row['authorID']
+
+        print(f"\nProcessing cluster {cluster_label} with {len(author_ids_in_cluster)} authors...")
+
+        # Generate style features using an LLM.
+        # We reuse the utility function from the interactive tool for consistency.
+        style_analysis = compute_clusters_style_representation_3(
+            background_corpus_df=cluster_df,
+            cluster_ids=author_ids_in_cluster,
+            cluster_label_clm_name='authorID',
+            max_num_feats=5,     # Request the top 5 features
+            max_num_authors=20,  # Use up to 20 authors from the cluster for analysis
+            return_only_feats=True
+        )
+
+        # When return_only_feats=True, style_analysis is a list of features
+        style_features_list = style_analysis
+        print(f"  Generated style features: {style_features_list}")
+
+        # Summarize the list of features into a coherent paragraph
+        style_paragraph = summarize_style_features_to_paragraph(style_features_list)
+        print(f"  Summarized paragraph: {style_paragraph}")
+
+        # JSON cannot serialize numpy integers, so convert cluster_label
+        interpretable_space[int(cluster_label)] = (avg_embedding, style_paragraph)
+
+    return interpretable_space
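+
+# Sketch of the returned structure (embedding values and descriptions invented for illustration):
+#   {0: ([0.71, -0.51, ...], "These authors favor short declarative sentences..."),
+#    1: ([0.11, 0.93, ...], "These authors write long, discursive paragraphs...")}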
+
+def generate_explanations(args):
+    input_file = args.input_file
+    interp_space_path = args.interp_space_path
+    output_file = args.output_file
+    model_name = args.model_name if args.model_name else 'AnnaWegmann/Style-Embedding'
+
+    instances_for_ex = json.load(open(input_file))
+    interp_space = json.load(open(interp_space_path))
+
+    output = []
+    for instance in instances_for_ex:
+        json_obj = {}
+        json_obj['Q_authorID'] = instance['Q_authorID']
+        json_obj['Q_fullText'] = instance['Q_fullText']
+        style_description, q_embeddings = find_closest_cluster_style(instance['Q_fullText'], interp_space, model_name=model_name)
+        json_obj['Q_top_style_feats'] = style_description
+
+        json_obj['a0_authorID'] = instance['a0_authorID']
+        json_obj['a0_fullText'] = instance['a0_fullText']
+        style_description, a0_embeddings = find_closest_cluster_style(instance['a0_fullText'], interp_space, model_name=model_name)
+        json_obj['a0_top_style_feats'] = style_description
+
+        json_obj['a1_authorID'] = instance['a1_authorID']
+        json_obj['a1_fullText'] = instance['a1_fullText']
+        style_description, a1_embeddings = find_closest_cluster_style(instance['a1_fullText'], interp_space, model_name=model_name)
+        json_obj['a1_top_style_feats'] = style_description
+
+        json_obj['a2_authorID'] = instance['a2_authorID']
+        json_obj['a2_fullText'] = instance['a2_fullText']
+        style_description, a2_embeddings = find_closest_cluster_style(instance['a2_fullText'], interp_space, model_name=model_name)
+        json_obj['a2_top_style_feats'] = style_description
+
+        # Compute pairwise similarity between q_embeddings and all a_embeddings.
+        # Ensure embeddings are 2D arrays for cosine_similarity.
+        q_emb_2d = np.array(q_embeddings).reshape(1, -1)
+        a0_emb_2d = np.array(a0_embeddings).reshape(1, -1)
+        a1_emb_2d = np.array(a1_embeddings).reshape(1, -1)
+        a2_emb_2d = np.array(a2_embeddings).reshape(1, -1)
+
+        similarity_q_a0 = cosine_similarity(q_emb_2d, a0_emb_2d)[0][0]
+        similarity_q_a1 = cosine_similarity(q_emb_2d, a1_emb_2d)[0][0]
+        similarity_q_a2 = cosine_similarity(q_emb_2d, a2_emb_2d)[0][0]
+
+        ranked_candidates = [
+            {'authorID': instance['a0_authorID'], 'similarity': float(similarity_q_a0)},
+            {'authorID': instance['a1_authorID'], 'similarity': float(similarity_q_a1)},
+            {'authorID': instance['a2_authorID'], 'similarity': float(similarity_q_a2)},
+        ]
+
+        # Rank candidate indices by descending similarity (argsort on the negated
+        # similarities) so latent_rank[0] is the most similar candidate.
+        # E.g. similarities [0.31, 0.87, 0.54] -> latent_rank [1, 2, 0] -> 'Candidate 2'.
+        json_obj['latent_rank'] = np.argsort([-x['similarity'] for x in ranked_candidates]).tolist()
+        json_obj['model_pred'] = 'Candidate {}'.format(json_obj['latent_rank'][0] + 1)
+
+        output.append(json_obj)
+
+    json.dump(output, open(output_file, 'w'), indent=4)
+
+
+def main():
+    """
+    Entry point: dispatch to the requested task.
+    """
+    parser = argparse.ArgumentParser(
+        description="Build a static interpretable space from clustered author data."
+    )
+
+    parser.add_argument(
+        "task",
+        type=str,
+        help="Task to run: build_static_interp_space or generate_explanations.",
+        choices=["build_static_interp_space", "generate_explanations"]
+    )
+
+    parser.add_argument(
+        "input_file",
+        type=str,
+        help="Path to the input file (clustered DataFrame .pkl, or explanation sample .json)."
+    )
+
+    parser.add_argument(
+        "output_file",
+        type=str,
+        help="Path to save the output."
+    )
+
+    parser.add_argument(
+        "--interp_space_path",
+        type=str,
+        help="Path to the input interpretable space (.json file)."
+    )
+
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        help="Style-analysis model name."
+    )
+
+    args = parser.parse_args()
+
+    if args.task == "build_static_interp_space":
+        return build_and_save_static_interp_space(args)
+    elif args.task == "generate_explanations":
+        return generate_explanations(args)
+    else:
+        raise ValueError(f"Unknown task: {args.task}")
+
+
+def build_and_save_static_interp_space(args):
+    print(f"Loading clustered data from {args.input_file}...")
+    clustered_df = pd.read_pickle(args.input_file)
+
+    interpretable_space = build_static_interp_space(clustered_df)
+
+    print(f"\nSaving interpretable space to {args.output_file}...")
+    with open(args.output_file, 'w') as f:
+        json.dump(interpretable_space, f, indent=4)
+
+    print("Done.")
+
+
+if __name__ == "__main__":
+    main()
prepare_data.py CHANGED
@@ -44,6 +44,9 @@ def sample_ds(input_file, output_file, num_insts=10000, min_num_text_per_inst=0,
     df = pd.DataFrame(out_list)
     df.to_pickle(output_file)
 
+    df = df.explode('fullText').reset_index()
+    df.to_json(output_file.replace('.pkl', '.json'))
+
 def get_reddit_data(input_path, random_seed=123, num_instances=100, num_documents_per_author=8, min_instance_len=10):
 
     df = pd.read_pickle(open(input_path, 'rb'))
utils/clustering_utils.py CHANGED
@@ -128,6 +128,7 @@ def clustering_author(background_corpus_df: pd.DataFrame,
         return background_corpus_df
 
     X = np.array(X_list) # Creates a 2D array from the list of 1D arrays
+    original_embeddings_list = [embeddings_list[i] for i in original_indices]
 
     if X.shape[0] == 1:
         print("Only one valid embedding found. Assigning cluster label 0 to it.")
@@ -279,6 +280,9 @@ def clustering_author(background_corpus_df: pd.DataFrame,
         print("No suitable DBSCAN clustering found meeting criteria. All processed embeddings marked as noise (-1).")
 
     background_corpus_df['cluster_label'] = final_labels_for_df
+    # Restore the original (pre-PCA) embeddings in the returned dataframe
+    print(original_embeddings_list[0].shape)
+    background_corpus_df[embedding_clm] = original_embeddings_list
     return background_corpus_df
utils/interp_space_utils.py CHANGED
@@ -25,6 +25,7 @@ from sklearn.decomposition import PCA
 CACHE_DIR = "datasets/embeddings_cache"
 ZOOM_CACHE = "datasets/zoom_cache/features_cache.json"
 REGION_CACHE = "datasets/region_cache/regions_cache.pkl"
+SUMMARY_CACHE = "datasets/summary_cache/summaries.json"
 os.makedirs(CACHE_DIR, exist_ok=True)
 os.makedirs(os.path.dirname(ZOOM_CACHE), exist_ok=True)
 os.makedirs(os.path.dirname(REGION_CACHE), exist_ok=True)
@@ -41,6 +42,9 @@ class FeatureIdentificationSchema(BaseModel):
 class SpanExtractionSchema(BaseModel):
     spans: dict[str, dict[str, list[str]]] # {author_name: {feature: [spans]}}
 
+class StyleSummarySchema(BaseModel):
+    summary_paragraph: str
+
 
 def compute_g2v_features(clustered_authors_df: pd.DataFrame, task_authors_df: pd.DataFrame=None, text_clm='fullText') -> pd.DataFrame:
@@ -398,7 +402,7 @@ def compute_clusters_style_representation_2(
     """
    Call openAI to analyze the common writing style features of the given list of texts
     """
-    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+    client = OpenAI(base_url=os.getenv("OPENAI_BASE_URL", None), api_key=os.getenv("OPENAI_API_KEY"))
 
     background_corpus_df['fullText'] = background_corpus_df['fullText'].map(lambda x: '\n\n'.join(x[:max_num_documents_per_author]) if isinstance(x, list) else x)
     background_corpus_df = background_corpus_df[background_corpus_df[cluster_label_clm_name].isin(cluster_ids)]
@@ -430,7 +434,7 @@
         else: # Else compute and cache
 
             response = client.chat.completions.create(
-                model="gpt-4o-mini",
+                model="gpt-4o",
                 messages=[
                     {"role":"assistant","content":"You are a forensic linguistic who knows how to analyze similarites in writing styles."},
                     {"role":"user","content":prompt}],
@@ -472,7 +476,7 @@ def identify_style_features(author_texts: list[str], author_names: list[str], ma
     else:
         print(f"Cache miss. Computing features for authors: {author_names}")
 
-        client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+        client = OpenAI(base_url=os.getenv("OPENAI_BASE_URL", None), api_key=os.getenv("OPENAI_API_KEY"))
         prompt = f"""Identify {max_num_feats} writing style features that are common between the authors texts.
 Author Texts:
 
@@ -530,7 +534,7 @@ def extract_all_spans(authors_df: pd.DataFrame, features: list[str], cluster_lab
     For each author, use `generate_feature_spans_cached` to get feature->span mappings.
     Returns a dict: {author_name: {feature: [spans]}}
     """
-    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+    client = OpenAI(base_url=os.getenv("OPENAI_BASE_URL", None), api_key=os.getenv("OPENAI_API_KEY"))
 
     spans_by_author = {}
 
@@ -552,7 +556,8 @@ def compute_clusters_style_representation_3(
     max_num_documents_per_author=10,
     max_num_authors=10,
     max_authors_for_span_extraction=4,
-    top_k: int = 10
+    top_k: int = 10,
+    return_only_feats=False,
 ):
 
     print(f"Computing style representation for visible clusters: {len(cluster_ids)}")
@@ -567,6 +572,9 @@ def compute_clusters_style_representation_3(
     print(author_names)
     features = identify_style_features(author_texts, author_names, max_num_feats=max_num_feats)
 
+    if return_only_feats:
+        return features
+
     print("Features: ", features)
     # STEP 2: Prepare author pool for span extraction
     span_df = background_corpus_df.iloc[:max_authors_for_span_extraction]
@@ -577,34 +585,6 @@ def compute_clusters_style_representation_3(
 
     # Filter-in only task authors that are part of the current selection
     task_author_names = {'Mystery author', 'Candidate Author 1', 'Candidate Author 2', 'Candidate Author 3'}
-    #filtered_task_authors = {author: feat_map for author, feat_map in spans_by_author.items() if author in task_author_names.intersection(set(cluster_ids))}
-
-    # Build per-author sets of features that have at least one span
-    # author_present_feature_sets = [
-    #     {feature for feature, spans in feature_map.items() if spans and len(spans) > 0}
-    #     for _, feature_map in filtered_task_authors.items()
-    # ]
-
-    # print(filtered_task_authors.keys(), author_present_feature_sets)
-
-
-    # if len(author_present_feature_sets) > 0: # we have more than one task author
-    #     coverage_counter = Counter()
-    #     for present_set in author_present_feature_sets:
-    #         coverage_counter.update(present_set)
-
-    #     # Keep features present in at least `min_authors_required` authors
-    #     eligible_features = [feat for feat, cnt in coverage_counter.items() if cnt >= len(author_present_feature_sets)]
-
-    #     # Preserve original LLM feature ordering as a secondary key where possible
-    #     feature_original_index = {feat: idx for idx, feat in enumerate(features)} if features else {}
-
-    #     selected_features_ranked = sorted(
-    #         eligible_features,
-    #         key=lambda f: (-coverage_counter[f], feature_original_index.get(f, 10**9))
-    #     )[:int(top_k)]
-    # else:
-    #     selected_features_ranked = features
 
 
     feature_importance = {f : 0 for f in features}
@@ -627,6 +607,109 @@
         "spans": spans_by_author
     }
 
+def summarize_style_features_to_paragraph(features: list[str]) -> str:
+    """
+    Takes a list of writing style features and uses an LLM to generate a
+    coherent, descriptive paragraph summarizing the style.
+
+    Args:
+        features (list[str]): A list of style features.
+
+    Returns:
+        str: A single paragraph summarizing the writing style.
+    """
+    if not features:
+        return "No style features were identified for this selection."
+
+    # Generate a cache key based on the sorted features to ensure consistency
+    feature_key = hashlib.md5(json.dumps(sorted(features)).encode()).hexdigest()
+
+    os.makedirs(os.path.dirname(SUMMARY_CACHE), exist_ok=True)
+    if os.path.exists(SUMMARY_CACHE):
+        with open(SUMMARY_CACHE, 'r') as f:
+            try:
+                cache = json.load(f)
+            except json.JSONDecodeError:
+                cache = {}
+    else:
+        cache = {}
+
+    if feature_key in cache:
+        print(f"Cache hit for style summary. Key: {feature_key}")
+        return cache[feature_key]
+
+    print("Cache miss for style summary. Generating new summary...")
+    client = OpenAI(base_url=os.getenv("OPENAI_BASE_URL", None), api_key=os.getenv("OPENAI_API_KEY"))
+
+    feature_list_str = "\n".join([f"- {feat}" for feat in features])
+    prompt = f"""You are a linguistic analyst. Your task is to synthesize the following list of writing style features into a single, coherent, and descriptive paragraph. The paragraph should flow naturally and explain the overall writing style of an author based on these features. Be concise and only mention the features without referring to example spans.
+
+Style Features:
+{feature_list_str}
+
+Please provide the summary as a single paragraph.
+"""
+
+    def _make_call():
+        response = client.chat.completions.create(
+            model="gpt-4o",
+            messages=[{"role": "user", "content": prompt}],
+            response_format={"type": "json_schema", "json_schema": {"name": "StyleSummarySchema", "schema": to_strict_json_schema(StyleSummarySchema)}}
+        )
+        return json.loads(response.choices[0].message.content)
+
+    summary_paragraph = retry_call(_make_call, StyleSummarySchema).summary_paragraph
+
+    # Save to cache
+    cache[feature_key] = summary_paragraph
+    with open(SUMMARY_CACHE, 'w') as f:
+        json.dump(cache, f, indent=2)
+
+    return summary_paragraph
+
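+# Usage sketch for the summarizer above (feature strings invented for illustration):
+#   summarize_style_features_to_paragraph(
+#       ["frequent parenthetical asides", "short declarative sentences"])
+#   -> a one-paragraph description, cached under an md5 of the sorted features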
+def find_closest_cluster_style(texts: list[str], interp_space, model_name: str) -> tuple[str, np.ndarray]:
+    """
+    Computes the average embedding for a list of texts and finds the most similar
+    cluster from the interpretable space, returning its style description.
+
+    Args:
+        texts (list[str]): A list of texts for which to find a style description.
+        interp_space (dict): The interpretable space mapping cluster labels to
+            (average embedding, style description) pairs.
+        model_name (str): The name of the sentence transformer model to use for embeddings.
+
+    Returns:
+        tuple[str, np.ndarray]: The style description of the most similar cluster
+        and the average embedding of the input texts.
+    """
+    if not texts:
+        return "No texts provided for analysis.", None
+
+    # 1. Compute the average embedding for the input texts.
+    # We create a temporary DataFrame to use the existing embedding generation utility.
+    temp_df = pd.DataFrame([{'fullText': texts}])
+    input_embedding_list = generate_style_embedding(temp_df, 'fullText', model_name, dimensionality_reduction=False)
+
+    if not input_embedding_list:
+        return "Could not generate an embedding for the provided texts.", None
+
+    input_embedding = np.array(input_embedding_list[0]).reshape(1, -1)
+
+    # 2. Find the most similar cluster
+    cluster_embeddings = {int(k): np.array(v[0]) for k, v in interp_space.items()}
+
+    best_cluster_label = -1
+    max_similarity = -1
+
+    for label, cluster_emb in cluster_embeddings.items():
+        similarity = cosine_similarity(input_embedding, cluster_emb.reshape(1, -1))[0][0]
+        if similarity > max_similarity:
+            max_similarity = similarity
+            best_cluster_label = label
+
+    # 3. Return the style description of the closest cluster (JSON keys are strings)
+    return interp_space.get(str(best_cluster_label), [None, "Could not find a matching style description."])[1], input_embedding[0]
+
+
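+# Usage sketch (path and text invented for illustration):
+#   interp_space = json.load(open('./datasets/reddit_interp_space.json'))
+#   desc, emb = find_closest_cluster_style(['example post text'], interp_space,
+#                                          model_name='AnnaWegmann/Style-Embedding')
+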
 def compute_clusters_g2v_representation(
     background_corpus_df: pd.DataFrame,
     author_ids: List[Any],