# grpo-countdown-problem / src/scripts/generate_training_dataset_grpo.py
# Uploaded by Dat1710 via huggingface_hub ("Upload folder using huggingface_hub"), commit 00db46c.
#!/usr/bin/env python3
"""
Script to generate a CSV file containing arithmetic problems for training data.
This script uses the arithmetic utilities to generate problems and creates a CSV
with columns: id, problem_description, correct_answer, num1, num2, num3, num4.
"""
import argparse
import csv
import logging
from pathlib import Path
from typing import Any
from src.utils.arithmetics import (
ArithmeticProblemDescriptionGenerator,
ArithmeticProblemGenerator,
Mode,
)
# Root logging is configured once, at import time, so every message emitted by
# this script carries a timestamp and its severity level.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Logger shared by the functions in this module.
logger = logging.getLogger(__name__)
def generate_training_data(
    num_problems: int,
    *,
    mode: "Mode | None" = None,
    max_attempts_per_problem: int = 10,
) -> list[dict[str, Any]]:
    """
    Generate training data with arithmetic problems.

    Problems are drawn one at a time from ``ArithmeticProblemGenerator``. A draw
    that returns ``None`` is discarded and counts as a failed attempt.
    Generation stops when ``num_problems`` entries exist or when the attempt
    budget (``num_problems * max_attempts_per_problem``) is used up. The result
    may therefore be shorter than requested; a warning is logged in that case.

    Args:
        num_problems: Number of problems to generate. Values <= 0 yield an
            empty list.
        mode: Mode passed to ``ArithmeticProblemGenerator``. ``None`` (the
            default) selects ``Mode.MUL_DIV``.
        max_attempts_per_problem: Attempt budget per requested problem
            (default 10).

    Returns:
        list[dict[str, Any]]: One dict per problem, with keys ``id`` (1-based,
        consecutive), ``problem_description``, ``correct_answer`` and
        ``num1``..``num4``.

    Raises:
        ValueError: If ``max_attempts_per_problem`` is less than 1.
    """
    if max_attempts_per_problem < 1:
        raise ValueError(
            f"max_attempts_per_problem must be >= 1, got {max_attempts_per_problem}"
        )

    problem_generator = ArithmeticProblemGenerator(
        mode=Mode.MUL_DIV if mode is None else mode
    )
    description_generator = ArithmeticProblemDescriptionGenerator()

    training_data: list[dict[str, Any]] = []
    attempts = 0
    # generate_problem() can return None, so cap the number of draws rather
    # than looping until enough usable problems appear.
    max_total_attempts = num_problems * max_attempts_per_problem

    logger.info("Starting generation of %d problems...", num_problems)
    while len(training_data) < num_problems and attempts < max_total_attempts:
        attempts += 1
        problem = problem_generator.generate_problem()
        if problem is None:
            continue

        problem_description, correct_answer = (
            description_generator.generate_description(problem)
        )
        training_data.append(
            {
                "id": len(training_data) + 1,
                "problem_description": problem_description,
                "correct_answer": correct_answer,
                "num1": problem.num_1,
                "num2": problem.num_2,
                "num3": problem.num_3,
                "num4": problem.num_4,
            }
        )
        if len(training_data) % 100 == 0:
            logger.info("Generated %d problems...", len(training_data))

    generated_count = len(training_data)
    if generated_count < num_problems:
        logger.warning(
            "Only generated %d out of %d requested problems after %d attempts",
            generated_count,
            num_problems,
            attempts,
        )
    else:
        logger.info(
            "Successfully generated %d problems in %d attempts",
            generated_count,
            attempts,
        )
    return training_data
def save_to_csv(
    training_data: list[dict[str, Any]], output_file: "Path | str"
) -> None:
    """
    Save training data to a CSV file.

    Writes a header row, then one row per entry. Columns are in the fixed order
    id, problem_description, correct_answer, num1, num2, num3, num4. Missing
    parent directories are created, and an existing file is overwritten. If
    ``training_data`` is empty, an error is logged and no file is written.

    Args:
        training_data: Row dictionaries keyed by the column names above. Keys
            outside that list make ``csv.DictWriter`` raise ``ValueError``;
            missing keys are written as empty fields.
        output_file: Destination CSV path, either a ``Path`` or a plain ``str``.
    """
    if not training_data:
        logger.error("No training data to save")
        return

    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    logger.info("Saving %d problems to %s", len(training_data), output_path)
    fieldnames = [
        "id",
        "problem_description",
        "correct_answer",
        "num1",
        "num2",
        "num3",
        "num4",
    ]
    # newline="" lets the csv module control line endings, as the csv docs require.
    with open(output_path, "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(training_data)
    logger.info("Successfully saved training data to %s", output_path)
def main(argv: "list[str] | None" = None) -> None:
    """
    Parse command-line arguments, generate the problems and write them to CSV.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) parses the real command line.

    Raises:
        SystemExit: Status 2 (via ``argparse``) for missing or invalid
            arguments, including a non-positive ``--num_problems``. Status 1 if
            no problems could be generated. Shells and pipelines can therefore
            detect failure instead of seeing a successful exit.
    """
    parser = argparse.ArgumentParser(
        description="Generate arithmetic problems for training data in CSV format"
    )
    parser.add_argument(
        "--num_problems", type=int, required=True, help="Number of problems to generate"
    )
    parser.add_argument(
        "--output_file", type=str, required=True, help="Path to the output CSV file"
    )
    args = parser.parse_args(argv)

    if args.num_problems <= 0:
        # parser.error prints the usage and this message to stderr, then exits 2.
        parser.error("Number of problems must be positive")

    output_path = Path(args.output_file)

    training_data = generate_training_data(args.num_problems)
    if not training_data:
        logger.error("Failed to generate any training data")
        raise SystemExit(1)

    save_to_csv(training_data, output_path)
# Script entry point. main() returns None, so SystemExit(None) exits with
# status 0, the same as falling off the end of the script.
if __name__ == "__main__":
    raise SystemExit(main())