|
|
import argparse |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import os |
|
|
|
|
|
from utils.interp_space_utils import cached_generate_style_embedding |
|
|
from utils.clustering_utils import clustering_author |
|
|
|
|
|
def load_corpus(filepath: str) -> pd.DataFrame:
    """
    Load a corpus from a CSV or pickle file into a pandas DataFrame.

    The reader is chosen from the file extension (case-insensitive):
    ``.csv`` is read with ``pd.read_csv``; ``.pkl`` / ``.pickle`` with
    ``pd.read_pickle``. The loaded table must contain the ``authorID`` and
    ``fullText`` columns.

    Args:
        filepath: Path to the corpus file (a ``pathlib.Path`` also works).

    Returns:
        The loaded DataFrame, unmodified.

    Raises:
        ValueError: If the extension is unsupported, the pickle does not hold
            a DataFrame, or a required column is missing.
    """
    # Accept path-like objects (e.g. pathlib.Path) as well as plain strings.
    path = os.fspath(filepath)
    print(f"Loading corpus from {path}...")

    lowered = path.lower()
    if lowered.endswith('.csv'):
        df = pd.read_csv(path)
    elif lowered.endswith(('.pkl', '.pickle')):
        # SECURITY: unpickling can execute arbitrary code; only load trusted files.
        df = pd.read_pickle(path)
    else:
        raise ValueError("Unsupported file format. Please use .csv or .pkl")

    # A pickle may hold any object; reject non-DataFrames with a clear message
    # instead of failing later with an AttributeError on `.columns`.
    if not isinstance(df, pd.DataFrame):
        raise ValueError(
            f"Expected a pandas DataFrame in {path}, got {type(df).__name__}."
        )

    required_columns = ('authorID', 'fullText')
    missing = [col for col in required_columns if col not in df.columns]
    if missing:
        raise ValueError(
            "Corpus must contain 'authorID' and 'fullText' columns. "
            f"Missing: {missing}"
        )

    print(f"Corpus loaded successfully with {len(df)} documents.")
    return df
|
|
|
|
|
def _positive_int(value: str) -> int:
    """Argparse type: parse *value* as an integer and require it to be >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def _positive_float(value: str) -> float:
    """Argparse type: parse *value* as a float and require it to be > 0."""
    try:
        number = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected a number, got {value!r}") from None
    if not number > 0:  # written this way so NaN is rejected too
        raise argparse.ArgumentTypeError(f"must be a positive number, got {number}")
    return number


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the embedding + clustering workflow."""
    parser = argparse.ArgumentParser(
        description="Generate style embeddings and cluster a corpus of documents."
    )
    parser.add_argument(
        "corpus_path",
        type=str,
        help="Path to the corpus file (.csv or .pkl)."
    )
    parser.add_argument(
        "test_corpus_path",
        type=str,
        help="Path to the test corpus file (.csv or .pkl)."
    )
    parser.add_argument(
        "model_name",
        type=str,
        help="Hugging Face model name for sentence-transformer embeddings (e.g., 'AnnaWegmann/Style-Embedding')."
    )
    parser.add_argument(
        "output_path",
        type=str,
        help="Path to save the output DataFrame with embeddings and clusters (.pkl)."
    )
    # DBSCAN requires min_samples >= 1 and eps > 0; validating here rejects bad
    # values before the expensive embedding step instead of after it.
    parser.add_argument(
        "--min_samples",
        type=_positive_int,
        default=5,
        help="min_samples parameter for DBSCAN clustering."
    )
    parser.add_argument(
        "--metric",
        type=str,
        default='cosine',
        choices=['cosine', 'euclidean'],
        help="Distance metric for DBSCAN clustering."
    )
    parser.add_argument(
        "--eps_values",
        type=_positive_float,
        nargs='+',
        default=None,
        help="A list of specific eps values to test for DBSCAN. If not provided, a default range is used."
    )
    parser.add_argument(
        "--pca_dimensions",
        type=int,
        default=None,
        help="If provided, apply PCA to reduce embeddings to this number of dimensions before clustering."
    )
    return parser


def main(argv=None):
    """
    Run the clustering workflow from the command line.

    Parses arguments, then:
    1. loads the background and test corpora with ``load_corpus``;
    2. embeds both with ``cached_generate_style_embedding``;
    3. clusters them with ``clustering_author``;
    4. pickles the DataFrame returned by ``clustering_author`` to ``output_path``.

    Args:
        argv: Optional list of arguments to parse instead of ``sys.argv[1:]``.
            Useful for tests and programmatic use.

    Raises:
        SystemExit: If the command-line arguments are invalid (argparse).
        ValueError: If a corpus cannot be loaded, or if the expected embedding
            column is absent after embedding.
    """
    args = _build_arg_parser().parse_args(argv)

    # Create the output directory up front so an unwritable location fails fast,
    # not after the (potentially expensive) embedding and clustering steps.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    corpus_df = load_corpus(args.corpus_path)
    test_corpus_df = load_corpus(args.test_corpus_path)

    print(f"\nGenerating style embeddings with model: {args.model_name}")
    # Only the first element of the returned tuple is used here.
    corpus_emb_df, _ = cached_generate_style_embedding(
        background_corpus_df=corpus_df,
        text_clm='fullText',
        model_name=args.model_name,
        task_authors_df=None,
    )
    test_emb_df, _ = cached_generate_style_embedding(
        background_corpus_df=test_corpus_df,
        text_clm='fullText',
        model_name=args.model_name,
        task_authors_df=None,
    )

    # NOTE(review): this column name mirrors the naming convention presumably
    # used inside cached_generate_style_embedding; verify against
    # utils.interp_space_utils. Checking it here turns a silent mismatch into
    # a clear error before clustering starts.
    embedding_col_name = f'{args.model_name.split("/")[-1]}_style_embedding'
    for label, frame in (("corpus", corpus_emb_df), ("test corpus", test_emb_df)):
        if embedding_col_name not in frame.columns:
            raise ValueError(
                f"Expected embedding column '{embedding_col_name}' in the {label} "
                f"DataFrame after embedding; found columns: {list(frame.columns)}"
            )
    print(f"Embeddings generated and stored in column '{embedding_col_name}'.")

    print(f"\nPerforming DBSCAN clustering with metric='{args.metric}' and min_samples={args.min_samples}...")
    # NOTE(review): only the DataFrame returned by clustering_author is saved;
    # confirm whether test-corpus results are merged into it.
    clustered_df = clustering_author(
        background_corpus_df=corpus_emb_df,
        test_corpus_df=test_emb_df,
        embedding_clm=embedding_col_name,
        eps_values=args.eps_values,
        min_samples=args.min_samples,
        pca_dimensions=args.pca_dimensions,
        metric=args.metric,
    )

    clustered_df.to_pickle(args.output_path)
    print(f"\nSuccessfully saved clustered DataFrame to: {args.output_path}")
    # NOTE(review): the 'cluster_label' column is assumed to be added by clustering_author.
    print("DataFrame includes cluster labels in the 'cluster_label' column.")
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when it is imported.
    main()