File size: 4,977 Bytes
0dd6c2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import argparse
import logging
from pathlib import Path

import linalg_zero.generator.difficulty_config as dc
from linalg_zero.generator.analysis.utils import (
    compute_stepwise_value_statistics,
    print_statistics_summary,
)
from linalg_zero.generator.core import DatasetGenerator
from linalg_zero.generator.difficulty_config import DETERMINISTIC_MODE
from linalg_zero.generator.models import DifficultyCategory, Question, Topic
from linalg_zero.generator.registry import create_default_registry, create_optimized_registry
from linalg_zero.generator.utils import (
    check_constraints,
    convert_to_dataset_splits,
    load_entropy_settings,
    print_dataset,
    print_split_statistics,
    set_seed,
    verify_dataset,
)
from linalg_zero.shared.utils import get_log_file_path, get_logger, push_to_hub, setup_logging


def main(
    push_dataset: bool,
    use_optimized_registry: bool,
    dataset_name: str,
    n_one: int,
    n_two: int,
    n_three: int,
    seed: int = 42,
) -> None:  # pragma: no cover
    """Generate, verify, split and optionally publish the linear algebra dataset.

    Args:
        push_dataset: When true, the resulting splits are pushed to the hub
            under ``dataset_name`` (as a non-private dataset).
        use_optimized_registry: When true, the registry is built from the
            entropy settings in ``gen_properties.json`` and the generated
            dataset is additionally checked against those settings; otherwise
            the default registry is used.
        dataset_name: Hub repository name used when ``push_dataset`` is set.
        n_one: Number of questions requested for ``ONE_TOOL_CALL``.
        n_two: Number of questions requested for ``TWO_TOOL_CALLS``.
        n_three: Number of questions requested for ``THREE_TOOL_CALLS``.
        seed: Seed forwarded to the train/validation/test split. It used to
            be read from a module-level ``argv`` that only exists when this
            file runs as a script, which made ``main`` unusable after import.
    """
    # Set up logging
    setup_logging(level=logging.INFO, include_timestamp=False)
    logger = get_logger(__name__)

    logger.info("Linear Algebra Dataset Generator")
    config_path = f"{Path(__file__).parent}/generator/config/gen_properties.json"

    # Create registry (either default or optimized)
    if use_optimized_registry:
        registry = create_optimized_registry(config_path=config_path)
        logger.info("Using optimized entropy settings from analysis results")
    else:
        registry = create_default_registry()

    logger.info("Available topics: %s", registry.list_topics())

    # -----------------------------------------------
    # Generate and display the linear algebra dataset
    # -----------------------------------------------
    def matrix_only_validator(question: Question) -> bool:
        # Keep only questions whose answer is non-empty.
        return len(question.answer) > 0

    # NOTE(review): the predicate itself is passed as ``validator_factory``;
    # confirm DatasetGenerator expects a validator here rather than a factory.
    generator = DatasetGenerator(
        topic=Topic.LINEAR_ALGEBRA, validator_factory=matrix_only_validator, registry=registry
    )

    # Request an exact number of questions per difficulty category; the
    # counts are supplied by the caller (already scaled when run from the CLI).
    dataset = generator.generate_exact_for_categories(
        requests={
            DifficultyCategory.ONE_TOOL_CALL: n_one,
            DifficultyCategory.TWO_TOOL_CALLS: n_two,
            DifficultyCategory.THREE_TOOL_CALLS: n_three,
        }
    )
    statistics = compute_stepwise_value_statistics(dataset)
    print_dataset(dataset)
    print_statistics_summary(statistics)
    verify_dataset(dataset)
    if use_optimized_registry:
        config = load_entropy_settings(config_path)
        check_constraints(dataset, config, statistics)

    # Create stratified splits by difficulty for balanced evaluation.
    # The seed is used as given: ``seed or 42`` would silently replace a
    # legitimate seed of 0 with 42.
    splits = convert_to_dataset_splits(
        dataset,
        test_size=0.1,
        val_size=0.1,
        seed=seed,
        stratify_by="difficulty",
    )
    print_split_statistics(splits)
    if push_dataset:
        push_to_hub(splits, dataset_name, private=False, config_path=config_path)

    # --------------------------------------------------
    # This is an example on generating other topic types
    # --------------------------------------------------
    # arithmetic_generator = DatasetGenerator(topic="arithmetic")
    # arithmetic_questions = arithmetic_generator.generate_dataset(num_questions=2)
    # print_dataset(arithmetic_questions)
    logger.info("Log file path: %s", get_log_file_path())


def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the dataset generation script.

    Returns:
        A parser exposing ``seed``, ``push_dataset``, ``dataset_name``,
        ``use_optimized_registry``, ``n_one``, ``n_two``, ``n_three`` and
        ``scale`` on the parsed namespace.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--push_dataset", action="store_true", default=False)
    parser.add_argument("--dataset_name", type=str, default="atomwalk12/linalgzero")
    parser.add_argument(
        "--use_optimized_registry",
        action="store_true",
        default=True,
        help="Use optimized entropy settings from analysis results for dataset generation",
    )
    # The flag above defaults to True, so on its own it can never be turned
    # off; this companion switch makes the default registry reachable.
    parser.add_argument(
        "--no_optimized_registry",
        dest="use_optimized_registry",
        action="store_false",
        help="Use the default registry instead of the optimized entropy settings",
    )
    parser.add_argument("--n_one", type=int, default=700, help="Per-generator 1-step samples")
    parser.add_argument("--n_two", type=int, default=900, help="Per-generator 2-step samples")
    parser.add_argument("--n_three", type=int, default=600, help="Per-generator 3-step samples")
    # The scale is fractional (default 1.1), so it must be parsed as float;
    # with type=int a value such as "1.5" was rejected. "-scale" is kept for
    # backward compatibility next to the conventional "--scale" spelling.
    parser.add_argument("-scale", "--scale", type=float, default=1.1, help="Scale the dataset by a factor")
    return parser


if __name__ == "__main__":  # pragma: no cover
    cli_args = build_arg_parser().parse_args()
    if cli_args.seed is not None:
        set_seed(cli_args.seed)
        if DETERMINISTIC_MODE:
            # Let CLI seed control per-question reseed base when deterministic
            # Importing module and setting its global is sufficient
            dc.DETERMINISTIC_BASE_SEED = int(cli_args.seed)

    # Scaled counts are truncated towards zero, e.g. 700 * 1.1 -> 770.
    n_one = int(cli_args.n_one * cli_args.scale)
    n_two = int(cli_args.n_two * cli_args.scale)
    n_three = int(cli_args.n_three * cli_args.scale)

    main(
        cli_args.push_dataset,
        cli_args.use_optimized_registry,
        cli_args.dataset_name,
        n_one,
        n_two,
        n_three,
        seed=42 if cli_args.seed is None else cli_args.seed,
    )