"""Curriculum Learning and Difficulty-Aware Sampling""" import logging import random from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple from collections import defaultdict import numpy as np from torch.utils.data import Sampler logger = logging.getLogger(__name__) @dataclass class CurriculumStage: """Definition of a curriculum learning stage.""" name: str epoch: int domains: List[str] min_difficulty: float = 0.0 max_difficulty: float = 1.0 sampling_strategy: str = "balanced" # "balanced", "weighted", "random" domain_weights: Optional[Dict[str, float]] = None class CurriculumSampler(Sampler): """Sampler that implements curriculum learning strategy.""" def __init__( self, dataset: Any, stages: List[CurriculumStage], current_epoch: int = 0, seed: int = 42, ): self.dataset = dataset self.stages = stages self.current_epoch = current_epoch self.seed = seed self.rng = np.random.RandomState(seed) # Build index by stage self._build_indices() self._current_stage = self._get_stage_for_epoch(current_epoch) logger.info(f"CurriculumSampler initialized with {len(stages)} stages") logger.info(f"Current epoch {current_epoch} -> stage '{self._current_stage.name}'") def _build_indices(self): """Build indices for each stage.""" self.stage_indices = defaultdict(list) for idx, sample in enumerate(self.dataset): # Determine which stages this sample belongs to domain = sample.get("domain", "unknown") difficulty = sample.get("difficulty", 0.5) for stage in self.stages: if domain in stage.domains: if stage.min_difficulty <= difficulty <= stage.max_difficulty: self.stage_indices[stage.name].append(idx) logger.info(f"Built stage indices: {[(s, len(self.stage_indices[s])) for s in self.stage_indices]}") def _get_stage_for_epoch(self, epoch: int) -> CurriculumStage: """Get current stage based on epoch.""" current_stage = self.stages[0] for stage in self.stages: if epoch >= stage.epoch: current_stage = stage else: break return current_stage def set_epoch(self, epoch: int): """Update current epoch and stage.""" self.current_epoch = epoch self._current_stage = self._get_stage_for_epoch(epoch) logger.info(f"Epoch {epoch} -> stage '{self._current_stage.name}'") def __iter__(self): """Iterate over indices according to current stage.""" indices = self.stage_indices[self._current_stage.name] if self._current_stage.sampling_strategy == "random": # Simple random shuffle shuffled = indices.copy() self.rng.shuffle(shuffled) return iter(shuffled) elif self._current_stage.sampling_strategy == "balanced": # Balanced across domains within stage # Group by domain domain_indices = defaultdict(list) for idx in indices: domain = self.dataset[idx].get("domain", "unknown") domain_indices[domain].append(idx) # Shuffle each domain for domain in domain_indices: self.rng.shuffle(domain_indices[domain]) # Interleave result = [] max_len = max(len(lst) for lst in domain_indices.values()) for i in range(max_len): for domain in sorted(domain_indices.keys()): if i < len(domain_indices[domain]): result.append(domain_indices[domain][i]) return iter(result) elif self._current_stage.sampling_strategy == "weighted": # Weighted sampling based on domain_weights weights = self._current_stage.domain_weights or {} # Build weighted list weighted_indices = [] for idx in indices: domain = self.dataset[idx].get("domain", "unknown") weight = weights.get(domain, 1.0) weighted_indices.extend([idx] * int(weight * 10)) self.rng.shuffle(weighted_indices) return iter(weighted_indices) else: # Default: random shuffle shuffled = indices.copy() self.rng.shuffle(shuffled) return iter(shuffled) def __len__(self) -> int: """Return number of samples in current stage.""" return len(self.stage_indices[self._current_stage.name]) class DifficultyAwareSampler(Sampler): """Sampler that weights samples by difficulty for progressive learning.""" def __init__( self, dataset: Any, difficulty_key: str = "difficulty", temperature: float = 1.0, start_difficulty: float = 0.0, end_difficulty: float = 1.0, epochs: int = 10, seed: int = 42, ): self.dataset = dataset self.difficulty_key = difficulty_key self.temperature = temperature self.start_difficulty = start_difficulty self.end_difficulty = end_difficulty self.epochs = epochs self.seed = seed self.rng = np.random.RandomState(seed) # Precompute difficulties self.difficulties = np.array([ sample.get(difficulty_key, 0.5) for sample in dataset ]) self.current_epoch = 0 def set_epoch(self, epoch: int): """Update epoch.""" self.current_epoch = epoch def _get_difficulty_threshold(self) -> float: """Get current difficulty threshold based on epoch.""" progress = min(1.0, self.current_epoch / self.epochs) return self.start_difficulty + progress * (self.end_difficulty - self.start_difficulty) def _compute_weights(self) -> np.ndarray: """Compute sampling weights based on current difficulty threshold.""" threshold = self._get_difficulty_threshold() # Weight samples below threshold more, above threshold less # Use exponential decay for weight weights = np.exp(-self.difficulties / (1.0 - threshold + 1e-6)) # Normalize weights = weights / weights.sum() return weights def __iter__(self): """Iterate with difficulty-based sampling.""" weights = self._compute_weights() # Sample indices according to weights indices = np.random.choice( len(self.dataset), size=len(self.dataset), replace=False, p=weights, ) return iter(indices.tolist()) def __len__(self) -> int: return len(self.dataset) def create_curriculum_sampler( dataset: Any, curriculum_config: Any, # CurriculumConfig current_epoch: int = 0, seed: int = 42, ) -> Optional[CurriculumSampler]: """Create curriculum sampler from config.""" if not curriculum_config.enable_curriculum: return None # Build stages from config stages = [] for stage_cfg in curriculum_config.stages: stage = CurriculumStage( name=stage_cfg["name"], epoch=stage_cfg["epoch"], domains=[d.value for d in stage_cfg["domains"]], min_difficulty=stage_cfg.get("min_difficulty", 0.0), max_difficulty=stage_cfg.get("max_difficulty", 1.0), sampling_strategy=stage_cfg.get("sampling_strategy", "balanced"), domain_weights=stage_cfg.get("domain_weights"), ) stages.append(stage) return CurriculumSampler(dataset, stages, current_epoch, seed)