File size: 3,384 Bytes
ea3113e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import argparse
import pandas as pd
import numpy as np
import os

from utils.interp_space_utils import cached_generate_style_embedding
from utils.clustering_utils import clustering_author

def load_corpus(filepath: str) -> pd.DataFrame:
    """
    Load a corpus from a CSV or Pickle file into a pandas DataFrame.

    The format is chosen from the file extension, matched case-insensitively:
    ``.csv`` is read with ``pd.read_csv``; ``.pkl`` / ``.pickle`` with
    ``pd.read_pickle``.

    Args:
        filepath: Path to the corpus file (a ``str`` or path-like object).

    Returns:
        The loaded DataFrame, guaranteed to contain the 'authorID' and
        'fullText' columns.

    Raises:
        ValueError: If the extension is unsupported or a required column is
            missing.
    """
    print(f"Loading corpus from {filepath}...")
    # Normalise once: str() lets pathlib.Path inputs work, lower() lets
    # "corpus.CSV" / "corpus.PKL" through.
    normalized = str(filepath).lower()
    if normalized.endswith('.csv'):
        df = pd.read_csv(filepath)
    elif normalized.endswith(('.pkl', '.pickle')):
        # SECURITY: unpickling can execute arbitrary code; only load pickle
        # files that come from a trusted source.
        df = pd.read_pickle(filepath)
    else:
        raise ValueError("Unsupported file format. Please use .csv, .pkl or .pickle")

    required_columns = ('authorID', 'fullText')
    missing = [col for col in required_columns if col not in df.columns]
    if missing:
        # Name the absent column(s) so the user does not have to inspect the file.
        raise ValueError(
            "Corpus must contain 'authorID' and 'fullText' columns. "
            f"Missing: {missing}"
        )

    print(f"Corpus loaded successfully with {len(df)} documents.")
    return df

def main(argv=None):
    """
    Run the clustering workflow: load a corpus, embed it, cluster it, save it.

    Args:
        argv: Optional list of command-line arguments, excluding the program
            name. ``None`` (the default) makes argparse read ``sys.argv``, so
            existing callers are unaffected. Passing a list makes the
            workflow scriptable from other Python code.

    Raises:
        SystemExit: Raised by argparse on invalid arguments, including a
            non-positive ``--min_samples``.
        KeyError: If the expected embedding column is not present after the
            embedding step.
    """
    parser = argparse.ArgumentParser(
        description="Generate style embeddings and cluster a corpus of documents."
    )
    parser.add_argument(
        "corpus_path",
        type=str,
        help="Path to the corpus file (.csv or .pkl)."
    )
    parser.add_argument(
        "model_name",
        type=str,
        help="Hugging Face model name for sentence-transformer embeddings (e.g., 'AnnaWegmann/Style-Embedding')."
    )
    parser.add_argument(
        "output_path",
        type=str,
        help="Path to save the output DataFrame with embeddings and clusters (.pkl)."
    )
    parser.add_argument(
        "--min_samples",
        type=int,
        default=5,
        help="min_samples parameter for DBSCAN clustering."
    )
    parser.add_argument(
        "--metric",
        type=str,
        default='cosine',
        choices=['cosine', 'euclidean'],
        help="Distance metric for DBSCAN clustering."
    )

    args = parser.parse_args(argv)
    # Reject nonsensical values at the CLI boundary instead of deep inside
    # the clustering code.
    if args.min_samples < 1:
        parser.error("--min_samples must be a positive integer.")

    # 1. Load the corpus
    corpus_df = load_corpus(args.corpus_path)

    # 2. Generate style embeddings
    print(f"\nGenerating style embeddings with model: {args.model_name}")
    # The helper returns two DataFrames; only the first (the embedded corpus)
    # is needed. task_authors_df=None because a single corpus is processed.
    embedded_df, _ = cached_generate_style_embedding(
        background_corpus_df=corpus_df,
        text_clm='fullText',
        model_name=args.model_name,
        task_authors_df=None
    )
    # The column name is reconstructed from the model name. This assumes it
    # matches the naming used inside cached_generate_style_embedding, so
    # verify it here rather than failing obscurely during clustering.
    embedding_col_name = f'{args.model_name.split("/")[-1]}_style_embedding'
    if embedding_col_name not in embedded_df.columns:
        raise KeyError(
            f"Expected embedding column '{embedding_col_name}' not found; "
            f"available columns: {list(embedded_df.columns)}"
        )
    print(f"Embeddings generated and stored in column '{embedding_col_name}'.")

    # 3. Perform clustering
    print(f"\nPerforming DBSCAN clustering with metric='{args.metric}' and min_samples={args.min_samples}...")
    clustered_df = clustering_author(
        background_corpus_df=embedded_df,
        embedding_clm=embedding_col_name,
        min_samples=args.min_samples,
        metric=args.metric
    )

    # 4. Save the results; create the parent directory only when the path has one.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    clustered_df.to_pickle(args.output_path)
    print(f"\nSuccessfully saved clustered DataFrame to: {args.output_path}")
    # NOTE(review): the 'cluster_label' column name is set by clustering_author
    # (not visible here) — confirm if that helper changes.
    print("DataFrame includes cluster labels in the 'cluster_label' column.")

if __name__ == "__main__":
    main()