File size: 8,049 Bytes
8d18b7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
"""Curriculum Learning and Difficulty-Aware Sampling"""

import logging
import random
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
from collections import defaultdict

import numpy as np
from torch.utils.data import Sampler

logger = logging.getLogger(__name__)


@dataclass
class CurriculumStage:
    """One phase of a curriculum: which samples are eligible and how to order them.

    A stage applies from training epoch ``epoch`` onward. A sample is eligible
    when its domain is listed in ``domains`` and its difficulty lies in the
    closed interval ``[min_difficulty, max_difficulty]``.
    """

    # Identifier for the stage; the samplers in this module key per-stage indices by it.
    name: str
    # First epoch at which this stage becomes active.
    epoch: int
    # Domain names whose samples may be drawn in this stage.
    domains: List[str]
    # Inclusive lower / upper bounds on sample difficulty.
    min_difficulty: float = field(default=0.0)
    max_difficulty: float = field(default=1.0)
    # Ordering strategy: "balanced", "weighted" or "random".
    sampling_strategy: str = field(default="balanced")
    # Optional per-domain weights, consulted by the "weighted" strategy.
    domain_weights: Optional[Dict[str, float]] = field(default=None)


class CurriculumSampler(Sampler):
    """Sampler that implements a staged curriculum-learning strategy.

    Each stage selects the samples whose ``domain`` is in ``stage.domains`` and
    whose ``difficulty`` lies in ``[stage.min_difficulty, stage.max_difficulty]``
    (a sample may belong to several stages). The active stage for an epoch is
    the one with the largest ``stage.epoch`` not exceeding it, and ``__iter__``
    orders that stage's indices according to ``stage.sampling_strategy``.

    Samples are read with ``sample.get("domain", "unknown")`` and
    ``sample.get("difficulty", 0.5)``, so they are expected to be mapping-like.
    """

    def __init__(
        self,
        dataset: Any,
        stages: List["CurriculumStage"],
        current_epoch: int = 0,
        seed: int = 42,
    ):
        """
        Args:
            dataset: Iterable and indexable collection of mapping-like samples.
            stages: Curriculum stages; they do not have to be sorted by epoch.
            current_epoch: Epoch used to pick the initial stage.
            seed: Seed for this sampler's private ``np.random.RandomState``.

        Raises:
            ValueError: If ``stages`` is empty.
        """
        if not stages:
            raise ValueError("CurriculumSampler requires at least one stage")

        self.dataset = dataset
        self.stages = stages
        self.current_epoch = current_epoch
        self.seed = seed
        self.rng = np.random.RandomState(seed)

        # Build index by stage (a single pass over the dataset).
        self._build_indices()
        self._current_stage = self._get_stage_for_epoch(current_epoch)

        logger.info(f"CurriculumSampler initialized with {len(stages)} stages")
        logger.info(f"Current epoch {current_epoch} -> stage '{self._current_stage.name}'")

    def _build_indices(self):
        """Record, per stage name, the dataset indices eligible for that stage.

        Also caches every sample's domain in ``self._sample_domains`` (indexed
        by dataset position) so ``__iter__`` does not re-fetch samples from the
        dataset each epoch just to read their domain.
        """
        self.stage_indices = defaultdict(list)
        self._sample_domains: List[str] = []

        for idx, sample in enumerate(self.dataset):
            domain = sample.get("domain", "unknown")
            difficulty = sample.get("difficulty", 0.5)
            self._sample_domains.append(domain)

            # A sample is added to every stage whose domain and difficulty filters it passes.
            for stage in self.stages:
                if domain in stage.domains and stage.min_difficulty <= difficulty <= stage.max_difficulty:
                    self.stage_indices[stage.name].append(idx)

        logger.info(f"Built stage indices: {[(s, len(self.stage_indices[s])) for s in self.stage_indices]}")

    def _get_stage_for_epoch(self, epoch: int) -> "CurriculumStage":
        """Return the stage active at ``epoch``.

        The active stage is the one with the largest ``stage.epoch`` that is
        ``<= epoch``; among equal start epochs the later-listed stage wins.
        If ``epoch`` precedes every stage, the earliest stage is returned.
        ``self.stages`` does not need to be sorted.
        """
        active = None
        for stage in self.stages:
            if stage.epoch <= epoch and (active is None or stage.epoch >= active.epoch):
                active = stage
        if active is None:
            active = min(self.stages, key=lambda s: s.epoch)
        return active

    def set_epoch(self, epoch: int):
        """Update current epoch and stage."""
        self.current_epoch = epoch
        self._current_stage = self._get_stage_for_epoch(epoch)
        logger.info(f"Epoch {epoch} -> stage '{self._current_stage.name}'")

    def __iter__(self):
        """Iterate over indices according to the current stage's strategy.

        "balanced" interleaves domains, "weighted" samples by domain weight,
        and "random" (or any unrecognised strategy) is a plain shuffle.
        """
        indices = self.stage_indices[self._current_stage.name]
        strategy = self._current_stage.sampling_strategy

        if strategy == "balanced":
            order = self._balanced_order(indices)
        elif strategy == "weighted":
            order = self._weighted_order(indices)
        else:
            order = self._shuffled(indices)
        return iter(order)

    def _shuffled(self, indices: List[int]) -> List[int]:
        """Return a shuffled copy of ``indices`` (the input is not modified)."""
        shuffled = list(indices)
        self.rng.shuffle(shuffled)
        return shuffled

    def _balanced_order(self, indices: List[int]) -> List[int]:
        """Shuffle within each domain, then round-robin across domains in sorted name order."""
        by_domain = defaultdict(list)
        for idx in indices:
            by_domain[self._sample_domains[idx]].append(idx)

        for domain in by_domain:
            self.rng.shuffle(by_domain[domain])

        # An empty stage has no domains; max() below would raise on it.
        if not by_domain:
            return []

        ordered_domains = sorted(by_domain)
        longest = max(len(bucket) for bucket in by_domain.values())
        result = []
        for i in range(longest):
            for domain in ordered_domains:
                bucket = by_domain[domain]
                if i < len(bucket):
                    result.append(bucket[i])
        return result

    def _weighted_order(self, indices: List[int]) -> List[int]:
        """Draw ``len(indices)`` indices with replacement, each proportional to its domain weight.

        Domains missing from ``domain_weights`` get weight 1.0 and negative
        weights count as 0. Drawing exactly ``len(indices)`` items keeps the
        epoch length equal to ``__len__``. No domain is silently dropped by
        rounding, as happened with the previous ``int(weight * 10)`` replication.
        """
        if not indices:
            return []

        weights = self._current_stage.domain_weights or {}
        raw = np.array(
            [max(float(weights.get(self._sample_domains[idx], 1.0)), 0.0) for idx in indices]
        )
        total = raw.sum()
        if total <= 0:
            logger.warning(
                "Stage '%s': all domain weights are zero; yielding no samples",
                self._current_stage.name,
            )
            return []

        picks = self.rng.choice(len(indices), size=len(indices), replace=True, p=raw / total)
        return [indices[j] for j in picks]

    def __len__(self) -> int:
        """Return number of samples in current stage."""
        return len(self.stage_indices[self._current_stage.name])


class DifficultyAwareSampler(Sampler):
    """Sampler that biases sample order by difficulty for progressive learning.

    A difficulty threshold moves linearly from ``start_difficulty`` to
    ``end_difficulty`` over ``epochs`` epochs (see ``set_epoch``). Samples at
    or below the threshold get full weight; harder samples are down-weighted
    by ``exp(-(difficulty - threshold) / temperature)``. Every epoch yields each
    index exactly once, in a weighted random order (sampling without
    replacement), so lower-weight samples tend to appear later.
    """

    def __init__(
        self,
        dataset: Any,
        difficulty_key: str = "difficulty",
        temperature: float = 1.0,
        start_difficulty: float = 0.0,
        end_difficulty: float = 1.0,
        epochs: int = 10,
        seed: int = 42,
    ):
        """
        Args:
            dataset: Iterable, sized collection of mapping-like samples.
            difficulty_key: Key read from each sample; missing values default to 0.5.
            temperature: Softness of the penalty above the threshold (clamped to >= 1e-6).
            start_difficulty: Threshold at epoch 0.
            end_difficulty: Threshold once ``epochs`` epochs have elapsed.
            epochs: Number of epochs over which the threshold ramps; <= 0 means
                the threshold starts at ``end_difficulty``.
            seed: Seed for this sampler's private ``np.random.RandomState``.
        """
        self.dataset = dataset
        self.difficulty_key = difficulty_key
        self.temperature = temperature
        self.start_difficulty = start_difficulty
        self.end_difficulty = end_difficulty
        self.epochs = epochs
        self.seed = seed
        self.rng = np.random.RandomState(seed)

        # Precompute difficulties once; float dtype keeps an empty dataset well-typed.
        self.difficulties = np.array(
            [sample.get(difficulty_key, 0.5) for sample in dataset], dtype=float
        )

        self.current_epoch = 0

    def set_epoch(self, epoch: int):
        """Update epoch."""
        self.current_epoch = epoch

    def _get_difficulty_threshold(self) -> float:
        """Linearly interpolate the threshold for the current epoch (progress clamped to [0, 1])."""
        if self.epochs <= 0:
            # Avoid ZeroDivisionError: with no ramp, use the final threshold.
            progress = 1.0
        else:
            progress = min(1.0, max(0.0, self.current_epoch / self.epochs))
        return self.start_difficulty + progress * (self.end_difficulty - self.start_difficulty)

    def _compute_logits(self) -> np.ndarray:
        """Return the unnormalised log-weight of every sample.

        Only the part of a sample's difficulty above the threshold is penalised,
        so samples at or below the threshold all share the maximal log-weight 0.
        """
        threshold = self._get_difficulty_threshold()
        excess = np.clip(self.difficulties - threshold, 0.0, None)
        return -excess / max(self.temperature, 1e-6)

    def _compute_weights(self) -> np.ndarray:
        """Compute normalised sampling probabilities for the current threshold."""
        logits = self._compute_logits()
        if logits.size == 0:
            return logits
        # Shift by the max before exponentiating so the largest weight is 1 and
        # the sum can never underflow to 0 (which previously produced NaNs).
        weights = np.exp(logits - logits.max())
        return weights / weights.sum()

    def __iter__(self):
        """Yield every index once, in a difficulty-weighted random order."""
        logits = self._compute_logits()
        if logits.size == 0:
            return iter([])
        # Gumbel-top-k: sorting logits + Gumbel noise draws a permutation with the
        # same law as sequential weighted sampling without replacement, but,
        # unlike np.random.choice(replace=False, p=...), it does not fail when
        # some probabilities underflow to zero. Uses the seeded self.rng.
        keys = logits + self.rng.gumbel(size=logits.shape)
        order = np.argsort(-keys, kind="stable")
        return iter(order.tolist())

    def __len__(self) -> int:
        return len(self.dataset)


def create_curriculum_sampler(
    dataset: Any,
    curriculum_config: Any,  # CurriculumConfig
    current_epoch: int = 0,
    seed: int = 42,
) -> Optional[CurriculumSampler]:
    """Create a CurriculumSampler from a curriculum config object.

    Args:
        dataset: Dataset passed through to ``CurriculumSampler``.
        curriculum_config: Object exposing ``enable_curriculum`` and ``stages``.
            Each stage entry is read as a dict with required keys ``name``,
            ``epoch`` and ``domains``, plus optional ``min_difficulty``,
            ``max_difficulty``, ``sampling_strategy`` and ``domain_weights``.
            Domain entries may be enum members (their ``.value`` is used) or
            plain strings.
        current_epoch: Epoch used to select the initial stage.
        seed: Seed forwarded to the sampler.

    Returns:
        A configured ``CurriculumSampler``, or ``None`` when the curriculum is disabled.
    """
    if not curriculum_config.enable_curriculum:
        return None

    # Build stages from config
    stages = []
    for stage_cfg in curriculum_config.stages:
        stage = CurriculumStage(
            name=stage_cfg["name"],
            epoch=stage_cfg["epoch"],
            # Accept both Enum members and raw strings for domains.
            domains=[getattr(d, "value", d) for d in stage_cfg["domains"]],
            min_difficulty=stage_cfg.get("min_difficulty", 0.0),
            max_difficulty=stage_cfg.get("max_difficulty", 1.0),
            sampling_strategy=stage_cfg.get("sampling_strategy", "balanced"),
            domain_weights=stage_cfg.get("domain_weights"),
        )
        stages.append(stage)

    # Order stages by start epoch so stage lookup never depends on config order.
    # The sort is stable, so stages sharing an epoch keep their configured order.
    stages.sort(key=lambda s: s.epoch)

    return CurriculumSampler(dataset, stages, current_epoch, seed)