Milad Alshomary committed on
Commit
bd0cb8d
·
1 Parent(s): 016bb2f
Files changed (2) hide show
  1. cluster_corpus.py +14 -0
  2. utils/clustering_utils.py +136 -61
cluster_corpus.py CHANGED
@@ -37,6 +37,11 @@ def main():
37
  type=str,
38
  help="Path to the corpus file (.csv or .pkl)."
39
  )
 
 
 
 
 
40
  parser.add_argument(
41
  "model_name",
42
  type=str,
@@ -65,6 +70,7 @@ def main():
65
 
66
  # 1. Load the corpus
67
  corpus_df = load_corpus(args.corpus_path)
 
68
 
69
  # 2. Generate style embeddings
70
  print(f"\nGenerating style embeddings with model: {args.model_name}")
@@ -76,6 +82,13 @@ def main():
76
  model_name=args.model_name,
77
  task_authors_df=None
78
  )
 
 
 
 
 
 
 
79
  embedding_col_name = f'{args.model_name.split("/")[-1]}_style_embedding'
80
  print(f"Embeddings generated and stored in column '{embedding_col_name}'.")
81
 
@@ -83,6 +96,7 @@ def main():
83
  print(f"\nPerforming DBSCAN clustering with metric='{args.metric}' and min_samples={args.min_samples}...")
84
  clustered_df = clustering_author(
85
  background_corpus_df=clustered_df,
 
86
  embedding_clm=embedding_col_name,
87
  min_samples=args.min_samples,
88
  metric=args.metric
 
37
  type=str,
38
  help="Path to the corpus file (.csv or .pkl)."
39
  )
40
+ parser.add_argument(
41
+ "test_corpus_path",
42
+ type=str,
43
+ help="Path to the test corpus file (.csv or .pkl)."
44
+ )
45
  parser.add_argument(
46
  "model_name",
47
  type=str,
 
70
 
71
  # 1. Load the corpus
72
  corpus_df = load_corpus(args.corpus_path)
73
+ test_corpus_df = load_corpus(args.test_corpus_path)
74
 
75
  # 2. Generate style embeddings
76
  print(f"\nGenerating style embeddings with model: {args.model_name}")
 
82
  model_name=args.model_name,
83
  task_authors_df=None
84
  )
85
+
86
+ clustered_test_df, _ = cached_generate_style_embedding(
87
+ background_corpus_df=test_corpus_df,
88
+ text_clm='fullText',
89
+ model_name=args.model_name,
90
+ task_authors_df=None
91
+ )
92
  embedding_col_name = f'{args.model_name.split("/")[-1]}_style_embedding'
93
  print(f"Embeddings generated and stored in column '{embedding_col_name}'.")
94
 
 
96
  print(f"\nPerforming DBSCAN clustering with metric='{args.metric}' and min_samples={args.min_samples}...")
97
  clustered_df = clustering_author(
98
  background_corpus_df=clustered_df,
99
+ test_corpus_df=clustered_test_df,
100
  embedding_clm=embedding_col_name,
101
  min_samples=args.min_samples,
102
  metric=args.metric
utils/clustering_utils.py CHANGED
@@ -5,7 +5,7 @@ from sklearn.cluster import DBSCAN
5
  from sklearn.metrics import silhouette_score
6
  # Required for analyze_space_distance_preservation
7
  from sklearn.metrics.pairwise import cosine_distances, cosine_similarity
8
- from scipy.stats import pearsonr
9
  from typing import List, Dict, Any
10
 
11
  import json
@@ -30,63 +30,35 @@ def sample_ds(input_file, output_file, num_insts=10000, min_num_text_per_inst=0,
30
  df = pd.DataFrame(out_list)
31
  df.to_pickle(output_file)
32
 
33
- def _find_best_dbscan_eps(X: np.ndarray,
34
- eps_values: List[float],
35
- min_samples: int,
36
- metric: str) -> tuple[float | None, np.ndarray | None, float]:
37
  """
38
- Iterates through eps_values for DBSCAN and returns the parameters
39
- that yield the highest silhouette score.
40
 
41
  Args:
42
  X (np.ndarray): The input data (embeddings).
43
- eps_values (List[float]): List of eps values to try.
44
- min_samples (int): DBSCAN min_samples parameter.
45
- metric (str): Distance metric for DBSCAN and silhouette score.
46
 
47
  Returns:
48
- tuple[float | None, np.ndarray | None, float]:
49
- - best_eps: The eps value that resulted in the best score. None if no suitable clustering.
50
- - best_labels: The cluster labels from the best DBSCAN run. None if no suitable clustering.
51
- - best_score: The highest silhouette score achieved.
52
  """
53
- best_score = -1.001 # Silhouette score is in [-1, 1]
54
- best_labels = None
55
- best_eps = None
56
-
57
- for eps in eps_values:
58
- if eps <= 1e-9: # eps must be positive
59
- continue
60
- db = DBSCAN(eps=eps, min_samples=min_samples, metric=metric)
61
- labels = db.fit_predict(X)
62
-
63
- unique_labels_set = set(labels)
64
- n_clusters_ = len(unique_labels_set) - (1 if -1 in unique_labels_set else 0)
 
65
 
66
- if n_clusters_ > 1:
67
- clustered_mask = (labels != -1)
68
- if np.sum(clustered_mask) >= 2: # Need at least 2 non-noise points
69
- X_clustered = X[clustered_mask]
70
- labels_clustered = labels[clustered_mask]
71
- try:
72
- score = silhouette_score(X_clustered, labels_clustered, metric=metric)
73
- if score > best_score:
74
- best_score = score
75
- best_labels = labels.copy()
76
- best_eps = eps
77
- print('EPS:', eps, 'SCORE:', score)
78
- except ValueError: # Catch errors from silhouette_score
79
- pass
80
- elif n_clusters_ == 1 and best_labels is None: # Fallback for single cluster
81
- if not all(l == -1 for l in labels):
82
- current_score_for_single_cluster = -0.5 # Nominal score
83
- if current_score_for_single_cluster > best_score:
84
- best_score = current_score_for_single_cluster
85
- best_labels = labels.copy()
86
- best_eps = eps
87
- return best_eps, best_labels, best_score
88
 
89
  def clustering_author(background_corpus_df: pd.DataFrame,
 
90
  embedding_clm: str = 'style_embedding',
91
  eps_values: List[float] = None,
92
  min_samples: int = 5,
@@ -178,14 +150,62 @@ def clustering_author(background_corpus_df: pd.DataFrame,
178
  print(f"Warning: `eps_values` not provided. Using default range for metric '{metric}': {eps_values}. "
179
  f"It's recommended to supply `eps_values` tuned to your data.")
180
 
181
- print(f"Performing DBSCAN clustering (min_samples={min_samples}, metric='{metric}') with eps values: "
182
- f"{[f'{e:.2f}' for e in eps_values]}")
183
 
184
- best_eps, best_labels, best_score = _find_best_dbscan_eps(X, eps_values, min_samples, metric)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
  if best_labels is not None:
187
  num_found_clusters = len(set(best_labels) - {-1})
188
- print(f"Best clustering found: eps={best_eps:.2f}, Silhouette Score={best_score:.4f} ({num_found_clusters} clusters).")
 
189
  for i, label in enumerate(best_labels):
190
  original_df_idx = original_indices[i]
191
  final_labels_for_df.iloc[original_df_idx] = label
@@ -334,17 +354,72 @@ def analyze_space_distance_preservation(
334
  distances_original_space.size != distances_new_space.size:
335
  return None # Mismatch or empty distances
336
 
337
- # Handle cases where variance is zero in one of the distance arrays (leads to NaN correlation)
338
- if np.all(distances_new_space == distances_new_space[0]) or \
339
- np.all(distances_original_space == distances_original_space[0]):
340
- return 0.0 # Correlation is undefined or 0 if one variable is constant
341
-
342
  try:
343
- correlation, _ = pearsonr(distances_original_space, distances_new_space)
344
- except ValueError: # Should be caught by variance checks, but as a safeguard
 
 
 
 
 
 
 
 
345
  return None
346
 
347
  if np.isnan(correlation):
348
  return 0.0 # Default for NaN correlation
349
-
350
- return correlation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from sklearn.metrics import silhouette_score
6
  # Required for analyze_space_distance_preservation
7
  from sklearn.metrics.pairwise import cosine_distances, cosine_similarity
8
+ from scipy.stats import pearsonr, ConstantInputWarning
9
  from typing import List, Dict, Any
10
 
11
  import json
 
30
  df = pd.DataFrame(out_list)
31
  df.to_pickle(output_file)
32
 
33
+ def _calculate_silhouette_score(X: np.ndarray, labels: np.ndarray, metric: str) -> float | None:
 
 
 
34
  """
35
+ Calculates the silhouette score for a given clustering result.
 
36
 
37
  Args:
38
  X (np.ndarray): The input data (embeddings).
39
+ labels (np.ndarray): The cluster labels for each point in X.
40
+ metric (str): The distance metric used for the score calculation.
 
41
 
42
  Returns:
43
+ float | None: The silhouette score, or None if it cannot be computed.
 
 
 
44
  """
45
+ unique_labels_set = set(labels)
46
+ n_clusters_ = len(unique_labels_set) - (1 if -1 in unique_labels_set else 0)
47
+
48
+ if n_clusters_ > 1:
49
+ clustered_mask = (labels != -1)
50
+ if np.sum(clustered_mask) > 1:
51
+ X_clustered = X[clustered_mask]
52
+ labels_clustered = labels[clustered_mask]
53
+ try:
54
+ return silhouette_score(X_clustered, labels_clustered, metric=metric)
55
+ except ValueError:
56
+ return None
57
+ return None
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  def clustering_author(background_corpus_df: pd.DataFrame,
61
+ test_corpus_df: pd.DataFrame = None,
62
  embedding_clm: str = 'style_embedding',
63
  eps_values: List[float] = None,
64
  min_samples: int = 5,
 
150
  print(f"Warning: `eps_values` not provided. Using default range for metric '{metric}': {eps_values}. "
151
  f"It's recommended to supply `eps_values` tuned to your data.")
152
 
153
+ print(f"\n--- Starting DBSCAN Clustering & Evaluation ---")
154
+ print(f"Metric: '{metric}', Min Samples: {min_samples}, EPS values: {[f'{e:.2f}' for e in eps_values]}")
155
 
156
+ best_score = -1.001
157
+ best_labels = None
158
+ best_eps = None
159
+
160
+ # This loop now lives in `clustering_author` to have access to the full DataFrame for evaluation.
161
+ for eps in eps_values:
162
+ if eps <= 1e-9: continue
163
+
164
+ print(f"\nTesting eps = {eps:.3f}...")
165
+ db = DBSCAN(eps=eps, min_samples=min_samples, metric=metric)
166
+ current_labels = db.fit_predict(X)
167
+
168
+ # --- Evaluation Step 1: Silhouette Score ---
169
+ score = _calculate_silhouette_score(X, current_labels, metric)
170
+ if score is not None:
171
+ print(f" - Silhouette Score: {score:.4f}")
172
+ if score > best_score:
173
+ best_score = score
174
+ best_labels = current_labels.copy()
175
+ best_eps = eps
176
+ else:
177
+ print(" - Silhouette Score: N/A (not enough clusters found)")
178
+
179
+ # --- Evaluation Step 2: Distance Preservation ---
180
+ # Temporarily assign labels to a copy of the DataFrame for evaluation
181
+ temp_df = background_corpus_df.copy()
182
+ temp_labels_for_df = pd.Series(-1, index=temp_df.index, dtype=int)
183
+ temp_labels_for_df.iloc[original_indices] = current_labels
184
+ temp_df['cluster_label'] = temp_labels_for_df
185
+
186
+ correlation = analyze_space_distance_preservation(temp_df, embedding_clm, 'cluster_label')
187
+ if correlation is not None:
188
+ print(f" - Distance Preservation (Pearson r): {correlation:.4f}")
189
+ else:
190
+ print(" - Distance Preservation (Pearson r): N/A (not enough clusters/data)")
191
+
192
+ # --- Evaluation Step 3: Distance Preservation on Test Corpus (if provided) ---
193
+ if test_corpus_df is not None:
194
+ # We need the centroids from the current clustering of the background corpus
195
+ centroids = _compute_cluster_centroids(temp_df[temp_df['cluster_label'] != -1], embedding_clm, 'cluster_label')
196
+
197
+ test_correlation = evaluate_test_set_distance_preservation(test_corpus_df, centroids, embedding_clm)
198
+ if test_correlation is not None:
199
+ print(f" - Test Set Distance Preservation (Pearson r): {test_correlation:.4f}")
200
+ else:
201
+ print(" - Test Set Distance Preservation (Pearson r): N/A (not enough test data or clusters)")
202
+
203
+ print('Eps {}, #clusters {}, solihouette {}, Pearson {}'.format(eps, len(set(current_labels) - {-1}), score, test_correlation))
204
 
205
  if best_labels is not None:
206
  num_found_clusters = len(set(best_labels) - {-1})
207
+ print(f"\n--- Best Clustering Result ---")
208
+ print(f"Best eps: {best_eps:.3f} yielded the highest Silhouette Score: {best_score:.4f} ({num_found_clusters} clusters).")
209
  for i, label in enumerate(best_labels):
210
  original_df_idx = original_indices[i]
211
  final_labels_for_df.iloc[original_df_idx] = label
 
354
  distances_original_space.size != distances_new_space.size:
355
  return None # Mismatch or empty distances
356
 
 
 
 
 
 
357
  try:
358
+ # Catching ConstantInputWarning that pearsonr can raise
359
+ import warnings
360
+ with warnings.catch_warnings():
361
+ warnings.filterwarnings('error', category=ConstantInputWarning)
362
+ correlation, _ = pearsonr(distances_original_space, distances_new_space)
363
+ except (ValueError, ConstantInputWarning):
364
+ # This happens if one of the distance arrays has zero variance (all distances are the same).
365
+ # This is a valid case where correlation is undefined or 0.
366
+ return 0.0
367
+ except Exception: # Safeguard for other unexpected errors
368
  return None
369
 
370
  if np.isnan(correlation):
371
  return 0.0 # Default for NaN correlation
372
+
373
+ return correlation
374
+
375
+ def evaluate_test_set_distance_preservation(
376
+ test_df: pd.DataFrame,
377
+ centroids_map: Dict[Any, np.ndarray],
378
+ embedding_clm: str = 'style_embedding'
379
+ ) -> float | None:
380
+ """
381
+ Evaluates how well a centroid space (from a background corpus) preserves
382
+ distances for a separate test corpus.
383
+
384
+ Args:
385
+ test_df (pd.DataFrame): The test corpus DataFrame with embeddings.
386
+ centroids_map (Dict[Any, np.ndarray]): A map of cluster IDs to centroid vectors,
387
+ pre-computed from the background corpus.
388
+ embedding_clm (str): The name of the embedding column.
389
+
390
+ Returns:
391
+ float | None: Pearson correlation coefficient, or None if analysis is not possible.
392
+ """
393
+ if test_df.shape[0] < 2:
394
+ return None # Need at least 2 items for pairwise distances
395
+
396
+ if not centroids_map or len(centroids_map) < 2:
397
+ return None # Need at least 2 centroids to define a meaningful projected space
398
+
399
+ # 1. Get original embeddings and distances for the test set
400
+ test_embeddings_matrix = _safe_embeddings_to_matrix(test_df[embedding_clm])
401
+ if test_embeddings_matrix.ndim != 2 or test_embeddings_matrix.shape[0] < 2:
402
+ return None # Not enough valid embeddings in the test set
403
+
404
+ distances_original_space = _get_pairwise_cosine_distances(test_embeddings_matrix)
405
+
406
+ # 2. Project test embeddings into the centroid space and get new distances
407
+ projected_embeddings_matrix = _project_to_centroid_space(test_embeddings_matrix, centroids_map)
408
+ if projected_embeddings_matrix.ndim != 2 or projected_embeddings_matrix.shape[1] < 2:
409
+ return None # Projection failed or resulted in a space with <2 dimensions
410
+
411
+ distances_new_space = _get_pairwise_cosine_distances(projected_embeddings_matrix)
412
+
413
+ # 3. Calculate Pearson correlation
414
+ if distances_original_space.size != distances_new_space.size or distances_original_space.size == 0:
415
+ return None
416
+
417
+ try:
418
+ import warnings
419
+ with warnings.catch_warnings():
420
+ warnings.filterwarnings('error', category=ConstantInputWarning)
421
+ correlation, _ = pearsonr(distances_original_space, distances_new_space)
422
+ except (ValueError, ConstantInputWarning):
423
+ return 0.0 # Zero variance in one of the distance sets
424
+
425
+ return correlation if not np.isnan(correlation) else 0.0