File size: 8,049 Bytes
8d18b7c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 | """Curriculum Learning and Difficulty-Aware Sampling"""
import logging
import random
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
from collections import defaultdict
import numpy as np
from torch.utils.data import Sampler
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@dataclass
class CurriculumStage:
    """One stage of a curriculum schedule.

    Attributes:
        name: Stage identifier; CurriculumSampler groups sample indices by it.
        epoch: Epoch from which this stage may become the active one.
        domains: Sample domains admitted in this stage.
        min_difficulty: Inclusive lower bound on sample difficulty.
        max_difficulty: Inclusive upper bound on sample difficulty.
        sampling_strategy: "balanced", "weighted" or "random".
        domain_weights: Optional per-domain weights for the "weighted" strategy.
    """

    name: str
    epoch: int
    domains: List[str]
    min_difficulty: float = field(default=0.0)
    max_difficulty: float = field(default=1.0)
    sampling_strategy: str = field(default="balanced")
    domain_weights: Optional[Dict[str, float]] = field(default=None)
class CurriculumSampler(Sampler):
    """Sampler that implements a stage-based curriculum learning strategy.

    A sample belongs to a stage when its ``"domain"`` is listed in
    ``stage.domains`` and its ``"difficulty"`` lies in the inclusive range
    ``[stage.min_difficulty, stage.max_difficulty]``. Samples must support
    ``.get`` (mapping-like); a missing domain is treated as ``"unknown"`` and a
    missing difficulty as ``0.5``. A sample may belong to several stages.

    The active stage is the one with the largest ``epoch`` that does not
    exceed the current epoch; before the earliest stage starts, the earliest
    stage is used. ``stages`` does not need to be sorted by ``epoch``.
    """

    def __init__(
        self,
        dataset: Any,
        stages: List[CurriculumStage],
        current_epoch: int = 0,
        seed: int = 42,
    ):
        """Index the dataset by stage and select the stage for ``current_epoch``.

        Args:
            dataset: Iterable of mapping-like samples (iterated once here).
            stages: Curriculum stages, in any order.
            current_epoch: Epoch used to choose the initial stage.
            seed: Seed for this sampler's private ``np.random.RandomState``.

        Raises:
            ValueError: If ``stages`` is empty.
        """
        if not stages:
            raise ValueError("CurriculumSampler requires at least one stage")
        self.dataset = dataset
        self.stages = stages
        self.current_epoch = current_epoch
        self.seed = seed
        self.rng = np.random.RandomState(seed)
        self._build_indices()
        self._current_stage = self._get_stage_for_epoch(current_epoch)
        logger.info(f"CurriculumSampler initialized with {len(stages)} stages")
        logger.info(f"Current epoch {current_epoch} -> stage '{self._current_stage.name}'")

    def _build_indices(self):
        """Fill ``self.stage_indices`` (stage name -> list of sample indices).

        Also caches every sample's domain in ``self._sample_domains`` so that
        ``__iter__`` does not have to re-fetch samples from the dataset on
        every epoch.
        """
        self.stage_indices = defaultdict(list)
        self._sample_domains = []
        for idx, sample in enumerate(self.dataset):
            domain = sample.get("domain", "unknown")
            difficulty = sample.get("difficulty", 0.5)
            self._sample_domains.append(domain)
            for stage in self.stages:
                if (
                    domain in stage.domains
                    and stage.min_difficulty <= difficulty <= stage.max_difficulty
                ):
                    self.stage_indices[stage.name].append(idx)
        logger.info(
            f"Built stage indices: {[(s, len(self.stage_indices[s])) for s in self.stage_indices]}"
        )

    def _get_stage_for_epoch(self, epoch: int) -> CurriculumStage:
        """Return the stage active at ``epoch``.

        Stages are scanned in ascending ``epoch`` order, so an unsorted
        ``self.stages`` list is handled correctly. The sort is stable: among
        stages sharing the same ``epoch``, the one listed last wins.
        """
        ordered = sorted(self.stages, key=lambda stage: stage.epoch)
        current_stage = ordered[0]
        for stage in ordered:
            if stage.epoch > epoch:
                break
            current_stage = stage
        return current_stage

    def set_epoch(self, epoch: int):
        """Update the current epoch and switch to the matching stage."""
        self.current_epoch = epoch
        self._current_stage = self._get_stage_for_epoch(epoch)
        logger.info(f"Epoch {epoch} -> stage '{self._current_stage.name}'")

    def __iter__(self):
        """Yield dataset indices for the current stage.

        Strategies (``stage.sampling_strategy``):
            * ``"balanced"``: shuffle each domain's indices, then interleave
              domains round-robin (sorted domain order) until all are used.
            * ``"weighted"``: draw ``len(self)`` indices with replacement, each
              with probability proportional to its domain's entry in
              ``stage.domain_weights`` (missing domains weigh 1.0).
            * ``"random"`` or any other value: a plain shuffle.

        Raises:
            ValueError: For ``"weighted"`` when a domain weight is negative or
                the stage's samples have zero total weight.
        """
        stage = self._current_stage
        indices = self.stage_indices[stage.name]
        if not indices:
            # The stage matched no samples; nothing to yield.
            return iter([])
        if stage.sampling_strategy == "balanced":
            return iter(self._balanced_order(indices))
        if stage.sampling_strategy == "weighted":
            return iter(self._weighted_order(indices, stage.domain_weights or {}))
        shuffled = indices.copy()
        self.rng.shuffle(shuffled)
        return iter(shuffled)

    def _balanced_order(self, indices: List[int]) -> List[int]:
        """Shuffle ``indices`` within each domain and interleave the domains."""
        domain_indices = defaultdict(list)
        for idx in indices:
            domain_indices[self._sample_domains[idx]].append(idx)
        for bucket in domain_indices.values():
            self.rng.shuffle(bucket)
        ordered_domains = sorted(domain_indices)
        longest = max(len(bucket) for bucket in domain_indices.values())
        result = []
        for i in range(longest):
            for domain in ordered_domains:
                bucket = domain_indices[domain]
                if i < len(bucket):
                    result.append(bucket[i])
        return result

    def _weighted_order(self, indices: List[int], weights: Dict[str, float]) -> List[int]:
        """Draw ``len(indices)`` indices with replacement, weighted by domain.

        Drawing exactly ``len(indices)`` items keeps the epoch length equal to
        ``len(self)``. Fractional and small weights keep their exact relative
        proportions instead of being truncated.
        """
        probs = np.array(
            [weights.get(self._sample_domains[idx], 1.0) for idx in indices],
            dtype=float,
        )
        if (probs < 0).any():
            raise ValueError(f"domain_weights must be non-negative, got {weights}")
        total = probs.sum()
        if total <= 0:
            raise ValueError(
                f"Stage '{self._current_stage.name}' has zero total domain weight"
            )
        picks = self.rng.choice(
            len(indices), size=len(indices), replace=True, p=probs / total
        )
        return [indices[i] for i in picks]

    def __len__(self) -> int:
        """Return the number of samples in the current stage."""
        return len(self.stage_indices[self._current_stage.name])
class DifficultyAwareSampler(Sampler):
    """Sampler that orders samples by difficulty for progressive learning.

    Each epoch yields every index in ``range(len(dataset))`` exactly once, as
    a weighted random permutation. A difficulty threshold moves linearly from
    ``start_difficulty`` at epoch 0 to ``end_difficulty`` at epoch ``epochs``
    and then stays there. A sample's weight is::

        exp(-max(0, difficulty - threshold) / temperature)

    Samples at or below the threshold share the top weight. Harder samples
    are down-weighted, which pushes them later in the permutation. As the
    threshold rises, harder samples move forward.
    """

    def __init__(
        self,
        dataset: Any,
        difficulty_key: str = "difficulty",
        temperature: float = 1.0,
        start_difficulty: float = 0.0,
        end_difficulty: float = 1.0,
        epochs: int = 10,
        seed: int = 42,
    ):
        """Read per-sample difficulties and set up a seeded RNG.

        Args:
            dataset: Iterable of mapping-like samples; supports ``len``.
            difficulty_key: Key holding each sample's difficulty (default 0.5
                when missing).
            temperature: Decay scale for samples above the threshold; smaller
                values penalise hard samples more sharply.
            start_difficulty: Threshold at epoch 0.
            end_difficulty: Threshold from epoch ``epochs`` onward.
            epochs: Number of epochs over which the threshold ramps.
            seed: Seed for this sampler's private ``np.random.RandomState``.

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.dataset = dataset
        self.difficulty_key = difficulty_key
        self.temperature = temperature
        self.start_difficulty = start_difficulty
        self.end_difficulty = end_difficulty
        self.epochs = epochs
        self.seed = seed
        self.rng = np.random.RandomState(seed)
        self.difficulties = np.array(
            [sample.get(difficulty_key, 0.5) for sample in dataset], dtype=float
        )
        self.current_epoch = 0

    def set_epoch(self, epoch: int):
        """Update the epoch that drives the difficulty threshold."""
        self.current_epoch = epoch

    def _get_difficulty_threshold(self) -> float:
        """Return the threshold for the current epoch (ramp clamped to [0, 1])."""
        if self.epochs <= 0:
            # No ramp requested: jump straight to the final threshold.
            progress = 1.0
        else:
            progress = min(1.0, max(0.0, self.current_epoch / self.epochs))
        return self.start_difficulty + progress * (self.end_difficulty - self.start_difficulty)

    def _compute_log_weights(self) -> np.ndarray:
        """Return unnormalised log-weights: ``-max(0, d - threshold) / temperature``."""
        threshold = self._get_difficulty_threshold()
        excess = np.maximum(self.difficulties - threshold, 0.0)
        return -excess / self.temperature

    def _compute_weights(self) -> np.ndarray:
        """Return sampling probabilities (sum to 1; empty for an empty dataset)."""
        log_weights = self._compute_log_weights()
        if log_weights.size == 0:
            return log_weights
        # Shift so the largest weight is exp(0) == 1; the sum cannot underflow to 0.
        weights = np.exp(log_weights - log_weights.max())
        return weights / weights.sum()

    def __iter__(self):
        """Yield a seeded, difficulty-weighted permutation of all indices."""
        n = len(self.difficulties)
        if n == 0:
            return iter([])
        # Gumbel-top-k trick: sorting log-weights plus i.i.d. Gumbel noise gives
        # the same distribution as sequential sampling without replacement with
        # p proportional to the weights. Working in log space means very small
        # weights never underflow to zero.
        keys = self._compute_log_weights() + self.rng.gumbel(size=n)
        order = np.argsort(-keys, kind="stable")
        return iter(order.tolist())

    def __len__(self) -> int:
        return len(self.dataset)
def create_curriculum_sampler(
    dataset: Any,
    curriculum_config: Any,  # CurriculumConfig
    current_epoch: int = 0,
    seed: int = 42,
) -> Optional[CurriculumSampler]:
    """Build a ``CurriculumSampler`` from a curriculum config object.

    Args:
        dataset: Dataset forwarded unchanged to ``CurriculumSampler``.
        curriculum_config: Object exposing ``enable_curriculum`` and
            ``stages``. Each stage is a mapping with keys ``"name"``,
            ``"epoch"`` and ``"domains"``, and optionally
            ``"min_difficulty"``, ``"max_difficulty"``,
            ``"sampling_strategy"`` and ``"domain_weights"``.
        current_epoch: Epoch used to choose the initial stage.
        seed: Seed forwarded to the sampler.

    Returns:
        The sampler, or ``None`` when ``enable_curriculum`` is falsy.
    """
    if not curriculum_config.enable_curriculum:
        return None
    stages = [
        CurriculumStage(
            name=stage_cfg["name"],
            epoch=stage_cfg["epoch"],
            # Entries with a ``.value`` attribute (e.g. Enum members) are
            # unwrapped; plain strings are used as-is.
            domains=[getattr(d, "value", d) for d in stage_cfg["domains"]],
            min_difficulty=stage_cfg.get("min_difficulty", 0.0),
            max_difficulty=stage_cfg.get("max_difficulty", 1.0),
            sampling_strategy=stage_cfg.get("sampling_strategy", "balanced"),
            domain_weights=stage_cfg.get("domain_weights"),
        )
        for stage_cfg in curriculum_config.stages
    ]
    return CurriculumSampler(dataset, stages, current_epoch, seed)
|