# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Task registry for meta-learning.
Tasks can be from the internal registry (get_task(task_id)) or provided from outside
via task_spec_from_dict() — the client sends the task definition to the environment.
Supports sinusoid (regression) and SLM (next-token prediction) task types.
"""
from dataclasses import dataclass
from typing import Any, Dict, List
import math
from .slm_model import DEFAULT_VOCAB_SIZE as SLM_DEFAULT_VOCAB_SIZE
# Sinusoid registry.
# Distribution A: 50 training tasks (low-freq sinusoids), IDs 0..49.
TRAIN_TASK_IDS: List[int] = list(range(50))
# Distribution B: held-out eval tasks (high-freq sinusoids — different distribution).
# get_task() maps every ID >= 50 to distribution B; only these two IDs are listed for eval.
EVAL_TASK_IDS: List[int] = [50, 51]
# SLM registry: 50 train tasks, 2 eval tasks.
# In get_slm_task() train and eval tasks differ only in their task_id-derived
# seeds and in the distribution label; all of them use corpus_id "default".
SLM_TRAIN_TASK_IDS: List[int] = list(range(50))
SLM_EVAL_TASK_IDS: List[int] = [50, 51]
# Fixed small corpus for SLM (character-level), hard-coded so tasks are reproducible.
# Five short sentences (199 characters in total) repeated 200 times:
# 39,800 characters, i.e. roughly 40 KB.
DEFAULT_CORPUS: str = (
    "The quick brown fox jumps over the lazy dog. "
    "Pack my box with five dozen liquor jugs. "
    "How vexingly quick daft zebras jump. "
    "Sphinx of black quartz, judge my vow. "
    "The five boxing wizards jump quickly. "
) * 200  # repeat to get enough length for sampling
# (low, high) bounds per distribution for frequency and amplitude.
# get_task() interpolates within them, so the low bound is reachable and the
# high bound is not. Phase has no per-distribution bounds: get_task() draws it
# from [0, 2*pi) for both distributions.
DIST_A_FREQ = (1.0, 3.0)
DIST_A_AMP = (0.5, 2.0)
DIST_B_FREQ = (4.0, 6.0)
DIST_B_AMP = (0.3, 1.5)
@dataclass
class TaskSpec:
    """Spec for one sinusoidal regression task.

    The regression target is y = amplitude * sin(2*pi*freq*x + phase).
    Instances come from get_task() (internal registry) or from
    task_spec_from_dict() (definition supplied by the client).
    """

    task_id: int
    input_dim: int  # 1 for scalar sinusoid input
    hidden_dim: int  # get_task() yields 32..64; task_spec_from_dict() defaults to 32
    output_dim: int  # both constructors in this file always set 1
    # Seeds default to task_id * 31337 and task_id * 131 + 7 respectively.
    # NOTE(review): presumably they seed data sampling and model init — the
    # consumers are outside this file, confirm there.
    data_seed: int
    arch_seed: int
    # Sinusoidal target: y = amplitude * sin(2*pi*freq*x + phase)
    amplitude: float
    freq: float
    phase: float
    distribution: str  # "A" or "B" from get_task(); "external" by default from task_spec_from_dict()
@dataclass
class SLMTaskSpec:
    """Spec for one SLM (next-token prediction) task.

    Instances come from get_slm_task() (internal registry) or from
    slm_task_spec_from_dict() (definition supplied by the client).
    """

    task_id: int
    # Seeds default to task_id * 31337 and task_id * 131 + 7 respectively.
    # NOTE(review): presumably they seed data sampling and model init — the
    # consumers are outside this file, confirm there.
    data_seed: int
    arch_seed: int
    vocab_size: int  # defaults to DEFAULT_VOCAB_SIZE imported from .slm_model
    context_len: int  # block size; 64 by default
    n_layer: int  # 2 by default
    n_head: int  # 4 by default
    n_embd: int  # 128 by default
    corpus_id: str  # e.g. "default"
    distribution: str  # "A" or "B" or "external"
def get_task(task_id: int) -> TaskSpec:
    """
    Look up the sinusoid task spec for ``task_id`` in the internal registry.

    IDs 0..49 belong to Distribution A (train); every ID from 50 upward
    belongs to Distribution B (eval). All fields are derived arithmetically
    from ``task_id``, so the same ID always yields the same spec.

    Raises:
        ValueError: if ``task_id`` is negative.
    """
    if task_id < 0:
        raise ValueError(f"task_id must be >= 0, got {task_id}")

    # Deterministic mixing value; different residues of it drive the
    # per-task variation below.
    mix = task_id * 7919 + 1

    is_dist_a = task_id < 50
    freq_lo, freq_hi = DIST_A_FREQ if is_dist_a else DIST_B_FREQ
    amp_lo, amp_hi = DIST_A_AMP if is_dist_a else DIST_B_AMP
    label = "A" if is_dist_a else "B"

    # Fractions in [0, 1) with a resolution of 1/1000.
    freq_frac = (mix % 1000) / 1000.0
    amp_frac = ((mix * 3) % 1000) / 1000.0
    phase_frac = ((mix * 7) % 1000) / 1000.0

    return TaskSpec(
        task_id=task_id,
        input_dim=1,
        hidden_dim=32 + (mix % 33),  # 32..64 inclusive
        output_dim=1,
        data_seed=task_id * 31337,
        arch_seed=task_id * 131 + 7,
        amplitude=amp_lo + amp_frac * (amp_hi - amp_lo),
        freq=freq_lo + freq_frac * (freq_hi - freq_lo),
        phase=phase_frac * 2 * math.pi,  # [0, 2*pi)
        distribution=label,
    )
def get_slm_task(task_id: int) -> SLMTaskSpec:
    """
    Return the SLM task spec for the given task_id.

    Task IDs 0..49 = Distribution A (train), 50+ = Distribution B (eval).
    Only the two seeds and the distribution label depend on ``task_id``;
    vocabulary size, context length, model shape and corpus are the same
    for every task.

    Raises:
        ValueError: if ``task_id`` is negative.
    """
    if task_id < 0:
        raise ValueError(f"task_id must be >= 0, got {task_id}")
    return SLMTaskSpec(
        task_id=task_id,
        # Same seed derivation as the sinusoid registry in get_task().
        data_seed=task_id * 31337,
        arch_seed=task_id * 131 + 7,
        vocab_size=SLM_DEFAULT_VOCAB_SIZE,
        context_len=64,
        n_layer=2,
        n_head=4,
        n_embd=128,
        corpus_id="default",
        distribution="A" if task_id < 50 else "B",
    )
def slm_task_spec_from_dict(d: Dict[str, Any]) -> SLMTaskSpec:
    """Build an SLMTaskSpec from an external dict (type='slm').

    Every key is optional. Integer fields are passed through ``int()``,
    ``corpus_id`` through ``str()``, and ``distribution`` is taken as-is
    (default "external"). Missing seeds are derived from ``task_id``.
    """
    task_id = int(d.get("task_id", 0))

    # Fallback for each integer field, listed in the order in which the
    # fields are read and converted.
    int_fallbacks = {
        "data_seed": task_id * 31337,
        "arch_seed": task_id * 131 + 7,
        "vocab_size": SLM_DEFAULT_VOCAB_SIZE,
        "context_len": 64,
        "n_layer": 2,
        "n_head": 4,
        "n_embd": 128,
    }
    int_fields = {
        key: int(d.get(key, fallback)) for key, fallback in int_fallbacks.items()
    }

    corpus_id = str(d.get("corpus_id", "default"))
    distribution = d.get("distribution", "external")
    return SLMTaskSpec(
        task_id=task_id,
        corpus_id=corpus_id,
        distribution=distribution,
        **int_fields,
    )
def task_spec_from_dict(d: Dict[str, Any]) -> TaskSpec | SLMTaskSpec:
    """
    Build a TaskSpec or SLMTaskSpec from an external dict (sent by the client).

    The "type" key selects the task kind and defaults to "slm" when absent.

    For type "sinusoid": amplitude, freq and phase are required; task_id,
    input_dim, hidden_dim, data_seed, arch_seed and distribution are optional.
    output_dim is always 1.
    For type "slm": vocab_size, context_len, n_layer, n_head, n_embd (all
    optional with defaults); see slm_task_spec_from_dict().

    Raises:
        KeyError: a sinusoid dict lacks amplitude, freq or phase.
        ValueError: the type is unknown, or a numeric field cannot be converted.
    """
    task_type = d.get("type", "slm")
    if task_type == "sinusoid":
        # Coerce up front, as the SLM path does. A string ID such as "3" would
        # otherwise be repeated by `task_id * 31337` instead of multiplied, and
        # would be stored un-converted in TaskSpec.task_id.
        task_id = int(d.get("task_id", 0))
        return TaskSpec(
            task_id=task_id,
            input_dim=int(d.get("input_dim", 1)),
            hidden_dim=int(d.get("hidden_dim", 32)),
            output_dim=1,
            data_seed=int(d.get("data_seed", task_id * 31337)),
            arch_seed=int(d.get("arch_seed", task_id * 131 + 7)),
            amplitude=float(d["amplitude"]),
            freq=float(d["freq"]),
            phase=float(d["phase"]),
            distribution=d.get("distribution", "external"),
        )
    if task_type == "slm":
        return slm_task_spec_from_dict(d)
    raise ValueError(f"Unknown task type: {task_type!r}")
|