# linalg-zero: linalg_zero/generate.py
# Author: atomwalk12 — commit 0dd6c2f ("initial commit")
import argparse
import logging
from pathlib import Path
import linalg_zero.generator.difficulty_config as dc
from linalg_zero.generator.analysis.utils import (
compute_stepwise_value_statistics,
print_statistics_summary,
)
from linalg_zero.generator.core import DatasetGenerator
from linalg_zero.generator.difficulty_config import DETERMINISTIC_MODE
from linalg_zero.generator.models import DifficultyCategory, Question, Topic
from linalg_zero.generator.registry import create_default_registry, create_optimized_registry
from linalg_zero.generator.utils import (
check_constraints,
convert_to_dataset_splits,
load_entropy_settings,
print_dataset,
print_split_statistics,
set_seed,
verify_dataset,
)
from linalg_zero.shared.utils import get_log_file_path, get_logger, push_to_hub, setup_logging
def main(
    push_dataset: bool,
    use_optimized_registry: bool,
    dataset_name: str,
    n_one: int,
    n_two: int,
    n_three: int,
    seed: int = 42,
) -> None:  # pragma: no cover
    """Generate the linear-algebra dataset, validate it, split it and optionally push it.

    Args:
        push_dataset: When True, upload the resulting splits with ``push_to_hub``
            (as a public dataset).
        use_optimized_registry: When True, build the registry from the entropy
            settings in ``generator/config/gen_properties.json`` and check the
            generated dataset against those settings; otherwise use the default
            registry.
        dataset_name: Dataset name passed to ``push_to_hub``.
        n_one: Number of samples requested for ``DifficultyCategory.ONE_TOOL_CALL``.
        n_two: Number of samples requested for ``DifficultyCategory.TWO_TOOL_CALLS``.
        n_three: Number of samples requested for ``DifficultyCategory.THREE_TOOL_CALLS``.
        seed: Seed for the train/validation/test split. Passed explicitly so that
            ``main`` does not depend on the CLI namespace and can be imported.
    """
    # Set up logging
    setup_logging(level=logging.INFO, include_timestamp=False)
    logger = get_logger(__name__)
    logger.info("Linear Algebra Dataset Generator")

    config_path = str(Path(__file__).parent / "generator" / "config" / "gen_properties.json")

    # Create registry (either default or optimized)
    if use_optimized_registry:
        registry = create_optimized_registry(config_path=config_path)
        logger.info("Using optimized entropy settings from analysis results")
    else:
        registry = create_default_registry()
    logger.info("Available topics: %s", registry.list_topics())

    # -----------------------------------------------
    # Generate and display the linear algebra dataset
    # -----------------------------------------------
    def matrix_only_validator(question: Question) -> bool:
        # Keep only questions whose answer is non-empty.
        return len(question.answer) > 0

    generator = DatasetGenerator(
        topic=Topic.LINEAR_ALGEBRA, validator_factory=matrix_only_validator, registry=registry
    )

    # Generate the exact number of samples requested for each difficulty category
    # (the counts come from the n_one / n_two / n_three arguments).
    dataset = generator.generate_exact_for_categories(
        requests={
            DifficultyCategory.ONE_TOOL_CALL: n_one,
            DifficultyCategory.TWO_TOOL_CALLS: n_two,
            DifficultyCategory.THREE_TOOL_CALLS: n_three,
        }
    )

    statistics = compute_stepwise_value_statistics(dataset)
    print_dataset(dataset)
    print_statistics_summary(statistics)
    verify_dataset(dataset)

    # The constraint check needs the entropy settings, which only exist for the
    # optimized registry.
    if use_optimized_registry:
        config = load_entropy_settings(config_path)
        check_constraints(dataset, config, statistics)

    # Create stratified splits by difficulty for balanced evaluation
    splits = convert_to_dataset_splits(
        dataset,
        test_size=0.1,
        val_size=0.1,
        seed=seed,
        stratify_by="difficulty",
    )
    print_split_statistics(splits)

    if push_dataset:
        push_to_hub(splits, dataset_name, private=False, config_path=config_path)

    # --------------------------------------------------
    # This is an example on generating other topic types
    # --------------------------------------------------
    # arithmetic_generator = DatasetGenerator(topic="arithmetic")
    # arithmetic_questions = arithmetic_generator.generate_dataset(num_questions=2)
    # print_dataset(arithmetic_questions)

    logger.info("Log file path: %s", get_log_file_path())


if __name__ == "__main__":  # pragma: no cover
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--push_dataset", action="store_true", default=False)
    parser.add_argument("--dataset_name", type=str, default="atomwalk12/linalgzero")
    parser.add_argument(
        "--use_optimized_registry",
        action="store_true",
        default=True,
        help="Use optimized entropy settings from analysis results for dataset generation",
    )
    # The flag above defaults to True, so without this store_false counterpart there
    # would be no way to select the default registry from the command line.
    parser.add_argument(
        "--no_optimized_registry",
        dest="use_optimized_registry",
        action="store_false",
        help="Use the default registry instead of the optimized entropy settings",
    )
    parser.add_argument("--n_one", type=int, default=700, help="Per-generator 1-step samples")
    parser.add_argument("--n_two", type=int, default=900, help="Per-generator 2-step samples")
    parser.add_argument("--n_three", type=int, default=600, help="Per-generator 3-step samples")
    # The scale is a fractional multiplier (default 1.1), so it must be parsed as a
    # float. "-scale" is kept as an alias for existing invocations.
    parser.add_argument(
        "--scale", "-scale", type=float, default=1.1, help="Scale the dataset by a factor"
    )
    argv = parser.parse_args()

    if argv.seed is not None:
        set_seed(argv.seed)
        if DETERMINISTIC_MODE:
            # Let CLI seed control per-question reseed base when deterministic
            # Importing module and setting its global is sufficient
            dc.DETERMINISTIC_BASE_SEED = int(argv.seed)

    # round() instead of int() so float error (e.g. 700 * 1.15 == 804.999...)
    # does not silently drop a sample.
    n_one = round(argv.n_one * argv.scale)
    n_two = round(argv.n_two * argv.scale)
    n_three = round(argv.n_three * argv.scale)

    main(
        argv.push_dataset,
        argv.use_optimized_registry,
        argv.dataset_name,
        n_one,
        n_two,
        n_three,
        seed=argv.seed,
    )