""" Curriculum learning for Vortex model. Progresses through stages: Foundation → Domain → Reasoning → Integration. """ from typing import List, Dict, Optional import torch class CurriculumScheduler: """ Schedules curriculum stages during training. Each stage has a start and end fraction of total training steps. """ STAGES = ["foundation", "domain", "reasoning", "integration"] def __init__( self, config: Dict, total_steps: int, ): """ Initialize curriculum scheduler. Args: config: Training config with curriculum_stages total_steps: Total number of training steps """ self.config = config self.total_steps = total_steps self.stages = config.get("curriculum_stages", [ {"name": "foundation", "start": 0.0, "end": 0.2}, {"name": "domain", "start": 0.2, "end": 0.5}, {"name": "reasoning", "start": 0.5, "end": 0.8}, {"name": "integration", "start": 0.8, "end": 1.0}, ]) # Convert fractions to step numbers for stage in self.stages: stage["start_step"] = int(stage["start"] * total_steps) stage["end_step"] = int(stage["end"] * total_steps) def get_stage( self, current_step: int, ) -> Optional[Dict]: """ Get current curriculum stage. Args: current_step: Current training step Returns: Stage dictionary or None if training complete """ for stage in self.stages: if stage["start_step"] <= current_step < stage["end_step"]: return stage return None def get_stage_name(self, current_step: int) -> str: """Get current stage name.""" stage = self.get_stage(current_step) return stage["name"] if stage else "complete" def get_stage_weight( self, current_step: int, base_weight: float, ) -> float: """ Get weight for a curriculum component based on stage. Args: current_step: Current training step base_weight: Base weight for the component Returns: Adjusted weight (can be 0 if component not active in current stage) """ stage = self.get_stage(current_step) if not stage: return 0.0 stage_name = stage["name"] # Define which components are active in each stage stage_components = { "foundation": ["lm_loss"], # Only language modeling "domain": ["lm_loss", "equation_loss", "domain_loss"], "reasoning": ["lm_loss", "equation_loss", "domain_loss", "citation_loss"], "integration": ["lm_loss", "equation_loss", "domain_loss", "citation_loss", "numerical_loss"], } active_components = stage_components.get(stage_name, ["lm_loss"]) # Return base weight if component active, else 0 # (Caller checks if their component is in active_components) return base_weight if "lm_loss" in active_components else 0.0 def get_dataset_sampler( self, current_step: int, ): """ Get dataset sampler for current stage. Different stages may mix datasets differently. Returns: Sampler weights for different datasets """ stage = self.get_stage(current_step) if not stage: return None stage_name = stage["name"] # Dataset mixing proportions per stage mixing_proportions = { "foundation": { "pile_scientific": 0.3, "s2orc": 0.3, "automath": 0.2, "pubmed_qa": 0.2, }, "domain": { "pile_scientific": 0.2, "s2orc": 0.2, "automath": 0.2, "pubmed_qa": 0.2, "deepmind_math": 0.2, }, "reasoning": { "pile_scientific": 0.15, "s2orc": 0.15, "automath": 0.3, "deepmind_math": 0.3, "pubmed_qa": 0.1, }, "integration": { "pile_scientific": 0.2, "s2orc": 0.2, "automath": 0.2, "deepmind_math": 0.2, "pubmed_qa": 0.2, }, } return mixing_proportions.get(stage_name, {"pile_scientific": 1.0}) def test_curriculum(): """Test curriculum scheduler.""" config = { "curriculum_stages": [ {"name": "foundation", "start": 0.0, "end": 0.2}, {"name": "domain", "start": 0.2, "end": 0.5}, {"name": "reasoning", "start": 0.5, "end": 0.8}, {"name": "integration", "start": 0.8, "end": 1.0}, ] } total_steps = 1000 scheduler = CurriculumScheduler(config, total_steps) for step in [0, 100, 250, 500, 750, 999]: stage = scheduler.get_stage(step) name = scheduler.get_stage_name(step) print(f"Step {step}: {name}") print("Curriculum test passed!") if __name__ == "__main__": test_curriculum()