# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Task registry for meta-learning.

Tasks can come from the internal registry (``get_task(task_id)``) or be
provided from outside via ``task_spec_from_dict()`` — in that case the client
sends the task definition to the environment, which only parses it.
"""

import math
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

# Number of Distribution-A (training) tasks. Registry IDs at or above this
# value belong to Distribution B.
_NUM_TRAIN_TASKS = 50

# Distribution A: 50 training tasks (low-freq sinusoids)
TRAIN_TASK_IDS: List[int] = list(range(_NUM_TRAIN_TASKS))
# Distribution B: held-out eval tasks (high-freq sinusoids — different distribution)
EVAL_TASK_IDS: List[int] = [_NUM_TRAIN_TASKS, _NUM_TRAIN_TASKS + 1]

# Per-distribution (low, high) bounds for frequency and amplitude. Phase has
# no per-distribution bounds: get_task draws it from [0, 2*pi) for both.
DIST_A_FREQ = (1.0, 3.0)
DIST_A_AMP = (0.5, 2.0)
DIST_B_FREQ = (4.0, 6.0)
DIST_B_AMP = (0.3, 1.5)

# Keys that task_spec_from_dict requires for a "sinusoid" task.
_REQUIRED_SINUSOID_KEYS = ("amplitude", "freq", "phase")


@dataclass
class TaskSpec:
    """Spec for one sinusoidal regression task.

    The regression target is ``y = amplitude * sin(2*pi*freq*x + phase)``.
    """

    task_id: int
    input_dim: int  # 1 for scalar sinusoid input
    hidden_dim: int
    output_dim: int
    data_seed: int
    arch_seed: int
    # Sinusoidal target: y = amplitude * sin(2*pi*freq*x + phase)
    amplitude: float
    freq: float
    phase: float
    # "A" or "B" for registry tasks; defaults to "external" for client-supplied ones
    distribution: str


def _default_seeds(task_id: int) -> Tuple[int, int]:
    """Return the default ``(data_seed, arch_seed)`` pair for ``task_id``.

    Shared by get_task and task_spec_from_dict so that an external task with
    the same ``task_id`` and no explicit seeds gets the registry's seeds.
    """
    return task_id * 31337, task_id * 131 + 7


def get_task(task_id: int) -> TaskSpec:
    """
    Return the task spec for the given task_id.

    Task IDs 0..49 = Distribution A (train), 50+ = Distribution B (eval).
    All parameters are derived deterministically from ``task_id``, so the same
    ID always yields the same spec.

    Raises:
        ValueError: if ``task_id`` is negative.
    """
    if task_id < 0:
        raise ValueError(f"task_id must be >= 0, got {task_id}")

    # Simple integer hash of task_id used to spread parameters over their ranges.
    r = task_id * 7919 + 1
    data_seed, arch_seed = _default_seeds(task_id)
    hidden_dim = 32 + (r % 33)  # in [32, 64]

    if task_id < _NUM_TRAIN_TASKS:
        # Distribution A
        f_lo, f_hi = DIST_A_FREQ
        a_lo, a_hi = DIST_A_AMP
        distribution = "A"
    else:
        # Distribution B
        f_lo, f_hi = DIST_B_FREQ
        a_lo, a_hi = DIST_B_AMP
        distribution = "B"

    # Deterministic but varied per task. NOTE: amplitude and phase reuse the
    # same hash r (scaled by 3 and 7), so the three parameters are correlated
    # rather than independent draws.
    freq = f_lo + (r % 1000) / 1000.0 * (f_hi - f_lo)
    amplitude = a_lo + ((r * 3) % 1000) / 1000.0 * (a_hi - a_lo)
    phase = ((r * 7) % 1000) / 1000.0 * 2 * math.pi

    return TaskSpec(
        task_id=task_id,
        input_dim=1,
        hidden_dim=hidden_dim,
        output_dim=1,
        data_seed=data_seed,
        arch_seed=arch_seed,
        amplitude=amplitude,
        freq=freq,
        phase=phase,
        distribution=distribution,
    )


def _optional_int(d: Dict[str, Any], key: str, default: int) -> int:
    """Return ``int(d[key])``, or ``default`` when the key is absent or None."""
    value = d.get(key)
    return default if value is None else int(value)


def task_spec_from_dict(d: Dict[str, Any]) -> TaskSpec:
    """
    Build a TaskSpec from an external dict (sent by the client).

    The task is defined outside the env; we just parse it here.

    Expected keys for type "sinusoid":
        type="sinusoid" (optional, this is the default),
        amplitude, freq, phase (required),
        data_seed (optional), arch_seed (optional) — default to the same
            seeds get_task would use for ``task_id``,
        input_dim (optional, default 1), hidden_dim (optional, default 32),
        task_id (optional, default 0), distribution (optional, default "external").
    Optional integer keys that are present but None are treated as missing.

    Raises:
        ValueError: if ``type`` is not "sinusoid", or a field cannot be parsed.
        KeyError: if any of amplitude / freq / phase is missing.
        TypeError: if a numeric field has a non-numeric type.
    """
    task_type = d.get("type", "sinusoid")
    if task_type != "sinusoid":
        raise ValueError(f"Unknown task type: {task_type}")

    missing = [key for key in _REQUIRED_SINUSOID_KEYS if key not in d]
    if missing:
        raise KeyError(f"sinusoid task spec is missing required key(s): {missing}")

    # Coerce task_id to int up front: a string ID (e.g. from JSON/query params)
    # would otherwise hit str * int repetition when deriving the default seeds.
    task_id = _optional_int(d, "task_id", 0)
    default_data_seed, default_arch_seed = _default_seeds(task_id)

    return TaskSpec(
        task_id=task_id,
        input_dim=_optional_int(d, "input_dim", 1),
        hidden_dim=_optional_int(d, "hidden_dim", 32),
        output_dim=1,
        data_seed=_optional_int(d, "data_seed", default_data_seed),
        arch_seed=_optional_int(d, "arch_seed", default_arch_seed),
        amplitude=float(d["amplitude"]),
        freq=float(d["freq"]),
        phase=float(d["phase"]),
        distribution=d.get("distribution", "external"),
    )