# Zenith-7b-V1 / data/curriculum_sampler.py
# Zandy-Wandy's picture
# Upload Zenith-7B model
# 8d18b7c verified
"""Curriculum Learning and Difficulty-Aware Sampling"""
import logging
import random
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
from collections import defaultdict
import numpy as np
from torch.utils.data import Sampler
logger = logging.getLogger(__name__)
@dataclass
class CurriculumStage:
    """One step of a curriculum schedule.

    Attributes:
        name: Identifier of the stage; used as the key for its index list.
        epoch: Epoch number that is compared against the current epoch when
            the active stage is selected.
        domains: Domain labels a sample must carry to belong to the stage.
        min_difficulty: Inclusive lower bound on sample difficulty.
        max_difficulty: Inclusive upper bound on sample difficulty.
        sampling_strategy: One of ``"balanced"``, ``"weighted"`` or
            ``"random"``.
        domain_weights: Optional per-domain weights, read by the
            ``"weighted"`` strategy.
    """

    # Required fields (positional order is part of the constructor signature).
    name: str
    epoch: int
    domains: List[str]

    # Difficulty window, both ends inclusive.
    min_difficulty: float = 0.0
    max_difficulty: float = 1.0

    # How indices are ordered within the stage.
    sampling_strategy: str = "balanced"
    domain_weights: Optional[Dict[str, float]] = None
class CurriculumSampler(Sampler):
    """Yield dataset indices restricted to the active curriculum stage.

    Each sample must be a mapping.  Its ``"domain"`` (default ``"unknown"``)
    and ``"difficulty"`` (default ``0.5``) decide which stages it belongs to:
    a sample is part of a stage when its domain is listed in
    ``stage.domains`` and its difficulty lies in the inclusive range
    ``[stage.min_difficulty, stage.max_difficulty]``.

    The active stage is the one with the greatest ``epoch`` that does not
    exceed the current epoch, regardless of the order in which the stages
    were supplied.  If no stage has started yet, the earliest stage is used.

    Sampling strategies (``stage.sampling_strategy``):

    * ``"balanced"`` - shuffle each domain, then interleave the domains
      round-robin in sorted domain order.
    * ``"weighted"`` - repeat every index ``int(weight * 10)`` times, where
      ``weight`` comes from ``stage.domain_weights`` (default ``1.0``), then
      shuffle.  ``len(sampler)`` reflects the repeated count.
    * anything else (including ``"random"``) - a plain shuffle.

    Args:
        dataset: Iterable of mapping-like samples.
        stages: Curriculum stages; must contain at least one entry.
        current_epoch: Epoch used to select the initial stage.
        seed: Seed for the private ``numpy`` random state.

    Raises:
        ValueError: If ``stages`` is empty.
    """

    # Same logger object as the module-level ``logger``; bound on the class so
    # the class does not depend on module globals.
    _log = logging.getLogger(__name__)

    # Copies emitted per unit of domain weight by the "weighted" strategy.
    _WEIGHT_SCALE = 10

    def __init__(
        self,
        dataset: Any,
        stages: "List[CurriculumStage]",
        current_epoch: int = 0,
        seed: int = 42,
    ):
        if not stages:
            raise ValueError("CurriculumSampler requires at least one stage")

        self.dataset = dataset
        self.stages = stages
        self.current_epoch = current_epoch
        self.seed = seed
        self.rng = np.random.RandomState(seed)

        # Build index by stage
        self._build_indices()
        self._current_stage = self._get_stage_for_epoch(current_epoch)

        self._log.info("CurriculumSampler initialized with %d stages", len(stages))
        self._log.info(
            "Current epoch %s -> stage '%s'", current_epoch, self._current_stage.name
        )

    def _build_indices(self):
        """Scan the dataset once and record the member indices of every stage.

        The domain of each matched sample is cached so that iteration and
        ``len()`` never have to load samples from the dataset again.
        """
        self.stage_indices = defaultdict(list)
        self._index_domain: Dict[int, Any] = {}

        for idx, sample in enumerate(self.dataset):
            domain = sample.get("domain", "unknown")
            difficulty = sample.get("difficulty", 0.5)
            for stage in self.stages:
                if domain not in stage.domains:
                    continue
                if stage.min_difficulty <= difficulty <= stage.max_difficulty:
                    self.stage_indices[stage.name].append(idx)
                    self._index_domain[idx] = domain

        self._log.info(
            "Built stage indices: %s",
            [(name, len(members)) for name, members in self.stage_indices.items()],
        )

    def _get_stage_for_epoch(self, epoch: int) -> "CurriculumStage":
        """Return the stage with the latest start epoch that is ``<= epoch``.

        Works for stages given in any order.  Among stages sharing the same
        start epoch the one listed last wins.  When ``epoch`` precedes every
        stage, the earliest-starting stage is returned.
        """
        started = [stage for stage in self.stages if stage.epoch <= epoch]
        if not started:
            # ``min`` keeps the first of equal candidates.
            return min(self.stages, key=lambda stage: stage.epoch)

        chosen = started[0]
        for stage in started[1:]:
            if stage.epoch >= chosen.epoch:
                chosen = stage
        return chosen

    def set_epoch(self, epoch: int):
        """Update the current epoch and re-select the active stage."""
        self.current_epoch = epoch
        self._current_stage = self._get_stage_for_epoch(epoch)
        self._log.info("Epoch %s -> stage '%s'", epoch, self._current_stage.name)

    def _active_indices(self) -> List[int]:
        """Indices of the active stage (empty if it matched no sample)."""
        # ``.get`` avoids inserting empty entries into the defaultdict.
        return self.stage_indices.get(self._current_stage.name, [])

    def _domain_for(self, idx: int) -> Any:
        """Domain of sample ``idx``, from the cache when available."""
        try:
            return self._index_domain[idx]
        except KeyError:
            # Index was not seen by _build_indices (e.g. stage_indices was
            # edited externally); fall back to reading the sample.
            return self.dataset[idx].get("domain", "unknown")

    def _copies_for(self, domain: Any) -> int:
        """Number of times an index of ``domain`` is emitted when weighted."""
        weights = self._current_stage.domain_weights or {}
        return max(0, int(weights.get(domain, 1.0) * self._WEIGHT_SCALE))

    def _balanced_order(self, indices: List[int]) -> List[int]:
        """Shuffle within each domain, then interleave domains round-robin."""
        by_domain = defaultdict(list)
        for idx in indices:
            by_domain[self._domain_for(idx)].append(idx)

        for bucket in by_domain.values():
            self.rng.shuffle(bucket)

        # ``default=0`` keeps an empty stage from raising.
        longest = max((len(bucket) for bucket in by_domain.values()), default=0)
        domain_order = sorted(by_domain)

        ordered: List[int] = []
        for position in range(longest):
            for domain in domain_order:
                bucket = by_domain[domain]
                if position < len(bucket):
                    ordered.append(bucket[position])
        return ordered

    def _weighted_order(self, indices: List[int]) -> List[int]:
        """Repeat each index according to its domain weight, then shuffle."""
        repeated: List[int] = []
        for idx in indices:
            repeated.extend([idx] * self._copies_for(self._domain_for(idx)))
        self.rng.shuffle(repeated)
        return repeated

    def __iter__(self):
        """Iterate over indices according to the active stage's strategy."""
        strategy = self._current_stage.sampling_strategy
        indices = self._active_indices()

        if strategy == "balanced":
            return iter(self._balanced_order(indices))
        if strategy == "weighted":
            return iter(self._weighted_order(indices))

        # "random" and any unrecognised strategy: plain shuffle of a copy.
        shuffled = list(indices)
        self.rng.shuffle(shuffled)
        return iter(shuffled)

    def __len__(self) -> int:
        """Return the number of indices one pass of ``__iter__`` yields."""
        indices = self._active_indices()
        if self._current_stage.sampling_strategy == "weighted":
            # The weighted strategy repeats indices, so count the copies.
            return sum(self._copies_for(self._domain_for(idx)) for idx in indices)
        return len(indices)
class DifficultyAwareSampler(Sampler):
    """Yield a weighted random permutation that favours easy samples early on.

    A difficulty threshold moves linearly from ``start_difficulty`` to
    ``end_difficulty`` over ``epochs`` epochs.  Samples whose difficulty is at
    or below the threshold share the maximum weight; samples above it are
    down-weighted by ``exp(-(difficulty - threshold) / temperature)``.  Each
    pass draws every index exactly once (without replacement), so the weights
    influence *when* a sample appears in the epoch, with heavily weighted
    samples tending to come first.

    Args:
        dataset: Sized iterable of mapping-like samples.
        difficulty_key: Key holding a sample's difficulty (default ``0.5``
            when the key is absent).
        temperature: Softness of the penalty above the threshold; larger
            values flatten the distribution.
        start_difficulty: Threshold at epoch 0.
        end_difficulty: Threshold from epoch ``epochs`` onwards.
        epochs: Number of epochs over which the threshold is ramped.
        seed: Seed for the private ``numpy`` random state.
    """

    # Floor applied to raw weights: drawing a full permutation without
    # replacement requires every sample to keep a non-zero probability.
    _MIN_WEIGHT = 1e-12

    def __init__(
        self,
        dataset: Any,
        difficulty_key: str = "difficulty",
        temperature: float = 1.0,
        start_difficulty: float = 0.0,
        end_difficulty: float = 1.0,
        epochs: int = 10,
        seed: int = 42,
    ):
        self.dataset = dataset
        self.difficulty_key = difficulty_key
        self.temperature = temperature
        self.start_difficulty = start_difficulty
        self.end_difficulty = end_difficulty
        self.epochs = epochs
        self.seed = seed
        self.rng = np.random.RandomState(seed)

        # Precompute difficulties once; samples are not read again afterwards.
        self.difficulties = np.array(
            [sample.get(difficulty_key, 0.5) for sample in dataset],
            dtype=float,
        )
        self.current_epoch = 0

    def set_epoch(self, epoch: int):
        """Update the epoch that drives the difficulty threshold."""
        self.current_epoch = epoch

    def _get_difficulty_threshold(self) -> float:
        """Return the difficulty threshold for the current epoch.

        Progress is clamped to ``[0, 1]``; a non-positive ``epochs`` is
        treated as an already finished ramp instead of dividing by zero.
        """
        if self.epochs <= 0:
            progress = 1.0
        else:
            progress = min(1.0, max(0.0, self.current_epoch / self.epochs))
        span = self.end_difficulty - self.start_difficulty
        return self.start_difficulty + progress * span

    def _compute_weights(self) -> np.ndarray:
        """Return normalised sampling probabilities for the current epoch.

        Samples at or below the threshold get the full weight; above it the
        weight decays exponentially with the distance to the threshold,
        scaled by ``temperature``.
        """
        if self.difficulties.size == 0:
            return self.difficulties.copy()

        threshold = self._get_difficulty_threshold()
        # Guard against a zero or negative temperature.
        scale = max(float(self.temperature), 1e-6)
        excess = np.maximum(self.difficulties - threshold, 0.0)
        weights = np.maximum(np.exp(-excess / scale), self._MIN_WEIGHT)
        return weights / weights.sum()

    def __iter__(self):
        """Iterate over a seeded, difficulty-weighted permutation of indices."""
        count = len(self.dataset)
        if count == 0:
            return iter([])

        # Use the sampler's own RandomState so ``seed`` makes runs repeatable.
        order = self.rng.choice(
            count,
            size=count,
            replace=False,
            p=self._compute_weights(),
        )
        return iter(order.tolist())

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return len(self.dataset)
def create_curriculum_sampler(
    dataset: Any,
    curriculum_config: Any,  # CurriculumConfig
    current_epoch: int = 0,
    seed: int = 42,
) -> Optional[CurriculumSampler]:
    """Create a curriculum sampler from a configuration object.

    Args:
        dataset: Dataset handed to ``CurriculumSampler``.
        curriculum_config: Object exposing ``enable_curriculum`` and
            ``stages``.  Each entry of ``stages`` is a mapping with the keys
            ``"name"``, ``"epoch"`` and ``"domains"`` and, optionally,
            ``"min_difficulty"``, ``"max_difficulty"``,
            ``"sampling_strategy"`` and ``"domain_weights"``.
        current_epoch: Epoch used to select the initial stage.
        seed: Seed forwarded to the sampler.

    Returns:
        A ``CurriculumSampler``, or ``None`` when
        ``curriculum_config.enable_curriculum`` is falsy.
    """
    if not curriculum_config.enable_curriculum:
        return None

    stages = [
        CurriculumStage(
            name=stage_cfg["name"],
            epoch=stage_cfg["epoch"],
            # Domains may be enum members (use their ``.value``) or already
            # plain values such as strings, which are passed through as-is.
            domains=[getattr(d, "value", d) for d in stage_cfg["domains"]],
            min_difficulty=stage_cfg.get("min_difficulty", 0.0),
            max_difficulty=stage_cfg.get("max_difficulty", 1.0),
            sampling_strategy=stage_cfg.get("sampling_strategy", "balanced"),
            domain_weights=stage_cfg.get("domain_weights"),
        )
        for stage_cfg in curriculum_config.stages
    ]
    return CurriculumSampler(dataset, stages, current_epoch, seed)