|
|
|
|
|
"""
|
|
|
Script to generate a CSV file containing arithmetic problems for training data.
|
|
|
|
|
|
This script uses the arithmetic utilities to generate problems and creates a CSV
|
|
|
with columns: id, problem_description, correct_answer, num1, num2, num3, num4.
|
|
|
"""
|
|
|
|
|
|
import argparse
|
|
|
import csv
|
|
|
import logging
|
|
|
from pathlib import Path
|
|
|
from typing import Any
|
|
|
|
|
|
from src.utils.arithmetics import (
|
|
|
ArithmeticProblemDescriptionGenerator,
|
|
|
ArithmeticProblemGenerator,
|
|
|
Mode,
|
|
|
)
|
|
|
|
|
|
|
|
|
# Configure the root logger once at import time so every message emitted by
# this script carries a timestamp and a severity level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger shared by all functions in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
def generate_training_data(num_problems: int) -> list[dict[str, Any]]:
    """
    Build a list of arithmetic training examples.

    Problems are drawn from an ``ArithmeticProblemGenerator`` in
    ``Mode.MUL_DIV`` and turned into text by an
    ``ArithmeticProblemDescriptionGenerator``. Attempts for which the
    generator returns ``None`` are skipped. Generation stops after
    ``num_problems * 10`` attempts, so the result may contain fewer entries
    than requested (a warning is logged in that case).

    Args:
        num_problems: Number of problems to generate

    Returns:
        One dictionary per generated problem with the keys ``id``,
        ``problem_description``, ``correct_answer``, ``num1``, ``num2``,
        ``num3`` and ``num4``. Ids are consecutive and start at 1.
    """
    problem_source = ArithmeticProblemGenerator(mode=Mode.MUL_DIV)
    describer = ArithmeticProblemDescriptionGenerator()

    # Upper bound on generator calls so the loop cannot spin forever when the
    # generator keeps returning None.
    attempt_budget = num_problems * 10
    rows: list[dict[str, Any]] = []
    tries = 0

    logger.info(f"Starting generation of {num_problems} problems...")

    while len(rows) < num_problems and tries < attempt_budget:
        tries += 1

        candidate = problem_source.generate_problem()
        if candidate is not None:
            description, answer = describer.generate_description(candidate)

            rows.append(
                {
                    "id": len(rows) + 1,
                    "problem_description": description,
                    "correct_answer": answer,
                    "num1": candidate.num_1,
                    "num2": candidate.num_2,
                    "num3": candidate.num_3,
                    "num4": candidate.num_4,
                }
            )

            # Progress report every 100 accepted problems.
            produced_so_far = len(rows)
            if produced_so_far % 100 == 0:
                logger.info(f"Generated {produced_so_far} problems...")

    produced = len(rows)
    if produced >= num_problems:
        logger.info(
            f"Successfully generated {produced} problems in {tries} attempts"
        )
    else:
        logger.warning(
            f"Only generated {produced} out of {num_problems} requested problems after {tries} attempts"
        )

    return rows
|
|
|
|
|
|
|
|
|
def save_to_csv(training_data: list[dict[str, Any]], output_file: Path) -> None:
    """
    Write the generated training examples to a CSV file.

    Nothing is written (and an error is logged) when ``training_data`` is
    empty. Otherwise the parent directory of ``output_file`` is created if
    needed and the file is (over)written as UTF-8 with a header row followed
    by one row per entry.

    Args:
        training_data: List of training data dictionaries
        output_file: Path to the output CSV file
    """
    if not training_data:
        logger.error("No training data to save")
        return

    # Column order of the resulting CSV file.
    columns = [
        "id",
        "problem_description",
        "correct_answer",
        "num1",
        "num2",
        "num3",
        "num4",
    ]

    output_file.parent.mkdir(parents=True, exist_ok=True)

    logger.info(f"Saving {len(training_data)} problems to {output_file}")

    # newline="" lets the csv module control line endings itself.
    with output_file.open("w", newline="", encoding="utf-8") as handle:
        csv_writer = csv.DictWriter(handle, fieldnames=columns)
        csv_writer.writeheader()
        csv_writer.writerows(training_data)

    logger.info(f"Successfully saved training data to {output_file}")
|
|
|
|
|
|
|
|
|
def main() -> None:
    """
    Parse command line arguments and orchestrate the generation process.

    Command line options:
        --num_problems: Number of problems to generate (must be positive).
        --output_file: Path to the output CSV file.

    Raises:
        SystemExit: With status 1 when ``--num_problems`` is not positive or
            when no training data could be generated, so that calling shells
            and pipelines can detect the failure instead of seeing a
            successful exit. argparse itself also exits on malformed
            arguments.
    """
    parser = argparse.ArgumentParser(
        description="Generate arithmetic problems for training data in CSV format"
    )
    parser.add_argument(
        "--num_problems", type=int, required=True, help="Number of problems to generate"
    )
    parser.add_argument(
        "--output_file", type=str, required=True, help="Path to the output CSV file"
    )

    args = parser.parse_args()

    if args.num_problems <= 0:
        logger.error("Number of problems must be positive")
        # Signal the failure through the exit status rather than returning
        # normally, which would make the process exit with status 0.
        raise SystemExit(1)

    output_path = Path(args.output_file)

    training_data = generate_training_data(args.num_problems)

    if not training_data:
        logger.error("Failed to generate any training data")
        # No CSV is written in this case, so report a non-zero exit status.
        raise SystemExit(1)

    save_to_csv(training_data, output_path)


if __name__ == "__main__":
    main()
|
|
|
|