grpo-countdown-problem / src /scripts /generate_training_dataset_sft.py
Dat1710's picture
Upload folder using huggingface_hub
00db46c verified
#!/usr/bin/env python3
"""
Script to generate a CSV file containing arithmetic problems for SFT training data.
This script uses the arithmetic utilities to generate problems and creates a CSV
with columns: id, problem_description, correct_answer, and reasoning.
The reasoning column contains the step-by-step thought process for solving the problem.
"""
import argparse
import csv
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any
from dotenv import load_dotenv
from src.dataset.grpo import get_reasoning_for_answer
from src.utils.arithmetics import (
ArithmeticProblemDescriptionGenerator,
ArithmeticProblemGenerator,
Mode,
)
# Load environment variables from a .env file (python-dotenv) at import time.
# NOTE(review): presumably this supplies the credentials used by
# get_reasoning_for_answer — confirm against src.dataset.grpo.
load_dotenv()
# Configure logging once for the whole script: INFO level, with each record
# rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
def generate_single_problem_with_reasoning(
    problem_generator: ArithmeticProblemGenerator,
    description_generator: ArithmeticProblemDescriptionGenerator,
    problem_id: int,
) -> dict[str, Any] | None:
    """
    Build one training row: a generated problem, its answer, and its reasoning.

    Args:
        problem_generator: Object whose ``generate_problem()`` yields a problem
            (or None when it could not produce one).
        description_generator: Object whose ``generate_description(problem)``
            returns a pair; the first element is used as the description.
        problem_id: Value stored under the ``"id"`` key of the returned row.

    Returns:
        A dict with the keys ``"id"``, ``"problem_description"``,
        ``"correct_answer"`` and ``"reasoning"``, or None when the problem
        generator returned None.
    """
    generated = problem_generator.generate_problem()
    if generated is None:
        # No problem was produced; the caller counts this as a failed attempt.
        return None

    # Only the first element of the (description, extra) pair is needed.
    description, _unused = description_generator.generate_description(generated)
    expression = generated.expression

    # The reasoning text is delegated to get_reasoning_for_answer, which gets
    # the description together with the expression that serves as the answer.
    logger.info(f"Generating reasoning for problem {problem_id}...")
    explanation = get_reasoning_for_answer(description, expression)

    return {
        "id": problem_id,
        "problem_description": description,
        "correct_answer": expression,
        "reasoning": explanation,
    }
def _run_generation_batch(
    executor: ThreadPoolExecutor,
    problem_generator: ArithmeticProblemGenerator,
    description_generator: ArithmeticProblemDescriptionGenerator,
    first_id: int,
    batch_size: int,
) -> list[dict[str, Any]]:
    """
    Submit one batch of generation tasks and gather the entries that succeeded.

    Args:
        executor: Thread pool the tasks are submitted to.
        problem_generator: Passed through to each generation task.
        description_generator: Passed through to each generation task.
        first_id: Provisional ID of the first task; later tasks count up from it.
        batch_size: Number of tasks to submit.

    Returns:
        The non-None task results, in completion order. Tasks that raised are
        logged (with traceback) and omitted.
    """
    futures = [
        executor.submit(
            generate_single_problem_with_reasoning,
            problem_generator,
            description_generator,
            first_id + offset,
        )
        for offset in range(batch_size)
    ]

    entries: list[dict[str, Any]] = []
    # Drain every future: the tasks run to completion regardless, so stopping
    # early would save no time and would only hide their exceptions.
    for future in as_completed(futures):
        try:
            entry = future.result()
        except Exception as exc:
            # Best effort: one failed task must not abort the whole run.
            logger.exception("Error generating problem: %s", exc)
            continue
        if entry is not None:
            entries.append(entry)
    return entries


def generate_training_data(
    num_problems: int, max_workers: int = 8
) -> list[dict[str, Any]]:
    """
    Generate training data with arithmetic problems and reasoning using threading.

    Tasks are submitted in batches to a single shared thread pool until either
    ``num_problems`` entries were collected or ``num_problems * 10`` attempts
    were spent. Entries receive consecutive IDs starting at 1 in the order
    they are accepted, so the returned list is already ordered by ID.

    Args:
        num_problems: Number of problems to generate
        max_workers: Maximum number of worker threads for parallel processing

    Returns:
        List of dictionaries containing training data with reasoning. May hold
        fewer than ``num_problems`` entries when the attempt budget runs out
        (a warning is logged in that case).
    """
    problem_generator = ArithmeticProblemGenerator(mode=Mode.MUL_DIV)
    description_generator = ArithmeticProblemDescriptionGenerator()
    training_data: list[dict[str, Any]] = []
    max_total_attempts = num_problems * 10  # Allow more attempts than problems
    logger.info(
        f"Starting generation of {num_problems} problems with reasoning using {max_workers} workers..."
    )

    attempts = 0
    # Only create the pool when there is work to do; one pool is reused for
    # every batch instead of being rebuilt each time round the loop.
    if num_problems > 0:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            while (
                len(training_data) < num_problems
                and attempts < max_total_attempts
            ):
                remaining_problems = num_problems - len(training_data)
                # No more than we need plus a small buffer for failures, and
                # never beyond the remaining attempt budget.
                batch_size = min(
                    max_workers,
                    remaining_problems + 2,
                    max_total_attempts - attempts,
                )
                logger.info(
                    f"Submitting batch of {batch_size} tasks. Need {remaining_problems} more problems."
                )

                batch_results = _run_generation_batch(
                    executor,
                    problem_generator,
                    description_generator,
                    attempts + 1,
                    batch_size,
                )

                # Accept at most what is still missing; surplus results from
                # the failure buffer are dropped.
                for result in batch_results[:remaining_problems]:
                    # Replace the provisional ID with the final sequential one.
                    result["id"] = len(training_data) + 1
                    training_data.append(result)
                    if len(training_data) % 10 == 0:
                        logger.info(
                            f"Generated {len(training_data)} problems with reasoning..."
                        )

                attempts += batch_size
                logger.info(
                    f"Completed batch. Have {len(training_data)} problems, need {num_problems}"
                )

    if len(training_data) < num_problems:
        logger.warning(
            f"Only generated {len(training_data)} out of {num_problems} requested problems after {attempts} attempts"
        )
    else:
        logger.info(
            f"Successfully generated {len(training_data)} problems with reasoning in {attempts} attempts"
        )
    return training_data
def save_to_csv(training_data: list[dict[str, Any]], output_file: Path) -> None:
    """
    Write the training rows to ``output_file`` as a UTF-8 CSV with a header.

    Args:
        training_data: Rows to write; each dict supplies the columns
            id, problem_description, correct_answer and reasoning.
        output_file: Destination path. Missing parent directories are created.

    Returns:
        None. When ``training_data`` is empty an error is logged and nothing
        is written.
    """
    if not training_data:
        logger.error("No training data to save")
        return

    # Make sure the target directory exists before opening the file.
    output_file.parent.mkdir(parents=True, exist_ok=True)
    logger.info(f"Saving {len(training_data)} problems to {output_file}")

    # newline="" lets the csv module control line endings itself.
    with output_file.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(
            handle,
            fieldnames=["id", "problem_description", "correct_answer", "reasoning"],
        )
        writer.writeheader()
        writer.writerows(training_data)

    logger.info(f"Successfully saved training data to {output_file}")
def main() -> None:
    """
    Parse the command line, generate the training data, and write it to CSV.

    Options: ``--num_problems`` (required int), ``--output_file`` (required
    path) and ``--max_workers`` (int, default 8). A non-positive count or
    worker number, or an empty generation result, is logged as an error and
    the function returns without writing anything.
    """
    parser = argparse.ArgumentParser(
        description="Generate arithmetic problems for SFT training data in CSV format with reasoning"
    )
    parser.add_argument(
        "--num_problems", type=int, required=True, help="Number of problems to generate"
    )
    parser.add_argument(
        "--output_file", type=str, required=True, help="Path to the output CSV file"
    )
    parser.add_argument(
        "--max_workers",
        type=int,
        default=8,
        help="Maximum number of worker threads for parallel processing (default: 8)",
    )
    options = parser.parse_args()

    # Both numeric options must be strictly positive; report the first
    # violation (checked in this order) and stop.
    positivity_checks = (
        (options.num_problems, "Number of problems must be positive"),
        (options.max_workers, "Number of workers must be positive"),
    )
    for value, complaint in positivity_checks:
        if value <= 0:
            logger.error(complaint)
            return

    rows = generate_training_data(options.num_problems, options.max_workers)
    if rows:
        save_to_csv(rows, Path(options.output_file))
    else:
        logger.error("Failed to generate any training data")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()