| | """
|
| | Curriculum learning for Vortex model.
|
| | Progresses through stages: Foundation → Domain → Reasoning → Integration.
|
| | """
|
| |
|
| | from typing import List, Dict, Optional
|
| | import torch
|
| |
|
| |
|
class CurriculumScheduler:
    """
    Schedules curriculum stages during training.

    Each stage has a start and end fraction of total training steps. The
    fractions are converted to absolute step indices once, at construction
    time, and a stage covers the half-open interval ``[start_step, end_step)``.
    """

    # Canonical stage names, in training order.
    STAGES = ["foundation", "domain", "reasoning", "integration"]

    # Stage layout used when the config has no "curriculum_stages" entry.
    # Instances copy each entry, so this shared default is never mutated.
    _DEFAULT_STAGES = (
        {"name": "foundation", "start": 0.0, "end": 0.2},
        {"name": "domain", "start": 0.2, "end": 0.5},
        {"name": "reasoning", "start": 0.5, "end": 0.8},
        {"name": "integration", "start": 0.8, "end": 1.0},
    )

    # Loss components enabled in each stage; later stages add components.
    # Read-only lookup table shared by all instances.
    _STAGE_COMPONENTS = {
        "foundation": ("lm_loss",),
        "domain": ("lm_loss", "equation_loss", "domain_loss"),
        "reasoning": ("lm_loss", "equation_loss", "domain_loss", "citation_loss"),
        "integration": (
            "lm_loss",
            "equation_loss",
            "domain_loss",
            "citation_loss",
            "numerical_loss",
        ),
    }

    # Dataset sampling proportions per stage (each row sums to 1.0).
    # Read-only; get_dataset_sampler hands out copies.
    _MIXING_PROPORTIONS = {
        "foundation": {
            "pile_scientific": 0.3,
            "s2orc": 0.3,
            "automath": 0.2,
            "pubmed_qa": 0.2,
        },
        "domain": {
            "pile_scientific": 0.2,
            "s2orc": 0.2,
            "automath": 0.2,
            "pubmed_qa": 0.2,
            "deepmind_math": 0.2,
        },
        "reasoning": {
            "pile_scientific": 0.15,
            "s2orc": 0.15,
            "automath": 0.3,
            "deepmind_math": 0.3,
            "pubmed_qa": 0.1,
        },
        "integration": {
            "pile_scientific": 0.2,
            "s2orc": 0.2,
            "automath": 0.2,
            "deepmind_math": 0.2,
            "pubmed_qa": 0.2,
        },
    }

    def __init__(
        self,
        config: Dict,
        total_steps: int,
    ):
        """
        Initialize curriculum scheduler.

        Args:
            config: Training config; its optional "curriculum_stages" entry is a
                sequence of dicts with "name", "start" and "end" keys (start/end
                are fractions of total_steps). A four-stage default is used
                when the entry is absent.
            total_steps: Total number of training steps
        """
        self.config = config
        self.total_steps = total_steps
        stages = config.get("curriculum_stages", self._DEFAULT_STAGES)

        # Copy each stage so the derived step indices are not written back into
        # the caller's config dicts (or into the shared class-level default).
        self.stages = [dict(stage) for stage in stages]
        for stage in self.stages:
            stage["start_step"] = int(stage["start"] * total_steps)
            stage["end_step"] = int(stage["end"] * total_steps)

    def get_stage(
        self,
        current_step: int,
    ) -> Optional[Dict]:
        """
        Get current curriculum stage.

        Stages are checked in configured order and the first one whose
        ``[start_step, end_step)`` interval contains the step is returned.

        Args:
            current_step: Current training step

        Returns:
            Stage dictionary, or None if the step lies outside every stage
            (e.g. once training is complete).
        """
        for stage in self.stages:
            if stage["start_step"] <= current_step < stage["end_step"]:
                return stage
        return None

    def get_stage_name(self, current_step: int) -> str:
        """Get current stage name, or "complete" if no stage covers the step."""
        stage = self.get_stage(current_step)
        return stage["name"] if stage else "complete"

    def get_stage_weight(
        self,
        current_step: int,
        base_weight: float,
        component: str = "lm_loss",
    ) -> float:
        """
        Get weight for a curriculum component based on stage.

        Args:
            current_step: Current training step
            base_weight: Base weight for the component
            component: Loss component name (e.g. "lm_loss", "equation_loss",
                "domain_loss", "citation_loss", "numerical_loss"). Defaults to
                "lm_loss", which is active in every stage.

        Returns:
            base_weight if the component is active in the current stage,
            otherwise 0.0 (also 0.0 when no stage covers the step). Stage names
            without an entry in the component table enable only "lm_loss".
        """
        stage = self.get_stage(current_step)
        if not stage:
            return 0.0

        active_components = self._STAGE_COMPONENTS.get(stage["name"], ("lm_loss",))
        return base_weight if component in active_components else 0.0

    def get_dataset_sampler(
        self,
        current_step: int,
    ) -> Optional[Dict[str, float]]:
        """
        Get dataset sampler for current stage.
        Different stages may mix datasets differently.

        Args:
            current_step: Current training step

        Returns:
            A fresh dict mapping dataset name to sampling proportion, or None
            if no stage covers the step. Stage names without an entry in the
            mixing table fall back to {"pile_scientific": 1.0}.
        """
        stage = self.get_stage(current_step)
        if not stage:
            return None

        proportions = self._MIXING_PROPORTIONS.get(
            stage["name"], {"pile_scientific": 1.0}
        )
        # Copy so callers cannot alter the shared class-level table.
        return dict(proportions)
|
| |
|
| |
|
def test_curriculum():
    """Check CurriculumScheduler stage lookup on a 1000-step, four-stage schedule."""
    config = {
        "curriculum_stages": [
            {"name": "foundation", "start": 0.0, "end": 0.2},
            {"name": "domain", "start": 0.2, "end": 0.5},
            {"name": "reasoning", "start": 0.5, "end": 0.8},
            {"name": "integration", "start": 0.8, "end": 1.0},
        ]
    }

    total_steps = 1000
    scheduler = CurriculumScheduler(config, total_steps)

    # Stages cover [start_step, end_step): step 500 opens "reasoning", step 999
    # is the last "integration" step, and step 1000 is past every stage.
    expected = {
        0: "foundation",
        100: "foundation",
        250: "domain",
        500: "reasoning",
        750: "reasoning",
        999: "integration",
        1000: "complete",
    }
    for step, expected_name in expected.items():
        name = scheduler.get_stage_name(step)
        print(f"Step {step}: {name}")
        assert name == expected_name, (
            f"step {step}: expected {expected_name!r}, got {name!r}"
        )

    print("Curriculum test passed!")
|
| |
|
| |
|
# Running this module directly executes the curriculum smoke test.
if __name__ == "__main__":
    test_curriculum()
|
| |
|