Wolfvin commited on
Commit
fa50230
·
verified ·
1 Parent(s): 0ce8908

Upload diffusion_llm/training/curriculum.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. diffusion_llm/training/curriculum.py +186 -0
diffusion_llm/training/curriculum.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """AAM Diffusion LLM — Curriculum Learning
2
+
3
+ Training from easy to hard:
4
+ Phase 1: Single-evidence simple narratives (basic arrangement)
5
+ Phase 2: Multi-evidence narratives (complex arrangement)
6
+ Phase 3: Complex reasoning chains (anomaly + reasoning)
7
+ Phase 4: Full model + RL fine-tuning (GRPO/DAPO)
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import logging
13
+ from dataclasses import dataclass, field
14
+ from enum import Enum
15
+ from typing import Dict, List, Optional, Tuple
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ class TrainingPhase(str, Enum):
21
+ PHASE_1_SINGLE_EVIDENCE = "phase_1_single_evidence"
22
+ PHASE_2_MULTI_EVIDENCE = "phase_2_multi_evidence"
23
+ PHASE_3_REASONING = "phase_3_reasoning"
24
+ PHASE_4_RL = "phase_4_rl"
25
+
26
+
27
+ @dataclass
28
+ class PhaseConfig:
29
+ phase: TrainingPhase
30
+ budget_fraction: float
31
+ start_step: Optional[int] = None
32
+ end_step: Optional[int] = None
33
+ learning_rate: float = 3e-4
34
+ max_evidence_nodes: int = 5
35
+ max_anomalies: int = 0
36
+ use_grpo: bool = False
37
+ use_dapo: bool = False
38
+ diffusion_steps: int = 50
39
+ use_anchored_decoder: bool = True
40
+ use_evoformer: bool = True
41
+ validation_threshold: Optional[float] = None
42
+
43
+
44
+ @dataclass
45
+ class PhaseTransition:
46
+ from_phase: TrainingPhase
47
+ to_phase: TrainingPhase
48
+ step: int
49
+ reason: str
50
+ from_metrics: Optional[Dict[str, float]] = None
51
+
52
+
53
+ class CurriculumScheduler:
54
+ """Curriculum Learning for AAM 4-Phase Training."""
55
+
56
+ def __init__(self, total_steps: int = 500000, learning_rate: float = 1e-4) -> None:
57
+ self.total_steps = total_steps
58
+ self.current_phase = TrainingPhase.PHASE_1_SINGLE_EVIDENCE
59
+ self.current_step = 0
60
+
61
+ self.phase_configs = self._build_phase_configs(learning_rate)
62
+ self.transition_history: List[PhaseTransition] = []
63
+ self.phase_step_counters: Dict[TrainingPhase, int] = {phase: 0 for phase in TrainingPhase}
64
+ self.validation_metrics: Dict[str, List[float]] = {"loss": [], "perplexity": []}
65
+
66
+ def _build_phase_configs(self, base_lr: float) -> Dict[TrainingPhase, PhaseConfig]:
67
+ configs = {
68
+ TrainingPhase.PHASE_1_SINGLE_EVIDENCE: PhaseConfig(
69
+ phase=TrainingPhase.PHASE_1_SINGLE_EVIDENCE,
70
+ budget_fraction=0.25,
71
+ learning_rate=base_lr,
72
+ max_evidence_nodes=3,
73
+ max_anomalies=0,
74
+ diffusion_steps=20,
75
+ use_anchored_decoder=True,
76
+ use_evoformer=False,
77
+ ),
78
+ TrainingPhase.PHASE_2_MULTI_EVIDENCE: PhaseConfig(
79
+ phase=TrainingPhase.PHASE_2_MULTI_EVIDENCE,
80
+ budget_fraction=0.30,
81
+ learning_rate=base_lr * 0.5,
82
+ max_evidence_nodes=10,
83
+ max_anomalies=0,
84
+ diffusion_steps=30,
85
+ use_anchored_decoder=True,
86
+ use_evoformer=True,
87
+ ),
88
+ TrainingPhase.PHASE_3_REASONING: PhaseConfig(
89
+ phase=TrainingPhase.PHASE_3_REASONING,
90
+ budget_fraction=0.30,
91
+ learning_rate=base_lr * 0.1,
92
+ max_evidence_nodes=20,
93
+ max_anomalies=5,
94
+ diffusion_steps=50,
95
+ use_anchored_decoder=True,
96
+ use_evoformer=True,
97
+ ),
98
+ TrainingPhase.PHASE_4_RL: PhaseConfig(
99
+ phase=TrainingPhase.PHASE_4_RL,
100
+ budget_fraction=0.15,
101
+ learning_rate=base_lr * 0.01,
102
+ max_evidence_nodes=50,
103
+ max_anomalies=10,
104
+ diffusion_steps=50,
105
+ use_anchored_decoder=True,
106
+ use_evoformer=True,
107
+ use_grpo=True,
108
+ use_dapo=True,
109
+ ),
110
+ }
111
+
112
+ cumulative_budget = 0.0
113
+ for phase in TrainingPhase:
114
+ cfg = configs[phase]
115
+ cfg.start_step = int(cumulative_budget * self.total_steps)
116
+ cumulative_budget += cfg.budget_fraction
117
+ cfg.end_step = int(cumulative_budget * self.total_steps)
118
+
119
+ return configs
120
+
121
+ def update(self, step: int, validation_loss: Optional[float] = None) -> TrainingPhase:
122
+ self.current_step = step
123
+ self.phase_step_counters[self.current_phase] += 1
124
+
125
+ if validation_loss is not None:
126
+ self.validation_metrics["loss"].append(validation_loss)
127
+
128
+ current_config = self.phase_configs[self.current_phase]
129
+ if current_config.end_step is not None and step >= current_config.end_step:
130
+ next_phase = self._get_next_phase(self.current_phase)
131
+ if next_phase is not None:
132
+ self._transition_to(next_phase, reason=f"step_threshold_reached (step={step})")
133
+ return self.current_phase
134
+
135
+ return self.current_phase
136
+
137
+ def _transition_to(self, next_phase: TrainingPhase, reason: str) -> None:
138
+ old_phase = self.current_phase
139
+ self.transition_history.append(PhaseTransition(
140
+ from_phase=old_phase, to_phase=next_phase, step=self.current_step, reason=reason,
141
+ ))
142
+ self.current_phase = next_phase
143
+ logger.info(f"Curriculum: {old_phase.value} → {next_phase.value} (reason: {reason})")
144
+
145
+ def _get_next_phase(self, current: TrainingPhase) -> Optional[TrainingPhase]:
146
+ phase_order = list(TrainingPhase)
147
+ try:
148
+ idx = phase_order.index(current)
149
+ if idx + 1 < len(phase_order):
150
+ return phase_order[idx + 1]
151
+ except ValueError:
152
+ pass
153
+ return None
154
+
155
+ def get_current_config(self) -> PhaseConfig:
156
+ return self.phase_configs[self.current_phase]
157
+
158
+ def get_progress(self) -> Dict[str, float]:
159
+ phase_config = self.phase_configs[self.current_phase]
160
+ phase_start = phase_config.start_step or 0
161
+ phase_end = phase_config.end_step or self.total_steps
162
+ phase_budget = phase_end - phase_start
163
+ phase_progress = min((self.current_step - phase_start) / max(phase_budget, 1), 1.0) if phase_budget > 0 else 0.0
164
+ return {
165
+ "total_progress": self.current_step / max(self.total_steps, 1),
166
+ "current_phase": self.current_phase.value,
167
+ "phase_progress": phase_progress,
168
+ }
169
+
170
+ def get_schedule_summary(self) -> List[Dict[str, object]]:
171
+ summary = []
172
+ for phase in TrainingPhase:
173
+ config = self.phase_configs[phase]
174
+ summary.append({
175
+ "phase": phase.value,
176
+ "is_current": phase == self.current_phase,
177
+ "budget_fraction": config.budget_fraction,
178
+ "start_step": config.start_step,
179
+ "end_step": config.end_step,
180
+ "learning_rate": config.learning_rate,
181
+ "max_evidence_nodes": config.max_evidence_nodes,
182
+ "max_anomalies": config.max_anomalies,
183
+ "use_grpo": config.use_grpo,
184
+ "use_dapo": config.use_dapo,
185
+ })
186
+ return summary