Anisha Bhatnagar
committed on
Commit
·
74947b9
1
Parent(s):
8367823
show span in background authors issue fixed
Browse files- utils/gram2vec_feat_utils.py +18 -7
- utils/interp_space_utils.py +4 -3
- utils/llm_feat_utils.py +3 -0
- utils/visualizations.py +1 -1
utils/gram2vec_feat_utils.py
CHANGED
|
@@ -126,7 +126,7 @@ def highlight_both_spans(text, llm_spans, gram_spans):
|
|
| 126 |
|
| 127 |
|
| 128 |
def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
|
| 129 |
-
llm_style_feats_analysis, background_authors_embeddings_df, task_authors_embeddings_df, visible_authors, predicted_author=None, ground_truth_author=None, max_num_authors=
|
| 130 |
"""
|
| 131 |
For mystery + 3 candidates:
|
| 132 |
1. get llm spans via your existing cache+API
|
|
@@ -152,9 +152,11 @@ def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
|
|
| 152 |
|
| 153 |
if selected_feature_llm and selected_feature_llm != "None":
|
| 154 |
# print(llm_style_feats_analysis)
|
|
|
|
| 155 |
author_list = list(llm_style_feats_analysis['spans'].values())
|
| 156 |
llm_spans_list = []
|
| 157 |
for i, (_, txt) in enumerate(texts):
|
|
|
|
| 158 |
author_spans_list = []
|
| 159 |
for txt_span in author_list[i][selected_feature_llm]:
|
| 160 |
author_spans_list.append(Span(txt.find(txt_span), txt.find(txt_span) + len(txt_span)))
|
|
@@ -167,6 +169,8 @@ def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
|
|
| 167 |
if selected_feature_g2v and selected_feature_g2v != "None":
|
| 168 |
# get gram2vec spans
|
| 169 |
gram_spans_list = []
|
|
|
|
|
|
|
| 170 |
print(f"Selected Gram2Vec feature: {selected_feature_g2v}")
|
| 171 |
short = get_shorthand(selected_feature_g2v)
|
| 172 |
print(f"short hand: {short}")
|
|
@@ -199,14 +203,19 @@ def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
|
|
| 199 |
)
|
| 200 |
combined_html = "<div>" + "\n<hr>\n".join(html_task_authors) + "</div>"
|
| 201 |
|
|
|
|
|
|
|
| 202 |
# Filter background authors to those with at least one Gram2Vec span
|
| 203 |
bg_start = 4
|
| 204 |
bg_indices = list(range(bg_start, len(texts)))
|
| 205 |
kept_indices = [i for i in bg_indices if gram_spans_list[i]]
|
|
|
|
| 206 |
filtered_texts_bg = [texts[i] for i in kept_indices]
|
| 207 |
filtered_llm_bg = [llm_spans_list[i] for i in kept_indices]
|
| 208 |
filtered_gram_bg = [gram_spans_list[i] for i in kept_indices]
|
| 209 |
|
|
|
|
|
|
|
| 210 |
html_background_authors = create_html(
|
| 211 |
filtered_texts_bg,
|
| 212 |
filtered_llm_bg,
|
|
@@ -219,6 +228,7 @@ def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
|
|
| 219 |
ground_truth_author=ground_truth_author
|
| 220 |
)
|
| 221 |
background_html = "<div>" + "\n<hr>\n".join(html_background_authors) + "</div>"
|
|
|
|
| 222 |
return combined_html, background_html
|
| 223 |
|
| 224 |
def get_label(label: str, predicted_author=None, ground_truth_author=None, bg_id: int=0) -> str:
|
|
@@ -230,26 +240,27 @@ def get_label(label: str, predicted_author=None, ground_truth_author=None, bg_id
|
|
| 230 |
return "Mystery Author"
|
| 231 |
elif label.startswith("a0_author") or label.startswith("a1_author") or label.startswith("a2_author") or label.startswith("Candidate"):
|
| 232 |
if label.startswith("Candidate"):
|
| 233 |
-
id = int(label.split(" ")[2]) # Get the number after 'Candidate Author'
|
| 234 |
else:
|
| 235 |
id = label.split("_")[0][-1] # Get the last character of the first part (a0, a1, a2)
|
| 236 |
if predicted_author is not None and ground_truth_author is not None:
|
| 237 |
if int(id) == predicted_author and int(id) == ground_truth_author:
|
| 238 |
-
return f"Candidate {int(id)} (Predicted & Ground Truth)"
|
| 239 |
elif int(id) == predicted_author:
|
| 240 |
-
return f"Candidate {int(id)} (Predicted)"
|
| 241 |
elif int(id) == ground_truth_author:
|
| 242 |
-
return f"Candidate {int(id)} (Ground Truth)"
|
| 243 |
else:
|
| 244 |
-
return f"Candidate {int(id)}"
|
| 245 |
else:
|
| 246 |
-
return f"Candidate {int(id)}"
|
| 247 |
else:
|
| 248 |
return f"Background Author {bg_id+1}"
|
| 249 |
|
| 250 |
def create_html(texts, llm_spans_list, gram_spans_list, selected_feature_llm, selected_feature_g2v, short=None, background = False, predicted_author=None, ground_truth_author=None):
|
| 251 |
html = []
|
| 252 |
for i, (label, txt) in enumerate(texts):
|
|
|
|
| 253 |
label = get_label(label, predicted_author, ground_truth_author, i) if background else get_label(label, predicted_author, ground_truth_author)
|
| 254 |
combined = highlight_both_spans(txt, llm_spans_list[i], gram_spans_list[i])
|
| 255 |
notice = ""
|
|
|
|
| 126 |
|
| 127 |
|
| 128 |
def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
|
| 129 |
+
llm_style_feats_analysis, background_authors_embeddings_df, task_authors_embeddings_df, visible_authors, predicted_author=None, ground_truth_author=None, max_num_authors=7):
|
| 130 |
"""
|
| 131 |
For mystery + 3 candidates:
|
| 132 |
1. get llm spans via your existing cache+API
|
|
|
|
| 152 |
|
| 153 |
if selected_feature_llm and selected_feature_llm != "None":
|
| 154 |
# print(llm_style_feats_analysis)
|
| 155 |
+
print(f"{len(llm_style_feats_analysis['spans'].values())}")
|
| 156 |
author_list = list(llm_style_feats_analysis['spans'].values())
|
| 157 |
llm_spans_list = []
|
| 158 |
for i, (_, txt) in enumerate(texts):
|
| 159 |
+
print(f"{i}/{len(texts)}")
|
| 160 |
author_spans_list = []
|
| 161 |
for txt_span in author_list[i][selected_feature_llm]:
|
| 162 |
author_spans_list.append(Span(txt.find(txt_span), txt.find(txt_span) + len(txt_span)))
|
|
|
|
| 169 |
if selected_feature_g2v and selected_feature_g2v != "None":
|
| 170 |
# get gram2vec spans
|
| 171 |
gram_spans_list = []
|
| 172 |
+
# clean the display string and get the feature name without the zscore
|
| 173 |
+
selected_feature_g2v = selected_feature_g2v.split(" | [Z=")[0].strip()
|
| 174 |
print(f"Selected Gram2Vec feature: {selected_feature_g2v}")
|
| 175 |
short = get_shorthand(selected_feature_g2v)
|
| 176 |
print(f"short hand: {short}")
|
|
|
|
| 203 |
)
|
| 204 |
combined_html = "<div>" + "\n<hr>\n".join(html_task_authors) + "</div>"
|
| 205 |
|
| 206 |
+
# print(f"\n\n\n\n{texts[4:]}")
|
| 207 |
+
|
| 208 |
# Filter background authors to those with at least one Gram2Vec span
|
| 209 |
bg_start = 4
|
| 210 |
bg_indices = list(range(bg_start, len(texts)))
|
| 211 |
kept_indices = [i for i in bg_indices if gram_spans_list[i]]
|
| 212 |
+
print(f"\n---> {kept_indices}")
|
| 213 |
filtered_texts_bg = [texts[i] for i in kept_indices]
|
| 214 |
filtered_llm_bg = [llm_spans_list[i] for i in kept_indices]
|
| 215 |
filtered_gram_bg = [gram_spans_list[i] for i in kept_indices]
|
| 216 |
|
| 217 |
+
print(filtered_texts_bg)
|
| 218 |
+
|
| 219 |
html_background_authors = create_html(
|
| 220 |
filtered_texts_bg,
|
| 221 |
filtered_llm_bg,
|
|
|
|
| 228 |
ground_truth_author=ground_truth_author
|
| 229 |
)
|
| 230 |
background_html = "<div>" + "\n<hr>\n".join(html_background_authors) + "</div>"
|
| 231 |
+
# print(f"Background HTML: {background_html}")
|
| 232 |
return combined_html, background_html
|
| 233 |
|
| 234 |
def get_label(label: str, predicted_author=None, ground_truth_author=None, bg_id: int=0) -> str:
|
|
|
|
| 240 |
return "Mystery Author"
|
| 241 |
elif label.startswith("a0_author") or label.startswith("a1_author") or label.startswith("a2_author") or label.startswith("Candidate"):
|
| 242 |
if label.startswith("Candidate"):
|
| 243 |
+
id = int(label.split(" ")[2])-1 # Get the number after 'Candidate Author'; convert to 0 index
|
| 244 |
else:
|
| 245 |
id = label.split("_")[0][-1] # Get the last character of the first part (a0, a1, a2)
|
| 246 |
if predicted_author is not None and ground_truth_author is not None:
|
| 247 |
if int(id) == predicted_author and int(id) == ground_truth_author:
|
| 248 |
+
return f"Candidate {int(id)+1} (Predicted & Ground Truth)"
|
| 249 |
elif int(id) == predicted_author:
|
| 250 |
+
return f"Candidate {int(id)+1} (Predicted)"
|
| 251 |
elif int(id) == ground_truth_author:
|
| 252 |
+
return f"Candidate {int(id)+1} (Ground Truth)"
|
| 253 |
else:
|
| 254 |
+
return f"Candidate {int(id)+1}"
|
| 255 |
else:
|
| 256 |
+
return f"Candidate {int(id)+1}"
|
| 257 |
else:
|
| 258 |
return f"Background Author {bg_id+1}"
|
| 259 |
|
| 260 |
def create_html(texts, llm_spans_list, gram_spans_list, selected_feature_llm, selected_feature_g2v, short=None, background = False, predicted_author=None, ground_truth_author=None):
|
| 261 |
html = []
|
| 262 |
for i, (label, txt) in enumerate(texts):
|
| 263 |
+
print(i, label, txt[:30])
|
| 264 |
label = get_label(label, predicted_author, ground_truth_author, i) if background else get_label(label, predicted_author, ground_truth_author)
|
| 265 |
combined = highlight_both_spans(txt, llm_spans_list[i], gram_spans_list[i])
|
| 266 |
notice = ""
|
utils/interp_space_utils.py
CHANGED
|
@@ -521,7 +521,8 @@ def compute_clusters_style_representation_3(
|
|
| 521 |
cluster_label_clm_name: str = 'authorID',
|
| 522 |
max_num_feats: int = 10,
|
| 523 |
max_num_documents_per_author=3,
|
| 524 |
-
max_num_authors=5
|
|
|
|
| 525 |
):
|
| 526 |
|
| 527 |
print(f"Computing style representation for visible clusters: {len(cluster_ids)}")
|
|
@@ -537,8 +538,8 @@ def compute_clusters_style_representation_3(
|
|
| 537 |
features = identify_style_features(author_texts, author_names, max_num_feats=max_num_feats)
|
| 538 |
|
| 539 |
# STEP 2: Prepare author pool for span extraction
|
| 540 |
-
span_df = background_corpus_df.iloc[:
|
| 541 |
-
author_names = span_df[cluster_label_clm_name].tolist()[:
|
| 542 |
print(f"Number of authors for span detection : {len(span_df)}")
|
| 543 |
print(author_names)
|
| 544 |
spans_by_author = extract_all_spans(span_df, features, cluster_label_clm_name)
|
|
|
|
| 521 |
cluster_label_clm_name: str = 'authorID',
|
| 522 |
max_num_feats: int = 10,
|
| 523 |
max_num_documents_per_author=3,
|
| 524 |
+
max_num_authors=5,
|
| 525 |
+
max_authors_for_span_extraction=7
|
| 526 |
):
|
| 527 |
|
| 528 |
print(f"Computing style representation for visible clusters: {len(cluster_ids)}")
|
|
|
|
| 538 |
features = identify_style_features(author_texts, author_names, max_num_feats=max_num_feats)
|
| 539 |
|
| 540 |
# STEP 2: Prepare author pool for span extraction
|
| 541 |
+
span_df = background_corpus_df.iloc[:max_authors_for_span_extraction]
|
| 542 |
+
author_names = span_df[cluster_label_clm_name].tolist()[:max_authors_for_span_extraction]
|
| 543 |
print(f"Number of authors for span detection : {len(span_df)}")
|
| 544 |
print(author_names)
|
| 545 |
spans_by_author = extract_all_spans(span_df, features, cluster_label_clm_name)
|
utils/llm_feat_utils.py
CHANGED
|
@@ -3,6 +3,7 @@ import os
|
|
| 3 |
import hashlib
|
| 4 |
import time
|
| 5 |
from json import JSONDecodeError
|
|
|
|
| 6 |
|
| 7 |
CACHE_DIR = "datasets/feature_spans_cache"
|
| 8 |
os.makedirs(CACHE_DIR, exist_ok=True)
|
|
@@ -66,10 +67,12 @@ def generate_feature_spans_with_retries(client, text: str, features: list[str])
|
|
| 66 |
for attempt in range(MAX_ATTEMPTS):
|
| 67 |
try:
|
| 68 |
response_str = generate_feature_spans(client, text, features)
|
|
|
|
| 69 |
result = json.loads(response_str)
|
| 70 |
return result
|
| 71 |
except (JSONDecodeError, ValueError) as e:
|
| 72 |
print(f"Attempt {attempt+1} failed: {e}")
|
|
|
|
| 73 |
if attempt < MAX_ATTEMPTS - 1:
|
| 74 |
wait_sec = WAIT_SECONDS * (2 ** attempt)
|
| 75 |
print(f"Retrying after {wait_sec} seconds...")
|
|
|
|
| 3 |
import hashlib
|
| 4 |
import time
|
| 5 |
from json import JSONDecodeError
|
| 6 |
+
import traceback
|
| 7 |
|
| 8 |
CACHE_DIR = "datasets/feature_spans_cache"
|
| 9 |
os.makedirs(CACHE_DIR, exist_ok=True)
|
|
|
|
| 67 |
for attempt in range(MAX_ATTEMPTS):
|
| 68 |
try:
|
| 69 |
response_str = generate_feature_spans(client, text, features)
|
| 70 |
+
print(response_str)
|
| 71 |
result = json.loads(response_str)
|
| 72 |
return result
|
| 73 |
except (JSONDecodeError, ValueError) as e:
|
| 74 |
print(f"Attempt {attempt+1} failed: {e}")
|
| 75 |
+
traceback.print_exc()
|
| 76 |
if attempt < MAX_ATTEMPTS - 1:
|
| 77 |
wait_sec = WAIT_SECONDS * (2 ** attempt)
|
| 78 |
print(f"Retrying after {wait_sec} seconds...")
|
utils/visualizations.py
CHANGED
|
@@ -225,7 +225,7 @@ def format_g2v_features_for_display(g2v_features_with_scores):
|
|
| 225 |
z_score = float(z_score)
|
| 226 |
|
| 227 |
# Create display string with z-score
|
| 228 |
-
display_string = f"{feature_name} | Z={z_score:.2f}]"
|
| 229 |
display_choices.append(display_string)
|
| 230 |
original_values.append(feature_name)
|
| 231 |
else:
|
|
|
|
| 225 |
z_score = float(z_score)
|
| 226 |
|
| 227 |
# Create display string with z-score
|
| 228 |
+
display_string = f"{feature_name} | [Z={z_score:.2f}]"
|
| 229 |
display_choices.append(display_string)
|
| 230 |
original_values.append(feature_name)
|
| 231 |
else:
|