File size: 5,412 Bytes
0dd6c2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from collections.abc import Callable

from linalg_zero.generator.models import DifficultyCategory, Question, Topic
from linalg_zero.generator.registry import FactoryRegistry, create_default_registry
from linalg_zero.shared.utils import get_logger

# Module-level logger; DatasetGenerator uses it to warn when fewer valid questions than requested were produced.
logger = get_logger(__name__)


class QuestionGenerator:
    """
    Build single questions from injectable callables (Instance Attribute Factory pattern).

    Both collaborators are plain callables, so functions, lambda expressions, bound methods,
    classes or ``functools.partial`` objects can all be supplied.
    """

    def __init__(
        self, question_factory: Callable[[], Question], validator_factory: Callable[[Question], bool] | None = None
    ) -> None:
        """
        Store the callables used to build and check questions.

        Args:
            question_factory: Zero-argument callable that produces a Question.
            validator_factory: Optional predicate that receives a Question and returns whether it
                is valid. Any falsy value (including None) selects ``_default_validator``.
        """
        self.question_factory = question_factory
        if validator_factory:
            self.validator_factory = validator_factory
        else:
            self.validator_factory = self._default_validator

    def generate(self) -> Question:
        """Build one question and stamp it with the configured validator's verdict."""
        produced = self.question_factory()
        verdict = self.validator_factory(produced)
        produced.is_valid = verdict
        return produced

    @staticmethod
    def _default_validator(question: Question) -> bool:
        """Fallback check: both ``question.question`` and ``question.answer`` must have length > 0."""
        # The answer is only inspected once the question text is known to be non-empty.
        if len(question.question) == 0:
            return False
        return len(question.answer) > 0


class DatasetGenerator:
    """
    Dataset generator using Instance Attribute Factory pattern.

    Following python-patterns.guide recommendations - instead of a function
    with many parameters, use a class that accepts configuration in __init__.
    """

    def __init__(
        self,
        topic: Topic = Topic.LINEAR_ALGEBRA,
        validator_factory: Callable[[Question], bool] | None = None,
        max_attempts: int = 999999999,
        registry: FactoryRegistry | None = None,
    ):
        """Initialize with generation configuration.

        Args:
            topic: Topic whose registered factories are drawn from.
            validator_factory: Optional predicate deciding whether a generated Question is valid.
                Any falsy value selects ``QuestionGenerator._default_validator``.
            max_attempts: Upper bound on generation attempts. It applies to a whole
                ``generate_dataset`` call, and separately to each factory in
                ``generate_exact_per_factory``.
            registry: Factory registry to use. When falsy, ``create_default_registry()`` is used.
        """
        self.topic = topic
        self.validator_factory = validator_factory or QuestionGenerator._default_validator
        self.max_attempts = max_attempts
        self.registry = registry or create_default_registry()

    def _collect_valid(self, generator: QuestionGenerator, target: int) -> tuple[list[Question], int]:
        """Draw from ``generator`` until ``target`` valid questions exist or ``max_attempts`` is reached.

        Invalid questions are discarded, but they still count as attempts.

        Returns:
            A pair of (valid questions collected, number of attempts actually made).
        """
        collected: list[Question] = []
        attempts = 0
        while len(collected) < target and attempts < self.max_attempts:
            question = generator.generate()
            if question.is_valid:
                collected.append(question)
            attempts += 1
        return collected, attempts

    def generate_dataset(self, num_questions: int) -> list[Question]:
        """Generate up to ``num_questions`` valid questions, drawing a random factory for each attempt.

        The result may be shorter than requested if ``max_attempts`` is exhausted; a warning is
        logged in that case.
        """
        generator = QuestionGenerator(
            # A new random factory is chosen on every call, so questions mix across all factories.
            question_factory=lambda: self.registry.get_random_factory(self.topic)(),
            validator_factory=self.validator_factory,
        )

        questions, attempts = self._collect_valid(generator, num_questions)

        if len(questions) < num_questions:
            logger.warning(
                "Only generated %d/%d valid questions after %d attempts",
                len(questions),
                num_questions,
                attempts,  # report the attempts actually made, not the configured cap
            )

        return questions

    def generate_exact_per_factory(self, difficulty: "DifficultyCategory", num_per_factory: int) -> list[Question]:
        """Generate ``num_per_factory`` valid questions from each registered factory of a category.

        If every factory reaches its quota within ``max_attempts`` attempts, the total equals
        (number_of_factories_in_category * num_per_factory). A factory that runs out of attempts
        contributes fewer questions, and a warning naming it is logged.

        Raises:
            TypeError: If ``difficulty`` is not a DifficultyCategory.
        """
        if not isinstance(difficulty, DifficultyCategory):
            raise TypeError("difficulty must be a DifficultyCategory")

        factories = self.registry.get_factories_by_difficulty(self.topic, difficulty)
        if not factories:
            return []

        all_questions: list[Question] = []
        for factory in factories:
            generator = QuestionGenerator(question_factory=factory, validator_factory=self.validator_factory)
            produced, attempts = self._collect_valid(generator, num_per_factory)
            if len(produced) < num_per_factory:
                # Name the factory so a shortfall can be traced when a category has several factories.
                logger.warning(
                    "Factory %s produced %d/%d valid questions after %d attempts (difficulty=%s)",
                    getattr(factory, "__name__", repr(factory)),
                    len(produced),
                    num_per_factory,
                    attempts,
                    difficulty,
                )
            all_questions.extend(produced)

        return all_questions

    def generate_exact_for_categories(self, requests: dict["DifficultyCategory", int]) -> list[Question]:
        """Generate N questions per factory for each requested category.

        Example: {ONE_TOOL_CALL: 3000} will produce up to 3000 per registered factory in that category.

        Raises:
            TypeError: If any key of ``requests`` is not a DifficultyCategory. All keys are checked
                before any generation starts, so no work is wasted on an invalid request.
        """
        if not all(isinstance(difficulty, DifficultyCategory) for difficulty in requests):
            raise TypeError("All keys in requests must be DifficultyCategory")

        total: list[Question] = []
        for difficulty, num_per_factory in requests.items():
            total.extend(self.generate_exact_per_factory(difficulty, num_per_factory))
        return total