| """Curriculum Learning and Difficulty-Aware Sampling"""
|
|
|
| import logging
|
| import random
|
| from dataclasses import dataclass, field
|
| from typing import Any, Dict, List, Optional, Tuple
|
| from collections import defaultdict
|
|
|
| import numpy as np
|
| from torch.utils.data import Sampler
|
|
|
# Module-level logger; CurriculumSampler reports its stage selection and
# per-stage index counts through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
@dataclass
class CurriculumStage:
    """Configuration record describing one stage of a curriculum schedule.

    Attributes:
        name: Stage identifier.
        epoch: Epoch from which this stage becomes eligible.
        domains: Domain labels that samples must carry to join this stage.
        min_difficulty: Inclusive lower bound on sample difficulty.
        max_difficulty: Inclusive upper bound on sample difficulty.
        sampling_strategy: Name of the strategy used to order the stage's
            samples.
        domain_weights: Optional mapping from domain label to a relative
            weight; ``None`` when no weights are given.
    """

    # Required fields (no defaults) come first so positional construction
    # keeps the order name, epoch, domains, ...
    name: str
    epoch: int
    domains: List[str]

    # Optional knobs.
    min_difficulty: float = field(default=0.0)
    max_difficulty: float = field(default=1.0)
    sampling_strategy: str = field(default="balanced")
    domain_weights: Optional[Dict[str, float]] = field(default=None)
|
|
|
|
|
class CurriculumSampler(Sampler):
    """Sampler that implements a staged curriculum-learning strategy.

    A stage becomes active once the current epoch reaches its ``epoch``
    field. The active stage is the one with the largest ``epoch`` that is
    <= the current epoch; before any stage has started, the earliest stage
    is used. Only samples whose ``"domain"`` is listed in the active stage
    and whose ``"difficulty"`` lies in ``[min_difficulty, max_difficulty]``
    (inclusive) are yielded. Missing fields default to ``"unknown"`` and
    ``0.5`` respectively.

    Sampling strategies (``stage.sampling_strategy``):
        * ``"random"``: shuffle the stage's indices.
        * ``"balanced"``: shuffle within each domain, then interleave the
          domains round-robin, visiting domains in sorted order.
        * ``"weighted"``: draw ``len(stage indices)`` indices with
          replacement, with probability proportional to
          ``domain_weights.get(domain, 1.0)``.
        * any other value: same as ``"random"``.
    """

    def __init__(
        self,
        dataset: Any,
        stages: List["CurriculumStage"],
        current_epoch: int = 0,
        seed: int = 42,
    ):
        """Bucket the dataset by stage and select the stage for ``current_epoch``.

        Args:
            dataset: Dataset of dict-like samples; iterating it must yield
                samples in index order (the enumeration index is what the
                sampler yields).
            stages: Curriculum stages; they need not be sorted by ``epoch``.
            current_epoch: Epoch used to choose the initial stage.
            seed: Seed for this sampler's private ``RandomState``.

        Raises:
            ValueError: if ``stages`` is empty.
        """
        if not stages:
            raise ValueError("CurriculumSampler requires at least one CurriculumStage")

        self.dataset = dataset
        self.stages = stages
        self.current_epoch = current_epoch
        self.seed = seed
        self.rng = np.random.RandomState(seed)

        self._build_indices()
        self._current_stage = self._get_stage_for_epoch(current_epoch)

        logger.info(f"CurriculumSampler initialized with {len(stages)} stages")
        logger.info(f"Current epoch {current_epoch} -> stage '{self._current_stage.name}'")

    def _build_indices(self):
        """Bucket dataset indices under every stage whose filters they pass.

        Also caches each sample's domain in ``self._sample_domains`` so that
        ``__iter__`` does not reload every sample from the dataset each epoch.
        """
        self.stage_indices = defaultdict(list)
        self._sample_domains: List[str] = []

        for idx, sample in enumerate(self.dataset):
            domain = sample.get("domain", "unknown")
            difficulty = sample.get("difficulty", 0.5)
            self._sample_domains.append(domain)

            for stage in self.stages:
                if (
                    domain in stage.domains
                    and stage.min_difficulty <= difficulty <= stage.max_difficulty
                ):
                    self.stage_indices[stage.name].append(idx)

        logger.info(f"Built stage indices: {[(s, len(self.stage_indices[s])) for s in self.stage_indices]}")

    def _get_stage_for_epoch(self, epoch: int) -> "CurriculumStage":
        """Return the stage active at ``epoch``.

        The active stage is the one with the largest ``epoch`` <= ``epoch``;
        among equal epochs the later entry in ``self.stages`` wins. Before
        any stage has started, the earliest stage is returned. This does not
        require ``self.stages`` to be sorted.
        """
        started = [stage for stage in self.stages if stage.epoch <= epoch]
        if not started:
            return min(self.stages, key=lambda stage: stage.epoch)
        # max() keeps the first maximal element it sees; scanning in reverse
        # makes the last-listed stage win ties.
        return max(reversed(started), key=lambda stage: stage.epoch)

    def set_epoch(self, epoch: int):
        """Update the current epoch and switch to the stage active at it."""
        self.current_epoch = epoch
        self._current_stage = self._get_stage_for_epoch(epoch)
        logger.info(f"Epoch {epoch} -> stage '{self._current_stage.name}'")

    def _balanced_order(self, indices: List[int]) -> List[int]:
        """Shuffle indices within each domain, then interleave domains round-robin."""
        by_domain: Dict[str, List[int]] = defaultdict(list)
        for idx in indices:
            by_domain[self._sample_domains[idx]].append(idx)

        # An empty stage would otherwise make max() below raise ValueError.
        if not by_domain:
            return []

        for bucket in by_domain.values():
            self.rng.shuffle(bucket)

        buckets = [by_domain[domain] for domain in sorted(by_domain)]
        result: List[int] = []
        for i in range(max(len(bucket) for bucket in buckets)):
            for bucket in buckets:
                if i < len(bucket):
                    result.append(bucket[i])
        return result

    def _weighted_order(self, indices: List[int]) -> List[int]:
        """Draw ``len(indices)`` indices with replacement, p proportional to domain weight.

        Drawing exactly ``len(indices)`` items keeps the number of yielded
        indices equal to ``len(self)``, as the Sampler protocol expects.
        Negative weights are treated as 0.
        """
        if not indices:
            return []

        weights = self._current_stage.domain_weights or {}
        probs = np.array(
            [max(float(weights.get(self._sample_domains[idx], 1.0)), 0.0) for idx in indices]
        )
        total = probs.sum()
        if total <= 0:
            logger.warning(
                f"Stage '{self._current_stage.name}': all domain weights are zero; yielding no samples"
            )
            return []

        chosen = self.rng.choice(len(indices), size=len(indices), replace=True, p=probs / total)
        return [indices[i] for i in chosen]

    def __iter__(self):
        """Iterate over indices according to the current stage's strategy."""
        indices = self.stage_indices[self._current_stage.name]
        strategy = self._current_stage.sampling_strategy

        if strategy == "balanced":
            order = self._balanced_order(indices)
        elif strategy == "weighted":
            order = self._weighted_order(indices)
        else:
            # "random" and any unrecognised strategy: plain shuffle of a copy.
            order = list(indices)
            self.rng.shuffle(order)
        return iter(order)

    def __len__(self) -> int:
        """Return the number of samples in the current stage."""
        return len(self.stage_indices[self._current_stage.name])
|
|
|
|
|
class DifficultyAwareSampler(Sampler):
    """Sampler that orders samples by difficulty for progressive learning.

    Each epoch yields every dataset index exactly once, as a weighted random
    permutation. The weighting is relative to a difficulty threshold:

    * Samples at or below the current threshold all get the highest weight.
    * Samples above it are down-weighted by
      ``exp(-(difficulty - threshold) / temperature)``, so they tend to
      appear later in the epoch.

    The threshold moves linearly from ``start_difficulty`` to
    ``end_difficulty`` over ``epochs`` epochs. With the defaults, the final
    threshold of 1.0 gives a uniform shuffle over difficulties in [0, 1].

    Args:
        dataset: Iterable of dict-like samples (read with ``.get``).
        difficulty_key: Key holding each sample's numeric difficulty;
            samples without it default to 0.5.
        temperature: Softness of the penalty for samples above the threshold.
            Smaller values bias the order more strongly. Must be > 0.
        start_difficulty: Threshold at epoch 0.
        end_difficulty: Threshold once ``epoch >= epochs``.
        epochs: Number of epochs over which the threshold ramps.
        seed: Seed for this sampler's private RNG. The global NumPy RNG is
            not used, so the order is reproducible from ``seed``.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """

    def __init__(
        self,
        dataset: Any,
        difficulty_key: str = "difficulty",
        temperature: float = 1.0,
        start_difficulty: float = 0.0,
        end_difficulty: float = 1.0,
        epochs: int = 10,
        seed: int = 42,
    ):
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        self.dataset = dataset
        self.difficulty_key = difficulty_key
        self.temperature = temperature
        self.start_difficulty = start_difficulty
        self.end_difficulty = end_difficulty
        self.epochs = epochs
        self.seed = seed
        self.rng = np.random.RandomState(seed)

        # Cache per-sample difficulties once as a float vector.
        self.difficulties = np.asarray(
            [sample.get(difficulty_key, 0.5) for sample in dataset],
            dtype=float,
        )

        self.current_epoch = 0

    def set_epoch(self, epoch: int):
        """Set the epoch used to compute the difficulty threshold."""
        self.current_epoch = epoch

    def _get_difficulty_threshold(self) -> float:
        """Return the threshold for the current epoch.

        The threshold is interpolated linearly from ``start_difficulty`` to
        ``end_difficulty``. Progress is clamped to [0, 1]. When
        ``epochs <= 0`` the end threshold is used immediately instead of
        dividing by zero.
        """
        if self.epochs <= 0:
            progress = 1.0
        else:
            progress = min(1.0, max(0.0, self.current_epoch / self.epochs))
        return self.start_difficulty + progress * (self.end_difficulty - self.start_difficulty)

    def _compute_logits(self) -> np.ndarray:
        """Return unnormalised log-weights: ``-max(0, d - threshold) / temperature``."""
        threshold = self._get_difficulty_threshold()
        excess = np.maximum(self.difficulties - threshold, 0.0)
        return -excess / self.temperature

    def _compute_weights(self) -> np.ndarray:
        """Return sampling probabilities (summing to 1) for every sample."""
        logits = self._compute_logits()
        if logits.size == 0:
            return logits
        # Shift by the max so at least one weight is exactly 1. The
        # normaliser therefore cannot underflow to 0 and produce NaNs.
        weights = np.exp(logits - logits.max())
        return weights / weights.sum()

    def __iter__(self):
        """Yield a difficulty-weighted random permutation of all indices."""
        n = len(self.difficulties)
        if n == 0:
            return iter([])

        # Gumbel-top-k: sorting logits + Gumbel noise in descending order
        # draws a permutation with the same distribution as sequential
        # sampling without replacement with p proportional to exp(logits).
        # Unlike np.random.choice(replace=False, p=...), it cannot fail when
        # some probabilities underflow to zero.
        keys = self._compute_logits() + self.rng.gumbel(size=n)
        order = np.argsort(-keys, kind="stable")
        return iter(order.tolist())

    def __len__(self) -> int:
        return len(self.dataset)
|
|
|
|
|
def create_curriculum_sampler(
    dataset: Any,
    curriculum_config: Any,
    current_epoch: int = 0,
    seed: int = 42,
) -> Optional["CurriculumSampler"]:
    """Build a CurriculumSampler from a curriculum config object.

    Args:
        dataset: Dataset passed through to the sampler.
        curriculum_config: Object exposing ``enable_curriculum`` and
            ``stages``. Each stage spec is dict-like with the keys ``name``,
            ``epoch`` and ``domains``, plus optional ``min_difficulty``,
            ``max_difficulty``, ``sampling_strategy`` and
            ``domain_weights``.
        current_epoch: Epoch used to select the initial stage.
        seed: Seed for the sampler's RNG.

    Returns:
        ``None`` when curriculum learning is disabled, otherwise the sampler.

    Raises:
        ValueError: if curriculum learning is enabled but no stages are
            configured.
    """
    if not curriculum_config.enable_curriculum:
        return None

    stages = []
    for stage_cfg in curriculum_config.stages:
        stages.append(
            CurriculumStage(
                name=stage_cfg["name"],
                epoch=stage_cfg["epoch"],
                # Domains may be enum members (use their ``.value``) or plain
                # strings. Either way they become the label compared against
                # each sample's "domain" field.
                domains=[getattr(d, "value", d) for d in stage_cfg["domains"]],
                min_difficulty=stage_cfg.get("min_difficulty", 0.0),
                max_difficulty=stage_cfg.get("max_difficulty", 1.0),
                sampling_strategy=stage_cfg.get("sampling_strategy", "balanced"),
                domain_weights=stage_cfg.get("domain_weights"),
            )
        )

    if not stages:
        raise ValueError(
            "Curriculum learning is enabled but no curriculum stages are configured"
        )

    return CurriculumSampler(dataset, stages, current_epoch, seed)
|
|
|