|
|
| """
|
| Usage:
|
|
|
| python scripts/data/dedupe_dataset.py \
|
| --dedupe_dataset <This is the dataset that gets deduped> \
|
| --dataset <Reference dataset whose problems are looked up in the RAG index built over --dedupe_dataset; unmodified> \
|
| --problem_column <name of column>
|
| """
|
|
|
| import json
|
| import os
|
|
|
| from tqdm import tqdm
|
|
|
| from rllm.data.dataset_types import TrainDataset
|
| from rllm.data.utils import load_dataset
|
| from rllm.utils import RAG
|
|
|
|
|
def normalize_string(text: str) -> str:
    """Return *text* lower-cased with whitespace canonicalized.

    Leading and trailing whitespace is dropped and every internal run of
    whitespace is collapsed into a single space.
    """
    lowered = text.lower()
    # split() with no argument discards leading/trailing whitespace and
    # treats any run of whitespace as one separator.
    words = lowered.split()
    return " ".join(words)
|
|
|
|
|
def get_prompt_from_chat_template(text: str) -> str:
    """Return the prompt text from a plain string or a chat-style message list.

    Despite the ``str`` annotation, a list is accepted as well.

    Args:
        text: Either a string, which is returned unchanged, or a list whose
            first element is subscriptable with the key ``"content"``.

    Returns:
        The string itself, the ``"content"`` value of the first list element,
        or ``""`` when the list is empty.

    Raises:
        ValueError: If ``text`` is neither a ``str`` nor a ``list``.
    """
    if isinstance(text, list):
        if not text:
            return ""
        # Only the first message is consulted; later messages are ignored.
        first_message = text[0]
        return first_message["content"]

    if isinstance(text, str):
        return text

    raise ValueError(f"Unsupported type for text: {type(text)}. Expected str or list.")
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: remove from one dataset every problem that closely
    # matches a problem in a second (reference) dataset, then write the
    # surviving rows to a JSON file.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dedupe_dataset", type=str, required=True, help="Path of the first dataset to check for duplicates")
    parser.add_argument("--dataset", type=str, required=True, help="Paths of 2nd dataset to check for duplicates against.")
    parser.add_argument("--problem_column", type=str, default="problem", help="Name of the column containing the problem (prompt).")
    parser.add_argument(
        "--data_dir",
        type=str,
        default=os.path.expanduser("~/rllm/rllm/data/train/code"),
        help="Data directory to save the deduped dataset. If not provided, will use the default data directory.",
    )
    parser.add_argument("--new_dataset_name", type=str, default=None, help="New name for the dataset. If not provided, will reuse the name and add a `_dedupe` to the name.")
    # The duplicate criteria used to be hard-coded; the defaults below keep
    # the previous behavior (top 3 matches, score strictly above 0.95).
    parser.add_argument("--top_k", type=int, default=3, help="Number of nearest matches to inspect for each reference problem.")
    parser.add_argument("--similarity_threshold", type=float, default=0.95, help="Matches scoring strictly above this value are treated as duplicates.")
    args = parser.parse_args()

    # Fail before any expensive work if there is nowhere to write the result.
    if not os.path.exists(args.data_dir):
        raise ValueError(f"Data directory {args.data_dir} does not exist.")

    # NOTE: although the help text says "Path", both dataset arguments are
    # used as upper-cased keys into TrainDataset.Code, not as file paths.
    ds_name = TrainDataset.Code[args.dedupe_dataset.upper()]
    orig_ds_name = TrainDataset.Code[args.dataset.upper()]

    ds = load_dataset(ds_name)
    orig_ds = load_dataset(orig_ds_name)

    # The retrieval index is built over the dataset being deduped, so the
    # "idx" of a match is treated below as a row position within `ds`.
    problem_col = [prob[args.problem_column] for prob in ds]
    rag = RAG(docs=problem_col)

    # Row positions in `ds` that matched some reference problem.
    dupe_idx = set()
    for prob_desc in tqdm(orig_ds, desc="Checking for duplicates"):
        desc = prob_desc[args.problem_column]
        if not isinstance(desc, str):
            print(f"Skipping due to non-string question: {desc}")
            continue

        normalized_question = normalize_string(desc)
        matches = rag.top_k(normalized_question, k=args.top_k)

        # Guard retained from the original: tolerate a falsy result
        # (e.g. None or an empty list) from the retriever.
        if matches:
            for match in matches:
                if match["score"] > args.similarity_threshold:
                    dupe_idx.add(match["idx"])

    if dupe_idx:
        print(f"Found {len(dupe_idx)} duplicates in the dataset.")
        ds = [p for i, p in enumerate(ds) if i not in dupe_idx]
        print(len(ds), "remaining after removing duplicates.")
    else:
        print("No duplicates found.")

    # Default output name: the --dedupe_dataset value (basename, extension
    # stripped) with "_dedupe" appended; ".json" is always enforced.
    new_ds_name = args.new_dataset_name or f"{os.path.splitext(os.path.basename(args.dedupe_dataset))[0]}_dedupe"
    if not new_ds_name.endswith(".json"):
        new_ds_name += ".json"

    with open(os.path.join(args.data_dir, new_ds_name), "w", encoding="utf-8") as f:
        json.dump(ds, f)

    print(f"All done! Saving the deduped dataset to {new_ds_name} in {args.data_dir}")
|
|
|