peter-zeng committed on
Commit
5d8a39b
·
1 Parent(s): 3676d21

Revert "updated to use mystery + predicted"

Browse files

This reverts commit 07c4d0f325f6905ba1f854f1014066abb46d2a72.
remove gram2vec changes

Files changed (1) hide show
  1. utils/interp_space_utils.py +3 -133
utils/interp_space_utils.py CHANGED
@@ -31,14 +31,6 @@ os.makedirs(os.path.dirname(REGION_CACHE), exist_ok=True)
31
  # Bump this whenever there is a change etc...
32
  CACHE_VERSION = 1
33
 
34
- # Features to exclude from Gram2Vec outputs
35
- EXCLUDED_G2V_FEATURE_PREFIXES = [
36
- 'num_tokens'
37
- ]
38
- EXCLUDED_G2V_FEATURES = set([
39
- 'num_tokens:num_tokens'
40
- ])
41
-
42
  class style_analysis_schema(BaseModel):
43
  features: list[str]
44
  spans: dict[str, dict[str, list[str]]]
@@ -67,8 +59,8 @@ def compute_g2v_features(clustered_authors_df: pd.DataFrame, task_authors_df: pd
67
  print(f"Number of authors after concatenation: {len(clustered_authors_df)}")
68
 
69
  # Gather the input texts (preserves list-of-strings if any)
70
- # If an entry is a list of strings, join; otherwise use the string as-is
71
- author_texts = [('\n\n'.join(x) if isinstance(x, list) else x) for x in clustered_authors_df.fullText.tolist()]
72
 
73
  print(f"Number of author_texts: {len(author_texts)}")
74
 
@@ -694,11 +686,7 @@ def compute_clusters_g2v_representation(
694
 
695
  # Keep only features that have a positive contrastive score
696
  top_g2v_feats = sorted(
697
- [
698
- (feat, val, z_score)
699
- for feat, val, z_score in zip(all_g2v_feats, final_g2v_feats_values, z_scores)
700
- if val > 0 and feat not in EXCLUDED_G2V_FEATURES and not any(feat.startswith(p) for p in EXCLUDED_G2V_FEATURE_PREFIXES)
701
- ],
702
  key=lambda x: -x[1] # Sort by contrastive score
703
  )
704
 
@@ -788,124 +776,6 @@ def compute_clusters_g2v_representation(
788
 
789
  return filtered_features[:top_n] # Return tuples with z-scores
790
 
791
- def compute_task_only_g2v_similarity(
792
- background_corpus_df: pd.DataFrame,
793
- visible_author_ids: List[Any],
794
- features_clm_name: str = 'g2v_vector',
795
- top_n: int = 10,
796
- require_spans: bool = True
797
- ) -> List[tuple]:
798
- """
799
- Compute top Gram2Vec features that are shared between the Mystery author and the
800
- predicted Candidate author, ignoring background authors and contrast.
801
-
802
- Selection is limited to task authors within the zoom (i.e., present in
803
- `visible_author_ids`). A feature is kept if:
804
- - it has a positive value (> 0) for both Mystery and Predicted Candidate,
805
- - and (optionally) at least one detected span exists in both authors' texts.
806
-
807
- Scoring strategy prioritizes features strong in both authors: score = min(mystery_value, predicted_value).
808
-
809
- Returns a list of (feature_name, score) tuples sorted by score desc, limited to top_n.
810
- """
811
- task_names = {'Mystery author', 'Candidate Author 1', 'Candidate Author 2', 'Candidate Author 3'}
812
-
813
- # Filter to visible task authors
814
- is_visible = background_corpus_df['authorID'].isin(visible_author_ids)
815
- is_task = background_corpus_df['authorID'].isin(task_names)
816
- visible_task_df = background_corpus_df[is_visible & is_task]
817
-
818
- if visible_task_df.empty:
819
- return []
820
-
821
- # Identify Mystery author row within the visible set
822
- mystery_rows = visible_task_df[visible_task_df['authorID'] == 'Mystery author']
823
- if mystery_rows.empty:
824
- # If Mystery is not visible, fall back to using any available Mystery row in the corpus
825
- mystery_rows = background_corpus_df[background_corpus_df['authorID'] == 'Mystery author']
826
- if mystery_rows.empty:
827
- return []
828
-
829
- mystery_row = mystery_rows.iloc[0]
830
-
831
- # Identify the predicted candidate within the visible set using the 'predicted' flag if present
832
- predicted_row = None
833
- if 'predicted' in visible_task_df.columns:
834
- pred_candidates = visible_task_df[visible_task_df['predicted'] == True]
835
- if not pred_candidates.empty:
836
- predicted_row = pred_candidates.iloc[0]
837
-
838
- # If not found in visible, try to find anywhere in the corpus
839
- if predicted_row is None and 'predicted' in background_corpus_df.columns:
840
- pred_any = background_corpus_df[background_corpus_df['predicted'] == True]
841
- # Prefer one that is also a task author
842
- pred_any = pred_any[pred_any['authorID'].isin(task_names)] if not pred_any.empty else pred_any
843
- if not pred_any.empty:
844
- predicted_row = pred_any.iloc[0]
845
-
846
- # If still not found, we cannot build a pair
847
- if predicted_row is None:
848
- return []
849
-
850
- mystery_vec = mystery_row.get(features_clm_name, {})
851
- predicted_vec = predicted_row.get(features_clm_name, {})
852
-
853
- if not isinstance(mystery_vec, dict) or not isinstance(predicted_vec, dict):
854
- return []
855
-
856
- # Prepare texts for optional span gating
857
- def _norm_txt(x):
858
- if isinstance(x, list):
859
- return '\n\n'.join(x)
860
- return str(x)
861
- mystery_text = _norm_txt(mystery_row.get('fullText', ''))
862
- predicted_text = _norm_txt(predicted_row.get('fullText', ''))
863
-
864
- try:
865
- from gram2vec.feature_locator import find_feature_spans as _find_feature_spans
866
- except Exception:
867
- _find_feature_spans = None
868
-
869
- shared_features = []
870
- # Iterate over union of feature keys (both authors share the same feature space in practice)
871
- for feature_name in set(list(mystery_vec.keys()) + list(predicted_vec.keys())):
872
- # Exclude unwanted features
873
- if feature_name in EXCLUDED_G2V_FEATURES or any(feature_name.startswith(p) for p in EXCLUDED_G2V_FEATURE_PREFIXES):
874
- continue
875
- m_val = float(mystery_vec.get(feature_name, 0.0))
876
- p_val = float(predicted_vec.get(feature_name, 0.0))
877
-
878
- # Optional span gate: require at least one span in both texts
879
- spans_m = spans_p = None
880
- if require_spans and _find_feature_spans is not None:
881
- try:
882
- spans_m = _find_feature_spans(mystery_text, feature_name) or []
883
- spans_p = _find_feature_spans(predicted_text, feature_name) or []
884
- if len(spans_m) == 0 or len(spans_p) == 0:
885
- continue
886
- except Exception:
887
- # On span errors, skip gating and proceed
888
- spans_m = spans_m if spans_m is not None else []
889
- spans_p = spans_p if spans_p is not None else []
890
-
891
- # Similarity metric: |m| + |p| - |m - p|
892
- score = abs(m_val) + abs(p_val) - abs(m_val - p_val)
893
- shared_features.append((feature_name, score, m_val, p_val, len(spans_m) if spans_m is not None else -1, len(spans_p) if spans_p is not None else -1))
894
-
895
- # Rank by score desc and return top_n
896
- shared_features.sort(key=lambda x: x[1], reverse=True)
897
- top = shared_features[:top_n]
898
-
899
- # Debug print of top-N with values and span counts for presence sanity-check
900
- try:
901
- print("[DEBUG] Task-only G2V top features (feature, mystery_val, predicted_val, score | spans_mystery, spans_predicted):")
902
- for feat_name, sc, m_val, p_val, c_m, c_p in top:
903
- print(f" {feat_name} | mystery={m_val:.4f}, predicted={p_val:.4f}, S={sc:.4f} | spans=({c_m}, {c_p})")
904
- except Exception:
905
- pass
906
-
907
- return [(f, s) for (f, s, _, _, _, _) in top]
908
-
909
  def generate_interpretable_space_representation(interp_space_path, styles_df_path, feat_clm, output_clm, num_feats=5):
910
 
911
  styles_df = pd.read_csv(styles_df_path)[[feat_clm, "documentID"]]
 
31
  # Bump this whenever there is a change etc...
32
  CACHE_VERSION = 1
33
 
 
 
 
 
 
 
 
 
34
  class style_analysis_schema(BaseModel):
35
  features: list[str]
36
  spans: dict[str, dict[str, list[str]]]
 
59
  print(f"Number of authors after concatenation: {len(clustered_authors_df)}")
60
 
61
  # Gather the input texts (preserves list-of-strings if any)
62
+ #texts = background_corpus_df[text_clm].fillna("").tolist()
63
+ author_texts = ['\n\n'.join(x) for x in clustered_authors_df.fullText.tolist()]
64
 
65
  print(f"Number of author_texts: {len(author_texts)}")
66
 
 
686
 
687
  # Keep only features that have a positive contrastive score
688
  top_g2v_feats = sorted(
689
+ [(feat, val, z_score) for feat, val, z_score in zip(all_g2v_feats, final_g2v_feats_values, z_scores) if val > 0],
 
 
 
 
690
  key=lambda x: -x[1] # Sort by contrastive score
691
  )
692
 
 
776
 
777
  return filtered_features[:top_n] # Return tuples with z-scores
778
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
779
  def generate_interpretable_space_representation(interp_space_path, styles_df_path, feat_clm, output_clm, num_feats=5):
780
 
781
  styles_df = pd.read_csv(styles_df_path)[[feat_clm, "documentID"]]