| | import os |
| | import json |
| | import pandas as pd |
| | import ast |
| |
|
| | import matplotlib.pyplot as plt |
| | from matplotlib import rcParams |
| |
|
| | import argparse |
| | import seaborn as sns |
| | from tqdm import tqdm |
| | import matplotlib.pyplot as plt |
| |
|
| | import numpy as np |
| |
|
if __name__ == "__main__":
    # Tag near-duplicate conversations for deduplication: conversations whose
    # concatenated user turns repeat more often than the given percentile of
    # prompt counts are marked high_freq, and only a fixed-size random sample
    # of each such group is kept (sampled=True).
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="output")
    parser.add_argument("--model", type=str, default=None)
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--percentile", type=float, default=0.9999)
    args = parser.parse_args()
    output_dir = args.output_dir
    input_file = args.input_file

    # Input is a JSON list of conversation records; read with explicit UTF-8
    # to match the non-ASCII-preserving output below (force_ascii=False).
    with open(input_file, encoding="utf-8") as f:
        data = json.load(f)

    os.makedirs(output_dir, exist_ok=True)

    # Build the dedup key per record: all user-role turns of "conversation_a"
    # joined with newlines, truncated to 10k chars so pathologically long
    # prompts still compare cheaply and near-identical ones collide.
    all_convs_new = []
    for row in data:
        conv = "".join(
            f"{turn['content']}\n"
            for turn in row["conversation_a"]
            if turn["role"] == "user"
        )
        row["post_process_conv"] = conv[:10000]
        all_convs_new.append(row)

    df = pd.DataFrame(all_convs_new)
    print("Number of conversations: ", len(df))

    # Frequency of each dedup key; show the 20 most repeated prompts.
    prompt_counts = df["post_process_conv"].value_counts()
    top_prompts = prompt_counts.head(20)
    print(top_prompts)

    percentile_cutoff = prompt_counts.quantile(args.percentile)
    print(f"{args.percentile*100} percentile count: {percentile_cutoff}")

    high_frequency_prompts = prompt_counts[prompt_counts > percentile_cutoff].index
    print(
        f"Number of high frequency prompts: {len(high_frequency_prompts)}/{len(prompt_counts)}"
    )

    # Default tag for every row: not high-frequency, kept. One fresh dict per
    # row — sharing a single dict object would let a later mutation of one
    # row's tag silently change all of them.
    dedup_tags = np.array(
        [{"high_freq": False, "sampled": True} for _ in range(len(df))]
    )
    # Rows to keep per over-represented prompt; counts are >= 1 so the floor
    # only guards against degenerate quantile results.
    keep_n = max(1, int(percentile_cutoff))
    high_freq_groups = df.groupby("post_process_conv")
    for prompt in tqdm(high_frequency_prompts):
        df_high_freq = high_freq_groups.get_group(prompt)
        # Deterministic down-sampling (fixed seed) of each high-freq group.
        sampled_indices = df_high_freq.sample(n=keep_n, random_state=42).index
        # Assign a fresh dict per index: fancy-indexed assignment of a single
        # dict literal would alias the same object across all rows.
        for idx in df_high_freq.index:
            dedup_tags[idx] = {"high_freq": True, "sampled": False}
        for idx in sampled_indices:
            dedup_tags[idx] = {"high_freq": True, "sampled": True}

    df["dedup_tag"] = dedup_tags

    # The working key is internal only — drop it before writing results.
    df = df.drop(columns=["post_process_conv"])

    df.to_json(
        os.path.join(output_dir, "dedup.json"),
        orient="records",
        indent=4,
        force_ascii=False,
    )
| |
|