| """
|
| Curriculum learning for Vortex model.
|
| Progresses through stages: Foundation → Domain → Reasoning → Integration.
|
| """
|
|
|
| from typing import List, Dict, Optional
|
| import torch
|
|
|
|
|
class CurriculumScheduler:
    """
    Schedules curriculum stages during training.

    Each stage is a dict with a ``name`` plus ``start`` and ``end`` fractions
    of the total training steps. On construction the fractions are converted
    to absolute ``start_step`` / ``end_step`` values, and a stage is active
    for steps in the half-open range ``[start_step, end_step)``.
    """

    STAGES = ["foundation", "domain", "reasoning", "integration"]

    # Stage layout used when the config has no "curriculum_stages" entry.
    _DEFAULT_STAGES = (
        {"name": "foundation", "start": 0.0, "end": 0.2},
        {"name": "domain", "start": 0.2, "end": 0.5},
        {"name": "reasoning", "start": 0.5, "end": 0.8},
        {"name": "integration", "start": 0.8, "end": 1.0},
    )

    # Loss components that are active in each stage.
    _STAGE_COMPONENTS = {
        "foundation": ["lm_loss"],
        "domain": ["lm_loss", "equation_loss", "domain_loss"],
        "reasoning": ["lm_loss", "equation_loss", "domain_loss", "citation_loss"],
        "integration": [
            "lm_loss",
            "equation_loss",
            "domain_loss",
            "citation_loss",
            "numerical_loss",
        ],
    }

    # Per-stage dataset sampling weights, keyed by dataset name.
    _MIXING_PROPORTIONS = {
        "foundation": {
            "pile_scientific": 0.3,
            "s2orc": 0.3,
            "automath": 0.2,
            "pubmed_qa": 0.2,
        },
        "domain": {
            "pile_scientific": 0.2,
            "s2orc": 0.2,
            "automath": 0.2,
            "pubmed_qa": 0.2,
            "deepmind_math": 0.2,
        },
        "reasoning": {
            "pile_scientific": 0.15,
            "s2orc": 0.15,
            "automath": 0.3,
            "deepmind_math": 0.3,
            "pubmed_qa": 0.1,
        },
        "integration": {
            "pile_scientific": 0.2,
            "s2orc": 0.2,
            "automath": 0.2,
            "deepmind_math": 0.2,
            "pubmed_qa": 0.2,
        },
    }

    def __init__(
        self,
        config: Dict,
        total_steps: int,
    ):
        """
        Initialize curriculum scheduler.

        Args:
            config: Training config; its optional ``"curriculum_stages"``
                entry is a list of ``{"name", "start", "end"}`` dicts. When
                the key is absent a four-stage default layout is used.
            total_steps: Total number of training steps
        """
        self.config = config
        self.total_steps = total_steps

        # Work on copies: the step bounds added below must not leak into the
        # caller's config, and must not be overwritten when another scheduler
        # is built from the same config with a different total_steps.
        self.stages: List[Dict] = [
            dict(stage)
            for stage in config.get("curriculum_stages", self._DEFAULT_STAGES)
        ]

        # Convert fractional boundaries to absolute step indices (truncating).
        for stage in self.stages:
            stage["start_step"] = int(stage["start"] * total_steps)
            stage["end_step"] = int(stage["end"] * total_steps)

    def get_stage(
        self,
        current_step: int,
    ) -> Optional[Dict]:
        """
        Get current curriculum stage.

        Args:
            current_step: Current training step

        Returns:
            The first stage whose ``[start_step, end_step)`` range contains
            ``current_step``, or None if no stage matches (e.g. training
            complete).
        """
        for stage in self.stages:
            if stage["start_step"] <= current_step < stage["end_step"]:
                return stage
        return None

    def get_stage_name(self, current_step: int) -> str:
        """Get current stage name, or ``"complete"`` when no stage is active."""
        stage = self.get_stage(current_step)
        return stage["name"] if stage is not None else "complete"

    def get_stage_weight(
        self,
        current_step: int,
        base_weight: float,
        component: str = "lm_loss",
    ) -> float:
        """
        Get weight for a curriculum component based on stage.

        Args:
            current_step: Current training step
            base_weight: Base weight for the component
            component: Name of the loss component to look up. Defaults to
                ``"lm_loss"``, which is listed for every known stage, so
                callers that omit it get ``base_weight`` for any step that
                falls inside a stage.

        Returns:
            ``base_weight`` if ``component`` is active in the current stage,
            otherwise 0.0 (also 0.0 when no stage is active).
        """
        stage = self.get_stage(current_step)
        if stage is None:
            return 0.0

        # Stages with an unrecognised name only train the language-model loss.
        active_components = self._STAGE_COMPONENTS.get(stage["name"], ["lm_loss"])
        return base_weight if component in active_components else 0.0

    def get_dataset_sampler(
        self,
        current_step: int,
    ) -> Optional[Dict[str, float]]:
        """
        Get dataset sampler for current stage.

        Different stages may mix datasets differently.

        Args:
            current_step: Current training step

        Returns:
            A new dict mapping dataset name to sampling weight for the
            current stage, or None when no stage is active. Stages with an
            unrecognised name get ``{"pile_scientific": 1.0}``.
        """
        stage = self.get_stage(current_step)
        if stage is None:
            return None

        proportions = self._MIXING_PROPORTIONS.get(
            stage["name"], {"pile_scientific": 1.0}
        )
        # Return a copy so callers cannot mutate the shared class-level table.
        return dict(proportions)
|
|
|
|
|
def test_curriculum():
    """Smoke-test the curriculum scheduler.

    Builds a scheduler over 1000 steps with the four standard stages, prints
    the stage name at a handful of steps, and asserts each one matches the
    stage whose fractional range contains that step.
    """
    config = {
        "curriculum_stages": [
            {"name": "foundation", "start": 0.0, "end": 0.2},
            {"name": "domain", "start": 0.2, "end": 0.5},
            {"name": "reasoning", "start": 0.5, "end": 0.8},
            {"name": "integration", "start": 0.8, "end": 1.0},
        ]
    }

    total_steps = 1000
    scheduler = CurriculumScheduler(config, total_steps)

    # Probe step -> stage expected from the fractions above (end is exclusive,
    # so step 500 already belongs to "reasoning").
    expected_names = {
        0: "foundation",
        100: "foundation",
        250: "domain",
        500: "reasoning",
        750: "reasoning",
        999: "integration",
    }

    for step, expected in expected_names.items():
        name = scheduler.get_stage_name(step)
        print(f"Step {step}: {name}")
        assert name == expected, f"step {step}: expected {expected!r}, got {name!r}"

    print("Curriculum test passed!")
|
|
|
|
|
# Script entry point: run the scheduler smoke test when this file is executed
# directly; importing the module does not trigger it.
if __name__ == "__main__":
    test_curriculum()
|
|
|