Commit
·
c0e59d3
1
Parent(s):
75cc8bf
changed default g2v clustering to contrastive, and added filtering to ensure spans show
Browse files- utils/gram2vec_feat_utils.py +11 -3
- utils/interp_space_utils.py +20 -5
- utils/visualizations.py +26 -2
utils/gram2vec_feat_utils.py
CHANGED
|
@@ -198,10 +198,18 @@ def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
|
|
| 198 |
)
|
| 199 |
combined_html = "<div>" + "\n<hr>\n".join(html_task_authors) + "</div>"
|
| 200 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
html_background_authors = create_html(
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
selected_feature_llm,
|
| 206 |
selected_feature_g2v,
|
| 207 |
short,
|
|
|
|
| 198 |
)
|
| 199 |
combined_html = "<div>" + "\n<hr>\n".join(html_task_authors) + "</div>"
|
| 200 |
|
| 201 |
+
# Filter background authors to those with at least one Gram2Vec span
|
| 202 |
+
bg_start = 4
|
| 203 |
+
bg_indices = list(range(bg_start, len(texts)))
|
| 204 |
+
kept_indices = [i for i in bg_indices if gram_spans_list[i]]
|
| 205 |
+
filtered_texts_bg = [texts[i] for i in kept_indices]
|
| 206 |
+
filtered_llm_bg = [llm_spans_list[i] for i in kept_indices]
|
| 207 |
+
filtered_gram_bg = [gram_spans_list[i] for i in kept_indices]
|
| 208 |
+
|
| 209 |
html_background_authors = create_html(
|
| 210 |
+
filtered_texts_bg,
|
| 211 |
+
filtered_llm_bg,
|
| 212 |
+
filtered_gram_bg,
|
| 213 |
selected_feature_llm,
|
| 214 |
selected_feature_g2v,
|
| 215 |
short,
|
utils/interp_space_utils.py
CHANGED
|
@@ -528,7 +528,7 @@ def compute_clusters_g2v_representation(
|
|
| 528 |
other_author_ids: List[Any],
|
| 529 |
features_clm_name: str,
|
| 530 |
top_n: int = 10,
|
| 531 |
-
mode: str = "
|
| 532 |
sharedness_method: str = "mean_minus_alpha_std",
|
| 533 |
alpha: float = 0.5
|
| 534 |
) -> List[str]:
|
|
@@ -569,14 +569,29 @@ def compute_clusters_g2v_representation(
|
|
| 569 |
# Contrastive mode (default): compute target mean and subtract contrast mean
|
| 570 |
all_g2v_values = np.array([list(x.values()) for x in selected_feats]).mean(axis=0)
|
| 571 |
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 575 |
|
| 576 |
final_g2v_feats_values = all_g2v_values - all_g2v_other_values
|
| 577 |
|
| 578 |
|
| 579 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 580 |
|
| 581 |
# Filter out features that are not present in any of the authors
|
| 582 |
selected_authors = {'Mystery author', 'Candidate Author 1', 'Candidate Author 2', 'Candidate Author 3'}.intersection(set(author_ids))
|
|
|
|
| 528 |
other_author_ids: List[Any],
|
| 529 |
features_clm_name: str,
|
| 530 |
top_n: int = 10,
|
| 531 |
+
mode: str = "contrastive",
|
| 532 |
sharedness_method: str = "mean_minus_alpha_std",
|
| 533 |
alpha: float = 0.5
|
| 534 |
) -> List[str]:
|
|
|
|
| 569 |
# Contrastive mode (default): compute target mean and subtract contrast mean
|
| 570 |
all_g2v_values = np.array([list(x.values()) for x in selected_feats]).mean(axis=0)
|
| 571 |
|
| 572 |
+
# If an explicit contrast set is provided, use it; otherwise use everyone outside selection
|
| 573 |
+
if other_author_ids:
|
| 574 |
+
explicit_mask = background_corpus_df['authorID'].isin(other_author_ids).to_numpy()
|
| 575 |
+
# Ensure contrast set is disjoint from the selected set
|
| 576 |
+
contrast_mask = np.logical_and(explicit_mask, ~selected_mask)
|
| 577 |
+
else:
|
| 578 |
+
contrast_mask = ~selected_mask
|
| 579 |
+
|
| 580 |
+
other_selected_feats = background_corpus_df[contrast_mask][features_clm_name].tolist()
|
| 581 |
+
if len(other_selected_feats) > 0:
|
| 582 |
+
all_g2v_other_values = np.array([list(x.values()) for x in other_selected_feats]).mean(axis=0)
|
| 583 |
+
else:
|
| 584 |
+
# No contrast docs — treat contrast mean as zeros
|
| 585 |
+
all_g2v_other_values = np.zeros_like(all_g2v_values)
|
| 586 |
|
| 587 |
final_g2v_feats_values = all_g2v_values - all_g2v_other_values
|
| 588 |
|
| 589 |
|
| 590 |
+
# Keep only features that have a positive contrastive score
|
| 591 |
+
top_g2v_feats = sorted(
|
| 592 |
+
[(feat, val) for feat, val in zip(all_g2v_feats, final_g2v_feats_values) if val > 0],
|
| 593 |
+
key=lambda x: -x[1]
|
| 594 |
+
)
|
| 595 |
|
| 596 |
# Filter out features that are not present in any of the authors
|
| 597 |
selected_authors = {'Mystery author', 'Candidate Author 1', 'Candidate Author 2', 'Candidate Author 3'}.intersection(set(author_ids))
|
utils/visualizations.py
CHANGED
|
@@ -13,6 +13,7 @@ import re
|
|
| 13 |
from utils.interp_space_utils import compute_clusters_style_representation_3, compute_clusters_g2v_representation
|
| 14 |
from utils.llm_feat_utils import split_features
|
| 15 |
from utils.gram2vec_feat_utils import get_shorthand, get_fullform
|
|
|
|
| 16 |
|
| 17 |
import plotly.io as pio
|
| 18 |
|
|
@@ -251,9 +252,32 @@ def handle_zoom(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors
|
|
| 251 |
features_clm_name='g2v_vector'
|
| 252 |
)
|
| 253 |
|
| 254 |
-
#
|
| 255 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
for feat in g2v_feats:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
HR_g2v = get_fullform(feat)
|
| 258 |
print(f"\n\n feat: {feat} ---> Human Readable: {HR_g2v}")
|
| 259 |
if HR_g2v is None:
|
|
|
|
| 13 |
from utils.interp_space_utils import compute_clusters_style_representation_3, compute_clusters_g2v_representation
|
| 14 |
from utils.llm_feat_utils import split_features
|
| 15 |
from utils.gram2vec_feat_utils import get_shorthand, get_fullform
|
| 16 |
+
from gram2vec.feature_locator import find_feature_spans
|
| 17 |
|
| 18 |
import plotly.io as pio
|
| 19 |
|
|
|
|
| 252 |
features_clm_name='g2v_vector'
|
| 253 |
)
|
| 254 |
|
| 255 |
+
# ── Span-existence filter on task authors in the zoom ───────────────────
|
| 256 |
+
# Keep only features that have at least one detected span in any of the
|
| 257 |
+
# visible task authors' texts
|
| 258 |
+
visible_task_authors = task_authors_df[task_authors_df['authorID'].isin(visible_authors)]
|
| 259 |
+
if visible_task_authors.empty:
|
| 260 |
+
visible_task_authors = task_authors_df
|
| 261 |
+
|
| 262 |
+
def _to_text(x):
|
| 263 |
+
return '\n\n =========== \n\n'.join(x) if isinstance(x, list) else x
|
| 264 |
+
|
| 265 |
+
task_texts = [_to_text(x) for x in visible_task_authors['fullText'].tolist()]
|
| 266 |
+
|
| 267 |
+
filtered_g2v_feats = []
|
| 268 |
for feat in g2v_feats:
|
| 269 |
+
try:
|
| 270 |
+
# `feat` is shorthand already (e.g., 'pos_bigrams:NOUN PROPN')
|
| 271 |
+
if any(find_feature_spans(txt, feat) for txt in task_texts):
|
| 272 |
+
filtered_g2v_feats.append(feat)
|
| 273 |
+
else:
|
| 274 |
+
print(f"[INFO] Dropping G2V feature with no spans in task texts: {feat}")
|
| 275 |
+
except Exception as e:
|
| 276 |
+
print(f"[WARN] Error while checking spans for {feat}: {e}")
|
| 277 |
+
|
| 278 |
+
# Convert to human readable for display
|
| 279 |
+
HR_g2v_list = []
|
| 280 |
+
for feat in filtered_g2v_feats:
|
| 281 |
HR_g2v = get_fullform(feat)
|
| 282 |
print(f"\n\n feat: {feat} ---> Human Readable: {HR_g2v}")
|
| 283 |
if HR_g2v is None:
|