import json import logging import re from argparse import ArgumentParser from typing import Any from datasets import Dataset, DatasetDict, DownloadMode, load_dataset from linalg_zero.shared.lib import get_tools from linalg_zero.shared.utils import ( get_logger, get_representative_examples_indices, normalize_text, setup_logging, ) # Log both to file and console setup_logging(level=logging.INFO, include_timestamp=True) logger = get_logger(__name__) def load_datasets(src_train: str, src_test: str) -> DatasetDict: """Load datasets for GRPO training.""" # Load training dataset (has solutions) logger.info(f"Loading train dataset from https://huggingface.co/datasets/{src_train}") train_dataset = load_dataset(src_train, split="train", download_mode=DownloadMode.FORCE_REDOWNLOAD) # Load base dataset to get validation and test splits logger.info(f"Loading base dataset from https://huggingface.co/datasets/{src_test}") base_dataset = load_dataset(src_test, download_mode=DownloadMode.FORCE_REDOWNLOAD) # Extract validation and test splits from the base dataset validation_dataset = base_dataset["validation"] test_dataset = base_dataset["test"] # Prepare results with proper three splits assert isinstance(train_dataset, Dataset) assert isinstance(validation_dataset, Dataset) assert isinstance(test_dataset, Dataset) return DatasetDict({"train": train_dataset, "validation": validation_dataset, "test": test_dataset}) def fix_think_tags(content: str) -> str: """Ensure exactly one newline after and before """ # First remove any existing whitespace around tags content = re.sub(r"\s*", "\n", content) content = re.sub(r"\s*", "\n", content) return content def process_dataset_for_grpo( # noqa: C901 dataset: DatasetDict, normalize_unicode: bool, seed: int, per_category: int ) -> DatasetDict: """Process dataset specifically for GRPO training.""" def ensure_tools(example: dict[str, Any]) -> dict[str, Any]: """Ensure tools field is present.""" if "tools" not in example or example["tools"] is None: example["tools"] = get_tools() return example def normalize_query(example: dict[str, Any]) -> dict[str, Any]: """Normalize query text if unicode normalization is enabled.""" if "query" in example: example["query"] = normalize_text(example["query"], normalize_unicode) return example def parse_messages_for_grpo(example: dict[str, Any]) -> dict[str, Any]: """Convert messages from JSON string to array and fix think tag formatting for GRPO.""" if example.get("messages"): messages = json.loads(example["messages"]) # Fix think tags in assistant messages and normalize content for msg in messages: if isinstance(msg, dict) and msg.get("role") == "assistant" and "content" in msg: content = fix_think_tags(msg["content"]) msg["content"] = normalize_text(content, normalize_unicode) # Store processed messages for reference (optional) example["processed_messages"] = messages return example def process_train_split(train_data: Dataset) -> Dataset: """Process training dataset (has messages field with solutions).""" train_data = train_data.shuffle(seed=seed) train_data = train_data.map(parse_messages_for_grpo) train_data = train_data.map(ensure_tools) train_data = train_data.map(normalize_query) return train_data def process_eval_split(eval_data: Dataset) -> Dataset: """Process evaluation dataset (validation or test - no messages, just problems).""" eval_data = eval_data.shuffle(seed=seed) eval_data = eval_data.map(ensure_tools) eval_data = eval_data.map(normalize_query) eval_indices = get_representative_examples_indices( eval_data, per_category=per_category, include_remaining=False ) eval_data = eval_data.select(eval_indices) return eval_data # The necessary columns for GRPO training keep_columns = [ "query", "ground_truth", "stepwise_ground_truths", "tools", ] # Process all splits train_dataset = process_train_split(dataset["train"]) validation_dataset = process_eval_split(dataset["validation"]) test_dataset = process_eval_split(dataset["test"]) # Remove unnecessary columns from all splits strip_cols = set(train_dataset.column_names) - set(keep_columns) if strip_cols: logger.info(f"Removing columns from train dataset: {strip_cols}") train_dataset = train_dataset.remove_columns(list(strip_cols)) strip_cols = set(validation_dataset.column_names) - set(keep_columns) if strip_cols: logger.info(f"Removing columns from validation dataset: {strip_cols}") validation_dataset = validation_dataset.remove_columns(list(strip_cols)) strip_cols = set(test_dataset.column_names) - set(keep_columns) if strip_cols: logger.info(f"Removing columns from test dataset: {strip_cols}") test_dataset = test_dataset.remove_columns(list(strip_cols)) # Ensure all schemas align with the train dataset validation_dataset = validation_dataset.cast(train_dataset.features) test_dataset = test_dataset.cast(train_dataset.features) # Prepare results with all three splits assert isinstance(train_dataset, Dataset) assert isinstance(validation_dataset, Dataset) assert isinstance(test_dataset, Dataset) return DatasetDict({"train": train_dataset, "validation": validation_dataset, "test": test_dataset}) def validate_grpo_dataset(dataset: DatasetDict) -> None: """Validate that the dataset is properly formatted for GRPO training.""" required_columns = ["query", "ground_truth", "stepwise_ground_truths", "tools"] for split_name, split_data in dataset.items(): logger.info(f"Validating {split_name} split...") # Check required columns missing_cols = set(required_columns) - set(split_data.column_names) if missing_cols: raise ValueError(f"Missing required columns in {split_name}: {missing_cols}") # Validate sample entries if len(split_data) > 0: sample = split_data[0] # Check query field if not sample["query"] or not sample["query"].strip(): raise ValueError(f"Empty query in {split_name} split") # Check ground_truth is valid JSON try: json.loads(sample["ground_truth"]) except json.JSONDecodeError as e: raise ValueError(f"Invalid ground_truth JSON in {split_name}: {e}") from e # Check stepwise_ground_truths is valid JSON list try: stepwise = json.loads(sample["stepwise_ground_truths"]) if not isinstance(stepwise, list): raise TypeError(f"stepwise_ground_truths must be a list in {split_name}") except json.JSONDecodeError as e: raise ValueError(f"Invalid stepwise_ground_truths JSON in {split_name}: {e}") from e # Check tools field if not isinstance(sample["tools"], list): raise ValueError(f"Tools field must be a list in {split_name}") logger.info("*** Dataset validation completed successfully ***") def prepare_debug(train: Dataset, validation: Dataset, test: Dataset, dataset_size: int) -> DatasetDict: """Prepare debug dataset with limited size.""" train = train.select(range(min(dataset_size, len(train)))) validation = validation.select(range(min(dataset_size, len(validation)))) test = test.select(range(min(dataset_size, len(test)))) return DatasetDict({"train": train, "validation": validation, "test": test}) def main( train_repo: str, test_repo: str, output_repo: str, push_to_hub: bool, normalize_unicode: bool, seed: int, per_category: int, ) -> None: """Main processing function for GRPO dataset preparation.""" logger.info("*** Loading datasets for GRPO training ***") dataset = load_datasets(train_repo, test_repo) # Process for GRPO logger.info("*** Processing dataset for GRPO training ***") dataset = process_dataset_for_grpo( dataset, normalize_unicode=normalize_unicode, seed=seed, per_category=per_category ) # Validate the processed dataset validate_grpo_dataset(dataset) # Push to hub if push_to_hub: logger.info("*** Pushing to Hub ***") try: dataset.push_to_hub(output_repo) logger.info(f"Successfully pushed dataset to https://huggingface.co/datasets/{output_repo}") logger.info("Dataset is now ready for GRPO training!") except Exception: logger.exception("Failed to push to hub") else: logger.info("Dataset processing completed. Use --push_to_hub to upload to HuggingFace Hub.") if __name__ == "__main__": """Script entry point for GRPO dataset preparation.""" parser = ArgumentParser(description="Prepare dataset for GRPO training") parser.add_argument( "--src_train", type=str, default="atomwalk12/linalgzero-distilled-clean", help="Source training dataset" ) parser.add_argument("--src_test", type=str, default="atomwalk12/linalgzero", help="Source test dataset") parser.add_argument( "--output_repo", default="atomwalk12/linalgzero-grpo", type=str, help="Output repository name for GRPO dataset" ) parser.add_argument( "--push_to_hub", default=False, action="store_true", help="Whether to push the dataset to HuggingFace Hub" ) parser.add_argument( "--no_normalize_unicode", default=False, action="store_true", help="Disable Unicode NFKC normalization and minus-sign replacement during dataset prep", ) parser.add_argument("--seed", default=20, type=int, help="Random seed for dataset shuffling") parser.add_argument( "--per_category", default=40, type=int, help="Number of representative examples per category for validation and test sets", ) args = parser.parse_args() main( train_repo=args.src_train, test_repo=args.src_test, output_repo=args.output_repo, push_to_hub=args.push_to_hub, normalize_unicode=(not args.no_normalize_unicode), seed=args.seed, per_category=args.per_category, )