| |
| |
| |
| |
| |
|
|
| """ |
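Task registry for meta-learning.

Tasks come either from the internal registry (``get_task(task_id)``) or from
outside via ``task_spec_from_dict()``, in which case the client sends the task
definition to the environment. Registry IDs 0..49 are training tasks
(Distribution A); IDs 50 and above are evaluation tasks (Distribution B).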
| """ |
|
|
import math
import operator
from dataclasses import dataclass
from typing import Any, Dict, List
|
|
| |
# Registry IDs used for training; get_task() maps every ID below 50 to
# Distribution A.
TRAIN_TASK_IDS: List[int] = [*range(50)]

# Registry IDs held out for evaluation; get_task() maps these to Distribution B.
EVAL_TASK_IDS: List[int] = [*range(50, 52)]

# (low, high) sampling ranges read by get_task(), one pair per parameter and
# distribution.
DIST_A_FREQ = (1.0, 3.0)  # Distribution A frequency range
DIST_A_AMP = (0.5, 2.0)   # Distribution A amplitude range
DIST_B_FREQ = (4.0, 6.0)  # Distribution B frequency range
DIST_B_AMP = (0.3, 1.5)   # Distribution B amplitude range
|
|
|
|
@dataclass()
class TaskSpec:
    """Full description of one sinusoidal regression task.

    Instances are produced either by ``get_task`` (registry tasks) or by
    ``task_spec_from_dict`` (client-defined tasks). Field order is part of the
    positional constructor signature and must not change.
    """

    # Identifier of the task (registry index, or client-supplied).
    task_id: int

    # Dimensions: input, hidden and output sizes.
    input_dim: int
    hidden_dim: int
    output_dim: int

    # Seeds for data generation and for architecture/initialisation
    # (consumed outside this module).
    data_seed: int
    arch_seed: int

    # Sinusoid parameters. NOTE(review): presumably the target is
    # amplitude * sin(freq * x + phase); confirm against the data generator.
    amplitude: float
    freq: float
    phase: float

    # Source label: "A"/"B" for registry tasks, "external" (or a client
    # value) for tasks parsed from a dict.
    distribution: str
|
|
|
|
def get_task(task_id: int) -> TaskSpec:
    """
    Return the registry task spec for ``task_id``.

    Task IDs 0..49 = Distribution A (train), 50+ = Distribution B (eval).
    Every field is derived deterministically from ``task_id``, so repeated
    calls with the same ID return equal specs.

    Args:
        task_id: Non-negative integer task index. Any integer-like object
            supporting ``__index__`` (e.g. a NumPy integer) is accepted and
            converted to a plain ``int``.

    Returns:
        A TaskSpec with ``input_dim == output_dim == 1``, ``hidden_dim`` in
        32..64, ``freq`` and ``amplitude`` in the distribution's [lo, hi)
        range, and ``phase`` in [0, 2*pi).

    Raises:
        TypeError: if ``task_id`` is not an integer (e.g. a float). Without
            this check such values produced float-valued hidden_dim and seeds.
        ValueError: if ``task_id`` is negative.
    """
    # Normalise to a plain int before any arithmetic; floats are rejected.
    task_id = operator.index(task_id)
    if task_id < 0:
        raise ValueError(f"task_id must be >= 0, got {task_id}")

    # Deterministic pseudo-random integer derived from the ID; its residues
    # below select the hidden size and the sinusoid parameters.
    r = task_id * 7919 + 1
    data_seed = task_id * 31337
    arch_seed = task_id * 131 + 7
    hidden_dim = 32 + (r % 33)  # 32..64 inclusive

    # The boundary 50 must agree with TRAIN_TASK_IDS / EVAL_TASK_IDS.
    if task_id < 50:
        f_lo, f_hi = DIST_A_FREQ
        a_lo, a_hi = DIST_A_AMP
        distribution = "A"
    else:
        f_lo, f_hi = DIST_B_FREQ
        a_lo, a_hi = DIST_B_AMP
        distribution = "B"

    # Each residue / 1000.0 lies in [0, 0.999], giving a value in [lo, hi).
    # Different multipliers (1, 3, 7) decorrelate the three parameters.
    freq = f_lo + (r % 1000) / 1000.0 * (f_hi - f_lo)
    amplitude = a_lo + ((r * 3) % 1000) / 1000.0 * (a_hi - a_lo)
    phase = ((r * 7) % 1000) / 1000.0 * 2 * math.pi

    return TaskSpec(
        task_id=task_id,
        input_dim=1,
        hidden_dim=hidden_dim,
        output_dim=1,
        data_seed=data_seed,
        arch_seed=arch_seed,
        amplitude=amplitude,
        freq=freq,
        phase=phase,
        distribution=distribution,
    )
|
|
|
|
def task_spec_from_dict(d: Dict[str, Any]) -> TaskSpec:
    """
    Build a TaskSpec from an external dict (sent by the client).

    The task is defined outside the environment; this function only parses the
    dict and coerces each value to the type TaskSpec expects.

    Expected keys for type "sinusoid":
        type:         optional, default "sinusoid" (the only supported type).
        amplitude:    required, coerced with float().
        freq:         required, coerced with float().
        phase:        required, coerced with float().
        task_id:      optional, default 0, coerced with int().
        data_seed:    optional, default task_id * 31337, coerced with int().
        arch_seed:    optional, default task_id * 131 + 7, coerced with int().
        input_dim:    optional, default 1, coerced with int().
        hidden_dim:   optional, default 32, coerced with int().
        distribution: optional label, default "external" (stored as given).
    ``output_dim`` is always 1.

    Raises:
        ValueError: if ``type`` is not "sinusoid", or a value cannot be
            converted (e.g. ``int("abc")``).
        KeyError: if ``amplitude``, ``freq`` or ``phase`` is missing.
    """
    task_type = d.get("type", "sinusoid")
    if task_type != "sinusoid":
        raise ValueError(f"Unknown task type: {task_type}")
    # Coerce like every other numeric field. Left uncoerced, a string ID such
    # as "3" made ``task_id * 31337`` repeat the string instead of multiplying.
    task_id = int(d.get("task_id", 0))
    return TaskSpec(
        task_id=task_id,
        input_dim=int(d.get("input_dim", 1)),
        hidden_dim=int(d.get("hidden_dim", 32)),
        output_dim=1,
        data_seed=int(d.get("data_seed", task_id * 31337)),
        arch_seed=int(d.get("arch_seed", task_id * 131 + 7)),
        amplitude=float(d["amplitude"]),
        freq=float(d["freq"]),
        phase=float(d["phase"]),
        distribution=d.get("distribution", "external"),
    )
|
|