# linalg-zero/linalg_zero/grpo/scripts/prepare_dataset.py
# atomwalk12
# initial commit
# 0dd6c2f
import json
import logging
import re
from argparse import ArgumentParser
from typing import Any
from datasets import Dataset, DatasetDict, DownloadMode, load_dataset
from linalg_zero.shared.lib import get_tools
from linalg_zero.shared.utils import (
get_logger,
get_representative_examples_indices,
normalize_text,
setup_logging,
)
# Configure logging once at import time (the original intent is to log to both
# file and console via setup_logging), then obtain this module's logger.
setup_logging(include_timestamp=True, level=logging.INFO)
logger = get_logger(__name__)
def load_datasets(src_train: str, src_test: str) -> DatasetDict:
    """Load the train, validation and test splits used for GRPO training.

    The ``train`` split is taken from ``src_train`` (the dataset with solutions);
    the ``validation`` and ``test`` splits are taken from ``src_test``. Both
    repositories are loaded with ``DownloadMode.FORCE_REDOWNLOAD``.

    Args:
        src_train: Hugging Face dataset repo id providing the ``train`` split.
        src_test: Hugging Face dataset repo id providing the ``validation`` and
            ``test`` splits.

    Returns:
        A ``DatasetDict`` with ``train``, ``validation`` and ``test`` splits.

    Raises:
        TypeError: If ``src_test`` does not load as a ``DatasetDict`` or any split
            is not a ``Dataset``.
        KeyError: If ``src_test`` lacks a ``validation`` or ``test`` split.
    """
    # Load training dataset (has solutions)
    logger.info(f"Loading train dataset from https://huggingface.co/datasets/{src_train}")
    train_dataset = load_dataset(src_train, split="train", download_mode=DownloadMode.FORCE_REDOWNLOAD)

    # Load base dataset to get validation and test splits
    logger.info(f"Loading base dataset from https://huggingface.co/datasets/{src_test}")
    base_dataset = load_dataset(src_test, download_mode=DownloadMode.FORCE_REDOWNLOAD)

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not isinstance(base_dataset, DatasetDict):
        raise TypeError(f"Expected {src_test} to load as a DatasetDict, got {type(base_dataset).__name__}")
    missing_splits = [name for name in ("validation", "test") if name not in base_dataset]
    if missing_splits:
        raise KeyError(f"Dataset {src_test} is missing required split(s): {missing_splits}")

    splits = {
        "train": train_dataset,
        "validation": base_dataset["validation"],
        "test": base_dataset["test"],
    }
    for split_name, split in splits.items():
        if not isinstance(split, Dataset):
            raise TypeError(f"Expected split '{split_name}' to be a Dataset, got {type(split).__name__}")
    return DatasetDict(splits)
def fix_think_tags(content: str) -> str:
    """Normalize whitespace inside ``<think>...</think>`` blocks.

    Any run of whitespace directly after ``<think>`` and directly before
    ``</think>`` is replaced by a single newline. Every occurrence in
    ``content`` is processed; text without think tags is returned unchanged.
    """
    # Opening tag first, then closing tag: applied in sequence.
    rewrites = (
        (r"<think>\s*", "<think>\n"),
        (r"\s*</think>", "\n</think>"),
    )
    text = content
    for pattern, replacement in rewrites:
        text = re.sub(pattern, replacement, text)
    return text
def process_dataset_for_grpo(  # noqa: C901
    dataset: DatasetDict, normalize_unicode: bool, seed: int, per_category: int
) -> DatasetDict:
    """Process a train/validation/test ``DatasetDict`` into the GRPO training format.

    Every split is shuffled with ``seed``, gets a ``tools`` field (filled from
    ``get_tools()`` when missing or None), and has its ``query`` passed through
    ``normalize_text``. The train split additionally has its JSON ``messages``
    parsed. The validation and test splits are reduced to the indices returned
    by ``get_representative_examples_indices`` (``per_category`` per category).
    All columns outside ``keep_columns`` are then removed, and the evaluation
    splits are cast to the train split's features.

    Args:
        dataset: Must contain ``train``, ``validation`` and ``test`` splits.
        normalize_unicode: Forwarded to ``normalize_text`` for queries and
            assistant message contents.
        seed: Shuffle seed applied to every split.
        per_category: Forwarded to ``get_representative_examples_indices`` for
            the evaluation splits.

    Returns:
        A new ``DatasetDict`` with processed ``train``, ``validation`` and ``test`` splits.
    """
    # The necessary columns for GRPO training; every other column is dropped.
    keep_columns = [
        "query",
        "ground_truth",
        "stepwise_ground_truths",
        "tools",
    ]

    def ensure_tools(example: dict[str, Any]) -> dict[str, Any]:
        """Fill the tools field with get_tools() when it is absent or None."""
        if "tools" not in example or example["tools"] is None:
            example["tools"] = get_tools()
        return example

    def normalize_query(example: dict[str, Any]) -> dict[str, Any]:
        """Normalize the query text via normalize_text (honouring normalize_unicode)."""
        if "query" in example:
            example["query"] = normalize_text(example["query"], normalize_unicode)
        return example

    def parse_messages_for_grpo(example: dict[str, Any]) -> dict[str, Any]:
        """Parse the JSON ``messages`` string and clean up assistant turns.

        Assistant contents get ``fix_think_tags`` and ``normalize_text`` applied.
        ``processed_messages`` is set on EVERY row (``None`` when the row has no
        messages) so all rows returned to ``Dataset.map`` share the same keys.
        """
        messages = None
        if example.get("messages"):
            messages = json.loads(example["messages"])
            for msg in messages:
                if isinstance(msg, dict) and msg.get("role") == "assistant" and "content" in msg:
                    content = fix_think_tags(msg["content"])
                    msg["content"] = normalize_text(content, normalize_unicode)
        # NOTE(review): "processed_messages" (and "messages") are not in keep_columns,
        # so they are removed below; the cleanup above does not reach the returned
        # dataset. Confirm this is intended.
        example["processed_messages"] = messages
        return example

    def process_train_split(train_data: Dataset) -> Dataset:
        """Shuffle the training split, then parse messages, add tools and normalize queries."""
        train_data = train_data.shuffle(seed=seed)
        train_data = train_data.map(parse_messages_for_grpo)
        train_data = train_data.map(ensure_tools)
        return train_data.map(normalize_query)

    def process_eval_split(eval_data: Dataset) -> Dataset:
        """Shuffle an eval split, add tools, normalize queries and keep representative rows."""
        eval_data = eval_data.shuffle(seed=seed)
        eval_data = eval_data.map(ensure_tools)
        eval_data = eval_data.map(normalize_query)
        eval_indices = get_representative_examples_indices(
            eval_data, per_category=per_category, include_remaining=False
        )
        return eval_data.select(eval_indices)

    def strip_extra_columns(split_data: Dataset, split_name: str) -> Dataset:
        """Remove every column not in keep_columns, logging which ones are dropped."""
        strip_cols = set(split_data.column_names) - set(keep_columns)
        if not strip_cols:
            return split_data
        logger.info(f"Removing columns from {split_name} dataset: {strip_cols}")
        return split_data.remove_columns(list(strip_cols))

    # Process all splits first, then strip unneeded columns from each.
    train_dataset = process_train_split(dataset["train"])
    validation_dataset = process_eval_split(dataset["validation"])
    test_dataset = process_eval_split(dataset["test"])

    train_dataset = strip_extra_columns(train_dataset, "train")
    validation_dataset = strip_extra_columns(validation_dataset, "validation")
    test_dataset = strip_extra_columns(test_dataset, "test")

    # Ensure all schemas align with the train dataset
    validation_dataset = validation_dataset.cast(train_dataset.features)
    test_dataset = test_dataset.cast(train_dataset.features)

    return DatasetDict({"train": train_dataset, "validation": validation_dataset, "test": test_dataset})
def validate_grpo_dataset(dataset: DatasetDict) -> None:
    """Sanity-check that every split is formatted for GRPO training.

    Each split must have the required columns. For non-empty splits, ONLY the
    first example is inspected: its query must be non-empty, ``ground_truth``
    must be valid JSON, ``stepwise_ground_truths`` must be a JSON list, and
    ``tools`` must be a list.

    Raises:
        ValueError: On a missing column, an empty query, undecodable JSON
            (including a non-string / None JSON field), or a non-list ``tools``.
        TypeError: If ``stepwise_ground_truths`` decodes to something other than a list.
    """
    required_columns = ["query", "ground_truth", "stepwise_ground_truths", "tools"]

    def load_json_field(sample: dict[str, Any], field: str, split_name: str) -> Any:
        """Decode a JSON-string field, reporting any failure as ValueError."""
        try:
            return json.loads(sample[field])
        except (json.JSONDecodeError, TypeError) as e:
            # TypeError covers non-string values such as None, which json.loads rejects.
            raise ValueError(f"Invalid {field} JSON in {split_name}: {e}") from e

    for split_name, split_data in dataset.items():
        logger.info(f"Validating {split_name} split...")

        # Check required columns
        missing_cols = set(required_columns) - set(split_data.column_names)
        if missing_cols:
            raise ValueError(f"Missing required columns in {split_name}: {missing_cols}")

        if len(split_data) == 0:
            continue
        sample = split_data[0]

        if not sample["query"] or not sample["query"].strip():
            raise ValueError(f"Empty query in {split_name} split")

        load_json_field(sample, "ground_truth", split_name)

        # The type check lives outside the try so only decoding errors are translated.
        stepwise = load_json_field(sample, "stepwise_ground_truths", split_name)
        if not isinstance(stepwise, list):
            raise TypeError(f"stepwise_ground_truths must be a list in {split_name}")

        if not isinstance(sample["tools"], list):
            raise ValueError(f"Tools field must be a list in {split_name}")

    logger.info("*** Dataset validation completed successfully ***")
def prepare_debug(train: Dataset, validation: Dataset, test: Dataset, dataset_size: int) -> DatasetDict:
    """Build a small debug ``DatasetDict`` from the leading rows of each split.

    Each split keeps its first ``min(dataset_size, len(split))`` rows; no
    shuffling is done here.
    """
    sources = {"train": train, "validation": validation, "test": test}
    truncated = {}
    for split_name, split in sources.items():
        row_count = min(dataset_size, len(split))
        truncated[split_name] = split.select(range(row_count))
    return DatasetDict(truncated)
def main(
    train_repo: str,
    test_repo: str,
    output_repo: str,
    push_to_hub: bool,
    normalize_unicode: bool,
    seed: int,
    per_category: int,
) -> None:
    """Load, process, validate and optionally publish the GRPO dataset.

    When ``push_to_hub`` is true the result is uploaded to ``output_repo``; an
    upload failure is logged with its traceback but is not re-raised.
    """
    logger.info("*** Loading datasets for GRPO training ***")
    raw_splits = load_datasets(train_repo, test_repo)

    logger.info("*** Processing dataset for GRPO training ***")
    grpo_splits = process_dataset_for_grpo(
        raw_splits, normalize_unicode=normalize_unicode, seed=seed, per_category=per_category
    )

    validate_grpo_dataset(grpo_splits)

    if not push_to_hub:
        logger.info("Dataset processing completed. Use --push_to_hub to upload to HuggingFace Hub.")
        return

    logger.info("*** Pushing to Hub ***")
    try:
        grpo_splits.push_to_hub(output_repo)
    except Exception:
        # Best-effort upload: report the failure without aborting the script.
        logger.exception("Failed to push to hub")
    else:
        logger.info(f"Successfully pushed dataset to https://huggingface.co/datasets/{output_repo}")
        logger.info("Dataset is now ready for GRPO training!")
if __name__ == "__main__":
    # Script entry point for GRPO dataset preparation: parse CLI flags and run main().
    parser = ArgumentParser(description="Prepare dataset for GRPO training")
    parser.add_argument(
        "--src_train", type=str, default="atomwalk12/linalgzero-distilled-clean", help="Source training dataset"
    )
    parser.add_argument(
        "--src_test",
        type=str,
        default="atomwalk12/linalgzero",
        help="Source dataset providing the validation and test splits",
    )
    parser.add_argument(
        "--output_repo", default="atomwalk12/linalgzero-grpo", type=str, help="Output repository name for GRPO dataset"
    )
    parser.add_argument(
        "--push_to_hub", default=False, action="store_true", help="Whether to push the dataset to HuggingFace Hub"
    )
    parser.add_argument(
        "--no_normalize_unicode",
        default=False,
        action="store_true",
        help="Disable Unicode NFKC normalization and minus-sign replacement during dataset prep",
    )
    parser.add_argument("--seed", default=20, type=int, help="Random seed for dataset shuffling")
    parser.add_argument(
        "--per_category",
        default=40,
        type=int,
        help="Number of representative examples per category for validation and test sets",
    )
    args = parser.parse_args()
    main(
        train_repo=args.src_train,
        test_repo=args.src_test,
        output_repo=args.output_repo,
        push_to_hub=args.push_to_hub,
        # The CLI exposes an opt-out flag, so invert it for main().
        normalize_unicode=(not args.no_normalize_unicode),
        seed=args.seed,
        per_category=args.per_category,
    )