Spaces:
Running on T4
Running on T4
Upload folder using huggingface_hub
Browse files- Dockerfile +3 -0
- models.py +4 -0
- scripts/patch_openenv_web_interface_chart.py +226 -0
- server/meta_optimizer_environment.py +216 -34
- server/slm_model.py +160 -0
- server/tasks.py +95 -18
Dockerfile
CHANGED
|
@@ -57,6 +57,9 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
|
| 57 |
# Patch OpenEnv web UI so number inputs use step=0.01 (allows lr_scale=0.02, momentum_coef=0.9)
|
| 58 |
RUN WEBIF="$$(find /app/env/.venv -path '*openenv*env_server*web_interface.py' | head -1)" && \
|
| 59 |
/app/env/.venv/bin/python /app/env/scripts/patch_openenv_web_interface.py "$$WEBIF"
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
# Final runtime stage
|
| 62 |
FROM ${BASE_IMAGE}
|
|
|
|
| 57 |
# Patch OpenEnv web UI so number inputs use step=0.01 (allows lr_scale=0.02, momentum_coef=0.9)
|
| 58 |
RUN WEBIF="$$(find /app/env/.venv -path '*openenv*env_server*web_interface.py' | head -1)" && \
|
| 59 |
/app/env/.venv/bin/python /app/env/scripts/patch_openenv_web_interface.py "$$WEBIF"
|
| 60 |
+
# Patch OpenEnv web UI to add loss/perplexity chart and Run baseline (AdamW) button
|
| 61 |
+
RUN WEBIF="$$(find /app/env/.venv -path '*openenv*env_server*web_interface.py' | head -1)" && \
|
| 62 |
+
/app/env/.venv/bin/python /app/env/scripts/patch_openenv_web_interface_chart.py "$$WEBIF"
|
| 63 |
|
| 64 |
# Final runtime stage
|
| 65 |
FROM ${BASE_IMAGE}
|
models.py
CHANGED
|
@@ -72,4 +72,8 @@ class MetaOptimizerObservation(Observation):
|
|
| 72 |
default=None,
|
| 73 |
description="Step at which loss first reached threshold (None if not yet reached)",
|
| 74 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
|
|
|
| 72 |
default=None,
|
| 73 |
description="Step at which loss first reached threshold (None if not yet reached)",
|
| 74 |
)
|
| 75 |
+
perplexity: float | None = Field(
|
| 76 |
+
default=None,
|
| 77 |
+
description="exp(loss) for language modeling (None for regression)",
|
| 78 |
+
)
|
| 79 |
|
scripts/patch_openenv_web_interface_chart.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Patch OpenEnv web_interface.py to add:
|
| 4 |
+
- Loss/perplexity chart and updateLossChart()
|
| 5 |
+
- POST /web/run-baseline and GET /web/current-task for baseline comparison
|
| 6 |
+
Idempotent: safe to run multiple times.
|
| 7 |
+
"""
|
| 8 |
+
import sys
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _apply_routes_patch(text: str) -> str:
|
| 13 |
+
"""Add /web/run-baseline and /web/current-task routes."""
|
| 14 |
+
old_routes = (
|
| 15 |
+
' @app.get("/web/state")\n'
|
| 16 |
+
" async def web_state():\n"
|
| 17 |
+
' """State endpoint for web interface."""\n'
|
| 18 |
+
" return web_manager.get_state()\n"
|
| 19 |
+
"\n"
|
| 20 |
+
" return app"
|
| 21 |
+
)
|
| 22 |
+
new_routes = (
|
| 23 |
+
' @app.get("/web/state")\n'
|
| 24 |
+
" async def web_state():\n"
|
| 25 |
+
' """State endpoint for web interface."""\n'
|
| 26 |
+
" return web_manager.get_state()\n"
|
| 27 |
+
"\n"
|
| 28 |
+
' @app.get("/web/current-task")\n'
|
| 29 |
+
" async def web_current_task():\n"
|
| 30 |
+
' """Current task spec for baseline comparison (if env supports it)."""\n'
|
| 31 |
+
" get_spec = getattr(web_manager.env, \"get_current_task_spec\", None)\n"
|
| 32 |
+
" if get_spec is None:\n"
|
| 33 |
+
" return {}\n"
|
| 34 |
+
" return get_spec() or {}\n"
|
| 35 |
+
"\n"
|
| 36 |
+
' @app.post("/web/run-baseline")\n'
|
| 37 |
+
" async def web_run_baseline():\n"
|
| 38 |
+
' """Run baseline optimizer for current task; returns loss_trajectory and steps."""\n'
|
| 39 |
+
" run_bl = getattr(web_manager.env, \"run_baseline\", None)\n"
|
| 40 |
+
" if run_bl is None:\n"
|
| 41 |
+
" return {\"loss_trajectory\": [], \"steps\": [], \"error\": \"Env has no run_baseline\"}\n"
|
| 42 |
+
" return run_bl()\n"
|
| 43 |
+
"\n"
|
| 44 |
+
" return app"
|
| 45 |
+
)
|
| 46 |
+
if "web/run-baseline" not in text and "web/state" in text and "return web_manager.get_state()" in text:
|
| 47 |
+
text = text.replace(old_routes, new_routes, 1)
|
| 48 |
+
return text
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def main() -> None:
|
| 52 |
+
if len(sys.argv) < 2:
|
| 53 |
+
import openenv.core.env_server.web_interface as m
|
| 54 |
+
path = Path(m.__file__).resolve()
|
| 55 |
+
else:
|
| 56 |
+
path = Path(sys.argv[1]).resolve()
|
| 57 |
+
|
| 58 |
+
if not path.exists():
|
| 59 |
+
print(f"File not found: {path}", file=sys.stderr)
|
| 60 |
+
sys.exit(1)
|
| 61 |
+
|
| 62 |
+
text = path.read_text()
|
| 63 |
+
|
| 64 |
+
# 1) Add Chart.js script in head (after title)
|
| 65 |
+
chart_script = ' <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>\n'
|
| 66 |
+
old_head = "<title>OpenEnv Web Interface</title>\n <style>"
|
| 67 |
+
new_head = "<title>OpenEnv Web Interface</title>\n" + chart_script + " <style>"
|
| 68 |
+
if chart_script not in text and old_head in text:
|
| 69 |
+
text = text.replace(old_head, new_head, 1)
|
| 70 |
+
|
| 71 |
+
# 2) Add chart container between Current Observation and Action Logs
|
| 72 |
+
old_section = """ </div>
|
| 73 |
+
</div>
|
| 74 |
+
|
| 75 |
+
<!-- Action Logs -->
|
| 76 |
+
<div class="logs-container">"""
|
| 77 |
+
new_section = """ </div>
|
| 78 |
+
</div>
|
| 79 |
+
|
| 80 |
+
<!-- Loss chart -->
|
| 81 |
+
<div class="state-display">
|
| 82 |
+
<h3>Loss / Perplexity</h3>
|
| 83 |
+
<div id="loss-chart-container" style="height:200px;"><canvas id="loss-chart"></canvas></div>
|
| 84 |
+
<button type="button" id="run-baseline-btn" class="btn btn-secondary" style="margin-top:8px;">Run baseline (AdamW)</button>
|
| 85 |
+
</div>
|
| 86 |
+
|
| 87 |
+
<!-- Action Logs -->
|
| 88 |
+
<div class="logs-container">"""
|
| 89 |
+
if "loss-chart-container" not in text and old_section in text:
|
| 90 |
+
text = text.replace(old_section, new_section, 1)
|
| 91 |
+
# If chart container exists but button does not, add button
|
| 92 |
+
if "loss-chart-container" in text and "run-baseline-btn" not in text:
|
| 93 |
+
text = text.replace(
|
| 94 |
+
"<canvas id=\"loss-chart\"></canvas></div>\n </div>",
|
| 95 |
+
"<canvas id=\"loss-chart\"></canvas></div>\n <button type=\"button\" id=\"run-baseline-btn\" class=\"btn btn-secondary\" style=\"margin-top:8px;\">Run baseline (AdamW)</button>\n </div>",
|
| 96 |
+
1,
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
# 3) Add updateLossChart call and method before updateChatInterface
|
| 100 |
+
old_update = """ }}
|
| 101 |
+
}}
|
| 102 |
+
|
| 103 |
+
updateChatInterface(episodeState) {{"""
|
| 104 |
+
new_update = """ }}
|
| 105 |
+
this.updateLossChart(episodeState);
|
| 106 |
+
}}
|
| 107 |
+
|
| 108 |
+
updateLossChart(episodeState) {{
|
| 109 |
+
const container = document.getElementById('loss-chart-container');
|
| 110 |
+
if (!container) return;
|
| 111 |
+
const steps = [];
|
| 112 |
+
const losses = [];
|
| 113 |
+
const perplexities = [];
|
| 114 |
+
if (episodeState.current_observation && typeof episodeState.current_observation.loss === 'number') {{
|
| 115 |
+
const o = episodeState.current_observation;
|
| 116 |
+
steps.push(o.step_count != null ? o.step_count : 0);
|
| 117 |
+
losses.push(o.loss);
|
| 118 |
+
if (typeof o.perplexity === 'number') perplexities.push(o.perplexity);
|
| 119 |
+
}}
|
| 120 |
+
(episodeState.action_logs || []).forEach(log => {{
|
| 121 |
+
if (log.observation && typeof log.observation.loss === 'number') {{
|
| 122 |
+
steps.push(log.observation.step_count != null ? log.observation.step_count : log.step_count);
|
| 123 |
+
losses.push(log.observation.loss);
|
| 124 |
+
if (typeof log.observation.perplexity === 'number') perplexities.push(log.observation.perplexity);
|
| 125 |
+
}}
|
| 126 |
+
}});
|
| 127 |
+
if (steps.length === 0) return;
|
| 128 |
+
const ctx = document.getElementById('loss-chart');
|
| 129 |
+
if (!ctx) return;
|
| 130 |
+
if (this._lossChart) this._lossChart.destroy();
|
| 131 |
+
this._lossChart = new Chart(ctx, {{
|
| 132 |
+
type: 'line',
|
| 133 |
+
data: {{
|
| 134 |
+
labels: steps,
|
| 135 |
+
datasets: [
|
| 136 |
+
{{ label: 'Loss', data: losses, borderColor: '#007bff', tension: 0.2, fill: false }}
|
| 137 |
+
].concat(perplexities.length ? [{{ label: 'Perplexity', data: perplexities, borderColor: '#28a745', tension: 0.2, fill: false }}] : [])
|
| 138 |
+
}},
|
| 139 |
+
options: {{ responsive: true, maintainAspectRatio: false, scales: {{ x: {{ title: {{ display: true, text: 'Step' }} }} }} }}
|
| 140 |
+
}});
|
| 141 |
+
}}
|
| 142 |
+
|
| 143 |
+
async runBaseline() {{
|
| 144 |
+
const btn = document.getElementById('run-baseline-btn');
|
| 145 |
+
if (btn) btn.disabled = true;
|
| 146 |
+
try {{
|
| 147 |
+
const r = await fetch('/web/run-baseline', {{ method: 'POST' }});
|
| 148 |
+
const data = await r.json();
|
| 149 |
+
if (data.error || !data.loss_trajectory || !this._lossChart) {{ if (btn) btn.disabled = false; return; }}
|
| 150 |
+
const L = data.loss_trajectory.length;
|
| 151 |
+
const steps = data.steps && data.steps.length === L ? data.steps : Array.from({{ length: L }}, (_, i) => i);
|
| 152 |
+
const curLen = this._lossChart.data.labels.length;
|
| 153 |
+
const newLen = Math.max(curLen, steps.length);
|
| 154 |
+
const newLabels = Array.from({{ length: newLen }}, (_, i) => i);
|
| 155 |
+
this._lossChart.data.labels = newLabels;
|
| 156 |
+
this._lossChart.data.datasets.forEach(ds => {{
|
| 157 |
+
while (ds.data.length < newLen) ds.data.push(null);
|
| 158 |
+
}});
|
| 159 |
+
const baselineData = data.loss_trajectory.slice();
|
| 160 |
+
while (baselineData.length < newLen) baselineData.push(null);
|
| 161 |
+
this._lossChart.data.datasets.push({{ label: 'Baseline (AdamW)', data: baselineData, borderColor: '#dc3545', tension: 0.2, fill: false }});
|
| 162 |
+
this._lossChart.update();
|
| 163 |
+
}} finally {{ if (btn) btn.disabled = false; }}
|
| 164 |
+
}}
|
| 165 |
+
|
| 166 |
+
updateChatInterface(episodeState) {{"""
|
| 167 |
+
if "updateLossChart(episodeState)" not in text and old_update in text:
|
| 168 |
+
text = text.replace(old_update, new_update, 1)
|
| 169 |
+
|
| 170 |
+
# 3b) Add Run baseline button click listener
|
| 171 |
+
old_listener = """ // State button
|
| 172 |
+
document.getElementById('state-btn').addEventListener('click', () => {{
|
| 173 |
+
this.getState();
|
| 174 |
+
}});
|
| 175 |
+
}}"""
|
| 176 |
+
new_listener = """ // State button
|
| 177 |
+
document.getElementById('state-btn').addEventListener('click', () => {{
|
| 178 |
+
this.getState();
|
| 179 |
+
}});
|
| 180 |
+
|
| 181 |
+
const runBaselineBtn = document.getElementById('run-baseline-btn');
|
| 182 |
+
if (runBaselineBtn) runBaselineBtn.addEventListener('click', () => this.runBaseline());
|
| 183 |
+
}}"""
|
| 184 |
+
if "run-baseline-btn" not in text or "runBaselineBtn.addEventListener" not in text:
|
| 185 |
+
if old_listener in text:
|
| 186 |
+
text = text.replace(old_listener, new_listener, 1)
|
| 187 |
+
|
| 188 |
+
# 3c) If updateLossChart exists but runBaseline does not, insert runBaseline
|
| 189 |
+
if "updateLossChart(episodeState)" in text and "async runBaseline()" not in text:
|
| 190 |
+
run_baseline_method = """
|
| 191 |
+
async runBaseline() {{
|
| 192 |
+
const btn = document.getElementById('run-baseline-btn');
|
| 193 |
+
if (btn) btn.disabled = true;
|
| 194 |
+
try {{
|
| 195 |
+
const r = await fetch('/web/run-baseline', {{ method: 'POST' }});
|
| 196 |
+
const data = await r.json();
|
| 197 |
+
if (data.error || !data.loss_trajectory || !this._lossChart) {{ if (btn) btn.disabled = false; return; }}
|
| 198 |
+
const L = data.loss_trajectory.length;
|
| 199 |
+
const newLen = Math.max(this._lossChart.data.labels.length, L);
|
| 200 |
+
const newLabels = Array.from({{ length: newLen }}, (_, i) => i);
|
| 201 |
+
this._lossChart.data.labels = newLabels;
|
| 202 |
+
this._lossChart.data.datasets.forEach(ds => {{
|
| 203 |
+
while (ds.data.length < newLen) ds.data.push(null);
|
| 204 |
+
}});
|
| 205 |
+
const baselineData = data.loss_trajectory.slice();
|
| 206 |
+
while (baselineData.length < newLen) baselineData.push(null);
|
| 207 |
+
this._lossChart.data.datasets.push({{ label: 'Baseline (AdamW)', data: baselineData, borderColor: '#dc3545', tension: 0.2, fill: false }});
|
| 208 |
+
this._lossChart.update();
|
| 209 |
+
}} finally {{ if (btn) btn.disabled = false; }}
|
| 210 |
+
}}
|
| 211 |
+
"""
|
| 212 |
+
text = text.replace(
|
| 213 |
+
" }});\n }}\n\n updateChatInterface(episodeState) {{",
|
| 214 |
+
" }});\n }}\n" + run_baseline_method + "\n updateChatInterface(episodeState) {{",
|
| 215 |
+
1,
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
# 4) Add run-baseline and current-task routes
|
| 219 |
+
text = _apply_routes_patch(text)
|
| 220 |
+
|
| 221 |
+
path.write_text(text)
|
| 222 |
+
print("Patched (chart + run-baseline):", path)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
if __name__ == "__main__":
|
| 226 |
+
main()
|
server/meta_optimizer_environment.py
CHANGED
|
@@ -5,15 +5,17 @@
|
|
| 5 |
# LICENSE file in the root directory of this source tree.
|
| 6 |
|
| 7 |
"""
|
| 8 |
-
Meta-optimizer environment: train an RL agent to act as an optimizer on
|
| 9 |
|
| 10 |
-
Supports
|
| 11 |
-
|
|
|
|
| 12 |
"""
|
| 13 |
|
| 14 |
import math
|
| 15 |
import random
|
| 16 |
-
from
|
|
|
|
| 17 |
from uuid import uuid4
|
| 18 |
|
| 19 |
import torch
|
|
@@ -23,7 +25,22 @@ from openenv.core.env_server.interfaces import Environment
|
|
| 23 |
from openenv.core.env_server.types import State
|
| 24 |
|
| 25 |
from my_env.models import MetaOptimizerAction, MetaOptimizerObservation
|
| 26 |
-
from .tasks import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
# Defaults
|
| 29 |
LOSS_THRESHOLD = 0.1
|
|
@@ -31,6 +48,8 @@ MAX_STEPS = 100
|
|
| 31 |
BATCH_SIZE = 32
|
| 32 |
# Dense reward scale: reward += DENSE_REWARD_SCALE * (prev_loss - current_loss) each step (potential-based, helps credit assignment)
|
| 33 |
DENSE_REWARD_SCALE = 0.2
|
|
|
|
|
|
|
| 34 |
|
| 35 |
|
| 36 |
def _default_device() -> torch.device:
|
|
@@ -60,6 +79,35 @@ def _get_batch(spec: TaskSpec, step: int, device: torch.device):
|
|
| 60 |
return X, y
|
| 61 |
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
def run_adam_baseline(
|
| 64 |
task_id: Optional[int] = None,
|
| 65 |
task_spec: Optional[Dict[str, Any]] = None,
|
|
@@ -78,6 +126,8 @@ def run_adam_baseline(
|
|
| 78 |
torch.manual_seed(seed)
|
| 79 |
device = _default_device()
|
| 80 |
spec = task_spec_from_dict(task_spec) if task_spec is not None else get_task(task_id)
|
|
|
|
|
|
|
| 81 |
model = _build_model(spec).to(device)
|
| 82 |
opt = torch.optim.Adam(model.parameters(), lr=lr)
|
| 83 |
loss_trajectory: List[float] = []
|
|
@@ -128,6 +178,8 @@ def run_sgd_baseline(
|
|
| 128 |
torch.manual_seed(seed)
|
| 129 |
device = _default_device()
|
| 130 |
spec = task_spec_from_dict(task_spec) if task_spec is not None else get_task(task_id)
|
|
|
|
|
|
|
| 131 |
model = _build_model(spec).to(device)
|
| 132 |
opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
|
| 133 |
loss_trajectory = []
|
|
@@ -159,6 +211,67 @@ def run_sgd_baseline(
|
|
| 159 |
}
|
| 160 |
|
| 161 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
def run_meta_optimizer_trajectory(
|
| 163 |
task_id: Optional[int] = None,
|
| 164 |
task_spec: Optional[Dict[str, Any]] = None,
|
|
@@ -226,7 +339,7 @@ class MetaOptimizerEnvironment(Environment[MetaOptimizerAction, MetaOptimizerObs
|
|
| 226 |
self._device = _default_device()
|
| 227 |
|
| 228 |
# Episode state (set in reset)
|
| 229 |
-
self._task_spec: Optional[TaskSpec] = None
|
| 230 |
self._model: Optional[nn.Module] = None
|
| 231 |
self._velocities: Optional[List[torch.Tensor]] = None
|
| 232 |
self._step_count: int = 0
|
|
@@ -235,6 +348,9 @@ class MetaOptimizerEnvironment(Environment[MetaOptimizerAction, MetaOptimizerObs
|
|
| 235 |
self._steps_to_threshold: Optional[int] = None
|
| 236 |
self._action_log: List[Dict[str, Any]] = []
|
| 237 |
self._episode_id: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
| 238 |
|
| 239 |
def reset(
|
| 240 |
self,
|
|
@@ -250,23 +366,44 @@ class MetaOptimizerEnvironment(Environment[MetaOptimizerAction, MetaOptimizerObs
|
|
| 250 |
if task_spec is not None:
|
| 251 |
self._task_spec = task_spec_from_dict(task_spec)
|
| 252 |
else:
|
| 253 |
-
tid = task_id if task_id is not None else random.choice(
|
| 254 |
-
self._task_spec =
|
| 255 |
-
self.
|
| 256 |
-
|
| 257 |
-
self.
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
self.
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
|
| 271 |
def step(
|
| 272 |
self,
|
|
@@ -289,10 +426,20 @@ class MetaOptimizerEnvironment(Environment[MetaOptimizerAction, MetaOptimizerObs
|
|
| 289 |
"weight_decay_this_step": wd,
|
| 290 |
})
|
| 291 |
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
self._model.zero_grad()
|
| 297 |
loss.backward()
|
| 298 |
|
|
@@ -314,19 +461,32 @@ class MetaOptimizerEnvironment(Environment[MetaOptimizerAction, MetaOptimizerObs
|
|
| 314 |
if wd > 0:
|
| 315 |
p.sub_(p, alpha=wd)
|
| 316 |
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 320 |
|
| 321 |
self._step_count += 1
|
| 322 |
-
if self._steps_to_threshold is None and self._current_loss <
|
| 323 |
self._steps_to_threshold = self._step_count
|
| 324 |
|
| 325 |
-
# Dense reward: reward loss decrease (potential-based shaping, does not change optimal policy)
|
| 326 |
dense_reward = DENSE_REWARD_SCALE * (prev_loss - self._current_loss)
|
| 327 |
self._prev_loss = self._current_loss
|
| 328 |
|
| 329 |
-
# End episode when we hit max_steps or when loss first crosses threshold (early termination)
|
| 330 |
done = self._step_count >= self.max_steps or self._steps_to_threshold is not None
|
| 331 |
if done:
|
| 332 |
terminal = -(self._steps_to_threshold if self._steps_to_threshold is not None else self.max_steps)
|
|
@@ -334,13 +494,14 @@ class MetaOptimizerEnvironment(Environment[MetaOptimizerAction, MetaOptimizerObs
|
|
| 334 |
else:
|
| 335 |
reward = dense_reward
|
| 336 |
|
| 337 |
-
return self._observation(reward=reward, grad_norm=grad_norm, done=done)
|
| 338 |
|
| 339 |
def _observation(
|
| 340 |
self,
|
| 341 |
reward: Optional[float] = None,
|
| 342 |
grad_norm: Optional[float] = None,
|
| 343 |
done: bool = False,
|
|
|
|
| 344 |
) -> MetaOptimizerObservation:
|
| 345 |
meta: Dict[str, Any] = {}
|
| 346 |
if self._steps_to_threshold is not None:
|
|
@@ -355,6 +516,7 @@ class MetaOptimizerEnvironment(Environment[MetaOptimizerAction, MetaOptimizerObs
|
|
| 355 |
done=done,
|
| 356 |
reward=reward,
|
| 357 |
metadata=meta,
|
|
|
|
| 358 |
)
|
| 359 |
|
| 360 |
@property
|
|
@@ -367,3 +529,23 @@ class MetaOptimizerEnvironment(Environment[MetaOptimizerAction, MetaOptimizerObs
|
|
| 367 |
def get_episode_action_log(self) -> List[Dict[str, Any]]:
|
| 368 |
"""Return the action log for the current episode (for in-process viz or eval)."""
|
| 369 |
return list(self._action_log)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
# LICENSE file in the root directory of this source tree.
|
| 6 |
|
| 7 |
"""
|
| 8 |
+
Meta-optimizer environment: train an RL agent to act as an optimizer on inner tasks.
|
| 9 |
|
| 10 |
+
Supports (1) SLM: next-token prediction with a tiny transformer; (2) sinusoid regression.
|
| 11 |
+
Rich action space (LR, momentum, grad clip, weight decay), convergence-speed reward.
|
| 12 |
+
Action log and loss/perplexity are exposed for dashboard visualization.
|
| 13 |
"""
|
| 14 |
|
| 15 |
import math
|
| 16 |
import random
|
| 17 |
+
from dataclasses import asdict
|
| 18 |
+
from typing import Any, Dict, List, Optional, Union
|
| 19 |
from uuid import uuid4
|
| 20 |
|
| 21 |
import torch
|
|
|
|
| 25 |
from openenv.core.env_server.types import State
|
| 26 |
|
| 27 |
from my_env.models import MetaOptimizerAction, MetaOptimizerObservation
|
| 28 |
+
from .tasks import (
|
| 29 |
+
DEFAULT_CORPUS,
|
| 30 |
+
SLM_TRAIN_TASK_IDS,
|
| 31 |
+
TRAIN_TASK_IDS,
|
| 32 |
+
get_slm_task,
|
| 33 |
+
get_task,
|
| 34 |
+
task_spec_from_dict,
|
| 35 |
+
TaskSpec,
|
| 36 |
+
SLMTaskSpec,
|
| 37 |
+
)
|
| 38 |
+
from .slm_model import (
|
| 39 |
+
TinyLM,
|
| 40 |
+
build_vocab,
|
| 41 |
+
get_corpus_tensor,
|
| 42 |
+
sample_batch_slm,
|
| 43 |
+
)
|
| 44 |
|
| 45 |
# Defaults
|
| 46 |
LOSS_THRESHOLD = 0.1
|
|
|
|
| 48 |
BATCH_SIZE = 32
|
| 49 |
# Dense reward scale: reward += DENSE_REWARD_SCALE * (prev_loss - current_loss) each step (potential-based, helps credit assignment)
|
| 50 |
DENSE_REWARD_SCALE = 0.2
|
| 51 |
+
# SLM loss threshold (cross-entropy); early termination when loss < this
|
| 52 |
+
SLM_LOSS_THRESHOLD = 1.5
|
| 53 |
|
| 54 |
|
| 55 |
def _default_device() -> torch.device:
|
|
|
|
| 79 |
return X, y
|
| 80 |
|
| 81 |
|
| 82 |
+
def _build_slm(spec: SLMTaskSpec) -> TinyLM:
|
| 83 |
+
"""Build a tiny decoder-only transformer for the given SLM task spec."""
|
| 84 |
+
torch.manual_seed(spec.arch_seed)
|
| 85 |
+
return TinyLM(
|
| 86 |
+
vocab_size=spec.vocab_size,
|
| 87 |
+
context_len=spec.context_len,
|
| 88 |
+
n_layer=spec.n_layer,
|
| 89 |
+
n_head=spec.n_head,
|
| 90 |
+
n_embd=spec.n_embd,
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _get_batch_slm(
|
| 95 |
+
spec: SLMTaskSpec,
|
| 96 |
+
step: int,
|
| 97 |
+
device: torch.device,
|
| 98 |
+
corpus_ids: torch.Tensor,
|
| 99 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 100 |
+
"""Sample a batch for next-token prediction. Returns input_ids [B,T], target_ids [B,T]."""
|
| 101 |
+
return sample_batch_slm(
|
| 102 |
+
corpus_ids,
|
| 103 |
+
BATCH_SIZE,
|
| 104 |
+
spec.context_len,
|
| 105 |
+
step,
|
| 106 |
+
spec.data_seed,
|
| 107 |
+
device,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
def run_adam_baseline(
|
| 112 |
task_id: Optional[int] = None,
|
| 113 |
task_spec: Optional[Dict[str, Any]] = None,
|
|
|
|
| 126 |
torch.manual_seed(seed)
|
| 127 |
device = _default_device()
|
| 128 |
spec = task_spec_from_dict(task_spec) if task_spec is not None else get_task(task_id)
|
| 129 |
+
if isinstance(spec, SLMTaskSpec):
|
| 130 |
+
raise ValueError("Use run_adamw_baseline for SLM tasks")
|
| 131 |
model = _build_model(spec).to(device)
|
| 132 |
opt = torch.optim.Adam(model.parameters(), lr=lr)
|
| 133 |
loss_trajectory: List[float] = []
|
|
|
|
| 178 |
torch.manual_seed(seed)
|
| 179 |
device = _default_device()
|
| 180 |
spec = task_spec_from_dict(task_spec) if task_spec is not None else get_task(task_id)
|
| 181 |
+
if isinstance(spec, SLMTaskSpec):
|
| 182 |
+
raise ValueError("Use run_adamw_baseline for SLM tasks")
|
| 183 |
model = _build_model(spec).to(device)
|
| 184 |
opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
|
| 185 |
loss_trajectory = []
|
|
|
|
| 211 |
}
|
| 212 |
|
| 213 |
|
| 214 |
+
def run_adamw_baseline(
|
| 215 |
+
task_id: Optional[int] = None,
|
| 216 |
+
task_spec: Optional[Dict[str, Any]] = None,
|
| 217 |
+
max_steps: int = MAX_STEPS,
|
| 218 |
+
loss_threshold: float = SLM_LOSS_THRESHOLD,
|
| 219 |
+
lr: float = 1e-3,
|
| 220 |
+
weight_decay: float = 0.01,
|
| 221 |
+
betas: tuple[float, float] = (0.9, 0.999),
|
| 222 |
+
seed: Optional[int] = None,
|
| 223 |
+
return_metrics: bool = False,
|
| 224 |
+
):
|
| 225 |
+
"""
|
| 226 |
+
Run AdamW on one SLM task. Returns steps to threshold, or full metrics dict if return_metrics=True.
|
| 227 |
+
"""
|
| 228 |
+
if (task_id is None) == (task_spec is None):
|
| 229 |
+
raise ValueError("Provide exactly one of task_id or task_spec")
|
| 230 |
+
if seed is not None:
|
| 231 |
+
torch.manual_seed(seed)
|
| 232 |
+
device = _default_device()
|
| 233 |
+
spec = task_spec_from_dict(task_spec) if task_spec is not None else get_slm_task(task_id)
|
| 234 |
+
if isinstance(spec, TaskSpec):
|
| 235 |
+
raise ValueError("Use run_adam_baseline or run_sgd_baseline for sinusoid tasks")
|
| 236 |
+
assert isinstance(spec, SLMTaskSpec)
|
| 237 |
+
model = _build_slm(spec).to(device)
|
| 238 |
+
opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, betas=betas)
|
| 239 |
+
char2idx, _ = build_vocab()
|
| 240 |
+
corpus_ids = get_corpus_tensor(DEFAULT_CORPUS, char2idx, device)
|
| 241 |
+
loss_trajectory: List[float] = []
|
| 242 |
+
steps_to_threshold: Optional[int] = None
|
| 243 |
+
for step in range(max_steps):
|
| 244 |
+
inp, tgt = _get_batch_slm(spec, step, device, corpus_ids)
|
| 245 |
+
model.train()
|
| 246 |
+
opt.zero_grad()
|
| 247 |
+
logits = model(inp)
|
| 248 |
+
loss = nn.functional.cross_entropy(logits.view(-1, spec.vocab_size), tgt.view(-1))
|
| 249 |
+
loss.backward()
|
| 250 |
+
opt.step()
|
| 251 |
+
with torch.no_grad():
|
| 252 |
+
L = nn.functional.cross_entropy(
|
| 253 |
+
model(inp).view(-1, spec.vocab_size), tgt.view(-1)
|
| 254 |
+
).item()
|
| 255 |
+
loss_trajectory.append(L)
|
| 256 |
+
if steps_to_threshold is None and L < loss_threshold:
|
| 257 |
+
steps_to_threshold = step + 1
|
| 258 |
+
final_loss = loss_trajectory[-1] if loss_trajectory else float("inf")
|
| 259 |
+
perplexity = math.exp(min(final_loss, 20.0))
|
| 260 |
+
if not return_metrics:
|
| 261 |
+
return steps_to_threshold if steps_to_threshold is not None else max_steps
|
| 262 |
+
last_k = min(10, len(loss_trajectory))
|
| 263 |
+
mean_last_k = sum(loss_trajectory[-last_k:]) / last_k if loss_trajectory else final_loss
|
| 264 |
+
return {
|
| 265 |
+
"steps_to_threshold": steps_to_threshold if steps_to_threshold is not None else max_steps,
|
| 266 |
+
"success": steps_to_threshold is not None,
|
| 267 |
+
"final_loss": final_loss,
|
| 268 |
+
"perplexity": perplexity,
|
| 269 |
+
"mean_last_10_loss": mean_last_k,
|
| 270 |
+
"loss_auc": sum(loss_trajectory) / len(loss_trajectory) if loss_trajectory else final_loss,
|
| 271 |
+
"loss_trajectory": loss_trajectory,
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
|
| 275 |
def run_meta_optimizer_trajectory(
|
| 276 |
task_id: Optional[int] = None,
|
| 277 |
task_spec: Optional[Dict[str, Any]] = None,
|
|
|
|
| 339 |
self._device = _default_device()
|
| 340 |
|
| 341 |
# Episode state (set in reset)
|
| 342 |
+
self._task_spec: Optional[Union[TaskSpec, SLMTaskSpec]] = None
|
| 343 |
self._model: Optional[nn.Module] = None
|
| 344 |
self._velocities: Optional[List[torch.Tensor]] = None
|
| 345 |
self._step_count: int = 0
|
|
|
|
| 348 |
self._steps_to_threshold: Optional[int] = None
|
| 349 |
self._action_log: List[Dict[str, Any]] = []
|
| 350 |
self._episode_id: Optional[str] = None
|
| 351 |
+
self._corpus_ids: Optional[torch.Tensor] = None # for SLM only
|
| 352 |
+
self._is_slm: bool = False
|
| 353 |
+
self._slm_loss_threshold: float = SLM_LOSS_THRESHOLD
|
| 354 |
|
| 355 |
def reset(
|
| 356 |
self,
|
|
|
|
| 366 |
if task_spec is not None:
|
| 367 |
self._task_spec = task_spec_from_dict(task_spec)
|
| 368 |
else:
|
| 369 |
+
tid = task_id if task_id is not None else random.choice(SLM_TRAIN_TASK_IDS)
|
| 370 |
+
self._task_spec = get_slm_task(tid)
|
| 371 |
+
self._is_slm = isinstance(self._task_spec, SLMTaskSpec)
|
| 372 |
+
|
| 373 |
+
if self._is_slm:
|
| 374 |
+
spec = self._task_spec
|
| 375 |
+
assert isinstance(spec, SLMTaskSpec)
|
| 376 |
+
self._model = _build_slm(spec).to(self._device)
|
| 377 |
+
self._velocities = [torch.zeros_like(p) for p in self._model.parameters()]
|
| 378 |
+
char2idx, _ = build_vocab()
|
| 379 |
+
self._corpus_ids = get_corpus_tensor(DEFAULT_CORPUS, char2idx, self._device)
|
| 380 |
+
self._step_count = 0
|
| 381 |
+
self._steps_to_threshold = None
|
| 382 |
+
self._action_log = []
|
| 383 |
+
self._episode_id = episode_id or str(uuid4())
|
| 384 |
+
inp, tgt = _get_batch_slm(spec, 0, self._device, self._corpus_ids)
|
| 385 |
+
with torch.no_grad():
|
| 386 |
+
logits = self._model(inp)
|
| 387 |
+
self._current_loss = nn.functional.cross_entropy(
|
| 388 |
+
logits.view(-1, spec.vocab_size), tgt.view(-1)
|
| 389 |
+
).item()
|
| 390 |
+
self._prev_loss = self._current_loss
|
| 391 |
+
return self._observation(reward=None, grad_norm=None, perplexity=math.exp(min(self._current_loss, 20.0)))
|
| 392 |
+
else:
|
| 393 |
+
spec = self._task_spec
|
| 394 |
+
assert isinstance(spec, TaskSpec)
|
| 395 |
+
self._model = _build_model(spec).to(self._device)
|
| 396 |
+
self._velocities = [torch.zeros_like(p) for p in self._model.parameters()]
|
| 397 |
+
self._step_count = 0
|
| 398 |
+
self._steps_to_threshold = None
|
| 399 |
+
self._action_log = []
|
| 400 |
+
self._episode_id = episode_id or str(uuid4())
|
| 401 |
+
X, y = _get_batch(spec, 0, self._device)
|
| 402 |
+
with torch.no_grad():
|
| 403 |
+
out = self._model(X)
|
| 404 |
+
self._current_loss = nn.functional.mse_loss(out, y).item()
|
| 405 |
+
self._prev_loss = self._current_loss
|
| 406 |
+
return self._observation(reward=None, grad_norm=None)
|
| 407 |
|
| 408 |
def step(
|
| 409 |
self,
|
|
|
|
| 426 |
"weight_decay_this_step": wd,
|
| 427 |
})
|
| 428 |
|
| 429 |
+
if self._is_slm:
|
| 430 |
+
spec = self._task_spec
|
| 431 |
+
assert isinstance(spec, SLMTaskSpec)
|
| 432 |
+
inp, tgt = _get_batch_slm(spec, self._step_count + 1, self._device, self._corpus_ids)
|
| 433 |
+
self._model.train()
|
| 434 |
+
logits = self._model(inp)
|
| 435 |
+
loss = nn.functional.cross_entropy(logits.view(-1, spec.vocab_size), tgt.view(-1))
|
| 436 |
+
else:
|
| 437 |
+
spec = self._task_spec
|
| 438 |
+
assert isinstance(spec, TaskSpec)
|
| 439 |
+
X, y = _get_batch(spec, self._step_count + 1, self._device)
|
| 440 |
+
self._model.train()
|
| 441 |
+
loss = nn.functional.mse_loss(self._model(X), y)
|
| 442 |
+
|
| 443 |
self._model.zero_grad()
|
| 444 |
loss.backward()
|
| 445 |
|
|
|
|
| 461 |
if wd > 0:
|
| 462 |
p.sub_(p, alpha=wd)
|
| 463 |
|
| 464 |
+
if self._is_slm:
|
| 465 |
+
spec = self._task_spec
|
| 466 |
+
assert isinstance(spec, SLMTaskSpec)
|
| 467 |
+
with torch.no_grad():
|
| 468 |
+
logits = self._model(inp)
|
| 469 |
+
self._current_loss = nn.functional.cross_entropy(
|
| 470 |
+
logits.view(-1, spec.vocab_size), tgt.view(-1)
|
| 471 |
+
).item()
|
| 472 |
+
loss_threshold = self._slm_loss_threshold
|
| 473 |
+
perp = math.exp(min(self._current_loss, 20.0))
|
| 474 |
+
else:
|
| 475 |
+
spec = self._task_spec
|
| 476 |
+
assert isinstance(spec, TaskSpec)
|
| 477 |
+
with torch.no_grad():
|
| 478 |
+
X, y = _get_batch(spec, self._step_count + 1, self._device)
|
| 479 |
+
self._current_loss = nn.functional.mse_loss(self._model(X), y).item()
|
| 480 |
+
loss_threshold = self.loss_threshold
|
| 481 |
+
perp = None
|
| 482 |
|
| 483 |
self._step_count += 1
|
| 484 |
+
if self._steps_to_threshold is None and self._current_loss < loss_threshold:
|
| 485 |
self._steps_to_threshold = self._step_count
|
| 486 |
|
|
|
|
| 487 |
dense_reward = DENSE_REWARD_SCALE * (prev_loss - self._current_loss)
|
| 488 |
self._prev_loss = self._current_loss
|
| 489 |
|
|
|
|
| 490 |
done = self._step_count >= self.max_steps or self._steps_to_threshold is not None
|
| 491 |
if done:
|
| 492 |
terminal = -(self._steps_to_threshold if self._steps_to_threshold is not None else self.max_steps)
|
|
|
|
| 494 |
else:
|
| 495 |
reward = dense_reward
|
| 496 |
|
| 497 |
+
return self._observation(reward=reward, grad_norm=grad_norm, done=done, perplexity=perp)
|
| 498 |
|
| 499 |
def _observation(
|
| 500 |
self,
|
| 501 |
reward: Optional[float] = None,
|
| 502 |
grad_norm: Optional[float] = None,
|
| 503 |
done: bool = False,
|
| 504 |
+
perplexity: Optional[float] = None,
|
| 505 |
) -> MetaOptimizerObservation:
|
| 506 |
meta: Dict[str, Any] = {}
|
| 507 |
if self._steps_to_threshold is not None:
|
|
|
|
| 516 |
done=done,
|
| 517 |
reward=reward,
|
| 518 |
metadata=meta,
|
| 519 |
+
perplexity=perplexity,
|
| 520 |
)
|
| 521 |
|
| 522 |
@property
|
|
|
|
| 529 |
def get_episode_action_log(self) -> List[Dict[str, Any]]:
|
| 530 |
"""Return the action log for the current episode (for in-process viz or eval)."""
|
| 531 |
return list(self._action_log)
|
| 532 |
+
|
| 533 |
+
def get_current_task_spec(self) -> Optional[Dict[str, Any]]:
|
| 534 |
+
"""Return current task spec as a dict for dashboard / run-baseline. None if no episode started."""
|
| 535 |
+
if self._task_spec is None:
|
| 536 |
+
return None
|
| 537 |
+
if isinstance(self._task_spec, SLMTaskSpec):
|
| 538 |
+
return {"type": "slm", **asdict(self._task_spec)}
|
| 539 |
+
return {"type": "sinusoid", **asdict(self._task_spec)}
|
| 540 |
+
|
| 541 |
+
def run_baseline(self) -> Dict[str, Any]:
|
| 542 |
+
"""Run the appropriate baseline (AdamW for SLM, Adam for sinusoid) for current task. Returns loss_trajectory and steps."""
|
| 543 |
+
spec_dict = self.get_current_task_spec()
|
| 544 |
+
if spec_dict is None:
|
| 545 |
+
return {"loss_trajectory": [], "steps": [], "error": "No task"}
|
| 546 |
+
if spec_dict.get("type") == "slm":
|
| 547 |
+
result = run_adamw_baseline(task_spec=spec_dict, max_steps=self.max_steps, return_metrics=True)
|
| 548 |
+
else:
|
| 549 |
+
result = run_adam_baseline(task_spec=spec_dict, max_steps=self.max_steps, return_metrics=True)
|
| 550 |
+
traj = result.get("loss_trajectory", [])
|
| 551 |
+
return {"loss_trajectory": traj, "steps": list(range(len(traj)))}
|
server/slm_model.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Tiny decoder-only transformer for SLM meta-optimizer inner task.
|
| 9 |
+
Pure PyTorch, no transformers dependency.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import math
|
| 13 |
+
from typing import Tuple
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Fixed character vocab for reproducible SLM tasks (subset of printable ASCII)
|
| 20 |
+
DEFAULT_CHARS = (
|
| 21 |
+
" \n\t"
|
| 22 |
+
"abcdefghijklmnopqrstuvwxyz"
|
| 23 |
+
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
| 24 |
+
"0123456789"
|
| 25 |
+
".,;:!?'\"-()"
|
| 26 |
+
)
|
| 27 |
+
DEFAULT_VOCAB_SIZE = len(DEFAULT_CHARS)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def build_vocab(chars: str = DEFAULT_CHARS) -> Tuple[dict, dict]:
|
| 31 |
+
"""Return char2idx and idx2char dicts."""
|
| 32 |
+
char2idx = {c: i for i, c in enumerate(chars)}
|
| 33 |
+
idx2char = {i: c for c, i in char2idx.items()}
|
| 34 |
+
return char2idx, idx2char
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def encode_corpus(text: str, char2idx: dict, default_idx: int = 0) -> torch.Tensor:
|
| 38 |
+
"""Encode string to long tensor of token ids. Unknown chars map to default_idx."""
|
| 39 |
+
ids = [char2idx.get(c, default_idx) for c in text]
|
| 40 |
+
return torch.tensor(ids, dtype=torch.long)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def get_corpus_tensor(
|
| 44 |
+
text: str,
|
| 45 |
+
char2idx: dict,
|
| 46 |
+
device: torch.device,
|
| 47 |
+
) -> torch.Tensor:
|
| 48 |
+
"""Return 1D long tensor of token ids on device."""
|
| 49 |
+
t = encode_corpus(text, char2idx)
|
| 50 |
+
return t.to(device)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def sample_batch_slm(
|
| 54 |
+
corpus_ids: torch.Tensor,
|
| 55 |
+
batch_size: int,
|
| 56 |
+
context_len: int,
|
| 57 |
+
step: int,
|
| 58 |
+
data_seed: int,
|
| 59 |
+
device: torch.device,
|
| 60 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 61 |
+
"""
|
| 62 |
+
Sample batch_size contiguous chunks from corpus for next-token prediction.
|
| 63 |
+
Returns: input_ids [B, context_len], target_ids [B, context_len] (target = input shifted by 1).
|
| 64 |
+
"""
|
| 65 |
+
L = corpus_ids.size(0)
|
| 66 |
+
if L <= context_len + 1:
|
| 67 |
+
raise ValueError("Corpus too short for context_len")
|
| 68 |
+
max_start = L - context_len - 1
|
| 69 |
+
g = torch.Generator(device=device)
|
| 70 |
+
g.manual_seed(data_seed + step)
|
| 71 |
+
starts = torch.randint(0, max_start, (batch_size,), device=device, generator=g)
|
| 72 |
+
inputs = []
|
| 73 |
+
targets = []
|
| 74 |
+
for b in range(batch_size):
|
| 75 |
+
s = int(starts[b].item())
|
| 76 |
+
chunk = corpus_ids[s : s + context_len + 1]
|
| 77 |
+
inputs.append(chunk[:context_len])
|
| 78 |
+
targets.append(chunk[1 : context_len + 1])
|
| 79 |
+
return torch.stack(inputs), torch.stack(targets)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class CausalSelfAttention(nn.Module):
|
| 83 |
+
def __init__(self, n_embd: int, n_head: int, block_size: int):
|
| 84 |
+
super().__init__()
|
| 85 |
+
assert n_embd % n_head == 0
|
| 86 |
+
self.n_head = n_head
|
| 87 |
+
self.n_embd = n_embd
|
| 88 |
+
self.head_dim = n_embd // n_head
|
| 89 |
+
self.register_buffer(
|
| 90 |
+
"mask",
|
| 91 |
+
torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size),
|
| 92 |
+
)
|
| 93 |
+
self.c_attn = nn.Linear(n_embd, 3 * n_embd)
|
| 94 |
+
self.c_proj = nn.Linear(n_embd, n_embd)
|
| 95 |
+
|
| 96 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 97 |
+
B, T, C = x.shape
|
| 98 |
+
qkv = self.c_attn(x)
|
| 99 |
+
q, k, v = qkv.split(self.n_embd, dim=2)
|
| 100 |
+
q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
|
| 101 |
+
k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
|
| 102 |
+
v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
|
| 103 |
+
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
|
| 104 |
+
att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
|
| 105 |
+
att = torch.softmax(att, dim=-1)
|
| 106 |
+
out = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
|
| 107 |
+
return self.c_proj(out)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class Block(nn.Module):
|
| 111 |
+
def __init__(self, n_embd: int, n_head: int, block_size: int):
|
| 112 |
+
super().__init__()
|
| 113 |
+
self.ln1 = nn.LayerNorm(n_embd)
|
| 114 |
+
self.attn = CausalSelfAttention(n_embd, n_head, block_size)
|
| 115 |
+
self.ln2 = nn.LayerNorm(n_embd)
|
| 116 |
+
self.mlp = nn.Sequential(
|
| 117 |
+
nn.Linear(n_embd, 4 * n_embd),
|
| 118 |
+
nn.GELU(),
|
| 119 |
+
nn.Linear(4 * n_embd, n_embd),
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 123 |
+
x = x + self.attn(self.ln1(x))
|
| 124 |
+
x = x + self.mlp(self.ln2(x))
|
| 125 |
+
return x
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class TinyLM(nn.Module):
|
| 129 |
+
"""Decoder-only transformer for next-token prediction."""
|
| 130 |
+
|
| 131 |
+
def __init__(
|
| 132 |
+
self,
|
| 133 |
+
vocab_size: int,
|
| 134 |
+
context_len: int,
|
| 135 |
+
n_layer: int,
|
| 136 |
+
n_head: int,
|
| 137 |
+
n_embd: int,
|
| 138 |
+
):
|
| 139 |
+
super().__init__()
|
| 140 |
+
self.context_len = context_len
|
| 141 |
+
self.vocab_size = vocab_size
|
| 142 |
+
self.token_embed = nn.Embedding(vocab_size, n_embd)
|
| 143 |
+
self.pos_embed = nn.Embedding(context_len, n_embd)
|
| 144 |
+
self.blocks = nn.ModuleList(
|
| 145 |
+
[Block(n_embd, n_head, context_len) for _ in range(n_layer)]
|
| 146 |
+
)
|
| 147 |
+
self.ln_f = nn.LayerNorm(n_embd)
|
| 148 |
+
self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
|
| 149 |
+
|
| 150 |
+
def forward(self, idx: torch.Tensor) -> torch.Tensor:
|
| 151 |
+
# idx: [B, T]; clamp to valid range in case of encoding drift
|
| 152 |
+
idx = idx.clamp(0, self.vocab_size - 1)
|
| 153 |
+
B, T = idx.shape
|
| 154 |
+
pos = torch.arange(0, T, device=idx.device, dtype=torch.long)
|
| 155 |
+
x = self.token_embed(idx) + self.pos_embed(pos)
|
| 156 |
+
for block in self.blocks:
|
| 157 |
+
x = block(x)
|
| 158 |
+
x = self.ln_f(x)
|
| 159 |
+
logits = self.lm_head(x)
|
| 160 |
+
return logits
|
server/tasks.py
CHANGED
|
@@ -9,6 +9,7 @@ Task registry for meta-learning.
|
|
| 9 |
|
| 10 |
Tasks can be from the internal registry (get_task(task_id)) or provided from outside
|
| 11 |
via task_spec_from_dict() — the client sends the task definition to the environment.
|
|
|
|
| 12 |
"""
|
| 13 |
|
| 14 |
from dataclasses import dataclass
|
|
@@ -16,12 +17,27 @@ from typing import Any, Dict, List
|
|
| 16 |
|
| 17 |
import math
|
| 18 |
|
|
|
|
|
|
|
| 19 |
# Distribution A: 50 training tasks (low-freq sinusoids)
|
| 20 |
TRAIN_TASK_IDS: List[int] = list(range(50))
|
| 21 |
|
| 22 |
# Distribution B: held-out eval tasks (high-freq sinusoids — different distribution)
|
| 23 |
EVAL_TASK_IDS: List[int] = [50, 51]
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
# Bounds for each distribution (freq, amplitude, phase)
|
| 26 |
DIST_A_FREQ = (1.0, 3.0)
|
| 27 |
DIST_A_AMP = (0.5, 2.0)
|
|
@@ -46,6 +62,22 @@ class TaskSpec:
|
|
| 46 |
distribution: str # "A" or "B"
|
| 47 |
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
def get_task(task_id: int) -> TaskSpec:
|
| 50 |
"""
|
| 51 |
Return the task spec for the given task_id.
|
|
@@ -88,28 +120,73 @@ def get_task(task_id: int) -> TaskSpec:
|
|
| 88 |
)
|
| 89 |
|
| 90 |
|
| 91 |
-
def
|
| 92 |
"""
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
Expected keys for type "sinusoid":
|
| 97 |
-
type="sinusoid", amplitude, freq, phase, data_seed (optional), arch_seed (optional),
|
| 98 |
-
input_dim (optional, default 1), hidden_dim (optional, default 32), task_id (optional).
|
| 99 |
"""
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
task_id=task_id,
|
| 106 |
-
input_dim=int(d.get("input_dim", 1)),
|
| 107 |
-
hidden_dim=int(d.get("hidden_dim", 32)),
|
| 108 |
-
output_dim=1,
|
| 109 |
data_seed=int(d.get("data_seed", task_id * 31337)),
|
| 110 |
arch_seed=int(d.get("arch_seed", task_id * 131 + 7)),
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
| 114 |
distribution=d.get("distribution", "external"),
|
| 115 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
Tasks can be from the internal registry (get_task(task_id)) or provided from outside
|
| 11 |
via task_spec_from_dict() — the client sends the task definition to the environment.
|
| 12 |
+
Supports sinusoid (regression) and SLM (next-token prediction) task types.
|
| 13 |
"""
|
| 14 |
|
| 15 |
from dataclasses import dataclass
|
|
|
|
| 17 |
|
| 18 |
import math
|
| 19 |
|
| 20 |
+
from .slm_model import DEFAULT_VOCAB_SIZE as SLM_DEFAULT_VOCAB_SIZE
|
| 21 |
+
|
| 22 |
# Distribution A: 50 training tasks (low-freq sinusoids)
|
| 23 |
TRAIN_TASK_IDS: List[int] = list(range(50))
|
| 24 |
|
| 25 |
# Distribution B: held-out eval tasks (high-freq sinusoids — different distribution)
|
| 26 |
EVAL_TASK_IDS: List[int] = [50, 51]
|
| 27 |
|
| 28 |
+
# SLM: 50 train tasks, 2 eval (different corpus split or seed range)
|
| 29 |
+
SLM_TRAIN_TASK_IDS: List[int] = list(range(50))
|
| 30 |
+
SLM_EVAL_TASK_IDS: List[int] = [50, 51]
|
| 31 |
+
|
| 32 |
+
# Fixed small corpus for SLM (character-level). ~10KB so tasks are reproducible.
|
| 33 |
+
DEFAULT_CORPUS: str = (
|
| 34 |
+
"The quick brown fox jumps over the lazy dog. "
|
| 35 |
+
"Pack my box with five dozen liquor jugs. "
|
| 36 |
+
"How vexingly quick daft zebras jump. "
|
| 37 |
+
"Sphinx of black quartz, judge my vow. "
|
| 38 |
+
"The five boxing wizards jump quickly. "
|
| 39 |
+
) * 200 # repeat to get enough length for sampling
|
| 40 |
+
|
| 41 |
# Bounds for each distribution (freq, amplitude, phase)
|
| 42 |
DIST_A_FREQ = (1.0, 3.0)
|
| 43 |
DIST_A_AMP = (0.5, 2.0)
|
|
|
|
| 62 |
distribution: str # "A" or "B"
|
| 63 |
|
| 64 |
|
| 65 |
+
@dataclass
|
| 66 |
+
class SLMTaskSpec:
|
| 67 |
+
"""Spec for one SLM (next-token prediction) task."""
|
| 68 |
+
|
| 69 |
+
task_id: int
|
| 70 |
+
data_seed: int
|
| 71 |
+
arch_seed: int
|
| 72 |
+
vocab_size: int
|
| 73 |
+
context_len: int # block size
|
| 74 |
+
n_layer: int
|
| 75 |
+
n_head: int
|
| 76 |
+
n_embd: int
|
| 77 |
+
corpus_id: str # e.g. "default"
|
| 78 |
+
distribution: str # "A" or "B" or "external"
|
| 79 |
+
|
| 80 |
+
|
| 81 |
def get_task(task_id: int) -> TaskSpec:
|
| 82 |
"""
|
| 83 |
Return the task spec for the given task_id.
|
|
|
|
| 120 |
)
|
| 121 |
|
| 122 |
|
| 123 |
+
def get_slm_task(task_id: int) -> SLMTaskSpec:
|
| 124 |
"""
|
| 125 |
+
Return the SLM task spec for the given task_id.
|
| 126 |
+
Task IDs 0..49 = Distribution A (train), 50+ = Distribution B (eval).
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
"""
|
| 128 |
+
if task_id < 0:
|
| 129 |
+
raise ValueError(f"task_id must be >= 0, got {task_id}")
|
| 130 |
+
r = task_id * 7919 + 1
|
| 131 |
+
data_seed = task_id * 31337
|
| 132 |
+
arch_seed = task_id * 131 + 7
|
| 133 |
+
if task_id < 50:
|
| 134 |
+
distribution = "A"
|
| 135 |
+
else:
|
| 136 |
+
distribution = "B"
|
| 137 |
+
return SLMTaskSpec(
|
| 138 |
+
task_id=task_id,
|
| 139 |
+
data_seed=data_seed,
|
| 140 |
+
arch_seed=arch_seed,
|
| 141 |
+
vocab_size=SLM_DEFAULT_VOCAB_SIZE,
|
| 142 |
+
context_len=64,
|
| 143 |
+
n_layer=2,
|
| 144 |
+
n_head=4,
|
| 145 |
+
n_embd=128,
|
| 146 |
+
corpus_id="default",
|
| 147 |
+
distribution=distribution,
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def slm_task_spec_from_dict(d: Dict[str, Any]) -> SLMTaskSpec:
|
| 152 |
+
"""Build an SLMTaskSpec from an external dict (type='slm')."""
|
| 153 |
+
task_id = int(d.get("task_id", 0))
|
| 154 |
+
return SLMTaskSpec(
|
| 155 |
task_id=task_id,
|
|
|
|
|
|
|
|
|
|
| 156 |
data_seed=int(d.get("data_seed", task_id * 31337)),
|
| 157 |
arch_seed=int(d.get("arch_seed", task_id * 131 + 7)),
|
| 158 |
+
vocab_size=int(d.get("vocab_size", SLM_DEFAULT_VOCAB_SIZE)),
|
| 159 |
+
context_len=int(d.get("context_len", 64)),
|
| 160 |
+
n_layer=int(d.get("n_layer", 2)),
|
| 161 |
+
n_head=int(d.get("n_head", 4)),
|
| 162 |
+
n_embd=int(d.get("n_embd", 128)),
|
| 163 |
+
corpus_id=str(d.get("corpus_id", "default")),
|
| 164 |
distribution=d.get("distribution", "external"),
|
| 165 |
)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def task_spec_from_dict(d: Dict[str, Any]) -> TaskSpec | SLMTaskSpec:
|
| 169 |
+
"""
|
| 170 |
+
Build a TaskSpec or SLMTaskSpec from an external dict (sent by the client).
|
| 171 |
+
|
| 172 |
+
For type "sinusoid": amplitude, freq, phase, data_seed (optional), arch_seed (optional), etc.
|
| 173 |
+
For type "slm": vocab_size, context_len, n_layer, n_head, n_embd (all optional with defaults).
|
| 174 |
+
"""
|
| 175 |
+
task_type = d.get("type", "slm")
|
| 176 |
+
if task_type == "sinusoid":
|
| 177 |
+
task_id = d.get("task_id", 0)
|
| 178 |
+
return TaskSpec(
|
| 179 |
+
task_id=task_id,
|
| 180 |
+
input_dim=int(d.get("input_dim", 1)),
|
| 181 |
+
hidden_dim=int(d.get("hidden_dim", 32)),
|
| 182 |
+
output_dim=1,
|
| 183 |
+
data_seed=int(d.get("data_seed", task_id * 31337)),
|
| 184 |
+
arch_seed=int(d.get("arch_seed", task_id * 131 + 7)),
|
| 185 |
+
amplitude=float(d["amplitude"]),
|
| 186 |
+
freq=float(d["freq"]),
|
| 187 |
+
phase=float(d["phase"]),
|
| 188 |
+
distribution=d.get("distribution", "external"),
|
| 189 |
+
)
|
| 190 |
+
if task_type == "slm":
|
| 191 |
+
return slm_task_spec_from_dict(d)
|
| 192 |
+
raise ValueError(f"Unknown task type: {task_type!r}")
|