Vortex-7b-V1 / training /curriculum.py
Zandy-Wandy's picture
Upload Vortex model
bf64b03 verified
"""
Curriculum learning for Vortex model.
Progresses through stages: Foundation → Domain → Reasoning → Integration.
"""
from typing import List, Dict, Optional
import torch
class CurriculumScheduler:
    """
    Schedules curriculum stages during training.

    Each stage has a start and end fraction of total training steps. The
    fractions are converted once, at construction time, into absolute step
    boundaries: ``start_step`` (inclusive) and ``end_step`` (exclusive).
    """

    STAGES = ["foundation", "domain", "reasoning", "integration"]

    # Loss components that are active in each stage.
    _STAGE_COMPONENTS = {
        "foundation": ("lm_loss",),  # Only language modeling
        "domain": ("lm_loss", "equation_loss", "domain_loss"),
        "reasoning": ("lm_loss", "equation_loss", "domain_loss", "citation_loss"),
        "integration": (
            "lm_loss",
            "equation_loss",
            "domain_loss",
            "citation_loss",
            "numerical_loss",
        ),
    }

    # Components assumed active when a stage name is not in _STAGE_COMPONENTS.
    _DEFAULT_COMPONENTS = ("lm_loss",)

    # Dataset mixing proportions per stage.
    _MIXING_PROPORTIONS = {
        "foundation": {
            "pile_scientific": 0.3,
            "s2orc": 0.3,
            "automath": 0.2,
            "pubmed_qa": 0.2,
        },
        "domain": {
            "pile_scientific": 0.2,
            "s2orc": 0.2,
            "automath": 0.2,
            "pubmed_qa": 0.2,
            "deepmind_math": 0.2,
        },
        "reasoning": {
            "pile_scientific": 0.15,
            "s2orc": 0.15,
            "automath": 0.3,
            "deepmind_math": 0.3,
            "pubmed_qa": 0.1,
        },
        "integration": {
            "pile_scientific": 0.2,
            "s2orc": 0.2,
            "automath": 0.2,
            "deepmind_math": 0.2,
            "pubmed_qa": 0.2,
        },
    }

    # Mix used when a stage name is not in _MIXING_PROPORTIONS.
    _DEFAULT_MIX = {"pile_scientific": 1.0}

    def __init__(
        self,
        config: Dict,
        total_steps: int,
    ):
        """
        Initialize curriculum scheduler.

        Args:
            config: Training config with curriculum_stages. Each stage is a
                dict with "name", "start" and "end" keys, where start/end
                are fractions of total_steps. If the key is absent, a
                four-stage default schedule is used.
            total_steps: Total number of training steps
        """
        self.config = config
        self.total_steps = total_steps

        configured = config.get("curriculum_stages", [
            {"name": "foundation", "start": 0.0, "end": 0.2},
            {"name": "domain", "start": 0.2, "end": 0.5},
            {"name": "reasoning", "start": 0.5, "end": 0.8},
            {"name": "integration", "start": 0.8, "end": 1.0},
        ])

        # Copy every stage dict so that the step boundaries added below do
        # not leak into the caller's config object.
        self.stages: List[Dict] = [dict(stage) for stage in configured]

        # Convert fractions to step numbers
        for stage in self.stages:
            stage["start_step"] = int(stage["start"] * total_steps)
            stage["end_step"] = int(stage["end"] * total_steps)

    def get_stage(
        self,
        current_step: int,
    ) -> Optional[Dict]:
        """
        Get current curriculum stage.

        Args:
            current_step: Current training step

        Returns:
            The first stage whose [start_step, end_step) range contains
            current_step, or None if no stage matches (e.g. training
            complete).
        """
        for stage in self.stages:
            if stage["start_step"] <= current_step < stage["end_step"]:
                return stage
        return None

    def get_stage_name(self, current_step: int) -> str:
        """Get current stage name, or "complete" when no stage is active."""
        stage = self.get_stage(current_step)
        return stage["name"] if stage else "complete"

    def get_stage_weight(
        self,
        current_step: int,
        base_weight: float,
        component: str = "lm_loss",
    ) -> float:
        """
        Get weight for a curriculum component based on stage.

        Args:
            current_step: Current training step
            base_weight: Base weight for the component
            component: Name of the loss component being weighted. Defaults
                to "lm_loss", which is active in every known stage, so
                callers that omit it get base_weight whenever a stage is
                active.

        Returns:
            base_weight if the component is active in the current stage,
            otherwise 0.0. Also 0.0 when no stage is active.
        """
        stage = self.get_stage(current_step)
        if stage is None:
            return 0.0

        active_components = self._STAGE_COMPONENTS.get(
            stage["name"], self._DEFAULT_COMPONENTS
        )
        return base_weight if component in active_components else 0.0

    def get_dataset_sampler(
        self,
        current_step: int,
    ) -> Optional[Dict[str, float]]:
        """
        Get dataset sampler for current stage.

        Different stages may mix datasets differently.

        Args:
            current_step: Current training step

        Returns:
            Mapping of dataset name to sampling proportion for the current
            stage, or None when no stage is active. A new dict is returned
            on every call, so callers may modify it freely.
        """
        stage = self.get_stage(current_step)
        if stage is None:
            return None

        proportions = self._MIXING_PROPORTIONS.get(stage["name"], self._DEFAULT_MIX)
        # Hand out a copy so the shared class-level tables stay intact.
        return dict(proportions)
def test_curriculum():
    """
    Test curriculum scheduler.

    Builds a scheduler over 1000 steps with the four standard stages and
    checks that representative steps map to the expected stage name,
    printing each step's stage as it goes.

    Raises:
        AssertionError: If a step resolves to an unexpected stage.
    """
    config = {
        "curriculum_stages": [
            {"name": "foundation", "start": 0.0, "end": 0.2},
            {"name": "domain", "start": 0.2, "end": 0.5},
            {"name": "reasoning", "start": 0.5, "end": 0.8},
            {"name": "integration", "start": 0.8, "end": 1.0},
        ]
    }
    total_steps = 1000
    scheduler = CurriculumScheduler(config, total_steps)

    # (step, expected stage name). With 1000 steps the stage boundaries
    # fall at steps 200, 500 and 800.
    expectations = [
        (0, "foundation"),
        (100, "foundation"),
        (250, "domain"),
        (500, "reasoning"),
        (750, "reasoning"),
        (999, "integration"),
    ]

    for step, expected_name in expectations:
        stage = scheduler.get_stage(step)
        name = scheduler.get_stage_name(step)
        print(f"Step {step}: {name}")
        # Both lookup APIs must agree with each other and with the schedule.
        assert stage is not None and stage["name"] == name
        assert name == expected_name, f"step {step}: expected {expected_name}, got {name}"

    print("Curriculum test passed!")
# Script entry point: run the scheduler smoke test when executed directly.
if __name__ == "__main__":
    test_curriculum()