Milad Alshomary
commited on
Commit
·
a5e49c0
1
Parent(s):
dcbbcbd
updates
Browse files- README.md +20 -0
- app.py +1 -1
- baseline_static_explanations.py +196 -0
- prepare_data.py +3 -0
- utils/clustering_utils.py +4 -0
- utils/interp_space_utils.py +116 -33
README.md
CHANGED
|
@@ -13,3 +13,23 @@ short_description: Interpreting the latent space of Authorship Attribution
|
|
| 13 |
---
|
| 14 |
|
| 15 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
---
|
| 14 |
|
| 15 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
## Useful commands
|
| 19 |
+
|
| 20 |
+
### Prepare data training/test
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
### Clustering the background corpus
|
| 25 |
+
|
| 26 |
+
python cluster_corpus.py ../../iarpa-hiatus/explanation_tool_files/reddit_cluster_training.pkl ../../iarpa-hiatus/explanation_tool_files/reddit_cluster_test.pkl "AnnaWegmann/Style-Embedding" ./datasets/reddit_clustered_authors.pkl --min_samples 2 --metric cosine --pca_dimensions 100 --eps 0.04
|
| 27 |
+
|
| 28 |
+
### Generate explainability sample
|
| 29 |
+
|
| 30 |
+
python prepare_data.py ../explanation_tool_files/reddit_cluster_test.pkl ./datasets/reddit_explanation_sample.json
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
### Generate static explanations for a sample
|
| 34 |
+
|
| 35 |
+
python baseline_static_explanations.py generate_explanations ./datasets/reddit_explanation_sample.json ./datasets/reddit_explanation_sample_with_explanations.json --interp_space_path ./datasets/reddit_interp_space.json --model_name 'AnnaWegmann/Style-Embedding'
|
app.py
CHANGED
|
@@ -42,7 +42,7 @@ from utils.interp_space_utils import *
|
|
| 42 |
from utils.ui import *
|
| 43 |
|
| 44 |
load_dotenv()
|
| 45 |
-
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
| 46 |
|
| 47 |
|
| 48 |
# ── load once at startup ────────────────────────────────────────
|
|
|
|
| 42 |
from utils.ui import *
|
| 43 |
|
| 44 |
load_dotenv()
|
| 45 |
+
client = OpenAI(base_url=os.getenv("OPENAI_API_BASE"), api_key=os.getenv("OPENAI_API_KEY"))
|
| 46 |
|
| 47 |
|
| 48 |
# ── load once at startup ────────────────────────────────────────
|
baseline_static_explanations.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os, json
|
| 5 |
+
|
| 6 |
+
from utils.interp_space_utils import cached_generate_style_embedding
|
| 7 |
+
from utils.clustering_utils import clustering_author
|
| 8 |
+
from utils.interp_space_utils import compute_clusters_style_representation_3, summarize_style_features_to_paragraph, find_closest_cluster_style
|
| 9 |
+
|
| 10 |
+
from sklearn.metrics.pairwise import cosine_distances, cosine_similarity
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def build_static_interp_space(cluster_df):
|
| 14 |
+
"""
|
| 15 |
+
Takes a dataframe with cluster_label indicates every author's cluster and return a
|
| 16 |
+
json file with key the cluster_label and value containing the style-embedding representation and the style description
|
| 17 |
+
|
| 18 |
+
Example cluster_df
|
| 19 |
+
fullText authorID Style-Embedding_style_embedding cluster_label
|
| 20 |
+
4 [I've play them all (D3, Torchlight 1&2, P... HaxRyter [0.7126333904811682, -0.5076461933032986, -0.1... 0
|
| 21 |
+
10 [Back in Texas. Buddy had a kid in an up and ... OaklandHellBent [0.11238726238181786, 0.9263576185812101, -0.2... 1
|
| 22 |
+
|
| 23 |
+
"""
|
| 24 |
+
# Find the embedding column (assuming it's the only one ending with '_style_embedding')
|
| 25 |
+
embedding_clm = next((col for col in cluster_df.columns if col.endswith('_style_embedding')), None)
|
| 26 |
+
if not embedding_clm:
|
| 27 |
+
raise ValueError("No style embedding column found in the DataFrame.")
|
| 28 |
+
|
| 29 |
+
print(f"Using embedding column: {embedding_clm}")
|
| 30 |
+
|
| 31 |
+
# Group by cluster label and calculate the average embedding for each cluster
|
| 32 |
+
# We also aggregate authorIDs to use them for style representation
|
| 33 |
+
cluster_groups = cluster_df.groupby('cluster_label').agg({
|
| 34 |
+
embedding_clm: lambda embs: np.mean(np.vstack(embs), axis=0).tolist(),
|
| 35 |
+
'authorID': list
|
| 36 |
+
}).reset_index()
|
| 37 |
+
|
| 38 |
+
interpretable_space = {}
|
| 39 |
+
|
| 40 |
+
for _, row in cluster_groups.iterrows():
|
| 41 |
+
cluster_label = row['cluster_label']
|
| 42 |
+
avg_embedding = row[embedding_clm]
|
| 43 |
+
author_ids_in_cluster = row['authorID']
|
| 44 |
+
|
| 45 |
+
print(f"\nProcessing cluster {cluster_label} with {len(author_ids_in_cluster)} authors...")
|
| 46 |
+
|
| 47 |
+
# Generate style description using an LLM
|
| 48 |
+
# We reuse the utility function from the interactive tool for consistency
|
| 49 |
+
style_analysis = compute_clusters_style_representation_3(
|
| 50 |
+
background_corpus_df=cluster_df,
|
| 51 |
+
cluster_ids=author_ids_in_cluster,
|
| 52 |
+
cluster_label_clm_name='authorID',
|
| 53 |
+
max_num_feats=5, # Requesting 5 top features
|
| 54 |
+
max_num_authors=20, # Use up to 20 authors from the cluster for analysis
|
| 55 |
+
return_only_feats=True
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
# When return_only_feats=True, style_analysis is a list of features
|
| 59 |
+
style_features_list = style_analysis
|
| 60 |
+
print(f" Generated style features: {style_features_list}")
|
| 61 |
+
|
| 62 |
+
# Summarize the list of features into a coherent paragraph
|
| 63 |
+
style_paragraph = summarize_style_features_to_paragraph(style_features_list)
|
| 64 |
+
print(f" Summarized paragraph: {style_paragraph}")
|
| 65 |
+
|
| 66 |
+
# JSON cannot serialize numpy integers, so convert cluster_label
|
| 67 |
+
interpretable_space[int(cluster_label)] = (avg_embedding, style_paragraph)
|
| 68 |
+
|
| 69 |
+
return interpretable_space
|
| 70 |
+
|
| 71 |
+
def generate_explanations(args):
|
| 72 |
+
input_file = args.input_file
|
| 73 |
+
interp_space_path = args.interp_space_path
|
| 74 |
+
output_file = args.output_file
|
| 75 |
+
model_name = args.model_name if args.model_name else 'AnnaWegmann/Style-Embedding'
|
| 76 |
+
|
| 77 |
+
instances_for_ex = json.load(open(input_file))
|
| 78 |
+
interp_space = json.load(open(interp_space_path))
|
| 79 |
+
|
| 80 |
+
output = []
|
| 81 |
+
for instance in instances_for_ex:
|
| 82 |
+
json_obj = {}
|
| 83 |
+
json_obj['Q_authorID'] = instance['Q_authorID']
|
| 84 |
+
json_obj['Q_fullText'] = instance['Q_fullText']
|
| 85 |
+
style_descirption, q_embeddings = find_closest_cluster_style(instance['Q_fullText'], interp_space, model_name=model_name)
|
| 86 |
+
json_obj['Q_top_style_feats'] = style_descirption
|
| 87 |
+
|
| 88 |
+
json_obj['a0_authorID'] = instance['a0_authorID']
|
| 89 |
+
json_obj['a0_fullText'] = instance['a0_fullText']
|
| 90 |
+
style_descirption, a0_embeddings = find_closest_cluster_style(instance['a0_fullText'], interp_space, model_name=model_name)
|
| 91 |
+
json_obj['a0_top_style_feats'] = style_descirption
|
| 92 |
+
|
| 93 |
+
json_obj['a1_authorID'] = instance['a1_authorID']
|
| 94 |
+
json_obj['a1_fullText'] = instance['a1_fullText']
|
| 95 |
+
style_descirption, a1_embeddings = find_closest_cluster_style(instance['a1_fullText'], interp_space, model_name=model_name)
|
| 96 |
+
json_obj['a1_top_style_feats'] = style_descirption
|
| 97 |
+
|
| 98 |
+
json_obj['a2_authorID'] = instance['a2_authorID']
|
| 99 |
+
json_obj['a2_fullText'] = instance['a2_fullText']
|
| 100 |
+
style_descirption, a2_embeddings = find_closest_cluster_style(instance['a2_fullText'], interp_space, model_name=model_name)
|
| 101 |
+
json_obj['a2_top_style_feats'] = style_descirption
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# Compute pairwise similarity between q_embeddings and all a_embeddings
|
| 105 |
+
# Ensure embeddings are 2D arrays for cosine_similarity
|
| 106 |
+
q_emb_2d = np.array(q_embeddings).reshape(1, -1)
|
| 107 |
+
a0_emb_2d = np.array(a0_embeddings).reshape(1, -1)
|
| 108 |
+
a1_emb_2d = np.array(a1_embeddings).reshape(1, -1)
|
| 109 |
+
a2_emb_2d = np.array(a2_embeddings).reshape(1, -1)
|
| 110 |
+
|
| 111 |
+
similarity_q_a0 = cosine_similarity(q_emb_2d, a0_emb_2d)[0][0]
|
| 112 |
+
similarity_q_a1 = cosine_similarity(q_emb_2d, a1_emb_2d)[0][0]
|
| 113 |
+
similarity_q_a2 = cosine_similarity(q_emb_2d, a2_emb_2d)[0][0]
|
| 114 |
+
|
| 115 |
+
ranked_candidates = [
|
| 116 |
+
{'authorID': instance['a0_authorID'], 'similarity': float(similarity_q_a0)},
|
| 117 |
+
{'authorID': instance['a1_authorID'], 'similarity': float(similarity_q_a1)},
|
| 118 |
+
{'authorID': instance['a2_authorID'], 'similarity': float(similarity_q_a2)},
|
| 119 |
+
]
|
| 120 |
+
|
| 121 |
+
json_obj['latent_rank'] = np.argsort([x['similarity'] for x in ranked_candidates]).tolist()
|
| 122 |
+
json_obj['model_pred'] = 'Candidate {}'.format(json_obj['latent_rank'][0] + 1)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
output.append(json_obj)
|
| 127 |
+
|
| 128 |
+
json.dump(output, open(output_file, 'w'), indent=4)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def main():
|
| 134 |
+
"""
|
| 135 |
+
Main function to generate and save the static interpretable space.
|
| 136 |
+
"""
|
| 137 |
+
|
| 138 |
+
parser = argparse.ArgumentParser(
|
| 139 |
+
description="Build a static interpretable space from clustered author data."
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
parser.add_argument(
|
| 143 |
+
"task",
|
| 144 |
+
type=str,
|
| 145 |
+
help="task: one of the following: build_static_interp_space, generate_explanations",
|
| 146 |
+
choices=["build_static_interp_space", "generate_explanations"]
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
parser.add_argument(
|
| 150 |
+
"input_file",
|
| 151 |
+
type=str,
|
| 152 |
+
help="Path to the input clustered DataFrame (.pkl file)."
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
parser.add_argument(
|
| 156 |
+
"output_file",
|
| 157 |
+
type=str,
|
| 158 |
+
help="file to save the output"
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
parser.add_argument(
|
| 162 |
+
"--interp_space_path",
|
| 163 |
+
type=str,
|
| 164 |
+
help="Path to the input interpretable space(.pkl file)."
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
parser.add_argument(
|
| 168 |
+
"--model_name",
|
| 169 |
+
type=str,
|
| 170 |
+
help="style analysis model name"
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
args = parser.parse_args()
|
| 174 |
+
|
| 175 |
+
if args.task == "build_static_interp_space":
|
| 176 |
+
return build_and_save_static_interp_space(args)
|
| 177 |
+
elif args.task == "generate_explanations":
|
| 178 |
+
return generate_explanations(args)
|
| 179 |
+
else:
|
| 180 |
+
raise ValueError(f"Unknown task: {args.task}")
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def build_and_save_static_interp_space(args):
|
| 184 |
+
print(f"Loading clustered data from {args.input_file}...")
|
| 185 |
+
clustered_df = pd.read_pickle(args.input_file)
|
| 186 |
+
|
| 187 |
+
interpretable_space = build_static_interp_space(clustered_df)
|
| 188 |
+
|
| 189 |
+
print(f"\nSaving interpretable space to {args.output_file}...")
|
| 190 |
+
with open(args.output_file, 'w') as f:
|
| 191 |
+
json.dump(interpretable_space, f, indent=4)
|
| 192 |
+
|
| 193 |
+
print("Done.")
|
| 194 |
+
|
| 195 |
+
if __name__ == "__main__":
|
| 196 |
+
main()
|
prepare_data.py
CHANGED
|
@@ -44,6 +44,9 @@ def sample_ds(input_file, output_file, num_insts=10000, min_num_text_per_inst=0,
|
|
| 44 |
df = pd.DataFrame(out_list)
|
| 45 |
df.to_pickle(output_file)
|
| 46 |
|
|
|
|
|
|
|
|
|
|
| 47 |
def get_reddit_data(input_path, random_seed=123, num_instances=100, num_documents_per_author=8, min_instance_len=10):
|
| 48 |
|
| 49 |
df = pd.read_pickle(open(input_path, 'rb'))
|
|
|
|
| 44 |
df = pd.DataFrame(out_list)
|
| 45 |
df.to_pickle(output_file)
|
| 46 |
|
| 47 |
+
df = df.explode('fullText').reset_index()
|
| 48 |
+
df.to_json(output_file.replace('.pkl', '.json'))
|
| 49 |
+
|
| 50 |
def get_reddit_data(input_path, random_seed=123, num_instances=100, num_documents_per_author=8, min_instance_len=10):
|
| 51 |
|
| 52 |
df = pd.read_pickle(open(input_path, 'rb'))
|
utils/clustering_utils.py
CHANGED
|
@@ -128,6 +128,7 @@ def clustering_author(background_corpus_df: pd.DataFrame,
|
|
| 128 |
return background_corpus_df
|
| 129 |
|
| 130 |
X = np.array(X_list) # Creates a 2D array from the list of 1D arrays
|
|
|
|
| 131 |
|
| 132 |
if X.shape[0] == 1:
|
| 133 |
print("Only one valid embedding found. Assigning cluster label 0 to it.")
|
|
@@ -279,6 +280,9 @@ def clustering_author(background_corpus_df: pd.DataFrame,
|
|
| 279 |
print("No suitable DBSCAN clustering found meeting criteria. All processed embeddings marked as noise (-1).")
|
| 280 |
|
| 281 |
background_corpus_df['cluster_label'] = final_labels_for_df
|
|
|
|
|
|
|
|
|
|
| 282 |
return background_corpus_df
|
| 283 |
|
| 284 |
|
|
|
|
| 128 |
return background_corpus_df
|
| 129 |
|
| 130 |
X = np.array(X_list) # Creates a 2D array from the list of 1D arrays
|
| 131 |
+
original_embeddings_list = [embeddings_list[i] for i in original_indices]
|
| 132 |
|
| 133 |
if X.shape[0] == 1:
|
| 134 |
print("Only one valid embedding found. Assigning cluster label 0 to it.")
|
|
|
|
| 280 |
print("No suitable DBSCAN clustering found meeting criteria. All processed embeddings marked as noise (-1).")
|
| 281 |
|
| 282 |
background_corpus_df['cluster_label'] = final_labels_for_df
|
| 283 |
+
# restore the original embedding
|
| 284 |
+
print(original_embeddings_list[0].shape)
|
| 285 |
+
background_corpus_df[embedding_clm] = original_embeddings_list
|
| 286 |
return background_corpus_df
|
| 287 |
|
| 288 |
|
utils/interp_space_utils.py
CHANGED
|
@@ -25,6 +25,7 @@ from sklearn.decomposition import PCA
|
|
| 25 |
CACHE_DIR = "datasets/embeddings_cache"
|
| 26 |
ZOOM_CACHE = "datasets/zoom_cache/features_cache.json"
|
| 27 |
REGION_CACHE = "datasets/region_cache/regions_cache.pkl"
|
|
|
|
| 28 |
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 29 |
os.makedirs(os.path.dirname(ZOOM_CACHE), exist_ok=True)
|
| 30 |
os.makedirs(os.path.dirname(REGION_CACHE), exist_ok=True)
|
|
@@ -41,6 +42,9 @@ class FeatureIdentificationSchema(BaseModel):
|
|
| 41 |
class SpanExtractionSchema(BaseModel):
|
| 42 |
spans: dict[str, dict[str, list[str]]] # {author_name: {feature: [spans]}}
|
| 43 |
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
|
| 46 |
def compute_g2v_features(clustered_authors_df: pd.DataFrame, task_authors_df: pd.DataFrame=None, text_clm='fullText') -> pd.DataFrame:
|
|
@@ -398,7 +402,7 @@ def compute_clusters_style_representation_2(
|
|
| 398 |
"""
|
| 399 |
Call openAI to analyze the common writing style features of the given list of texts
|
| 400 |
"""
|
| 401 |
-
client = OpenAI(
|
| 402 |
|
| 403 |
background_corpus_df['fullText'] = background_corpus_df['fullText'].map(lambda x: '\n\n'.join(x[:max_num_documents_per_author]) if isinstance(x, list) else x)
|
| 404 |
background_corpus_df = background_corpus_df[background_corpus_df[cluster_label_clm_name].isin(cluster_ids)]
|
|
@@ -430,7 +434,7 @@ def compute_clusters_style_representation_2(
|
|
| 430 |
else: # Else compute and cache
|
| 431 |
|
| 432 |
response = client.chat.completions.create(
|
| 433 |
-
model="gpt-4o
|
| 434 |
messages=[
|
| 435 |
{"role":"assistant","content":"You are a forensic linguistic who knows how to analyze similarites in writing styles."},
|
| 436 |
{"role":"user","content":prompt}],
|
|
@@ -472,7 +476,7 @@ def identify_style_features(author_texts: list[str], author_names: list[str], ma
|
|
| 472 |
else:
|
| 473 |
print(f"Cache miss. Computing features for authors: {author_names}")
|
| 474 |
|
| 475 |
-
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
| 476 |
prompt = f"""Identify {max_num_feats} writing style features that are common between the authors texts.
|
| 477 |
Author Texts:
|
| 478 |
|
|
@@ -530,7 +534,7 @@ def extract_all_spans(authors_df: pd.DataFrame, features: list[str], cluster_lab
|
|
| 530 |
For each author, use `generate_feature_spans_cached` to get feature->span mappings.
|
| 531 |
Returns a dict: {author_name: {feature: [spans]}}
|
| 532 |
"""
|
| 533 |
-
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
| 534 |
|
| 535 |
spans_by_author = {}
|
| 536 |
|
|
@@ -552,7 +556,8 @@ def compute_clusters_style_representation_3(
|
|
| 552 |
max_num_documents_per_author=10,
|
| 553 |
max_num_authors=10,
|
| 554 |
max_authors_for_span_extraction=4,
|
| 555 |
-
top_k: int = 10
|
|
|
|
| 556 |
):
|
| 557 |
|
| 558 |
print(f"Computing style representation for visible clusters: {len(cluster_ids)}")
|
|
@@ -567,6 +572,9 @@ def compute_clusters_style_representation_3(
|
|
| 567 |
print(author_names)
|
| 568 |
features = identify_style_features(author_texts, author_names, max_num_feats=max_num_feats)
|
| 569 |
|
|
|
|
|
|
|
|
|
|
| 570 |
print("Features: ", features)
|
| 571 |
# STEP 2: Prepare author pool for span extraction
|
| 572 |
span_df = background_corpus_df.iloc[:max_authors_for_span_extraction]
|
|
@@ -577,34 +585,6 @@ def compute_clusters_style_representation_3(
|
|
| 577 |
|
| 578 |
# Filter-in only task authors that are part of the current selection
|
| 579 |
task_author_names = {'Mystery author', 'Candidate Author 1', 'Candidate Author 2', 'Candidate Author 3'}
|
| 580 |
-
#filtered_task_authors = {author: feat_map for author, feat_map in spans_by_author.items() if author in task_author_names.intersection(set(cluster_ids))}
|
| 581 |
-
|
| 582 |
-
# Build per-author sets of features that have at least one span
|
| 583 |
-
# author_present_feature_sets = [
|
| 584 |
-
# {feature for feature, spans in feature_map.items() if spans and len(spans) > 0}
|
| 585 |
-
# for _, feature_map in filtered_task_authors.items()
|
| 586 |
-
# ]
|
| 587 |
-
|
| 588 |
-
# print(filtered_task_authors.keys(), author_present_feature_sets)
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
# if len(author_present_feature_sets) > 0: # we have more than one task author
|
| 592 |
-
# coverage_counter = Counter()
|
| 593 |
-
# for present_set in author_present_feature_sets:
|
| 594 |
-
# coverage_counter.update(present_set)
|
| 595 |
-
|
| 596 |
-
# # Keep features present in at least `min_authors_required` authors
|
| 597 |
-
# eligible_features = [feat for feat, cnt in coverage_counter.items() if cnt >= len(author_present_feature_sets)]
|
| 598 |
-
|
| 599 |
-
# # Preserve original LLM feature ordering as a secondary key where possible
|
| 600 |
-
# feature_original_index = {feat: idx for idx, feat in enumerate(features)} if features else {}
|
| 601 |
-
|
| 602 |
-
# selected_features_ranked = sorted(
|
| 603 |
-
# eligible_features,
|
| 604 |
-
# key=lambda f: (-coverage_counter[f], feature_original_index.get(f, 10**9))
|
| 605 |
-
# )[:int(top_k)]
|
| 606 |
-
# else:
|
| 607 |
-
# selected_features_ranked = features
|
| 608 |
|
| 609 |
|
| 610 |
feature_importance = {f : 0 for f in features}
|
|
@@ -627,6 +607,109 @@ def compute_clusters_style_representation_3(
|
|
| 627 |
"spans": spans_by_author
|
| 628 |
}
|
| 629 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 630 |
def compute_clusters_g2v_representation(
|
| 631 |
background_corpus_df: pd.DataFrame,
|
| 632 |
author_ids: List[Any],
|
|
|
|
| 25 |
CACHE_DIR = "datasets/embeddings_cache"
|
| 26 |
ZOOM_CACHE = "datasets/zoom_cache/features_cache.json"
|
| 27 |
REGION_CACHE = "datasets/region_cache/regions_cache.pkl"
|
| 28 |
+
SUMMARY_CACHE = "datasets/summary_cache/summaries.json"
|
| 29 |
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 30 |
os.makedirs(os.path.dirname(ZOOM_CACHE), exist_ok=True)
|
| 31 |
os.makedirs(os.path.dirname(REGION_CACHE), exist_ok=True)
|
|
|
|
| 42 |
class SpanExtractionSchema(BaseModel):
|
| 43 |
spans: dict[str, dict[str, list[str]]] # {author_name: {feature: [spans]}}
|
| 44 |
|
| 45 |
+
class StyleSummarySchema(BaseModel):
|
| 46 |
+
summary_paragraph: str
|
| 47 |
+
|
| 48 |
|
| 49 |
|
| 50 |
def compute_g2v_features(clustered_authors_df: pd.DataFrame, task_authors_df: pd.DataFrame=None, text_clm='fullText') -> pd.DataFrame:
|
|
|
|
| 402 |
"""
|
| 403 |
Call openAI to analyze the common writing style features of the given list of texts
|
| 404 |
"""
|
| 405 |
+
client = OpenAI(base_url=os.getenv("OPENAI_BASE_URL", None), pi_key=os.getenv("OPENAI_API_KEY"))
|
| 406 |
|
| 407 |
background_corpus_df['fullText'] = background_corpus_df['fullText'].map(lambda x: '\n\n'.join(x[:max_num_documents_per_author]) if isinstance(x, list) else x)
|
| 408 |
background_corpus_df = background_corpus_df[background_corpus_df[cluster_label_clm_name].isin(cluster_ids)]
|
|
|
|
| 434 |
else: # Else compute and cache
|
| 435 |
|
| 436 |
response = client.chat.completions.create(
|
| 437 |
+
model="gpt-4o",
|
| 438 |
messages=[
|
| 439 |
{"role":"assistant","content":"You are a forensic linguistic who knows how to analyze similarites in writing styles."},
|
| 440 |
{"role":"user","content":prompt}],
|
|
|
|
| 476 |
else:
|
| 477 |
print(f"Cache miss. Computing features for authors: {author_names}")
|
| 478 |
|
| 479 |
+
client = OpenAI(base_url=os.getenv("OPENAI_BASE_URL", None), api_key=os.getenv("OPENAI_API_KEY"))
|
| 480 |
prompt = f"""Identify {max_num_feats} writing style features that are common between the authors texts.
|
| 481 |
Author Texts:
|
| 482 |
|
|
|
|
| 534 |
For each author, use `generate_feature_spans_cached` to get feature->span mappings.
|
| 535 |
Returns a dict: {author_name: {feature: [spans]}}
|
| 536 |
"""
|
| 537 |
+
client = OpenAI(base_url=os.getenv("OPENAI_BASE_URL", None), api_key=os.getenv("OPENAI_API_KEY"))
|
| 538 |
|
| 539 |
spans_by_author = {}
|
| 540 |
|
|
|
|
| 556 |
max_num_documents_per_author=10,
|
| 557 |
max_num_authors=10,
|
| 558 |
max_authors_for_span_extraction=4,
|
| 559 |
+
top_k: int = 10,
|
| 560 |
+
return_only_feats= False,
|
| 561 |
):
|
| 562 |
|
| 563 |
print(f"Computing style representation for visible clusters: {len(cluster_ids)}")
|
|
|
|
| 572 |
print(author_names)
|
| 573 |
features = identify_style_features(author_texts, author_names, max_num_feats=max_num_feats)
|
| 574 |
|
| 575 |
+
if return_only_feats:
|
| 576 |
+
return features
|
| 577 |
+
|
| 578 |
print("Features: ", features)
|
| 579 |
# STEP 2: Prepare author pool for span extraction
|
| 580 |
span_df = background_corpus_df.iloc[:max_authors_for_span_extraction]
|
|
|
|
| 585 |
|
| 586 |
# Filter-in only task authors that are part of the current selection
|
| 587 |
task_author_names = {'Mystery author', 'Candidate Author 1', 'Candidate Author 2', 'Candidate Author 3'}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 588 |
|
| 589 |
|
| 590 |
feature_importance = {f : 0 for f in features}
|
|
|
|
| 607 |
"spans": spans_by_author
|
| 608 |
}
|
| 609 |
|
| 610 |
+
def summarize_style_features_to_paragraph(features: list[str]) -> str:
|
| 611 |
+
"""
|
| 612 |
+
Takes a list of writing style features and uses an LLM to generate a
|
| 613 |
+
coherent, descriptive paragraph summarizing the style.
|
| 614 |
+
|
| 615 |
+
Args:
|
| 616 |
+
features (list[str]): A list of style features.
|
| 617 |
+
|
| 618 |
+
Returns:
|
| 619 |
+
str: A single paragraph summarizing the writing style.
|
| 620 |
+
"""
|
| 621 |
+
if not features:
|
| 622 |
+
return "No style features were identified for this selection."
|
| 623 |
+
|
| 624 |
+
# Generate a cache key based on the sorted features to ensure consistency
|
| 625 |
+
feature_key = hashlib.md5(json.dumps(sorted(features)).encode()).hexdigest()
|
| 626 |
+
|
| 627 |
+
os.makedirs(os.path.dirname(SUMMARY_CACHE), exist_ok=True)
|
| 628 |
+
if os.path.exists(SUMMARY_CACHE):
|
| 629 |
+
with open(SUMMARY_CACHE, 'r') as f:
|
| 630 |
+
try:
|
| 631 |
+
cache = json.load(f)
|
| 632 |
+
except json.JSONDecodeError:
|
| 633 |
+
cache = {}
|
| 634 |
+
else:
|
| 635 |
+
cache = {}
|
| 636 |
+
|
| 637 |
+
if feature_key in cache:
|
| 638 |
+
print(f"Cache hit for style summary. Key: {feature_key}")
|
| 639 |
+
return cache[feature_key]
|
| 640 |
+
|
| 641 |
+
print(f"Cache miss for style summary. Generating new summary...")
|
| 642 |
+
client = OpenAI(base_url=os.getenv("OPENAI_BASE_URL", None), api_key=os.getenv("OPENAI_API_KEY"))
|
| 643 |
+
|
| 644 |
+
feature_list_str = "\n".join([f"- {feat}" for feat in features])
|
| 645 |
+
prompt = f"""You are a linguistic analyst. Your task is to synthesize the following list of writing style features into a single, coherent, and descriptive paragraph. The paragraph should flow naturally and explain the overall writing style of an author based on these features. Be concise and only mention the features without referring to example spans.
|
| 646 |
+
|
| 647 |
+
Style Features:
|
| 648 |
+
{feature_list_str}
|
| 649 |
+
|
| 650 |
+
Please provide the summary as a single paragraph.
|
| 651 |
+
"""
|
| 652 |
+
|
| 653 |
+
def _make_call():
|
| 654 |
+
response = client.chat.completions.create(
|
| 655 |
+
model="gpt-4o",
|
| 656 |
+
messages=[{"role": "user", "content": prompt}],
|
| 657 |
+
response_format={"type": "json_schema", "json_schema": {"name": "StyleSummarySchema", "schema": to_strict_json_schema(StyleSummarySchema)}}
|
| 658 |
+
)
|
| 659 |
+
return json.loads(response.choices[0].message.content)
|
| 660 |
+
|
| 661 |
+
summary_paragraph = retry_call(_make_call, StyleSummarySchema).summary_paragraph
|
| 662 |
+
|
| 663 |
+
# Save to cache
|
| 664 |
+
cache[feature_key] = summary_paragraph
|
| 665 |
+
with open(SUMMARY_CACHE, 'w') as f:
|
| 666 |
+
json.dump(cache, f, indent=2)
|
| 667 |
+
|
| 668 |
+
return summary_paragraph
|
| 669 |
+
|
| 670 |
+
def find_closest_cluster_style(texts: list[str], interp_space, model_name: str) -> str:
|
| 671 |
+
"""
|
| 672 |
+
Computes the average embedding for a list of texts and finds the most similar
|
| 673 |
+
cluster from the interpretable space, returning its style description.
|
| 674 |
+
|
| 675 |
+
Args:
|
| 676 |
+
texts (list[str]): A list of texts for which to find a style description.
|
| 677 |
+
interp_space_path (str): Path to the interpretable_space.json file.
|
| 678 |
+
model_name (str): The name of the sentence transformer model to use for embeddings.
|
| 679 |
+
|
| 680 |
+
Returns:
|
| 681 |
+
str: The style description paragraph of the most similar cluster.
|
| 682 |
+
"""
|
| 683 |
+
if not texts:
|
| 684 |
+
return "No texts provided for analysis."
|
| 685 |
+
|
| 686 |
+
|
| 687 |
+
# 2. Compute the average embedding for the input texts
|
| 688 |
+
# We create a temporary DataFrame to use the existing embedding generation utility
|
| 689 |
+
temp_df = pd.DataFrame([{'fullText': texts}])
|
| 690 |
+
input_embedding_list = generate_style_embedding(temp_df, 'fullText', model_name, dimensionality_reduction=False)
|
| 691 |
+
|
| 692 |
+
if not input_embedding_list:
|
| 693 |
+
return "Could not generate an embedding for the provided texts."
|
| 694 |
+
|
| 695 |
+
input_embedding = np.array(input_embedding_list[0]).reshape(1, -1)
|
| 696 |
+
|
| 697 |
+
# 3. Find the most similar cluster
|
| 698 |
+
cluster_embeddings = {int(k): np.array(v[0]) for k, v in interp_space.items()}
|
| 699 |
+
|
| 700 |
+
best_cluster_label = -1
|
| 701 |
+
max_similarity = -1
|
| 702 |
+
|
| 703 |
+
for label, cluster_emb in cluster_embeddings.items():
|
| 704 |
+
similarity = cosine_similarity(input_embedding, cluster_emb.reshape(1, -1))[0][0]
|
| 705 |
+
if similarity > max_similarity:
|
| 706 |
+
max_similarity = similarity
|
| 707 |
+
best_cluster_label = label
|
| 708 |
+
|
| 709 |
+
# 4. Return the style description of the closest cluster
|
| 710 |
+
return interp_space.get(str(best_cluster_label), [None, "Could not find a matching style description."])[1], input_embedding[0]
|
| 711 |
+
|
| 712 |
+
|
| 713 |
def compute_clusters_g2v_representation(
|
| 714 |
background_corpus_df: pd.DataFrame,
|
| 715 |
author_ids: List[Any],
|