updates
Browse files- app.py +2 -2
- config/config.yaml +2 -2
- utils/gram2vec_feat_utils.py +1 -1
- utils/interp_space_utils.py +23 -29
- utils/visualizations.py +12 -9
app.py
CHANGED
|
@@ -58,8 +58,8 @@ def app(share=False, use_cluster_feats=False):
|
|
| 58 |
instances, instance_ids = get_instances(cfg['instances_to_explain_path'])
|
| 59 |
|
| 60 |
interp = load_interp_space(cfg)
|
| 61 |
-
clustered_authors_df = interp['clustered_authors_df'][:
|
| 62 |
-
clustered_authors_df['fullText'] = clustered_authors_df['fullText']
|
| 63 |
|
| 64 |
with gr.Blocks(title="Author Attribution Explainability Tool") as demo:
|
| 65 |
# ── Big Centered Title ──────────────────────────────────────────
|
|
|
|
| 58 |
instances, instance_ids = get_instances(cfg['instances_to_explain_path'])
|
| 59 |
|
| 60 |
interp = load_interp_space(cfg)
|
| 61 |
+
clustered_authors_df = interp['clustered_authors_df'][:1000]
|
| 62 |
+
clustered_authors_df['fullText'] = clustered_authors_df['fullText']
|
| 63 |
|
| 64 |
with gr.Blocks(title="Author Attribution Explainability Tool") as demo:
|
| 65 |
# ── Big Centered Title ──────────────────────────────────────────
|
config/config.yaml
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
# config.yaml
|
| 2 |
instances_to_explain_path: "./datasets/hrs_explanations.json"
|
| 3 |
instances_to_explain_url: "https://huggingface.co/datasets/miladalsh/explanation_tool_files/resolve/main/hrs_explanations_luar_clusters_18_balanced.json?download=true"
|
| 4 |
-
interp_space_path: "./datasets/
|
| 5 |
-
interp_space_url: "https://huggingface.co/datasets/miladalsh/explanation_tool_files/resolve/main/
|
| 6 |
gram2vec_feats_path: "./datasets/gram2vec_feats.csv"
|
| 7 |
gram2vec_feats_url: "https://huggingface.co/datasets/miladalsh/explanation_tool_files/resolve/main/gram2vec_feats.csv?download=true"
|
| 8 |
|
|
|
|
| 1 |
# config.yaml
|
| 2 |
instances_to_explain_path: "./datasets/hrs_explanations.json"
|
| 3 |
instances_to_explain_url: "https://huggingface.co/datasets/miladalsh/explanation_tool_files/resolve/main/hrs_explanations_luar_clusters_18_balanced.json?download=true"
|
| 4 |
+
interp_space_path: "./datasets/sentence_luar_interp_space_2_35/"
|
| 5 |
+
interp_space_url: "https://huggingface.co/datasets/miladalsh/explanation_tool_files/resolve/main/sentence_luar_interp_space_2_35.zip?download=true"
|
| 6 |
gram2vec_feats_path: "./datasets/gram2vec_feats.csv"
|
| 7 |
gram2vec_feats_url: "https://huggingface.co/datasets/miladalsh/explanation_tool_files/resolve/main/gram2vec_feats.csv?download=true"
|
| 8 |
|
utils/gram2vec_feat_utils.py
CHANGED
|
@@ -126,7 +126,7 @@ def highlight_both_spans(text, llm_spans, gram_spans):
|
|
| 126 |
|
| 127 |
|
| 128 |
def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
|
| 129 |
-
llm_style_feats_analysis, background_authors_embeddings_df, task_authors_embeddings_df, visible_authors, predicted_author=None, ground_truth_author=None, max_num_authors=
|
| 130 |
"""
|
| 131 |
For mystery + 3 candidates:
|
| 132 |
1. get llm spans via your existing cache+API
|
|
|
|
| 126 |
|
| 127 |
|
| 128 |
def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
|
| 129 |
+
llm_style_feats_analysis, background_authors_embeddings_df, task_authors_embeddings_df, visible_authors, predicted_author=None, ground_truth_author=None, max_num_authors=4):
|
| 130 |
"""
|
| 131 |
For mystery + 3 candidates:
|
| 132 |
1. get llm spans via your existing cache+API
|
utils/interp_space_utils.py
CHANGED
|
@@ -79,7 +79,7 @@ def compute_g2v_features(clustered_authors_df: pd.DataFrame, task_authors_df: pd
|
|
| 79 |
clustered_authors_df = pickle.load(f)
|
| 80 |
|
| 81 |
else: # Else compute and cache
|
| 82 |
-
g2v_feats_df = vectorizer.from_documents(author_texts, batch_size=
|
| 83 |
|
| 84 |
print(f"Number of g2v features: {len(g2v_feats_df)}")
|
| 85 |
print(f"Number of clustered_authors_df.authorID.tolist(): {len(clustered_authors_df.authorID.tolist())}")
|
|
@@ -471,11 +471,11 @@ def identify_style_features(author_texts: list[str], author_names: list[str], ma
|
|
| 471 |
print(f"Cache miss. Computing features for authors: {author_names}")
|
| 472 |
|
| 473 |
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
| 474 |
-
prompt = f"""Identify {max_num_feats} writing style features that are commonly
|
| 475 |
Author Texts:
|
| 476 |
-
\"\"\"{chr(10).join(author_texts)}\"\"\"
|
| 477 |
-
"""
|
| 478 |
|
|
|
|
|
|
|
| 479 |
def _make_call():
|
| 480 |
response = client.chat.completions.create(
|
| 481 |
model="gpt-4o-mini",
|
|
@@ -495,7 +495,6 @@ def identify_style_features(author_texts: list[str], author_names: list[str], ma
|
|
| 495 |
|
| 496 |
features = retry_call(_make_call, FeatureIdentificationSchema).features
|
| 497 |
|
| 498 |
-
print(f"Adding to zoom cache")
|
| 499 |
if cache_key and author_names:
|
| 500 |
cache[cache_key] = {
|
| 501 |
"features": features
|
|
@@ -541,10 +540,10 @@ def compute_clusters_style_representation_3(
|
|
| 541 |
background_corpus_df: pd.DataFrame,
|
| 542 |
cluster_ids: List[Any],
|
| 543 |
cluster_label_clm_name: str = 'authorID',
|
| 544 |
-
max_num_feats: int =
|
| 545 |
max_num_documents_per_author=3,
|
| 546 |
max_num_authors=5,
|
| 547 |
-
max_authors_for_span_extraction=
|
| 548 |
):
|
| 549 |
|
| 550 |
print(f"Computing style representation for visible clusters: {len(cluster_ids)}")
|
|
@@ -568,20 +567,17 @@ def compute_clusters_style_representation_3(
|
|
| 568 |
|
| 569 |
# Filter out features that are not present in any of the authors
|
| 570 |
filtered_spans_by_author = {x[0] : x[1] for x in spans_by_author.items() if x[0] in {'Mystery author', 'Candidate Author 1', 'Candidate Author 2', 'Candidate Author 3'}.intersection(set(cluster_ids))}
|
| 571 |
-
print(
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
if found_in_any_author:
|
| 580 |
-
filtered_features.append(feature)
|
| 581 |
-
features = filtered_features
|
| 582 |
|
| 583 |
return {
|
| 584 |
-
"features":
|
| 585 |
"spans": spans_by_author
|
| 586 |
}
|
| 587 |
|
|
@@ -668,19 +664,17 @@ def compute_clusters_g2v_representation(
|
|
| 668 |
key=lambda x: -x[1] # Sort by contrastive score
|
| 669 |
)
|
| 670 |
|
| 671 |
-
# Filter
|
| 672 |
selected_authors = {'Mystery author', 'Candidate Author 1', 'Candidate Author 2', 'Candidate Author 3'}.intersection(set(author_ids))
|
| 673 |
-
|
| 674 |
-
|
|
|
|
| 675 |
filtered_features = []
|
| 676 |
for feature, score, z_score in top_g2v_feats:
|
| 677 |
-
|
| 678 |
-
for author_g2v_feats in
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
break
|
| 682 |
-
if found_in_any_author:
|
| 683 |
-
filtered_features.append((feature, score, z_score))
|
| 684 |
|
| 685 |
print('Filtered G2V features: ', [(f[0], f[2]) for f in filtered_features]) # Print feature names and z-scores
|
| 686 |
|
|
|
|
| 79 |
clustered_authors_df = pickle.load(f)
|
| 80 |
|
| 81 |
else: # Else compute and cache
|
| 82 |
+
g2v_feats_df = vectorizer.from_documents(author_texts, batch_size=8)
|
| 83 |
|
| 84 |
print(f"Number of g2v features: {len(g2v_feats_df)}")
|
| 85 |
print(f"Number of clustered_authors_df.authorID.tolist(): {len(clustered_authors_df.authorID.tolist())}")
|
|
|
|
| 471 |
print(f"Cache miss. Computing features for authors: {author_names}")
|
| 472 |
|
| 473 |
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
| 474 |
+
prompt = f"""Identify {max_num_feats} writing style features that are commonly shared between the authors' texts.
|
| 475 |
Author Texts:
|
|
|
|
|
|
|
| 476 |
|
| 477 |
+
{author_texts}
|
| 478 |
+
"""
|
| 479 |
def _make_call():
|
| 480 |
response = client.chat.completions.create(
|
| 481 |
model="gpt-4o-mini",
|
|
|
|
| 495 |
|
| 496 |
features = retry_call(_make_call, FeatureIdentificationSchema).features
|
| 497 |
|
|
|
|
| 498 |
if cache_key and author_names:
|
| 499 |
cache[cache_key] = {
|
| 500 |
"features": features
|
|
|
|
| 540 |
background_corpus_df: pd.DataFrame,
|
| 541 |
cluster_ids: List[Any],
|
| 542 |
cluster_label_clm_name: str = 'authorID',
|
| 543 |
+
max_num_feats: int = 20,
|
| 544 |
max_num_documents_per_author=3,
|
| 545 |
max_num_authors=5,
|
| 546 |
+
max_authors_for_span_extraction=4
|
| 547 |
):
|
| 548 |
|
| 549 |
print(f"Computing style representation for visible clusters: {len(cluster_ids)}")
|
|
|
|
| 567 |
|
| 568 |
# Filter out features that are not present in any of the authors
|
| 569 |
filtered_spans_by_author = {x[0] : x[1] for x in spans_by_author.items() if x[0] in {'Mystery author', 'Candidate Author 1', 'Candidate Author 2', 'Candidate Author 3'}.intersection(set(cluster_ids))}
|
| 570 |
+
print(filtered_spans_by_author.keys())
|
| 571 |
+
filtered_spans_by_author = [set([f[0] for f in x[1].items() if len(f[1]) > 0]) for x in filtered_spans_by_author.items()]
|
| 572 |
+
|
| 573 |
+
filtered_set_of_features = filtered_spans_by_author[0] # all features that appear in all the sets in the filtered_spans_by_author list
|
| 574 |
+
for x in filtered_spans_by_author[1:]:
|
| 575 |
+
filtered_set_of_features = filtered_set_of_features.intersection(x)
|
| 576 |
+
|
| 577 |
+
print('filtered set of features: ', filtered_set_of_features)
|
|
|
|
|
|
|
|
|
|
| 578 |
|
| 579 |
return {
|
| 580 |
+
"features": list(filtered_set_of_features),
|
| 581 |
"spans": spans_by_author
|
| 582 |
}
|
| 583 |
|
|
|
|
| 664 |
key=lambda x: -x[1] # Sort by contrastive score
|
| 665 |
)
|
| 666 |
|
| 667 |
+
# Filter in only features that are present in selected_authors
|
| 668 |
selected_authors = {'Mystery author', 'Candidate Author 1', 'Candidate Author 2', 'Candidate Author 3'}.intersection(set(author_ids))
|
| 669 |
+
# Filter in only features that are present in selected_authors
|
| 670 |
+
selected_authors_g2v_data = background_corpus_df[background_corpus_df['authorID'].isin(selected_authors)][features_clm_name].tolist()
|
| 671 |
+
|
| 672 |
filtered_features = []
|
| 673 |
for feature, score, z_score in top_g2v_feats:
|
| 674 |
+
# Check if the feature has a non-zero value in all of the selected authors
|
| 675 |
+
if all(author_g2v_feats.get(feature, 0) > 0 for author_g2v_feats in selected_authors_g2v_data):
|
| 676 |
+
filtered_features.append((feature, score, z_score)) # Only return feature and z-score
|
| 677 |
+
|
|
|
|
|
|
|
|
|
|
| 678 |
|
| 679 |
print('Filtered G2V features: ', [(f[0], f[2]) for f in filtered_features]) # Print feature names and z-scores
|
| 680 |
|
utils/visualizations.py
CHANGED
|
@@ -14,7 +14,7 @@ from utils.interp_space_utils import compute_clusters_style_representation_3, co
|
|
| 14 |
from utils.llm_feat_utils import split_features
|
| 15 |
from utils.gram2vec_feat_utils import get_shorthand, get_fullform
|
| 16 |
from gram2vec.feature_locator import find_feature_spans
|
| 17 |
-
|
| 18 |
import plotly.io as pio
|
| 19 |
|
| 20 |
def clean_text(text: str) -> str:
|
|
@@ -132,8 +132,10 @@ def compute_tsne_with_cache(embeddings: np.ndarray, cache_path: str = 'datasets/
|
|
| 132 |
return cache[hash_key]
|
| 133 |
else:
|
| 134 |
print("Computing t-SNE")
|
| 135 |
-
tsne_result = TSNE(n_components=2, learning_rate='auto',
|
| 136 |
-
|
|
|
|
|
|
|
| 137 |
cache[hash_key] = tsne_result
|
| 138 |
with open(cache_path, 'wb') as f:
|
| 139 |
pkl.dump(cache, f)
|
|
@@ -147,7 +149,7 @@ def load_interp_space(cfg):
|
|
| 147 |
|
| 148 |
# Load authors embeddings and their cluster labels
|
| 149 |
clustered_authors_df = pd.read_pickle(clustered_authors_path)
|
| 150 |
-
clustered_authors_df = clustered_authors_df[clustered_authors_df.cluster_label != -1]
|
| 151 |
author_embedding = clustered_authors_df.author_embedding.tolist()
|
| 152 |
author_labels = clustered_authors_df.cluster_label.tolist()
|
| 153 |
author_ids = clustered_authors_df.authorID.tolist()
|
|
@@ -276,11 +278,12 @@ def handle_zoom(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors
|
|
| 276 |
print(f"Task authors: {len(task_authors_df)}, Clustered authors: {len(clustered_authors_df)}")
|
| 277 |
merged_authors_df = pd.concat([task_authors_df, clustered_authors_df])
|
| 278 |
print(f"Merged authors DataFrame:\n{len(merged_authors_df)}")
|
| 279 |
-
style_analysis_response =
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
|
|
|
| 284 |
|
| 285 |
llm_feats = ['None'] + style_analysis_response['features']
|
| 286 |
|
|
|
|
| 14 |
from utils.llm_feat_utils import split_features
|
| 15 |
from utils.gram2vec_feat_utils import get_shorthand, get_fullform
|
| 16 |
from gram2vec.feature_locator import find_feature_spans
|
| 17 |
+
import umap
|
| 18 |
import plotly.io as pio
|
| 19 |
|
| 20 |
def clean_text(text: str) -> str:
|
|
|
|
| 132 |
return cache[hash_key]
|
| 133 |
else:
|
| 134 |
print("Computing t-SNE")
|
| 135 |
+
# tsne_result = TSNE(n_components=2, learning_rate='auto',
|
| 136 |
+
# init='random', perplexity=3).fit_transform(embeddings)
|
| 137 |
+
tsne_result = umap.UMAP(n_components=2, n_neighbors=15, min_dist=0.0, metric='cosine').fit_transform(embeddings)
|
| 138 |
+
|
| 139 |
cache[hash_key] = tsne_result
|
| 140 |
with open(cache_path, 'wb') as f:
|
| 141 |
pkl.dump(cache, f)
|
|
|
|
| 149 |
|
| 150 |
# Load authors embeddings and their cluster labels
|
| 151 |
clustered_authors_df = pd.read_pickle(clustered_authors_path)
|
| 152 |
+
#clustered_authors_df = clustered_authors_df[clustered_authors_df.cluster_label != -1]
|
| 153 |
author_embedding = clustered_authors_df.author_embedding.tolist()
|
| 154 |
author_labels = clustered_authors_df.cluster_label.tolist()
|
| 155 |
author_ids = clustered_authors_df.authorID.tolist()
|
|
|
|
| 278 |
print(f"Task authors: {len(task_authors_df)}, Clustered authors: {len(clustered_authors_df)}")
|
| 279 |
merged_authors_df = pd.concat([task_authors_df, clustered_authors_df])
|
| 280 |
print(f"Merged authors DataFrame:\n{len(merged_authors_df)}")
|
| 281 |
+
style_analysis_response = {'features': [], 'spans': []}
|
| 282 |
+
# style_analysis_response = compute_clusters_style_representation_3(
|
| 283 |
+
# background_corpus_df=merged_authors_df,
|
| 284 |
+
# cluster_ids=visible_authors,
|
| 285 |
+
# cluster_label_clm_name='authorID',
|
| 286 |
+
# )
|
| 287 |
|
| 288 |
llm_feats = ['None'] + style_analysis_response['features']
|
| 289 |
|