Spaces:
Running on Zero
Running on Zero
| import argparse | |
| import logging | |
| from pathlib import Path | |
| import linalg_zero.generator.difficulty_config as dc | |
| from linalg_zero.generator.analysis.utils import ( | |
| compute_stepwise_value_statistics, | |
| print_statistics_summary, | |
| ) | |
| from linalg_zero.generator.core import DatasetGenerator | |
| from linalg_zero.generator.difficulty_config import DETERMINISTIC_MODE | |
| from linalg_zero.generator.models import DifficultyCategory, Question, Topic | |
| from linalg_zero.generator.registry import create_default_registry, create_optimized_registry | |
| from linalg_zero.generator.utils import ( | |
| check_constraints, | |
| convert_to_dataset_splits, | |
| load_entropy_settings, | |
| print_dataset, | |
| print_split_statistics, | |
| set_seed, | |
| verify_dataset, | |
| ) | |
| from linalg_zero.shared.utils import get_log_file_path, get_logger, push_to_hub, setup_logging | |
def main(
    push_dataset: bool,
    use_optimized_registry: bool,
    dataset_name: str,
    n_one: int,
    n_two: int,
    n_three: int,
    seed: "int | None" = None,
) -> None:  # pragma: no cover
    """Generate the linear algebra dataset, verify it, split it and optionally upload it.

    Args:
        push_dataset: When true, the resulting splits are handed to ``push_to_hub``.
        use_optimized_registry: When true, the registry is built with
            ``create_optimized_registry`` from ``gen_properties.json`` and the
            generated dataset is additionally checked with ``check_constraints``;
            otherwise ``create_default_registry`` is used.
        dataset_name: Name passed to ``push_to_hub``.
        n_one: Number of questions requested for ``DifficultyCategory.ONE_TOOL_CALL``.
        n_two: Number of questions requested for ``DifficultyCategory.TWO_TOOL_CALLS``.
        n_three: Number of questions requested for ``DifficultyCategory.THREE_TOOL_CALLS``.
        seed: Seed used for the train/validation/test split. When ``None``, the
            ``seed`` attribute of the module-level ``argv`` namespace (set by the
            script entry point) is used if it exists, and 42 otherwise.
    """
    # Set up logging
    setup_logging(level=logging.INFO, include_timestamp=False)
    logger = get_logger(__name__)

    logger.info("Linear Algebra Dataset Generator")

    # Resolve the split seed up front. The global lookup keeps the script entry
    # point working unchanged, while importing and calling main() directly no
    # longer fails with a NameError on the missing ``argv`` global.
    if seed is None:
        cli_args = globals().get("argv")
        seed = getattr(cli_args, "seed", None)
    # Compare against None explicitly so that a seed of 0 is honoured.
    split_seed = 42 if seed is None else seed

    config_path = f"{Path(__file__).parent}/generator/config/gen_properties.json"

    # Create registry (either default or optimized)
    if use_optimized_registry:
        registry = create_optimized_registry(config_path=config_path)
        logger.info("Using optimized entropy settings from analysis results")
    else:
        registry = create_default_registry()

    logger.info("Available topics: %s", registry.list_topics())

    # -----------------------------------------------
    # Generate and display the linear algebra dataset
    # -----------------------------------------------
    def matrix_only_validator(question: Question) -> bool:
        # A filter that only keeps questions with a non-empty answer
        return len(question.answer) > 0

    generator = DatasetGenerator(
        topic=Topic.LINEAR_ALGEBRA, validator_factory=matrix_only_validator, registry=registry
    )

    # Request an exact number of questions for each difficulty category;
    # the amounts are supplied by the caller.
    dataset = generator.generate_exact_for_categories(
        requests={
            DifficultyCategory.ONE_TOOL_CALL: n_one,
            DifficultyCategory.TWO_TOOL_CALLS: n_two,
            DifficultyCategory.THREE_TOOL_CALLS: n_three,
        }
    )
    statistics = compute_stepwise_value_statistics(dataset)
    print_dataset(dataset)
    print_statistics_summary(statistics)
    verify_dataset(dataset)

    if use_optimized_registry:
        config = load_entropy_settings(config_path)
        check_constraints(dataset, config, statistics)

    # Create stratified splits by difficulty for balanced evaluation
    splits = convert_to_dataset_splits(
        dataset,
        test_size=0.1,
        val_size=0.1,
        seed=split_seed,
        stratify_by="difficulty",
    )
    print_split_statistics(splits)

    if push_dataset:
        push_to_hub(splits, dataset_name, private=False, config_path=config_path)

    # --------------------------------------------------
    # This is an example on generating other topic types
    # --------------------------------------------------
    # arithmetic_generator = DatasetGenerator(topic="arithmetic")
    # arithmetic_questions = arithmetic_generator.generate_dataset(num_questions=2)
    # print_dataset(arithmetic_questions)

    logger.info("Log file path: %s", get_log_file_path())
| if __name__ == "__main__": # pragma: no cover | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--seed", type=int, default=42) | |
| parser.add_argument("--push_dataset", action="store_true", default=False) | |
| parser.add_argument("--dataset_name", type=str, default="atomwalk12/linalgzero") | |
| parser.add_argument( | |
| "--use_optimized_registry", | |
| action="store_true", | |
| default=True, | |
| help="Use optimized entropy settings from analysis results for dataset generation", | |
| ) | |
| parser.add_argument("--n_one", type=int, default=700, help="Per-generator 1-step samples") | |
| parser.add_argument("--n_two", type=int, default=900, help="Per-generator 2-step samples") | |
| parser.add_argument("--n_three", type=int, default=600, help="Per-generator 3-step samples") | |
| parser.add_argument("-scale", type=int, default=1.1, help="Scale the dataset by a factor") | |
| argv = parser.parse_args() | |
| if argv.seed is not None: | |
| set_seed(argv.seed) | |
| if DETERMINISTIC_MODE: | |
| # Let CLI seed control per-question reseed base when deterministic | |
| # Importing module and setting its global is sufficient | |
| dc.DETERMINISTIC_BASE_SEED = int(argv.seed) | |
| n_one = int(argv.n_one * argv.scale) | |
| n_two = int(argv.n_two * argv.scale) | |
| n_three = int(argv.n_three * argv.scale) | |
| main(argv.push_dataset, argv.use_optimized_registry, argv.dataset_name, n_one, n_two, n_three) | |