Milad Alshomary commited on
Commit
ea3113e
·
1 Parent(s): b623cb3
cluster_corpus.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import pandas as pd
3
+ import numpy as np
4
+ import os
5
+
6
+ from utils.interp_space_utils import cached_generate_style_embedding
7
+ from utils.clustering_utils import clustering_author
8
+
9
+ def load_corpus(filepath: str) -> pd.DataFrame:
10
+ """
11
+ Loads a corpus from a CSV or Pickle file into a pandas DataFrame.
12
+ The file is expected to have 'authorID' and 'fullText' columns.
13
+ """
14
+ print(f"Loading corpus from {filepath}...")
15
+ if filepath.endswith('.csv'):
16
+ df = pd.read_csv(filepath)
17
+ elif filepath.endswith('.pkl'):
18
+ df = pd.read_pickle(filepath)
19
+ else:
20
+ raise ValueError("Unsupported file format. Please use .csv or .pkl")
21
+
22
+ if 'authorID' not in df.columns or 'fullText' not in df.columns:
23
+ raise ValueError("Corpus must contain 'authorID' and 'fullText' columns.")
24
+
25
+ print(f"Corpus loaded successfully with {len(df)} documents.")
26
+ return df
27
+
28
+ def main():
29
+ """
30
+ Main function to run the clustering workflow.
31
+ """
32
+ parser = argparse.ArgumentParser(
33
+ description="Generate style embeddings and cluster a corpus of documents."
34
+ )
35
+ parser.add_argument(
36
+ "corpus_path",
37
+ type=str,
38
+ help="Path to the corpus file (.csv or .pkl)."
39
+ )
40
+ parser.add_argument(
41
+ "model_name",
42
+ type=str,
43
+ help="Hugging Face model name for sentence-transformer embeddings (e.g., 'AnnaWegmann/Style-Embedding')."
44
+ )
45
+ parser.add_argument(
46
+ "output_path",
47
+ type=str,
48
+ help="Path to save the output DataFrame with embeddings and clusters (.pkl)."
49
+ )
50
+ parser.add_argument(
51
+ "--min_samples",
52
+ type=int,
53
+ default=5,
54
+ help="min_samples parameter for DBSCAN clustering."
55
+ )
56
+ parser.add_argument(
57
+ "--metric",
58
+ type=str,
59
+ default='cosine',
60
+ choices=['cosine', 'euclidean'],
61
+ help="Distance metric for DBSCAN clustering."
62
+ )
63
+
64
+ args = parser.parse_args()
65
+
66
+ # 1. Load the corpus
67
+ corpus_df = load_corpus(args.corpus_path)
68
+
69
+ # 2. Generate style embeddings
70
+ print(f"\nGenerating style embeddings with model: {args.model_name}")
71
+ # The function returns two dataframes, we are only interested in the first one here.
72
+ # We pass `task_authors_df=None` as we are processing a single corpus.
73
+ clustered_df, _ = cached_generate_style_embedding(
74
+ background_corpus_df=corpus_df,
75
+ text_clm='fullText',
76
+ model_name=args.model_name,
77
+ task_authors_df=None
78
+ )
79
+ embedding_col_name = f'{args.model_name.split("/")[-1]}_style_embedding'
80
+ print(f"Embeddings generated and stored in column '{embedding_col_name}'.")
81
+
82
+ # 3. Perform clustering
83
+ print(f"\nPerforming DBSCAN clustering with metric='{args.metric}' and min_samples={args.min_samples}...")
84
+ clustered_df = clustering_author(
85
+ background_corpus_df=clustered_df,
86
+ embedding_clm=embedding_col_name,
87
+ min_samples=args.min_samples,
88
+ metric=args.metric
89
+ )
90
+
91
+ # 4. Save the results
92
+ output_dir = os.path.dirname(args.output_path)
93
+ if output_dir:
94
+ os.makedirs(output_dir, exist_ok=True)
95
+
96
+ clustered_df.to_pickle(args.output_path)
97
+ print(f"\nSuccessfully saved clustered DataFrame to: {args.output_path}")
98
+ print(f"DataFrame includes cluster labels in the 'cluster_label' column.")
99
+
100
+ if __name__ == "__main__":
101
+ main()
utils/clustering_utils.py CHANGED
@@ -8,6 +8,28 @@ from sklearn.metrics.pairwise import cosine_distances, cosine_similarity
8
  from scipy.stats import pearsonr
9
  from typing import List, Dict, Any
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def _find_best_dbscan_eps(X: np.ndarray,
12
  eps_values: List[float],
13
  min_samples: int,
@@ -143,12 +165,14 @@ def clustering_author(background_corpus_df: pd.DataFrame,
143
  if eps_values is None:
144
  if metric == 'cosine':
145
  eps_values = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
146
- else:
147
  if X.shape[0] > 1:
148
- data_spread = np.std(X)
 
 
149
  eps_values = [round(data_spread * f, 2) for f in [0.25, 0.5, 1.0]]
150
- eps_values = [e for e in eps_values if e > 1e-6]
151
- if not eps_values or X.shape[0] <=1:
152
  eps_values = [0.5, 1.0, 1.5]
153
  print(f"Warning: `eps_values` not provided. Using default range for metric '{metric}': {eps_values}. "
154
  f"It's recommended to supply `eps_values` tuned to your data.")
 
8
  from scipy.stats import pearsonr
9
  from typing import List, Dict, Any
10
 
11
+ import json
12
+
13
+ def sample_ds(input_file, output_file, num_insts=10000, min_num_text_per_inst=0, max_num_text_per_inst=3):
14
+
15
+ """
16
+ Usage
17
+ sample_ds('/mnt/swordfish-pool2/nikhil/raw_all/data.jsonl', '/mnt/swordfish-pool2/milad/hiatus-data/reddit_cluster_training.pkl',
18
+ num_insts=5000,
19
+ min_num_text_per_inst=3,
20
+ max_num_text_per_inst=10)
21
+ """
22
+ f = open(input_file)
23
+ out_list = []
24
+ for i in range(num_insts):
25
+ json_obj = json.loads(f.readline())
26
+ out_list.append({
27
+ 'fullText': json_obj['syms'],
28
+ 'authorID': json_obj['author_id']
29
+ })
30
+ df = pd.DataFrame(out_list)
31
+ df.to_pickle(output_file)
32
+
33
  def _find_best_dbscan_eps(X: np.ndarray,
34
  eps_values: List[float],
35
  min_samples: int,
 
165
  if eps_values is None:
166
  if metric == 'cosine':
167
  eps_values = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
168
+ else: # 'euclidean' or other
169
  if X.shape[0] > 1:
170
+ # For Euclidean, eps depends on the scale of the data.
171
+ # A simple heuristic: a fraction of the data's standard deviation.
172
+ data_spread = np.std(X)
173
  eps_values = [round(data_spread * f, 2) for f in [0.25, 0.5, 1.0]]
174
+ eps_values = [e for e in eps_values if e > 1e-6] # Filter out zero or near-zero eps
175
+ if not eps_values or X.shape[0] <=1: # Fallback if heuristic fails or not enough data
176
  eps_values = [0.5, 1.0, 1.5]
177
  print(f"Warning: `eps_values` not provided. Using default range for metric '{metric}': {eps_values}. "
178
  f"It's recommended to supply `eps_values` tuned to your data.")
utils/interp_space_utils.py CHANGED
@@ -172,7 +172,6 @@ def generate_style_embedding(background_corpus_df: pd.DataFrame, text_clm: str,
172
 
173
  print(f"Generating style embeddings using {model_name} on column '{text_clm}'...")
174
 
175
- print(background_corpus_df.fullText.tolist()[:10])
176
  model = SentenceTransformer(model_name)
177
  embedding_dim = model.get_sentence_embedding_dimension()
178
 
 
172
 
173
  print(f"Generating style embeddings using {model_name} on column '{text_clm}'...")
174
 
 
175
  model = SentenceTransformer(model_name)
176
  embedding_dim = model.get_sentence_embedding_dimension()
177