# NOTE(review): the three lines that preceded the imports were non-Python
# extraction artifacts (a file-size banner, git blame hashes, and a column
# ruler) and have been removed so the module parses.
import argparse
import pandas as pd
import numpy as np
import os, json
from utils.interp_space_utils import cached_generate_style_embedding
from utils.clustering_utils import clustering_author
from utils.interp_space_utils import compute_clusters_style_representation_3, summarize_style_features_to_paragraph, find_closest_cluster_style
from sklearn.metrics.pairwise import cosine_distances, cosine_similarity
def build_static_interp_space(cluster_df):
    """
    Build a static interpretable space from a clustered author DataFrame.

    ``cluster_df`` must carry a ``cluster_label`` column, an ``authorID``
    column, and exactly one embedding column whose name ends with
    ``'_style_embedding'``.  For each cluster the mean style embedding is
    computed and an LLM-generated paragraph describing the cluster's style
    is produced via the same utilities used by the interactive tool.

    Example cluster_df
    fullText authorID Style-Embedding_style_embedding cluster_label
    4 [I've play them all (D3, Torchlight 1&2, P... HaxRyter [0.7126333904811682, -0.5076461933032986, -0.1... 0
    10 [Back in Texas. Buddy had a kid in an up and ... OaklandHellBent [0.11238726238181786, 0.9263576185812101, -0.2... 1

    Returns:
        dict: ``{int cluster_label: (mean_embedding_list, style_paragraph)}``.

    Raises:
        ValueError: when no ``*_style_embedding`` column is present.
    """
    # Locate the single embedding column by its suffix.
    embedding_clm = None
    for col in cluster_df.columns:
        if col.endswith('_style_embedding'):
            embedding_clm = col
            break
    if not embedding_clm:
        raise ValueError("No style embedding column found in the DataFrame.")
    print(f"Using embedding column: {embedding_clm}")

    # One row per cluster: the mean embedding plus the member author IDs
    # (the author IDs feed the style-representation utility below).
    cluster_groups = cluster_df.groupby('cluster_label').agg({
        embedding_clm: lambda embs: np.mean(np.vstack(embs), axis=0).tolist(),
        'authorID': list
    }).reset_index()

    interpretable_space = {}
    for _, row in cluster_groups.iterrows():
        label = row['cluster_label']
        members = row['authorID']
        print(f"\nProcessing cluster {label} with {len(members)} authors...")
        # Reuse the interactive tool's utility so descriptions stay consistent.
        feats = compute_clusters_style_representation_3(
            background_corpus_df=cluster_df,
            cluster_ids=members,
            cluster_label_clm_name='authorID',
            max_num_feats=5,  # request 5 top features
            max_num_authors=20,  # analyze at most 20 authors per cluster
            return_only_feats=True
        )
        # With return_only_feats=True the call yields a plain feature list.
        print(f" Generated style features: {feats}")
        # Condense the feature list into one coherent paragraph.
        paragraph = summarize_style_features_to_paragraph(feats)
        print(f" Summarized paragraph: {paragraph}")
        # json cannot serialize numpy integers, hence the int() conversion.
        interpretable_space[int(label)] = (row[embedding_clm], paragraph)
    return interpretable_space
def generate_explanations(args):
    """
    Generate style-based explanations for authorship verification instances.

    Reads a JSON list of instances from ``args.input_file``; for the query
    text and each of the three candidate texts, finds the closest cluster in
    the interpretable space (``args.interp_space_path``) and records its
    style description.  Candidates are then ranked by cosine similarity to
    the query in the latent embedding space, and the enriched instances are
    written to ``args.output_file`` as JSON.

    Args:
        args: Parsed CLI namespace with ``input_file``, ``interp_space_path``,
            ``output_file``, and optional ``model_name`` attributes.
    """
    model_name = args.model_name if args.model_name else 'AnnaWegmann/Style-Embedding'

    # Use context managers so the file handles are closed deterministically.
    with open(args.input_file) as f:
        instances_for_ex = json.load(f)
    with open(args.interp_space_path) as f:
        interp_space = json.load(f)

    output = []
    for instance in instances_for_ex:
        json_obj = {}
        embeddings = {}
        # The query and the three candidates share identical processing.
        for prefix in ('Q', 'a0', 'a1', 'a2'):
            json_obj[f'{prefix}_authorID'] = instance[f'{prefix}_authorID']
            json_obj[f'{prefix}_fullText'] = '\n\n'.join(instance[f'{prefix}_fullText'])
            style_description, emb = find_closest_cluster_style(
                instance[f'{prefix}_fullText'], interp_space, model_name=model_name)
            json_obj[f'{prefix}_top_style_feats'] = style_description
            # Ensure a 2D array for cosine_similarity.
            embeddings[prefix] = np.array(emb).reshape(1, -1)
        json_obj['gt_idx'] = instance['gt_idx']

        # Cosine similarity between the query and each candidate.
        similarities = [
            float(cosine_similarity(embeddings['Q'], embeddings[p])[0][0])
            for p in ('a0', 'a1', 'a2')
        ]
        # BUG FIX: np.argsort is ascending, so index 0 previously pointed at
        # the LEAST similar candidate and model_pred named the worst match.
        # Reverse so latent_rank lists candidates most-similar first.
        json_obj['latent_rank'] = np.argsort(similarities)[::-1].tolist()
        json_obj['model_pred'] = 'Candidate {}'.format(json_obj['latent_rank'][0] + 1)
        output.append(json_obj)

    with open(args.output_file, 'w') as f:
        json.dump(output, f, indent=4)
def main():
    """
    CLI entry point: parse arguments and dispatch to the requested task.
    """
    parser = argparse.ArgumentParser(
        description="Build a static interpretable space from clustered author data."
    )
    parser.add_argument(
        "task",
        type=str,
        help="task: one of the following: build_static_interp_space, generate_explanations",
        choices=["build_static_interp_space", "generate_explanations"]
    )
    parser.add_argument(
        "input_file",
        type=str,
        help="Path to the input clustered DataFrame (.pkl file)."
    )
    parser.add_argument(
        "output_file",
        type=str,
        help="file to save the output"
    )
    parser.add_argument(
        "--interp_space_path",
        type=str,
        help="Path to the input interpretable space(.pkl file)."
    )
    parser.add_argument(
        "--model_name",
        type=str,
        help="style analysis model name"
    )
    args = parser.parse_args()

    # Dispatch table instead of an if/elif chain; the fallback branch is
    # unreachable in practice because argparse enforces `choices`.
    handlers = {
        "build_static_interp_space": build_and_save_static_interp_space,
        "generate_explanations": generate_explanations,
    }
    handler = handlers.get(args.task)
    if handler is None:
        raise ValueError(f"Unknown task: {args.task}")
    return handler(args)
def build_and_save_static_interp_space(args):
    """
    Load a clustered DataFrame from ``args.input_file``, build the static
    interpretable space, and write it to ``args.output_file`` as JSON.
    """
    print(f"Loading clustered data from {args.input_file}...")
    clustered_df = pd.read_pickle(args.input_file)
    space = build_static_interp_space(clustered_df)
    print(f"\nSaving interpretable space to {args.output_file}...")
    with open(args.output_file, 'w') as out_fp:
        json.dump(space, out_fp, indent=4)
    print("Done.")
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|