File size: 3,473 Bytes
fc6b5c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Task registry for meta-learning.

Tasks come either from the internal registry (get_task(task_id)) or from outside via
task_spec_from_dict(), which parses a task definition the client sends to the environment.
"""

from dataclasses import dataclass
from typing import Any, Dict, List

import math

# Distribution A: 50 training tasks, ids 0..49 (low-frequency sinusoids).
TRAIN_TASK_IDS: List[int] = list(range(50))

# Distribution B: held-out eval tasks (higher-frequency sinusoids whose frequency
# range does not overlap A's). get_task() assigns every id >= 50 to B; these two
# ids are the designated eval set.
EVAL_TASK_IDS: List[int] = [50, 51]

# (low, high) sampling bounds for frequency and amplitude in each distribution.
# There are no phase bounds: get_task() draws phase over [0, 2*pi) for both
# distributions.
DIST_A_FREQ = (1.0, 3.0)
DIST_A_AMP = (0.5, 2.0)
DIST_B_FREQ = (4.0, 6.0)
DIST_B_AMP = (0.3, 1.5)


@dataclass()
class TaskSpec:
    """
    Parameters describing one sinusoidal regression task.

    The regression target is ``y = amplitude * sin(2*pi*freq*x + phase)``.
    Instances are plain mutable dataclasses with generated ``__init__``,
    ``__repr__`` and ``__eq__``.
    """

    # Identifier of the task (registry id, or client-supplied for external tasks).
    task_id: int
    # Model shape: input_dim is 1 for a scalar sinusoid input.
    input_dim: int
    hidden_dim: int
    output_dim: int
    # Seeds carried with the task; presumably used for data sampling and
    # architecture/weight init respectively — confirm against consumers.
    data_seed: int
    arch_seed: int
    # Waveform parameters of the target sinusoid.
    amplitude: float
    freq: float
    phase: float
    # Source label: "A" or "B" for registry tasks ("external" is the default
    # used by task_spec_from_dict).
    distribution: str


def get_task(task_id: int) -> TaskSpec:
    """
    Look up a task in the built-in registry.

    Every parameter is derived deterministically from ``task_id`` by integer
    arithmetic, so the same id always yields the same spec. Ids below 50 map to
    Distribution A (train); every id from 50 upward maps to Distribution B (eval).

    Raises:
        ValueError: if ``task_id`` is negative.
    """
    if task_id < 0:
        raise ValueError(f"task_id must be >= 0, got {task_id}")

    # A single mixed integer drives all per-task variation below.
    mix = task_id * 7919 + 1

    # Choose the (freq, amplitude) bounds and label of this id's distribution.
    freq_bounds, amp_bounds, label = (
        (DIST_A_FREQ, DIST_A_AMP, "A")
        if task_id < 50
        else (DIST_B_FREQ, DIST_B_AMP, "B")
    )
    freq_lo, freq_hi = freq_bounds
    amp_lo, amp_hi = amp_bounds

    def unit(scale: int) -> float:
        # Fraction on a 1/1000 grid in [0, 0.999]; the multiplier decorrelates
        # the three parameters that share `mix`.
        return ((mix * scale) % 1000) / 1000.0

    return TaskSpec(
        task_id=task_id,
        input_dim=1,
        output_dim=1,
        hidden_dim=32 + (mix % 33),  # 32..64 inclusive
        data_seed=task_id * 31337,
        arch_seed=task_id * 131 + 7,
        freq=freq_lo + unit(1) * (freq_hi - freq_lo),
        amplitude=amp_lo + unit(3) * (amp_hi - amp_lo),
        phase=unit(7) * 2 * math.pi,
        distribution=label,
    )


def task_spec_from_dict(d: Dict[str, Any]) -> TaskSpec:
    """
    Build a TaskSpec from an external dict (sent by the client).

    The task is defined outside the env; this function only parses it. Numeric
    fields are coerced with int()/float(), so numbers sent as strings
    (e.g. "3") are accepted.

    Expected keys for type "sinusoid":
      type (optional, default "sinusoid"; any other value is rejected),
      amplitude, freq, phase (required),
      task_id (optional, default 0),
      data_seed (optional, default task_id * 31337),
      arch_seed (optional, default task_id * 131 + 7),
      input_dim (optional, default 1),
      hidden_dim (optional, default 32),
      distribution (optional, default "external").
    output_dim is always 1.

    Raises:
        ValueError: if ``type`` is not "sinusoid", or a numeric field is a
            string that int()/float() cannot parse.
        KeyError: if amplitude, freq or phase is missing.
        TypeError: if a numeric field has a non-numeric, non-string type
            (e.g. None).
    """
    task_type = d.get("type", "sinusoid")
    if task_type != "sinusoid":
        raise ValueError(f"Unknown task type: {task_type}")
    # Coerce task_id like every other numeric field. The seed defaults below are
    # evaluated eagerly from it (even when the client sends the seeds), and a
    # str id would otherwise hit str repetition ("3" * 31337) or a TypeError.
    task_id = int(d.get("task_id", 0))
    return TaskSpec(
        task_id=task_id,
        input_dim=int(d.get("input_dim", 1)),
        hidden_dim=int(d.get("hidden_dim", 32)),
        output_dim=1,
        # Same seed formulas get_task uses for registry tasks.
        data_seed=int(d.get("data_seed", task_id * 31337)),
        arch_seed=int(d.get("arch_seed", task_id * 131 + 7)),
        amplitude=float(d["amplitude"]),
        freq=float(d["freq"]),
        phase=float(d["phase"]),
        distribution=d.get("distribution", "external"),
    )