SavirD committed on
Commit
4d2821f
·
verified ·
1 Parent(s): d9452da

Upload folder using huggingface_hub

Browse files
Dockerfile CHANGED
@@ -57,6 +57,9 @@ RUN --mount=type=cache,target=/root/.cache/uv \
57
  # Patch OpenEnv web UI so number inputs use step=0.01 (allows lr_scale=0.02, momentum_coef=0.9)
58
  RUN WEBIF="$$(find /app/env/.venv -path '*openenv*env_server*web_interface.py' | head -1)" && \
59
  /app/env/.venv/bin/python /app/env/scripts/patch_openenv_web_interface.py "$$WEBIF"
 
 
 
60
 
61
  # Final runtime stage
62
  FROM ${BASE_IMAGE}
 
57
  # Patch OpenEnv web UI so number inputs use step=0.01 (allows lr_scale=0.02, momentum_coef=0.9)
58
  RUN WEBIF="$$(find /app/env/.venv -path '*openenv*env_server*web_interface.py' | head -1)" && \
59
  /app/env/.venv/bin/python /app/env/scripts/patch_openenv_web_interface.py "$$WEBIF"
60
+ # Patch OpenEnv web UI to add loss/perplexity chart and Run baseline (AdamW) button
61
+ RUN WEBIF="$$(find /app/env/.venv -path '*openenv*env_server*web_interface.py' | head -1)" && \
62
+ /app/env/.venv/bin/python /app/env/scripts/patch_openenv_web_interface_chart.py "$$WEBIF"
63
 
64
  # Final runtime stage
65
  FROM ${BASE_IMAGE}
models.py CHANGED
@@ -72,4 +72,8 @@ class MetaOptimizerObservation(Observation):
72
  default=None,
73
  description="Step at which loss first reached threshold (None if not yet reached)",
74
  )
 
 
 
 
75
 
 
72
  default=None,
73
  description="Step at which loss first reached threshold (None if not yet reached)",
74
  )
75
+ perplexity: float | None = Field(
76
+ default=None,
77
+ description="exp(loss) for language modeling (None for regression)",
78
+ )
79
 
scripts/patch_openenv_web_interface_chart.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Patch OpenEnv web_interface.py to add:
4
+ - Loss/perplexity chart and updateLossChart()
5
+ - POST /web/run-baseline and GET /web/current-task for baseline comparison
6
+ Idempotent: safe to run multiple times.
7
+ """
8
+ import sys
9
+ from pathlib import Path
10
+
11
+
12
+ def _apply_routes_patch(text: str) -> str:
13
+ """Add /web/run-baseline and /web/current-task routes."""
14
+ old_routes = (
15
+ ' @app.get("/web/state")\n'
16
+ " async def web_state():\n"
17
+ ' """State endpoint for web interface."""\n'
18
+ " return web_manager.get_state()\n"
19
+ "\n"
20
+ " return app"
21
+ )
22
+ new_routes = (
23
+ ' @app.get("/web/state")\n'
24
+ " async def web_state():\n"
25
+ ' """State endpoint for web interface."""\n'
26
+ " return web_manager.get_state()\n"
27
+ "\n"
28
+ ' @app.get("/web/current-task")\n'
29
+ " async def web_current_task():\n"
30
+ ' """Current task spec for baseline comparison (if env supports it)."""\n'
31
+ " get_spec = getattr(web_manager.env, \"get_current_task_spec\", None)\n"
32
+ " if get_spec is None:\n"
33
+ " return {}\n"
34
+ " return get_spec() or {}\n"
35
+ "\n"
36
+ ' @app.post("/web/run-baseline")\n'
37
+ " async def web_run_baseline():\n"
38
+ ' """Run baseline optimizer for current task; returns loss_trajectory and steps."""\n'
39
+ " run_bl = getattr(web_manager.env, \"run_baseline\", None)\n"
40
+ " if run_bl is None:\n"
41
+ " return {\"loss_trajectory\": [], \"steps\": [], \"error\": \"Env has no run_baseline\"}\n"
42
+ " return run_bl()\n"
43
+ "\n"
44
+ " return app"
45
+ )
46
+ if "web/run-baseline" not in text and "web/state" in text and "return web_manager.get_state()" in text:
47
+ text = text.replace(old_routes, new_routes, 1)
48
+ return text
49
+
50
+
51
+ def main() -> None:
52
+ if len(sys.argv) < 2:
53
+ import openenv.core.env_server.web_interface as m
54
+ path = Path(m.__file__).resolve()
55
+ else:
56
+ path = Path(sys.argv[1]).resolve()
57
+
58
+ if not path.exists():
59
+ print(f"File not found: {path}", file=sys.stderr)
60
+ sys.exit(1)
61
+
62
+ text = path.read_text()
63
+
64
+ # 1) Add Chart.js script in head (after title)
65
+ chart_script = ' <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>\n'
66
+ old_head = "<title>OpenEnv Web Interface</title>\n <style>"
67
+ new_head = "<title>OpenEnv Web Interface</title>\n" + chart_script + " <style>"
68
+ if chart_script not in text and old_head in text:
69
+ text = text.replace(old_head, new_head, 1)
70
+
71
+ # 2) Add chart container between Current Observation and Action Logs
72
+ old_section = """ </div>
73
+ </div>
74
+
75
+ <!-- Action Logs -->
76
+ <div class="logs-container">"""
77
+ new_section = """ </div>
78
+ </div>
79
+
80
+ <!-- Loss chart -->
81
+ <div class="state-display">
82
+ <h3>Loss / Perplexity</h3>
83
+ <div id="loss-chart-container" style="height:200px;"><canvas id="loss-chart"></canvas></div>
84
+ <button type="button" id="run-baseline-btn" class="btn btn-secondary" style="margin-top:8px;">Run baseline (AdamW)</button>
85
+ </div>
86
+
87
+ <!-- Action Logs -->
88
+ <div class="logs-container">"""
89
+ if "loss-chart-container" not in text and old_section in text:
90
+ text = text.replace(old_section, new_section, 1)
91
+ # If chart container exists but button does not, add button
92
+ if "loss-chart-container" in text and "run-baseline-btn" not in text:
93
+ text = text.replace(
94
+ "<canvas id=\"loss-chart\"></canvas></div>\n </div>",
95
+ "<canvas id=\"loss-chart\"></canvas></div>\n <button type=\"button\" id=\"run-baseline-btn\" class=\"btn btn-secondary\" style=\"margin-top:8px;\">Run baseline (AdamW)</button>\n </div>",
96
+ 1,
97
+ )
98
+
99
+ # 3) Add updateLossChart call and method before updateChatInterface
100
+ old_update = """ }}
101
+ }}
102
+
103
+ updateChatInterface(episodeState) {{"""
104
+ new_update = """ }}
105
+ this.updateLossChart(episodeState);
106
+ }}
107
+
108
+ updateLossChart(episodeState) {{
109
+ const container = document.getElementById('loss-chart-container');
110
+ if (!container) return;
111
+ const steps = [];
112
+ const losses = [];
113
+ const perplexities = [];
114
+ if (episodeState.current_observation && typeof episodeState.current_observation.loss === 'number') {{
115
+ const o = episodeState.current_observation;
116
+ steps.push(o.step_count != null ? o.step_count : 0);
117
+ losses.push(o.loss);
118
+ if (typeof o.perplexity === 'number') perplexities.push(o.perplexity);
119
+ }}
120
+ (episodeState.action_logs || []).forEach(log => {{
121
+ if (log.observation && typeof log.observation.loss === 'number') {{
122
+ steps.push(log.observation.step_count != null ? log.observation.step_count : log.step_count);
123
+ losses.push(log.observation.loss);
124
+ if (typeof log.observation.perplexity === 'number') perplexities.push(log.observation.perplexity);
125
+ }}
126
+ }});
127
+ if (steps.length === 0) return;
128
+ const ctx = document.getElementById('loss-chart');
129
+ if (!ctx) return;
130
+ if (this._lossChart) this._lossChart.destroy();
131
+ this._lossChart = new Chart(ctx, {{
132
+ type: 'line',
133
+ data: {{
134
+ labels: steps,
135
+ datasets: [
136
+ {{ label: 'Loss', data: losses, borderColor: '#007bff', tension: 0.2, fill: false }}
137
+ ].concat(perplexities.length ? [{{ label: 'Perplexity', data: perplexities, borderColor: '#28a745', tension: 0.2, fill: false }}] : [])
138
+ }},
139
+ options: {{ responsive: true, maintainAspectRatio: false, scales: {{ x: {{ title: {{ display: true, text: 'Step' }} }} }} }}
140
+ }});
141
+ }}
142
+
143
+ async runBaseline() {{
144
+ const btn = document.getElementById('run-baseline-btn');
145
+ if (btn) btn.disabled = true;
146
+ try {{
147
+ const r = await fetch('/web/run-baseline', {{ method: 'POST' }});
148
+ const data = await r.json();
149
+ if (data.error || !data.loss_trajectory || !this._lossChart) {{ if (btn) btn.disabled = false; return; }}
150
+ const L = data.loss_trajectory.length;
151
+ const steps = data.steps && data.steps.length === L ? data.steps : Array.from({{ length: L }}, (_, i) => i);
152
+ const curLen = this._lossChart.data.labels.length;
153
+ const newLen = Math.max(curLen, steps.length);
154
+ const newLabels = Array.from({{ length: newLen }}, (_, i) => i);
155
+ this._lossChart.data.labels = newLabels;
156
+ this._lossChart.data.datasets.forEach(ds => {{
157
+ while (ds.data.length < newLen) ds.data.push(null);
158
+ }});
159
+ const baselineData = data.loss_trajectory.slice();
160
+ while (baselineData.length < newLen) baselineData.push(null);
161
+ this._lossChart.data.datasets.push({{ label: 'Baseline (AdamW)', data: baselineData, borderColor: '#dc3545', tension: 0.2, fill: false }});
162
+ this._lossChart.update();
163
+ }} finally {{ if (btn) btn.disabled = false; }}
164
+ }}
165
+
166
+ updateChatInterface(episodeState) {{"""
167
+ if "updateLossChart(episodeState)" not in text and old_update in text:
168
+ text = text.replace(old_update, new_update, 1)
169
+
170
+ # 3b) Add Run baseline button click listener
171
+ old_listener = """ // State button
172
+ document.getElementById('state-btn').addEventListener('click', () => {{
173
+ this.getState();
174
+ }});
175
+ }}"""
176
+ new_listener = """ // State button
177
+ document.getElementById('state-btn').addEventListener('click', () => {{
178
+ this.getState();
179
+ }});
180
+
181
+ const runBaselineBtn = document.getElementById('run-baseline-btn');
182
+ if (runBaselineBtn) runBaselineBtn.addEventListener('click', () => this.runBaseline());
183
+ }}"""
184
+ if "run-baseline-btn" not in text or "runBaselineBtn.addEventListener" not in text:
185
+ if old_listener in text:
186
+ text = text.replace(old_listener, new_listener, 1)
187
+
188
+ # 3c) If updateLossChart exists but runBaseline does not, insert runBaseline
189
+ if "updateLossChart(episodeState)" in text and "async runBaseline()" not in text:
190
+ run_baseline_method = """
191
+ async runBaseline() {{
192
+ const btn = document.getElementById('run-baseline-btn');
193
+ if (btn) btn.disabled = true;
194
+ try {{
195
+ const r = await fetch('/web/run-baseline', {{ method: 'POST' }});
196
+ const data = await r.json();
197
+ if (data.error || !data.loss_trajectory || !this._lossChart) {{ if (btn) btn.disabled = false; return; }}
198
+ const L = data.loss_trajectory.length;
199
+ const newLen = Math.max(this._lossChart.data.labels.length, L);
200
+ const newLabels = Array.from({{ length: newLen }}, (_, i) => i);
201
+ this._lossChart.data.labels = newLabels;
202
+ this._lossChart.data.datasets.forEach(ds => {{
203
+ while (ds.data.length < newLen) ds.data.push(null);
204
+ }});
205
+ const baselineData = data.loss_trajectory.slice();
206
+ while (baselineData.length < newLen) baselineData.push(null);
207
+ this._lossChart.data.datasets.push({{ label: 'Baseline (AdamW)', data: baselineData, borderColor: '#dc3545', tension: 0.2, fill: false }});
208
+ this._lossChart.update();
209
+ }} finally {{ if (btn) btn.disabled = false; }}
210
+ }}
211
+ """
212
+ text = text.replace(
213
+ " }});\n }}\n\n updateChatInterface(episodeState) {{",
214
+ " }});\n }}\n" + run_baseline_method + "\n updateChatInterface(episodeState) {{",
215
+ 1,
216
+ )
217
+
218
+ # 4) Add run-baseline and current-task routes
219
+ text = _apply_routes_patch(text)
220
+
221
+ path.write_text(text)
222
+ print("Patched (chart + run-baseline):", path)
223
+
224
+
225
+ if __name__ == "__main__":
226
+ main()
server/meta_optimizer_environment.py CHANGED
@@ -5,15 +5,17 @@
5
  # LICENSE file in the root directory of this source tree.
6
 
7
  """
8
- Meta-optimizer environment: train an RL agent to act as an optimizer on random regression tasks.
9
 
10
- Supports 50 training tasks, held-out eval, rich action space (LR, momentum, grad clip, weight decay),
11
- and convergence-speed reward. Action log is exposed for emergent-behavior visualization.
 
12
  """
13
 
14
  import math
15
  import random
16
- from typing import Any, Dict, List, Optional
 
17
  from uuid import uuid4
18
 
19
  import torch
@@ -23,7 +25,22 @@ from openenv.core.env_server.interfaces import Environment
23
  from openenv.core.env_server.types import State
24
 
25
  from my_env.models import MetaOptimizerAction, MetaOptimizerObservation
26
- from .tasks import TRAIN_TASK_IDS, get_task, task_spec_from_dict, TaskSpec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  # Defaults
29
  LOSS_THRESHOLD = 0.1
@@ -31,6 +48,8 @@ MAX_STEPS = 100
31
  BATCH_SIZE = 32
32
  # Dense reward scale: reward += DENSE_REWARD_SCALE * (prev_loss - current_loss) each step (potential-based, helps credit assignment)
33
  DENSE_REWARD_SCALE = 0.2
 
 
34
 
35
 
36
  def _default_device() -> torch.device:
@@ -60,6 +79,35 @@ def _get_batch(spec: TaskSpec, step: int, device: torch.device):
60
  return X, y
61
 
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def run_adam_baseline(
64
  task_id: Optional[int] = None,
65
  task_spec: Optional[Dict[str, Any]] = None,
@@ -78,6 +126,8 @@ def run_adam_baseline(
78
  torch.manual_seed(seed)
79
  device = _default_device()
80
  spec = task_spec_from_dict(task_spec) if task_spec is not None else get_task(task_id)
 
 
81
  model = _build_model(spec).to(device)
82
  opt = torch.optim.Adam(model.parameters(), lr=lr)
83
  loss_trajectory: List[float] = []
@@ -128,6 +178,8 @@ def run_sgd_baseline(
128
  torch.manual_seed(seed)
129
  device = _default_device()
130
  spec = task_spec_from_dict(task_spec) if task_spec is not None else get_task(task_id)
 
 
131
  model = _build_model(spec).to(device)
132
  opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
133
  loss_trajectory = []
@@ -159,6 +211,67 @@ def run_sgd_baseline(
159
  }
160
 
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  def run_meta_optimizer_trajectory(
163
  task_id: Optional[int] = None,
164
  task_spec: Optional[Dict[str, Any]] = None,
@@ -226,7 +339,7 @@ class MetaOptimizerEnvironment(Environment[MetaOptimizerAction, MetaOptimizerObs
226
  self._device = _default_device()
227
 
228
  # Episode state (set in reset)
229
- self._task_spec: Optional[TaskSpec] = None
230
  self._model: Optional[nn.Module] = None
231
  self._velocities: Optional[List[torch.Tensor]] = None
232
  self._step_count: int = 0
@@ -235,6 +348,9 @@ class MetaOptimizerEnvironment(Environment[MetaOptimizerAction, MetaOptimizerObs
235
  self._steps_to_threshold: Optional[int] = None
236
  self._action_log: List[Dict[str, Any]] = []
237
  self._episode_id: Optional[str] = None
 
 
 
238
 
239
  def reset(
240
  self,
@@ -250,23 +366,44 @@ class MetaOptimizerEnvironment(Environment[MetaOptimizerAction, MetaOptimizerObs
250
  if task_spec is not None:
251
  self._task_spec = task_spec_from_dict(task_spec)
252
  else:
253
- tid = task_id if task_id is not None else random.choice(TRAIN_TASK_IDS)
254
- self._task_spec = get_task(tid)
255
- self._model = _build_model(self._task_spec).to(self._device)
256
- self._velocities = [torch.zeros_like(p) for p in self._model.parameters()]
257
- self._step_count = 0
258
- self._steps_to_threshold = None
259
- self._action_log = []
260
- self._episode_id = episode_id or str(uuid4())
261
-
262
- # Initial loss (no update yet)
263
- X, y = _get_batch(self._task_spec, 0, self._device)
264
- with torch.no_grad():
265
- out = self._model(X)
266
- self._current_loss = nn.functional.mse_loss(out, y).item()
267
- self._prev_loss = self._current_loss
268
-
269
- return self._observation(reward=None, grad_norm=None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
 
271
  def step(
272
  self,
@@ -289,10 +426,20 @@ class MetaOptimizerEnvironment(Environment[MetaOptimizerAction, MetaOptimizerObs
289
  "weight_decay_this_step": wd,
290
  })
291
 
292
- X, y = _get_batch(self._task_spec, self._step_count + 1, self._device)
293
- self._model.train()
294
- out = self._model(X)
295
- loss = nn.functional.mse_loss(out, y)
 
 
 
 
 
 
 
 
 
 
296
  self._model.zero_grad()
297
  loss.backward()
298
 
@@ -314,19 +461,32 @@ class MetaOptimizerEnvironment(Environment[MetaOptimizerAction, MetaOptimizerObs
314
  if wd > 0:
315
  p.sub_(p, alpha=wd)
316
 
317
- with torch.no_grad():
318
- new_out = self._model(X)
319
- self._current_loss = nn.functional.mse_loss(new_out, y).item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
 
321
  self._step_count += 1
322
- if self._steps_to_threshold is None and self._current_loss < self.loss_threshold:
323
  self._steps_to_threshold = self._step_count
324
 
325
- # Dense reward: reward loss decrease (potential-based shaping, does not change optimal policy)
326
  dense_reward = DENSE_REWARD_SCALE * (prev_loss - self._current_loss)
327
  self._prev_loss = self._current_loss
328
 
329
- # End episode when we hit max_steps or when loss first crosses threshold (early termination)
330
  done = self._step_count >= self.max_steps or self._steps_to_threshold is not None
331
  if done:
332
  terminal = -(self._steps_to_threshold if self._steps_to_threshold is not None else self.max_steps)
@@ -334,13 +494,14 @@ class MetaOptimizerEnvironment(Environment[MetaOptimizerAction, MetaOptimizerObs
334
  else:
335
  reward = dense_reward
336
 
337
- return self._observation(reward=reward, grad_norm=grad_norm, done=done)
338
 
339
  def _observation(
340
  self,
341
  reward: Optional[float] = None,
342
  grad_norm: Optional[float] = None,
343
  done: bool = False,
 
344
  ) -> MetaOptimizerObservation:
345
  meta: Dict[str, Any] = {}
346
  if self._steps_to_threshold is not None:
@@ -355,6 +516,7 @@ class MetaOptimizerEnvironment(Environment[MetaOptimizerAction, MetaOptimizerObs
355
  done=done,
356
  reward=reward,
357
  metadata=meta,
 
358
  )
359
 
360
  @property
@@ -367,3 +529,23 @@ class MetaOptimizerEnvironment(Environment[MetaOptimizerAction, MetaOptimizerObs
367
  def get_episode_action_log(self) -> List[Dict[str, Any]]:
368
  """Return the action log for the current episode (for in-process viz or eval)."""
369
  return list(self._action_log)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  # LICENSE file in the root directory of this source tree.
6
 
7
  """
8
+ Meta-optimizer environment: train an RL agent to act as an optimizer on inner tasks.
9
 
10
+ Supports (1) SLM: next-token prediction with a tiny transformer; (2) sinusoid regression.
11
+ Rich action space (LR, momentum, grad clip, weight decay), convergence-speed reward.
12
+ Action log and loss/perplexity are exposed for dashboard visualization.
13
  """
14
 
15
  import math
16
  import random
17
+ from dataclasses import asdict
18
+ from typing import Any, Dict, List, Optional, Union
19
  from uuid import uuid4
20
 
21
  import torch
 
25
  from openenv.core.env_server.types import State
26
 
27
  from my_env.models import MetaOptimizerAction, MetaOptimizerObservation
28
+ from .tasks import (
29
+ DEFAULT_CORPUS,
30
+ SLM_TRAIN_TASK_IDS,
31
+ TRAIN_TASK_IDS,
32
+ get_slm_task,
33
+ get_task,
34
+ task_spec_from_dict,
35
+ TaskSpec,
36
+ SLMTaskSpec,
37
+ )
38
+ from .slm_model import (
39
+ TinyLM,
40
+ build_vocab,
41
+ get_corpus_tensor,
42
+ sample_batch_slm,
43
+ )
44
 
45
  # Defaults
46
  LOSS_THRESHOLD = 0.1
 
48
  BATCH_SIZE = 32
49
  # Dense reward scale: reward += DENSE_REWARD_SCALE * (prev_loss - current_loss) each step (potential-based, helps credit assignment)
50
  DENSE_REWARD_SCALE = 0.2
51
+ # SLM loss threshold (cross-entropy); early termination when loss < this
52
+ SLM_LOSS_THRESHOLD = 1.5
53
 
54
 
55
  def _default_device() -> torch.device:
 
79
  return X, y
80
 
81
 
82
+ def _build_slm(spec: SLMTaskSpec) -> TinyLM:
83
+ """Build a tiny decoder-only transformer for the given SLM task spec."""
84
+ torch.manual_seed(spec.arch_seed)
85
+ return TinyLM(
86
+ vocab_size=spec.vocab_size,
87
+ context_len=spec.context_len,
88
+ n_layer=spec.n_layer,
89
+ n_head=spec.n_head,
90
+ n_embd=spec.n_embd,
91
+ )
92
+
93
+
94
+ def _get_batch_slm(
95
+ spec: SLMTaskSpec,
96
+ step: int,
97
+ device: torch.device,
98
+ corpus_ids: torch.Tensor,
99
+ ) -> tuple[torch.Tensor, torch.Tensor]:
100
+ """Sample a batch for next-token prediction. Returns input_ids [B,T], target_ids [B,T]."""
101
+ return sample_batch_slm(
102
+ corpus_ids,
103
+ BATCH_SIZE,
104
+ spec.context_len,
105
+ step,
106
+ spec.data_seed,
107
+ device,
108
+ )
109
+
110
+
111
  def run_adam_baseline(
112
  task_id: Optional[int] = None,
113
  task_spec: Optional[Dict[str, Any]] = None,
 
126
  torch.manual_seed(seed)
127
  device = _default_device()
128
  spec = task_spec_from_dict(task_spec) if task_spec is not None else get_task(task_id)
129
+ if isinstance(spec, SLMTaskSpec):
130
+ raise ValueError("Use run_adamw_baseline for SLM tasks")
131
  model = _build_model(spec).to(device)
132
  opt = torch.optim.Adam(model.parameters(), lr=lr)
133
  loss_trajectory: List[float] = []
 
178
  torch.manual_seed(seed)
179
  device = _default_device()
180
  spec = task_spec_from_dict(task_spec) if task_spec is not None else get_task(task_id)
181
+ if isinstance(spec, SLMTaskSpec):
182
+ raise ValueError("Use run_adamw_baseline for SLM tasks")
183
  model = _build_model(spec).to(device)
184
  opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
185
  loss_trajectory = []
 
211
  }
212
 
213
 
214
+ def run_adamw_baseline(
215
+ task_id: Optional[int] = None,
216
+ task_spec: Optional[Dict[str, Any]] = None,
217
+ max_steps: int = MAX_STEPS,
218
+ loss_threshold: float = SLM_LOSS_THRESHOLD,
219
+ lr: float = 1e-3,
220
+ weight_decay: float = 0.01,
221
+ betas: tuple[float, float] = (0.9, 0.999),
222
+ seed: Optional[int] = None,
223
+ return_metrics: bool = False,
224
+ ):
225
+ """
226
+ Run AdamW on one SLM task. Returns steps to threshold, or full metrics dict if return_metrics=True.
227
+ """
228
+ if (task_id is None) == (task_spec is None):
229
+ raise ValueError("Provide exactly one of task_id or task_spec")
230
+ if seed is not None:
231
+ torch.manual_seed(seed)
232
+ device = _default_device()
233
+ spec = task_spec_from_dict(task_spec) if task_spec is not None else get_slm_task(task_id)
234
+ if isinstance(spec, TaskSpec):
235
+ raise ValueError("Use run_adam_baseline or run_sgd_baseline for sinusoid tasks")
236
+ assert isinstance(spec, SLMTaskSpec)
237
+ model = _build_slm(spec).to(device)
238
+ opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, betas=betas)
239
+ char2idx, _ = build_vocab()
240
+ corpus_ids = get_corpus_tensor(DEFAULT_CORPUS, char2idx, device)
241
+ loss_trajectory: List[float] = []
242
+ steps_to_threshold: Optional[int] = None
243
+ for step in range(max_steps):
244
+ inp, tgt = _get_batch_slm(spec, step, device, corpus_ids)
245
+ model.train()
246
+ opt.zero_grad()
247
+ logits = model(inp)
248
+ loss = nn.functional.cross_entropy(logits.view(-1, spec.vocab_size), tgt.view(-1))
249
+ loss.backward()
250
+ opt.step()
251
+ with torch.no_grad():
252
+ L = nn.functional.cross_entropy(
253
+ model(inp).view(-1, spec.vocab_size), tgt.view(-1)
254
+ ).item()
255
+ loss_trajectory.append(L)
256
+ if steps_to_threshold is None and L < loss_threshold:
257
+ steps_to_threshold = step + 1
258
+ final_loss = loss_trajectory[-1] if loss_trajectory else float("inf")
259
+ perplexity = math.exp(min(final_loss, 20.0))
260
+ if not return_metrics:
261
+ return steps_to_threshold if steps_to_threshold is not None else max_steps
262
+ last_k = min(10, len(loss_trajectory))
263
+ mean_last_k = sum(loss_trajectory[-last_k:]) / last_k if loss_trajectory else final_loss
264
+ return {
265
+ "steps_to_threshold": steps_to_threshold if steps_to_threshold is not None else max_steps,
266
+ "success": steps_to_threshold is not None,
267
+ "final_loss": final_loss,
268
+ "perplexity": perplexity,
269
+ "mean_last_10_loss": mean_last_k,
270
+ "loss_auc": sum(loss_trajectory) / len(loss_trajectory) if loss_trajectory else final_loss,
271
+ "loss_trajectory": loss_trajectory,
272
+ }
273
+
274
+
275
  def run_meta_optimizer_trajectory(
276
  task_id: Optional[int] = None,
277
  task_spec: Optional[Dict[str, Any]] = None,
 
339
  self._device = _default_device()
340
 
341
  # Episode state (set in reset)
342
+ self._task_spec: Optional[Union[TaskSpec, SLMTaskSpec]] = None
343
  self._model: Optional[nn.Module] = None
344
  self._velocities: Optional[List[torch.Tensor]] = None
345
  self._step_count: int = 0
 
348
  self._steps_to_threshold: Optional[int] = None
349
  self._action_log: List[Dict[str, Any]] = []
350
  self._episode_id: Optional[str] = None
351
+ self._corpus_ids: Optional[torch.Tensor] = None # for SLM only
352
+ self._is_slm: bool = False
353
+ self._slm_loss_threshold: float = SLM_LOSS_THRESHOLD
354
 
355
  def reset(
356
  self,
 
366
  if task_spec is not None:
367
  self._task_spec = task_spec_from_dict(task_spec)
368
  else:
369
+ tid = task_id if task_id is not None else random.choice(SLM_TRAIN_TASK_IDS)
370
+ self._task_spec = get_slm_task(tid)
371
+ self._is_slm = isinstance(self._task_spec, SLMTaskSpec)
372
+
373
+ if self._is_slm:
374
+ spec = self._task_spec
375
+ assert isinstance(spec, SLMTaskSpec)
376
+ self._model = _build_slm(spec).to(self._device)
377
+ self._velocities = [torch.zeros_like(p) for p in self._model.parameters()]
378
+ char2idx, _ = build_vocab()
379
+ self._corpus_ids = get_corpus_tensor(DEFAULT_CORPUS, char2idx, self._device)
380
+ self._step_count = 0
381
+ self._steps_to_threshold = None
382
+ self._action_log = []
383
+ self._episode_id = episode_id or str(uuid4())
384
+ inp, tgt = _get_batch_slm(spec, 0, self._device, self._corpus_ids)
385
+ with torch.no_grad():
386
+ logits = self._model(inp)
387
+ self._current_loss = nn.functional.cross_entropy(
388
+ logits.view(-1, spec.vocab_size), tgt.view(-1)
389
+ ).item()
390
+ self._prev_loss = self._current_loss
391
+ return self._observation(reward=None, grad_norm=None, perplexity=math.exp(min(self._current_loss, 20.0)))
392
+ else:
393
+ spec = self._task_spec
394
+ assert isinstance(spec, TaskSpec)
395
+ self._model = _build_model(spec).to(self._device)
396
+ self._velocities = [torch.zeros_like(p) for p in self._model.parameters()]
397
+ self._step_count = 0
398
+ self._steps_to_threshold = None
399
+ self._action_log = []
400
+ self._episode_id = episode_id or str(uuid4())
401
+ X, y = _get_batch(spec, 0, self._device)
402
+ with torch.no_grad():
403
+ out = self._model(X)
404
+ self._current_loss = nn.functional.mse_loss(out, y).item()
405
+ self._prev_loss = self._current_loss
406
+ return self._observation(reward=None, grad_norm=None)
407
 
408
  def step(
409
  self,
 
426
  "weight_decay_this_step": wd,
427
  })
428
 
429
+ if self._is_slm:
430
+ spec = self._task_spec
431
+ assert isinstance(spec, SLMTaskSpec)
432
+ inp, tgt = _get_batch_slm(spec, self._step_count + 1, self._device, self._corpus_ids)
433
+ self._model.train()
434
+ logits = self._model(inp)
435
+ loss = nn.functional.cross_entropy(logits.view(-1, spec.vocab_size), tgt.view(-1))
436
+ else:
437
+ spec = self._task_spec
438
+ assert isinstance(spec, TaskSpec)
439
+ X, y = _get_batch(spec, self._step_count + 1, self._device)
440
+ self._model.train()
441
+ loss = nn.functional.mse_loss(self._model(X), y)
442
+
443
  self._model.zero_grad()
444
  loss.backward()
445
 
 
461
  if wd > 0:
462
  p.sub_(p, alpha=wd)
463
 
464
+ if self._is_slm:
465
+ spec = self._task_spec
466
+ assert isinstance(spec, SLMTaskSpec)
467
+ with torch.no_grad():
468
+ logits = self._model(inp)
469
+ self._current_loss = nn.functional.cross_entropy(
470
+ logits.view(-1, spec.vocab_size), tgt.view(-1)
471
+ ).item()
472
+ loss_threshold = self._slm_loss_threshold
473
+ perp = math.exp(min(self._current_loss, 20.0))
474
+ else:
475
+ spec = self._task_spec
476
+ assert isinstance(spec, TaskSpec)
477
+ with torch.no_grad():
478
+ X, y = _get_batch(spec, self._step_count + 1, self._device)
479
+ self._current_loss = nn.functional.mse_loss(self._model(X), y).item()
480
+ loss_threshold = self.loss_threshold
481
+ perp = None
482
 
483
  self._step_count += 1
484
+ if self._steps_to_threshold is None and self._current_loss < loss_threshold:
485
  self._steps_to_threshold = self._step_count
486
 
 
487
  dense_reward = DENSE_REWARD_SCALE * (prev_loss - self._current_loss)
488
  self._prev_loss = self._current_loss
489
 
 
490
  done = self._step_count >= self.max_steps or self._steps_to_threshold is not None
491
  if done:
492
  terminal = -(self._steps_to_threshold if self._steps_to_threshold is not None else self.max_steps)
 
494
  else:
495
  reward = dense_reward
496
 
497
+ return self._observation(reward=reward, grad_norm=grad_norm, done=done, perplexity=perp)
498
 
499
  def _observation(
500
  self,
501
  reward: Optional[float] = None,
502
  grad_norm: Optional[float] = None,
503
  done: bool = False,
504
+ perplexity: Optional[float] = None,
505
  ) -> MetaOptimizerObservation:
506
  meta: Dict[str, Any] = {}
507
  if self._steps_to_threshold is not None:
 
516
  done=done,
517
  reward=reward,
518
  metadata=meta,
519
+ perplexity=perplexity,
520
  )
521
 
522
  @property
 
529
  def get_episode_action_log(self) -> List[Dict[str, Any]]:
530
  """Return the action log for the current episode (for in-process viz or eval)."""
531
  return list(self._action_log)
532
+
533
+ def get_current_task_spec(self) -> Optional[Dict[str, Any]]:
534
+ """Return current task spec as a dict for dashboard / run-baseline. None if no episode started."""
535
+ if self._task_spec is None:
536
+ return None
537
+ if isinstance(self._task_spec, SLMTaskSpec):
538
+ return {"type": "slm", **asdict(self._task_spec)}
539
+ return {"type": "sinusoid", **asdict(self._task_spec)}
540
+
541
+ def run_baseline(self) -> Dict[str, Any]:
542
+ """Run the appropriate baseline (AdamW for SLM, Adam for sinusoid) for current task. Returns loss_trajectory and steps."""
543
+ spec_dict = self.get_current_task_spec()
544
+ if spec_dict is None:
545
+ return {"loss_trajectory": [], "steps": [], "error": "No task"}
546
+ if spec_dict.get("type") == "slm":
547
+ result = run_adamw_baseline(task_spec=spec_dict, max_steps=self.max_steps, return_metrics=True)
548
+ else:
549
+ result = run_adam_baseline(task_spec=spec_dict, max_steps=self.max_steps, return_metrics=True)
550
+ traj = result.get("loss_trajectory", [])
551
+ return {"loss_trajectory": traj, "steps": list(range(len(traj)))}
server/slm_model.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Tiny decoder-only transformer for SLM meta-optimizer inner task.
9
+ Pure PyTorch, no transformers dependency.
10
+ """
11
+
12
+ import math
13
+ from typing import Tuple
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+
19
+ # Fixed character vocab for reproducible SLM tasks (subset of printable ASCII)
20
+ DEFAULT_CHARS = (
21
+ " \n\t"
22
+ "abcdefghijklmnopqrstuvwxyz"
23
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
24
+ "0123456789"
25
+ ".,;:!?'\"-()"
26
+ )
27
+ DEFAULT_VOCAB_SIZE = len(DEFAULT_CHARS)
28
+
29
+
30
+ def build_vocab(chars: str = DEFAULT_CHARS) -> Tuple[dict, dict]:
31
+ """Return char2idx and idx2char dicts."""
32
+ char2idx = {c: i for i, c in enumerate(chars)}
33
+ idx2char = {i: c for c, i in char2idx.items()}
34
+ return char2idx, idx2char
35
+
36
+
37
+ def encode_corpus(text: str, char2idx: dict, default_idx: int = 0) -> torch.Tensor:
38
+ """Encode string to long tensor of token ids. Unknown chars map to default_idx."""
39
+ ids = [char2idx.get(c, default_idx) for c in text]
40
+ return torch.tensor(ids, dtype=torch.long)
41
+
42
+
43
+ def get_corpus_tensor(
44
+ text: str,
45
+ char2idx: dict,
46
+ device: torch.device,
47
+ ) -> torch.Tensor:
48
+ """Return 1D long tensor of token ids on device."""
49
+ t = encode_corpus(text, char2idx)
50
+ return t.to(device)
51
+
52
+
53
+ def sample_batch_slm(
54
+ corpus_ids: torch.Tensor,
55
+ batch_size: int,
56
+ context_len: int,
57
+ step: int,
58
+ data_seed: int,
59
+ device: torch.device,
60
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
61
+ """
62
+ Sample batch_size contiguous chunks from corpus for next-token prediction.
63
+ Returns: input_ids [B, context_len], target_ids [B, context_len] (target = input shifted by 1).
64
+ """
65
+ L = corpus_ids.size(0)
66
+ if L <= context_len + 1:
67
+ raise ValueError("Corpus too short for context_len")
68
+ max_start = L - context_len - 1
69
+ g = torch.Generator(device=device)
70
+ g.manual_seed(data_seed + step)
71
+ starts = torch.randint(0, max_start, (batch_size,), device=device, generator=g)
72
+ inputs = []
73
+ targets = []
74
+ for b in range(batch_size):
75
+ s = int(starts[b].item())
76
+ chunk = corpus_ids[s : s + context_len + 1]
77
+ inputs.append(chunk[:context_len])
78
+ targets.append(chunk[1 : context_len + 1])
79
+ return torch.stack(inputs), torch.stack(targets)
80
+
81
+
82
+ class CausalSelfAttention(nn.Module):
83
+ def __init__(self, n_embd: int, n_head: int, block_size: int):
84
+ super().__init__()
85
+ assert n_embd % n_head == 0
86
+ self.n_head = n_head
87
+ self.n_embd = n_embd
88
+ self.head_dim = n_embd // n_head
89
+ self.register_buffer(
90
+ "mask",
91
+ torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size),
92
+ )
93
+ self.c_attn = nn.Linear(n_embd, 3 * n_embd)
94
+ self.c_proj = nn.Linear(n_embd, n_embd)
95
+
96
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
97
+ B, T, C = x.shape
98
+ qkv = self.c_attn(x)
99
+ q, k, v = qkv.split(self.n_embd, dim=2)
100
+ q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
101
+ k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
102
+ v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
103
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
104
+ att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
105
+ att = torch.softmax(att, dim=-1)
106
+ out = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
107
+ return self.c_proj(out)
108
+
109
+
110
+ class Block(nn.Module):
111
+ def __init__(self, n_embd: int, n_head: int, block_size: int):
112
+ super().__init__()
113
+ self.ln1 = nn.LayerNorm(n_embd)
114
+ self.attn = CausalSelfAttention(n_embd, n_head, block_size)
115
+ self.ln2 = nn.LayerNorm(n_embd)
116
+ self.mlp = nn.Sequential(
117
+ nn.Linear(n_embd, 4 * n_embd),
118
+ nn.GELU(),
119
+ nn.Linear(4 * n_embd, n_embd),
120
+ )
121
+
122
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
123
+ x = x + self.attn(self.ln1(x))
124
+ x = x + self.mlp(self.ln2(x))
125
+ return x
126
+
127
+
128
+ class TinyLM(nn.Module):
129
+ """Decoder-only transformer for next-token prediction."""
130
+
131
+ def __init__(
132
+ self,
133
+ vocab_size: int,
134
+ context_len: int,
135
+ n_layer: int,
136
+ n_head: int,
137
+ n_embd: int,
138
+ ):
139
+ super().__init__()
140
+ self.context_len = context_len
141
+ self.vocab_size = vocab_size
142
+ self.token_embed = nn.Embedding(vocab_size, n_embd)
143
+ self.pos_embed = nn.Embedding(context_len, n_embd)
144
+ self.blocks = nn.ModuleList(
145
+ [Block(n_embd, n_head, context_len) for _ in range(n_layer)]
146
+ )
147
+ self.ln_f = nn.LayerNorm(n_embd)
148
+ self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
149
+
150
+ def forward(self, idx: torch.Tensor) -> torch.Tensor:
151
+ # idx: [B, T]; clamp to valid range in case of encoding drift
152
+ idx = idx.clamp(0, self.vocab_size - 1)
153
+ B, T = idx.shape
154
+ pos = torch.arange(0, T, device=idx.device, dtype=torch.long)
155
+ x = self.token_embed(idx) + self.pos_embed(pos)
156
+ for block in self.blocks:
157
+ x = block(x)
158
+ x = self.ln_f(x)
159
+ logits = self.lm_head(x)
160
+ return logits
server/tasks.py CHANGED
@@ -9,6 +9,7 @@ Task registry for meta-learning.
9
 
10
  Tasks can be from the internal registry (get_task(task_id)) or provided from outside
11
  via task_spec_from_dict() — the client sends the task definition to the environment.
 
12
  """
13
 
14
  from dataclasses import dataclass
@@ -16,12 +17,27 @@ from typing import Any, Dict, List
16
 
17
  import math
18
 
 
 
19
  # Distribution A: 50 training tasks (low-freq sinusoids)
20
  TRAIN_TASK_IDS: List[int] = list(range(50))
21
 
22
  # Distribution B: held-out eval tasks (high-freq sinusoids — different distribution)
23
  EVAL_TASK_IDS: List[int] = [50, 51]
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  # Bounds for each distribution (freq, amplitude, phase)
26
  DIST_A_FREQ = (1.0, 3.0)
27
  DIST_A_AMP = (0.5, 2.0)
@@ -46,6 +62,22 @@ class TaskSpec:
46
  distribution: str # "A" or "B"
47
 
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  def get_task(task_id: int) -> TaskSpec:
50
  """
51
  Return the task spec for the given task_id.
@@ -88,28 +120,73 @@ def get_task(task_id: int) -> TaskSpec:
88
  )
89
 
90
 
91
- def task_spec_from_dict(d: Dict[str, Any]) -> TaskSpec:
92
  """
93
- Build a TaskSpec from an external dict (sent by the client).
94
- The task is defined outside the env; we just parse it here.
95
-
96
- Expected keys for type "sinusoid":
97
- type="sinusoid", amplitude, freq, phase, data_seed (optional), arch_seed (optional),
98
- input_dim (optional, default 1), hidden_dim (optional, default 32), task_id (optional).
99
  """
100
- task_type = d.get("type", "sinusoid")
101
- if task_type != "sinusoid":
102
- raise ValueError(f"Unknown task type: {task_type}")
103
- task_id = d.get("task_id", 0)
104
- return TaskSpec(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  task_id=task_id,
106
- input_dim=int(d.get("input_dim", 1)),
107
- hidden_dim=int(d.get("hidden_dim", 32)),
108
- output_dim=1,
109
  data_seed=int(d.get("data_seed", task_id * 31337)),
110
  arch_seed=int(d.get("arch_seed", task_id * 131 + 7)),
111
- amplitude=float(d["amplitude"]),
112
- freq=float(d["freq"]),
113
- phase=float(d["phase"]),
 
 
 
114
  distribution=d.get("distribution", "external"),
115
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  Tasks can be from the internal registry (get_task(task_id)) or provided from outside
11
  via task_spec_from_dict() — the client sends the task definition to the environment.
12
+ Supports sinusoid (regression) and SLM (next-token prediction) task types.
13
  """
14
 
15
  from dataclasses import dataclass
 
17
 
18
  import math
19
 
20
+ from .slm_model import DEFAULT_VOCAB_SIZE as SLM_DEFAULT_VOCAB_SIZE
21
+
22
  # Distribution A: 50 training tasks (low-freq sinusoids)
23
  TRAIN_TASK_IDS: List[int] = list(range(50))
24
 
25
  # Distribution B: held-out eval tasks (high-freq sinusoids — different distribution)
26
  EVAL_TASK_IDS: List[int] = [50, 51]
27
 
28
+ # SLM: 50 train tasks, 2 eval (different corpus split or seed range)
29
+ SLM_TRAIN_TASK_IDS: List[int] = list(range(50))
30
+ SLM_EVAL_TASK_IDS: List[int] = [50, 51]
31
+
32
+ # Fixed small corpus for SLM (character-level). ~10KB so tasks are reproducible.
33
+ DEFAULT_CORPUS: str = (
34
+ "The quick brown fox jumps over the lazy dog. "
35
+ "Pack my box with five dozen liquor jugs. "
36
+ "How vexingly quick daft zebras jump. "
37
+ "Sphinx of black quartz, judge my vow. "
38
+ "The five boxing wizards jump quickly. "
39
+ ) * 200 # repeat to get enough length for sampling
40
+
41
  # Bounds for each distribution (freq, amplitude, phase)
42
  DIST_A_FREQ = (1.0, 3.0)
43
  DIST_A_AMP = (0.5, 2.0)
 
62
  distribution: str # "A" or "B"
63
 
64
 
65
+ @dataclass
66
+ class SLMTaskSpec:
67
+ """Spec for one SLM (next-token prediction) task."""
68
+
69
+ task_id: int
70
+ data_seed: int
71
+ arch_seed: int
72
+ vocab_size: int
73
+ context_len: int # block size
74
+ n_layer: int
75
+ n_head: int
76
+ n_embd: int
77
+ corpus_id: str # e.g. "default"
78
+ distribution: str # "A" or "B" or "external"
79
+
80
+
81
  def get_task(task_id: int) -> TaskSpec:
82
  """
83
  Return the task spec for the given task_id.
 
120
  )
121
 
122
 
123
+ def get_slm_task(task_id: int) -> SLMTaskSpec:
124
  """
125
+ Return the SLM task spec for the given task_id.
126
+ Task IDs 0..49 = Distribution A (train), 50+ = Distribution B (eval).
 
 
 
 
127
  """
128
+ if task_id < 0:
129
+ raise ValueError(f"task_id must be >= 0, got {task_id}")
130
+ r = task_id * 7919 + 1
131
+ data_seed = task_id * 31337
132
+ arch_seed = task_id * 131 + 7
133
+ if task_id < 50:
134
+ distribution = "A"
135
+ else:
136
+ distribution = "B"
137
+ return SLMTaskSpec(
138
+ task_id=task_id,
139
+ data_seed=data_seed,
140
+ arch_seed=arch_seed,
141
+ vocab_size=SLM_DEFAULT_VOCAB_SIZE,
142
+ context_len=64,
143
+ n_layer=2,
144
+ n_head=4,
145
+ n_embd=128,
146
+ corpus_id="default",
147
+ distribution=distribution,
148
+ )
149
+
150
+
151
+ def slm_task_spec_from_dict(d: Dict[str, Any]) -> SLMTaskSpec:
152
+ """Build an SLMTaskSpec from an external dict (type='slm')."""
153
+ task_id = int(d.get("task_id", 0))
154
+ return SLMTaskSpec(
155
  task_id=task_id,
 
 
 
156
  data_seed=int(d.get("data_seed", task_id * 31337)),
157
  arch_seed=int(d.get("arch_seed", task_id * 131 + 7)),
158
+ vocab_size=int(d.get("vocab_size", SLM_DEFAULT_VOCAB_SIZE)),
159
+ context_len=int(d.get("context_len", 64)),
160
+ n_layer=int(d.get("n_layer", 2)),
161
+ n_head=int(d.get("n_head", 4)),
162
+ n_embd=int(d.get("n_embd", 128)),
163
+ corpus_id=str(d.get("corpus_id", "default")),
164
  distribution=d.get("distribution", "external"),
165
  )
166
+
167
+
168
+ def task_spec_from_dict(d: Dict[str, Any]) -> TaskSpec | SLMTaskSpec:
169
+ """
170
+ Build a TaskSpec or SLMTaskSpec from an external dict (sent by the client).
171
+
172
+ For type "sinusoid": amplitude, freq, phase, data_seed (optional), arch_seed (optional), etc.
173
+ For type "slm": vocab_size, context_len, n_layer, n_head, n_embd (all optional with defaults).
174
+ """
175
+ task_type = d.get("type", "slm")
176
+ if task_type == "sinusoid":
177
+ task_id = d.get("task_id", 0)
178
+ return TaskSpec(
179
+ task_id=task_id,
180
+ input_dim=int(d.get("input_dim", 1)),
181
+ hidden_dim=int(d.get("hidden_dim", 32)),
182
+ output_dim=1,
183
+ data_seed=int(d.get("data_seed", task_id * 31337)),
184
+ arch_seed=int(d.get("arch_seed", task_id * 131 + 7)),
185
+ amplitude=float(d["amplitude"]),
186
+ freq=float(d["freq"]),
187
+ phase=float(d["phase"]),
188
+ distribution=d.get("distribution", "external"),
189
+ )
190
+ if task_type == "slm":
191
+ return slm_task_spec_from_dict(d)
192
+ raise ValueError(f"Unknown task type: {task_type!r}")