explainability-tool-for-aa / baseline_static_explanations.py
Milad Alshomary
updates
a721fcf
raw
history blame
7.94 kB
import argparse
import pandas as pd
import numpy as np
import os, json
from utils.interp_space_utils import cached_generate_style_embedding
from utils.clustering_utils import clustering_author
from utils.interp_space_utils import compute_clusters_style_representation_3, summarize_style_features_to_paragraph, find_closest_cluster_style
from sklearn.metrics.pairwise import cosine_distances, cosine_similarity
def build_static_interp_space(cluster_df):
    """
    Build a static interpretable space from a clustered author DataFrame.

    `cluster_df` must contain a `cluster_label` column, an `authorID` column,
    and exactly one embedding column whose name ends with '_style_embedding'.

    Returns a dict mapping each cluster label (as a plain int) to a tuple of
    (mean style embedding as a list of floats, LLM-generated style paragraph).

    Raises ValueError if no '*_style_embedding' column is present.
    """
    # Locate the (single) style-embedding column by its suffix.
    embedding_columns = [c for c in cluster_df.columns if c.endswith('_style_embedding')]
    if not embedding_columns:
        raise ValueError("No style embedding column found in the DataFrame.")
    embedding_clm = embedding_columns[0]
    print(f"Using embedding column: {embedding_clm}")

    # One row per cluster: mean embedding plus the list of member author IDs
    # (the author IDs feed the style-representation helper below).
    grouped = cluster_df.groupby('cluster_label').agg({
        embedding_clm: lambda embs: np.mean(np.vstack(embs), axis=0).tolist(),
        'authorID': list
    }).reset_index()

    interpretable_space = {}
    for _, cluster_row in grouped.iterrows():
        label = cluster_row['cluster_label']
        mean_embedding = cluster_row[embedding_clm]
        member_authors = cluster_row['authorID']
        print(f"\nProcessing cluster {label} with {len(member_authors)} authors...")

        # Reuse the interactive tool's LLM utility for consistency;
        # return_only_feats=True yields a plain list of feature strings.
        features = compute_clusters_style_representation_3(
            background_corpus_df=cluster_df,
            cluster_ids=member_authors,
            cluster_label_clm_name='authorID',
            max_num_feats=5,  # Requesting 5 top features
            max_num_authors=20,  # Use up to 20 authors from the cluster for analysis
            return_only_feats=True
        )
        print(f" Generated style features: {features}")

        # Condense the raw feature list into one readable paragraph.
        paragraph = summarize_style_features_to_paragraph(features)
        print(f" Summarized paragraph: {paragraph}")

        # json cannot serialize numpy integer keys, hence the int() cast.
        interpretable_space[int(label)] = (mean_embedding, paragraph)
    return interpretable_space
def generate_explanations(args):
    """
    Generate static style explanations for authorship-attribution instances.

    For each instance in `args.input_file` (JSON list): map the query (Q) and
    the three candidate authors (a0/a1/a2) to their closest cluster in the
    interpretable space loaded from `args.interp_space_path`, rank the
    candidates by cosine similarity of their style embeddings to the query,
    and write the enriched instances to `args.output_file` as JSON.

    Output fields per instance: {role}_authorID, {role}_fullText (joined),
    {role}_top_style_feats, gt_idx, latent_rank (candidate indices ordered
    most→least similar), model_pred ("Candidate N", 1-based).
    """
    model_name = args.model_name if args.model_name else 'AnnaWegmann/Style-Embedding'

    # Use context managers so file handles are not leaked.
    with open(args.input_file) as f:
        instances_for_ex = json.load(f)
    with open(args.interp_space_path) as f:
        interp_space = json.load(f)

    output = []
    for instance in instances_for_ex:
        json_obj = {}
        embeddings = {}
        # The query and the three candidates all get the same treatment.
        for role in ('Q', 'a0', 'a1', 'a2'):
            json_obj[f'{role}_authorID'] = instance[f'{role}_authorID']
            json_obj[f'{role}_fullText'] = '\n\n'.join(instance[f'{role}_fullText'])
            style_description, emb = find_closest_cluster_style(
                instance[f'{role}_fullText'], interp_space, model_name=model_name)
            json_obj[f'{role}_top_style_feats'] = style_description
            # cosine_similarity needs 2D inputs: reshape to (1, dim).
            embeddings[role] = np.array(emb).reshape(1, -1)

        json_obj['gt_idx'] = instance['gt_idx']

        # Cosine similarity of each candidate's embedding to the query's.
        similarities = [
            float(cosine_similarity(embeddings['Q'], embeddings[role])[0][0])
            for role in ('a0', 'a1', 'a2')
        ]

        # BUGFIX: rank candidates from MOST to least similar. The previous
        # ascending argsort put the least similar candidate first, so
        # model_pred named the worst match instead of the best.
        json_obj['latent_rank'] = np.argsort(similarities)[::-1].tolist()
        json_obj['model_pred'] = 'Candidate {}'.format(json_obj['latent_rank'][0] + 1)
        output.append(json_obj)

    with open(args.output_file, 'w') as f:
        json.dump(output, f, indent=4)
def main():
    """
    CLI entry point: parse arguments and dispatch to the requested task.
    """
    parser = argparse.ArgumentParser(
        description="Build a static interpretable space from clustered author data."
    )
    parser.add_argument(
        "task",
        type=str,
        help="task: one of the following: build_static_interp_space, generate_explanations",
        choices=["build_static_interp_space", "generate_explanations"]
    )
    parser.add_argument(
        "input_file",
        type=str,
        help="Path to the input clustered DataFrame (.pkl file)."
    )
    parser.add_argument(
        "output_file",
        type=str,
        help="file to save the output"
    )
    parser.add_argument(
        "--interp_space_path",
        type=str,
        help="Path to the input interpretable space(.pkl file)."
    )
    parser.add_argument(
        "--model_name",
        type=str,
        help="style analysis model name"
    )
    args = parser.parse_args()

    # Dispatch table instead of an if/elif chain; argparse `choices`
    # already guarantees the task is one of these, but keep the guard.
    handlers = {
        "build_static_interp_space": build_and_save_static_interp_space,
        "generate_explanations": generate_explanations,
    }
    handler = handlers.get(args.task)
    if handler is None:
        raise ValueError(f"Unknown task: {args.task}")
    return handler(args)
def build_and_save_static_interp_space(args):
    """
    Load the clustered DataFrame from `args.input_file`, build the static
    interpretable space, and save it to `args.output_file` as JSON.
    """
    print(f"Loading clustered data from {args.input_file}...")
    clustered = pd.read_pickle(args.input_file)
    space = build_static_interp_space(clustered)
    print(f"\nSaving interpretable space to {args.output_file}...")
    with open(args.output_file, 'w') as out:
        json.dump(space, out, indent=4)
    print("Done.")
# Script entry point: only run when executed directly, not on import.
if __name__ == "__main__":
    main()