Milad Alshomary
committed on
Commit
Β·
c54eddb
1
Parent(s):
74947b9
updates
Browse files- app.py +2 -2
- utils/gram2vec_feat_utils.py +1 -1
- utils/interp_space_utils.py +22 -28
- utils/visualizations.py +1 -0
app.py
CHANGED
|
@@ -58,8 +58,8 @@ def app(share=False, use_cluster_feats=False):
|
|
| 58 |
instances, instance_ids = get_instances(cfg['instances_to_explain_path'])
|
| 59 |
|
| 60 |
interp = load_interp_space(cfg)
|
| 61 |
-
clustered_authors_df = interp['clustered_authors_df']
|
| 62 |
-
clustered_authors_df['fullText'] = clustered_authors_df['fullText'].map(lambda l: l[:
|
| 63 |
|
| 64 |
with gr.Blocks(title="Author Attribution Explainability Tool") as demo:
|
| 65 |
# ββ Big Centered Title ββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 58 |
instances, instance_ids = get_instances(cfg['instances_to_explain_path'])
|
| 59 |
|
| 60 |
interp = load_interp_space(cfg)
|
| 61 |
+
clustered_authors_df = interp['clustered_authors_df']
|
| 62 |
+
clustered_authors_df['fullText'] = clustered_authors_df['fullText'].map(lambda l: l[:5]) # Take at most 3 texts per author
|
| 63 |
|
| 64 |
with gr.Blocks(title="Author Attribution Explainability Tool") as demo:
|
| 65 |
# ββ Big Centered Title ββββββββββββββββββββββββββββββββββββββββββ
|
utils/gram2vec_feat_utils.py
CHANGED
|
@@ -126,7 +126,7 @@ def highlight_both_spans(text, llm_spans, gram_spans):
|
|
| 126 |
|
| 127 |
|
| 128 |
def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
|
| 129 |
-
llm_style_feats_analysis, background_authors_embeddings_df, task_authors_embeddings_df, visible_authors, predicted_author=None, ground_truth_author=None, max_num_authors=
|
| 130 |
"""
|
| 131 |
For mystery + 3 candidates:
|
| 132 |
1. get llm spans via your existing cache+API
|
|
|
|
| 126 |
|
| 127 |
|
| 128 |
def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
|
| 129 |
+
llm_style_feats_analysis, background_authors_embeddings_df, task_authors_embeddings_df, visible_authors, predicted_author=None, ground_truth_author=None, max_num_authors=4):
|
| 130 |
"""
|
| 131 |
For mystery + 3 candidates:
|
| 132 |
1. get llm spans via your existing cache+API
|
utils/interp_space_utils.py
CHANGED
|
@@ -449,11 +449,11 @@ def identify_style_features(author_texts: list[str], author_names: list[str], ma
|
|
| 449 |
print(f"Cache miss. Computing features for authors: {author_names}")
|
| 450 |
|
| 451 |
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
| 452 |
-
prompt = f"""Identify {max_num_feats} writing style features that are commonly
|
| 453 |
Author Texts:
|
| 454 |
-
\"\"\"{chr(10).join(author_texts)}\"\"\"
|
| 455 |
-
"""
|
| 456 |
|
|
|
|
|
|
|
| 457 |
def _make_call():
|
| 458 |
response = client.chat.completions.create(
|
| 459 |
model="gpt-4o-mini",
|
|
@@ -473,7 +473,6 @@ def identify_style_features(author_texts: list[str], author_names: list[str], ma
|
|
| 473 |
|
| 474 |
features = retry_call(_make_call, FeatureIdentificationSchema).features
|
| 475 |
|
| 476 |
-
print(f"Adding to zoom cache")
|
| 477 |
if cache_key and author_names:
|
| 478 |
cache[cache_key] = {
|
| 479 |
"features": features
|
|
@@ -519,10 +518,10 @@ def compute_clusters_style_representation_3(
|
|
| 519 |
background_corpus_df: pd.DataFrame,
|
| 520 |
cluster_ids: List[Any],
|
| 521 |
cluster_label_clm_name: str = 'authorID',
|
| 522 |
-
max_num_feats: int =
|
| 523 |
max_num_documents_per_author=3,
|
| 524 |
max_num_authors=5,
|
| 525 |
-
max_authors_for_span_extraction=
|
| 526 |
):
|
| 527 |
|
| 528 |
print(f"Computing style representation for visible clusters: {len(cluster_ids)}")
|
|
@@ -546,20 +545,17 @@ def compute_clusters_style_representation_3(
|
|
| 546 |
|
| 547 |
# Filter out features that are not present in any of the authors
|
| 548 |
filtered_spans_by_author = {x[0] : x[1] for x in spans_by_author.items() if x[0] in {'Mystery author', 'Candidate Author 1', 'Candidate Author 2', 'Candidate Author 3'}.intersection(set(cluster_ids))}
|
| 549 |
-
print(
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
if found_in_any_author:
|
| 558 |
-
filtered_features.append(feature)
|
| 559 |
-
features = filtered_features
|
| 560 |
|
| 561 |
return {
|
| 562 |
-
"features":
|
| 563 |
"spans": spans_by_author
|
| 564 |
}
|
| 565 |
|
|
@@ -646,19 +642,17 @@ def compute_clusters_g2v_representation(
|
|
| 646 |
key=lambda x: -x[1] # Sort by contrastive score
|
| 647 |
)
|
| 648 |
|
| 649 |
-
# Filter
|
| 650 |
selected_authors = {'Mystery author', 'Candidate Author 1', 'Candidate Author 2', 'Candidate Author 3'}.intersection(set(author_ids))
|
| 651 |
-
|
| 652 |
-
|
|
|
|
| 653 |
filtered_features = []
|
| 654 |
for feature, score, z_score in top_g2v_feats:
|
| 655 |
-
|
| 656 |
-
for author_g2v_feats in
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
break
|
| 660 |
-
if found_in_any_author:
|
| 661 |
-
filtered_features.append((feature, score, z_score))
|
| 662 |
|
| 663 |
print('Filtered G2V features: ', [(f[0], f[2]) for f in filtered_features]) # Print feature names and z-scores
|
| 664 |
|
|
|
|
| 449 |
print(f"Cache miss. Computing features for authors: {author_names}")
|
| 450 |
|
| 451 |
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
| 452 |
+
prompt = f"""Identify {max_num_feats} writing style features that are commonly between the authors texts.
|
| 453 |
Author Texts:
|
|
|
|
|
|
|
| 454 |
|
| 455 |
+
{author_texts}
|
| 456 |
+
"""
|
| 457 |
def _make_call():
|
| 458 |
response = client.chat.completions.create(
|
| 459 |
model="gpt-4o-mini",
|
|
|
|
| 473 |
|
| 474 |
features = retry_call(_make_call, FeatureIdentificationSchema).features
|
| 475 |
|
|
|
|
| 476 |
if cache_key and author_names:
|
| 477 |
cache[cache_key] = {
|
| 478 |
"features": features
|
|
|
|
| 518 |
background_corpus_df: pd.DataFrame,
|
| 519 |
cluster_ids: List[Any],
|
| 520 |
cluster_label_clm_name: str = 'authorID',
|
| 521 |
+
max_num_feats: int = 20,
|
| 522 |
max_num_documents_per_author=3,
|
| 523 |
max_num_authors=5,
|
| 524 |
+
max_authors_for_span_extraction=4
|
| 525 |
):
|
| 526 |
|
| 527 |
print(f"Computing style representation for visible clusters: {len(cluster_ids)}")
|
|
|
|
| 545 |
|
| 546 |
# Filter out features that are not present in any of the authors
|
| 547 |
filtered_spans_by_author = {x[0] : x[1] for x in spans_by_author.items() if x[0] in {'Mystery author', 'Candidate Author 1', 'Candidate Author 2', 'Candidate Author 3'}.intersection(set(cluster_ids))}
|
| 548 |
+
print(filtered_spans_by_author.keys())
|
| 549 |
+
filtered_spans_by_author = [set([f[0] for f in x[1].items() if len(f[1]) > 0]) for x in filtered_spans_by_author.items()]
|
| 550 |
+
|
| 551 |
+
filtered_set_of_features = filtered_spans_by_author[0] # all features that appear in all the sets in the filtered_Spans_by_authors list
|
| 552 |
+
for x in filtered_spans_by_author[1:]:
|
| 553 |
+
filtered_set_of_features = filtered_set_of_features.intersection(x)
|
| 554 |
+
|
| 555 |
+
print('filtered set of features: ', filtered_set_of_features)
|
|
|
|
|
|
|
|
|
|
| 556 |
|
| 557 |
return {
|
| 558 |
+
"features": list(filtered_set_of_features),
|
| 559 |
"spans": spans_by_author
|
| 560 |
}
|
| 561 |
|
|
|
|
| 642 |
key=lambda x: -x[1] # Sort by contrastive score
|
| 643 |
)
|
| 644 |
|
| 645 |
+
# Filter in only features that are present in selected_authors
|
| 646 |
selected_authors = {'Mystery author', 'Candidate Author 1', 'Candidate Author 2', 'Candidate Author 3'}.intersection(set(author_ids))
|
| 647 |
+
# Filter in only features that are present in selected_authors
|
| 648 |
+
selected_authors_g2v_data = background_corpus_df[background_corpus_df['authorID'].isin(selected_authors)][features_clm_name].tolist()
|
| 649 |
+
|
| 650 |
filtered_features = []
|
| 651 |
for feature, score, z_score in top_g2v_feats:
|
| 652 |
+
# Check if the feature has a non-zero value in all of the selected authors
|
| 653 |
+
if all(author_g2v_feats.get(feature, 0) > 0 for author_g2v_feats in selected_authors_g2v_data):
|
| 654 |
+
filtered_features.append((feature, score, z_score)) # Only return feature and z-score
|
| 655 |
+
|
|
|
|
|
|
|
|
|
|
| 656 |
|
| 657 |
print('Filtered G2V features: ', [(f[0], f[2]) for f in filtered_features]) # Print feature names and z-scores
|
| 658 |
|
utils/visualizations.py
CHANGED
|
@@ -276,6 +276,7 @@ def handle_zoom(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors
|
|
| 276 |
print(f"Task authors: {len(task_authors_df)}, Clustered authors: {len(clustered_authors_df)}")
|
| 277 |
merged_authors_df = pd.concat([task_authors_df, clustered_authors_df])
|
| 278 |
print(f"Merged authors DataFrame:\n{len(merged_authors_df)}")
|
|
|
|
| 279 |
style_analysis_response = compute_clusters_style_representation_3(
|
| 280 |
background_corpus_df=merged_authors_df,
|
| 281 |
cluster_ids=visible_authors,
|
|
|
|
| 276 |
print(f"Task authors: {len(task_authors_df)}, Clustered authors: {len(clustered_authors_df)}")
|
| 277 |
merged_authors_df = pd.concat([task_authors_df, clustered_authors_df])
|
| 278 |
print(f"Merged authors DataFrame:\n{len(merged_authors_df)}")
|
| 279 |
+
#style_analysis_response = {'features': [], 'spans': []}
|
| 280 |
style_analysis_response = compute_clusters_style_representation_3(
|
| 281 |
background_corpus_df=merged_authors_df,
|
| 282 |
cluster_ids=visible_authors,
|