| | """ |
| | |
| | Usage: |
| | python3 topic_clustering.py --in arena.json --english-only --min-length 32 |
| | python3 topic_clustering.py --in clean_conv_20230809_100k.json --english-only --min-length 32 --max-length 1536 |
| | """ |
| | import argparse |
| | import json |
| | import pickle |
| | import string |
| | import time |
| |
|
| | import numpy as np |
| | from sentence_transformers import SentenceTransformer |
| | from sentence_transformers.util import cos_sim |
| | from sklearn.cluster import KMeans, AgglomerativeClustering |
| | import torch |
| | from tqdm import tqdm |
| | from openai import OpenAI |
| |
|
| | from fastchat.utils import detect_language |
| |
|
| |
|
def remove_punctuation(input_string):
    # Build a translator that maps every punctuation character to None
    translator = str.maketrans("", "", string.punctuation)

    # Apply it to strip punctuation from the input
    no_punct = input_string.translate(translator)
    return no_punct


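# read_texts accepts several conversation-log schemas: a record may carry its
# prompt under "text", under "conversation_a" / "conversation" as a list of
# {"role": ..., "content": ...} turns (only "user" turns are kept), or under
# "turns". Prompts are then filtered and de-duplicated below.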
def read_texts(input_file, min_length, max_length, english_only):
    visited = set()
    texts = []

    lines = json.load(open(input_file, "r"))

    for l in tqdm(lines):
        if "text" in l:
            line_texts = [l["text"]]
        elif "conversation_a" in l:
            line_texts = [
                x["content"] for x in l["conversation_a"] if x["role"] == "user"
            ]
        elif "conversation" in l:
            line_texts = [
                x["content"] for x in l["conversation"] if x["role"] == "user"
            ]
        elif "turns" in l:
            line_texts = l["turns"]
        else:
            # Skip records with no recognized prompt field
            continue

        for text in line_texts:
            text = text.strip()

            # Filter by language
            if english_only:
                lang = detect_language(text)
                if lang != "English":
                    continue

            # Filter out prompts that are too short or too long
            if min_length:
                if len(text) < min_length:
                    continue

            if max_length:
                if len(text) > max_length:
                    continue

            # De-duplicate on a sorted, lowercased, punctuation-free word signature
            words = sorted([x.lower() for x in remove_punctuation(text).split(" ")])
            words = "".join(words)
            if words in visited:
                continue

            visited.add(words)
            texts.append(text)
    return np.array(texts)


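# Two embedding backends: the name "text-embedding-ada-002" routes through the
# OpenAI API (OPENAI_API_KEY must be set in the environment); any other name is
# loaded as a local sentence-transformers checkpoint.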
def get_embeddings(texts, model_name, batch_size):
    if model_name == "text-embedding-ada-002":
        client = OpenAI()
        texts = texts.tolist()

        embeddings = []
        for i in tqdm(range(0, len(texts), batch_size)):
            text = texts[i : i + batch_size]
            responses = client.embeddings.create(input=text, model=model_name).data
            embeddings.extend([data.embedding for data in responses])
        embeddings = torch.tensor(embeddings)
    else:
        model = SentenceTransformer(model_name)
        embeddings = model.encode(
            texts,
            batch_size=batch_size,
            show_progress_bar=True,
            device="cuda" if torch.cuda.is_available() else "cpu",
            convert_to_tensor=True,
        )

    # L2-normalize so cosine similarity reduces to a dot product
    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
    return embeddings.cpu()


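# Because get_embeddings returns unit-norm vectors, Euclidean K-means below
# behaves like clustering by cosine similarity: for unit vectors a and b,
# ||a - b||^2 = 2 - 2 * cos(a, b).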
def run_k_means(embeddings, num_clusters):
    np.random.seed(42)
    clustering_model = KMeans(n_clusters=num_clusters, n_init="auto", random_state=42)
    clustering_model.fit(embeddings.numpy())
    centers = torch.from_numpy(clustering_model.cluster_centers_)
    labels = torch.from_numpy(clustering_model.labels_)

    # Relabel clusters in decreasing order of size
    classes, counts = np.unique(labels, return_counts=True)
    indices = np.argsort(counts)[::-1]
    classes = [classes[i] for i in indices]
    new_labels = torch.empty_like(labels)
    new_centers = torch.empty_like(centers)
    for i, c in enumerate(classes):
        new_labels[labels == c] = i
        new_centers[i] = centers[c]
    return new_centers, new_labels


def run_agg_cluster(embeddings, num_clusters):
    np.random.seed(42)
    clustering_model = AgglomerativeClustering(n_clusters=num_clusters)
    clustering_model.fit(embeddings)
    labels = torch.from_numpy(clustering_model.labels_)

    # Relabel clusters in decreasing order of size
    classes, counts = np.unique(labels, return_counts=True)
    indices = np.argsort(counts)[::-1]
    classes = [classes[i] for i in indices]
    new_labels = torch.empty_like(labels)
    for i, c in enumerate(classes):
        new_labels[labels == c] = i

    # Compute a center for each cluster as the mean of its members
    centers = []
    for i in range(len(classes)):
        centers.append(embeddings[new_labels == i].mean(axis=0, keepdim=True))
    centers = torch.cat(centers)
    return centers, new_labels


def run_hdbscan_cluster(embeddings):
    import hdbscan

    np.random.seed(42)
    clusterer = hdbscan.HDBSCAN(min_cluster_size=10)
    labels = torch.from_numpy(clusterer.fit_predict(embeddings))

    # Relabel clusters in decreasing order of size. HDBSCAN chooses the number
    # of clusters itself and marks noise points with -1; the noise group is
    # folded in here as an ordinary cluster.
    classes, counts = np.unique(labels, return_counts=True)
    indices = np.argsort(counts)[::-1]
    classes = [classes[i] for i in indices]
    new_labels = torch.empty_like(labels)
    for i, c in enumerate(classes):
        new_labels[labels == c] = i

    # Compute a center for each cluster as the mean of its members
    centers = []
    for i in range(len(classes)):
        centers.append(embeddings[new_labels == i].mean(axis=0, keepdim=True))
    centers = torch.cat(centers)
    return centers, new_labels


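# topk is capped at the size of the smallest cluster so that every cluster
# contributes the same number of rows and the per-cluster index tensors can be
# concatenated into one rectangular tensor.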
def get_topk_indices(centers, labels, embeddings, topk):
    indices = []
    arange = torch.arange(len(labels))
    counts = torch.unique(labels, return_counts=True)[1]
    topk = min(topk, counts.min().item())
    for i in range(len(centers)):
        tmp_indices = labels == i
        tmp_arange = arange[tmp_indices]
        tmp_embeddings = embeddings[tmp_indices]

        # Rank cluster members by cosine similarity to the cluster center
        scores = cos_sim(centers[i].unsqueeze(0), tmp_embeddings)[0]
        sorted_indices = torch.argsort(scores, descending=True)
        indices.append(tmp_arange[sorted_indices[:topk]].unsqueeze(0))
    return torch.cat(indices)


def print_topk(texts, labels, topk_indices, show_cut_off):
    ret = ""
    for k in range(len(topk_indices)):
        num_samples = torch.sum(labels == k).item()

        ret += "=" * 20 + f" cluster {k}, #samples: {num_samples} " + "=" * 20 + "\n"
        for idx in topk_indices[k]:
            ret += "PROMPT: " + texts[idx][:show_cut_off] + "\n"
        ret += "=" * 40 + "\n\n"

    return ret


def get_cluster_info(texts, labels, topk_indices):
    np.random.seed(42)

    cluster_info = []
    for k in range(len(topk_indices)):
        num_samples = torch.sum(labels == k).item()
        topk_prompts = []
        for idx in topk_indices[k]:
            topk_prompts.append(texts[idx])
        # Draw random prompts from the whole corpus as a contrast set for the
        # cluster's top-k prompts
        random_prompts = []
        for _ in range(len(topk_indices)):
            random_prompts.append(np.random.choice(texts))
        cluster_info.append((num_samples, topk_prompts, random_prompts))

    return cluster_info


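# Driver: read prompts -> embed -> cluster -> rank members by similarity to
# their cluster center -> dump a human-readable top-k file, a full JSONL, and
# a pickled per-cluster summary.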
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-file", type=str, required=True)
    parser.add_argument("--model", type=str, default="all-mpnet-base-v2")
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--min-length", type=int)
    parser.add_argument("--max-length", type=int)
    parser.add_argument("--english-only", action="store_true")
    parser.add_argument("--num-clusters", type=int, default=20)
    parser.add_argument(
        "--cluster-alg",
        type=str,
        choices=["kmeans", "aggcls", "HDBSCAN"],
        default="kmeans",
    )
    parser.add_argument("--show-top-k", type=int, default=200)
    parser.add_argument("--show-cut-off", type=int, default=512)
    parser.add_argument("--save-embeddings", action="store_true")
    parser.add_argument("--embeddings-file", type=str, default=None)
    args = parser.parse_args()

    num_clusters = args.num_clusters

    texts = read_texts(
        args.input_file, args.min_length, args.max_length, args.english_only
    )
    print(f"#text: {len(texts)}")

    if args.embeddings_file is None:
        embeddings = get_embeddings(texts, args.model, args.batch_size)
        if args.save_embeddings:
            # Save embeddings so later runs can skip the model/API call
            torch.save(embeddings, "embeddings.pt")
    else:
        embeddings = torch.load(args.embeddings_file)
    print(f"embeddings shape: {embeddings.shape}")

    if args.cluster_alg == "kmeans":
        centers, labels = run_k_means(embeddings, num_clusters)
    elif args.cluster_alg == "aggcls":
        centers, labels = run_agg_cluster(embeddings, num_clusters)
    elif args.cluster_alg == "HDBSCAN":
        centers, labels = run_hdbscan_cluster(embeddings)
    else:
        raise ValueError(f"Invalid clustering algorithm: {args.cluster_alg}")

    topk_indices = get_topk_indices(centers, labels, embeddings, args.show_top_k)
    topk_str = print_topk(texts, labels, topk_indices, args.show_cut_off)
    num_clusters = len(centers)  # HDBSCAN decides the cluster count itself

    # Dump results
    filename_prefix = f"results_c{num_clusters}_{args.cluster_alg}"
    print(topk_str)
    with open(filename_prefix + "_topk.txt", "w") as fout:
        fout.write(topk_str)

    with open(filename_prefix + "_all.jsonl", "w") as fout:
        for i in range(len(centers)):
            tmp_indices = labels == i
            tmp_embeddings = embeddings[tmp_indices]
            tmp_texts = texts[tmp_indices]

            scores = cos_sim(centers[i].unsqueeze(0), tmp_embeddings)[0]
            sorted_indices = torch.argsort(scores, descending=True)

            for text, score in zip(tmp_texts[sorted_indices], scores[sorted_indices]):
                obj = {"cluster": i, "text": text, "sim": score.item()}
                fout.write(json.dumps(obj, ensure_ascii=False) + "\n")

    cluster_info = get_cluster_info(texts, labels, topk_indices)
    with open(filename_prefix + "_cluster.pkl", "wb") as fout:
        pickle.dump(cluster_info, fout)
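
# The script writes three artifacts into the current working directory:
#   results_c{K}_{alg}_topk.txt    - human-readable top-k prompts per cluster
#   results_c{K}_{alg}_all.jsonl   - every prompt with its cluster id and
#                                    similarity to the cluster center
#   results_c{K}_{alg}_cluster.pkl - (num_samples, topk_prompts, random_prompts)
#                                    tuples from get_cluster_info
#
# A minimal sketch for reading the pickle back (the filename assumes the
# default of 20 kmeans clusters):
#
#   import pickle
#   with open("results_c20_kmeans_cluster.pkl", "rb") as fin:
#       cluster_info = pickle.load(fin)
#   for num_samples, topk_prompts, random_prompts in cluster_info:
#       print(num_samples, topk_prompts[0][:80])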