# my_env/server/meta_optimizer_environment.py
# Uploaded by SavirD via huggingface_hub (commit ab2a5b9, verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Meta-optimizer environment: train an RL agent to act as an optimizer on random regression tasks.
Supports 50 training tasks, held-out eval, rich action space (LR, momentum, grad clip, weight decay),
and convergence-speed reward. Action log is exposed for emergent-behavior visualization.
"""
import math
import random
from typing import Any, Dict, List, Optional
from uuid import uuid4
import torch
import torch.nn as nn
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
from my_env.models import MetaOptimizerAction, MetaOptimizerObservation
from .tasks import TRAIN_TASK_IDS, get_task, task_spec_from_dict, TaskSpec
# ---------------------------------------------------------------------------
# Episode / training defaults
# ---------------------------------------------------------------------------
# A run counts as solved once the post-update MSE drops below this value.
LOSS_THRESHOLD: float = 0.1
# Hard cap on optimizer steps per episode and per baseline run.
MAX_STEPS: int = 100
# Number of samples in every generated mini-batch.
BATCH_SIZE: int = 32
# Weight of the dense shaping term added every step:
#     DENSE_REWARD_SCALE * (prev_loss - current_loss)
# i.e. a reward for decreasing the loss, giving the agent per-step credit
# assignment (potential-based shaping with potential -DENSE_REWARD_SCALE * loss).
DENSE_REWARD_SCALE: float = 0.2
def _build_model(spec: TaskSpec) -> nn.Module:
    """Build the 2-layer MLP (Linear -> ReLU -> Linear) described by ``spec``.

    Parameters are initialised deterministically from ``spec.arch_seed``, so
    every episode / baseline run on the same task starts from the same weights.

    The seed is applied inside a forked CPU RNG context: the caller's global
    torch RNG state is restored on exit, so building a model no longer
    overrides a seed set by the caller (or the sampling RNG of an in-process
    torch policy).

    Args:
        spec: task spec providing ``arch_seed``, ``input_dim``, ``hidden_dim``
            and ``output_dim``.

    Returns:
        A freshly initialised ``nn.Sequential``.
    """
    # fork_rng(devices=[]) snapshots and restores only the CPU generator;
    # seeding the CPU default generator directly (rather than torch.manual_seed)
    # also leaves any CUDA RNG state untouched. nn.Linear's default init draws
    # from this generator, so the resulting weights match a plain
    # torch.manual_seed(arch_seed) call.
    with torch.random.fork_rng(devices=[]):
        torch.default_generator.manual_seed(spec.arch_seed)
        return nn.Sequential(
            nn.Linear(spec.input_dim, spec.hidden_dim),
            nn.ReLU(),
            nn.Linear(spec.hidden_dim, spec.output_dim),
        )
def _get_batch(spec: TaskSpec, step: int, device: torch.device):
    """Return the deterministic mini-batch ``(X, y)`` for ``step`` of a task.

    Sinusoidal regression: ``X`` has shape ``(BATCH_SIZE, spec.input_dim)`` with
    entries uniform in ``[0, 1)``. Only the first column drives the target:

        y = spec.amplitude * sin(2*pi*spec.freq*x0 + spec.phase) + 0.05 * N(0, 1)

    ``y`` has shape ``(BATCH_SIZE, 1)``; the remaining input columns carry no
    signal.

    A dedicated generator seeded with ``spec.data_seed + step`` makes the batch
    a pure function of (task, step), independent of global RNG state.

    Args:
        spec: task spec providing ``data_seed``, ``input_dim``, ``amplitude``,
            ``freq`` and ``phase``.
        step: index into the task's data stream.
        device: device on which the generator and tensors are created.

    Returns:
        Tuple ``(X, y)``.
    """
    g = torch.Generator(device=device)
    g.manual_seed(spec.data_seed + step)
    X = torch.rand(BATCH_SIZE, spec.input_dim, device=device, generator=g)
    x = X[:, 0:1]  # slicing (not indexing) keeps the trailing dim -> (B, 1)
    y = spec.amplitude * torch.sin(2 * math.pi * spec.freq * x + spec.phase)
    # torch.randn accepts ``generator=``; torch.randn_like does not in many
    # PyTorch releases (TypeError). Drawing with the explicit shape/dtype from
    # the same generator yields the same noise.
    noise = torch.randn(y.shape, dtype=y.dtype, device=device, generator=g)
    return X, y + 0.05 * noise
def _trajectory_metrics(
    loss_trajectory: List[float],
    steps_to_threshold: Optional[int],
    max_steps: int,
) -> Dict[str, Any]:
    """Summarise a per-step post-update loss trajectory into the metrics dict.

    ``steps_to_threshold`` is reported as ``max_steps`` when the threshold was
    never crossed; ``success`` disambiguates that case. For an empty trajectory
    (``max_steps <= 0``) every loss statistic is ``inf``.
    """
    final_loss = loss_trajectory[-1] if loss_trajectory else float("inf")
    if loss_trajectory:
        window = loss_trajectory[-10:]  # up to the last 10 losses
        mean_last_k = sum(window) / len(window)
        # Mean loss over the run (area under the curve normalised by length).
        loss_auc = sum(loss_trajectory) / len(loss_trajectory)
    else:
        mean_last_k = loss_auc = final_loss
    return {
        "steps_to_threshold": steps_to_threshold if steps_to_threshold is not None else max_steps,
        "success": steps_to_threshold is not None,
        "final_loss": final_loss,
        "mean_last_10_loss": mean_last_k,
        "loss_auc": loss_auc,
        "loss_trajectory": loss_trajectory,
    }


def _run_torch_optimizer_baseline(
    task_id: Optional[int],
    task_spec: Optional[Dict[str, Any]],
    make_optimizer: Any,
    max_steps: int,
    loss_threshold: float,
    seed: Optional[int],
    return_metrics: bool,
):
    """Train a task's MLP with a stock torch optimizer; shared by the baselines.

    ``make_optimizer`` maps ``model.parameters()`` to a ``torch.optim`` optimizer.
    Step ``t`` (0-based) trains on batch ``t`` and then re-evaluates the same
    batch; ``steps_to_threshold`` is the 1-based index of the first step whose
    post-update loss is below ``loss_threshold``.

    NOTE(review): MetaOptimizerEnvironment uses batch 0 for its reset loss and
    trains step ``t`` on batch ``t + 1``, so baseline and environment data
    streams are offset by one batch. Confirm this is intended for head-to-head
    comparisons.

    Raises:
        ValueError: unless exactly one of ``task_id`` / ``task_spec`` is given.
    """
    if (task_id is None) == (task_spec is None):
        raise ValueError("Provide exactly one of task_id or task_spec")
    if seed is not None:
        torch.manual_seed(seed)
    device = torch.device("cpu")
    spec = task_spec_from_dict(task_spec) if task_spec is not None else get_task(task_id)
    model = _build_model(spec).to(device)
    opt = make_optimizer(model.parameters())
    loss_trajectory: List[float] = []
    steps_to_threshold: Optional[int] = None
    for step in range(max_steps):
        X, y = _get_batch(spec, step, device)
        model.train()
        opt.zero_grad()
        loss = nn.functional.mse_loss(model(X), y)
        loss.backward()
        opt.step()
        # Record the post-update loss on the same batch.
        with torch.no_grad():
            L = nn.functional.mse_loss(model(X), y).item()
        loss_trajectory.append(L)
        if steps_to_threshold is None and L < loss_threshold:
            steps_to_threshold = step + 1
    if not return_metrics:
        return steps_to_threshold if steps_to_threshold is not None else max_steps
    return _trajectory_metrics(loss_trajectory, steps_to_threshold, max_steps)


def run_adam_baseline(
    task_id: Optional[int] = None,
    task_spec: Optional[Dict[str, Any]] = None,
    max_steps: int = MAX_STEPS,
    loss_threshold: float = LOSS_THRESHOLD,
    lr: float = 1e-2,
    seed: Optional[int] = None,
    return_metrics: bool = False,
):
    """
    Run Adam on one task. Returns steps to threshold, or full metrics dict if return_metrics=True.

    Args:
        task_id: id of a registered task (mutually exclusive with ``task_spec``).
        task_spec: task spec dict (mutually exclusive with ``task_id``).
        max_steps: number of optimizer steps to run (no early stopping).
        loss_threshold: post-update loss below which the task counts as solved.
        lr: Adam learning rate.
        seed: seeds the global torch RNG before the run. Model init and batches
            are seeded from the task spec, so this does not change the result.
        return_metrics: if True return the metrics dict, else an int.

    Returns:
        ``int`` steps to threshold (``max_steps`` if never reached), or a dict
        with keys ``steps_to_threshold``, ``success``, ``final_loss``,
        ``mean_last_10_loss``, ``loss_auc``, ``loss_trajectory``.

    Raises:
        ValueError: unless exactly one of ``task_id`` / ``task_spec`` is given.
    """
    return _run_torch_optimizer_baseline(
        task_id,
        task_spec,
        lambda params: torch.optim.Adam(params, lr=lr),
        max_steps,
        loss_threshold,
        seed,
        return_metrics,
    )


def run_sgd_baseline(
    task_id: Optional[int] = None,
    task_spec: Optional[Dict[str, Any]] = None,
    max_steps: int = MAX_STEPS,
    loss_threshold: float = LOSS_THRESHOLD,
    lr: float = 1e-2,
    momentum: float = 0.9,
    seed: Optional[int] = None,
    return_metrics: bool = False,
):
    """
    Run SGD (with optional momentum) on one task. Returns steps to threshold, or full metrics dict if return_metrics=True.

    Args:
        task_id: id of a registered task (mutually exclusive with ``task_spec``).
        task_spec: task spec dict (mutually exclusive with ``task_id``).
        max_steps: number of optimizer steps to run (no early stopping).
        loss_threshold: post-update loss below which the task counts as solved.
        lr: SGD learning rate.
        momentum: SGD momentum factor (0 disables momentum).
        seed: seeds the global torch RNG before the run. Model init and batches
            are seeded from the task spec, so this does not change the result.
        return_metrics: if True return the metrics dict, else an int.

    Returns:
        ``int`` steps to threshold (``max_steps`` if never reached), or a dict
        with keys ``steps_to_threshold``, ``success``, ``final_loss``,
        ``mean_last_10_loss``, ``loss_auc``, ``loss_trajectory``.

    Raises:
        ValueError: unless exactly one of ``task_id`` / ``task_spec`` is given.
    """
    return _run_torch_optimizer_baseline(
        task_id,
        task_spec,
        lambda params: torch.optim.SGD(params, lr=lr, momentum=momentum),
        max_steps,
        loss_threshold,
        seed,
        return_metrics,
    )
def run_meta_optimizer_trajectory(
    task_id: Optional[int] = None,
    task_spec: Optional[Dict[str, Any]] = None,
    max_steps: int = MAX_STEPS,
    loss_threshold: float = LOSS_THRESHOLD,
    seed: Optional[int] = None,
    policy_callable: Optional[Any] = None,
) -> Dict[str, Any]:
    """Roll out one MetaOptimizerEnvironment episode and summarise it.

    Args:
        task_id: id of a registered task (mutually exclusive with ``task_spec``).
        task_spec: task spec dict (mutually exclusive with ``task_id``).
        max_steps: episode step cap passed to the environment.
        loss_threshold: loss below which the environment ends the episode.
        seed: if given, seeds ``random`` and ``torch`` here and is forwarded to
            ``env.reset``.
        policy_callable: maps an observation to a ``MetaOptimizerAction``. When
            None, a fixed action (lr 0.02, momentum 0.9, clip 1.0, no weight
            decay) is used every step.

    Returns:
        Dict with ``steps_to_threshold`` (``max_steps`` if never reached),
        ``success``, ``final_loss``, ``mean_last_10_loss``, ``loss_auc`` and
        ``loss_trajectory``. The trajectory starts with the reset (pre-update)
        loss.

    Raises:
        ValueError: unless exactly one of ``task_id`` / ``task_spec`` is given.
    """
    if (task_id is None) == (task_spec is None):
        raise ValueError("Provide exactly one of task_id or task_spec")
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)

    def _fixed_policy(_obs: Any) -> MetaOptimizerAction:
        # Constant heavy-ball SGD settings, used when no policy is supplied.
        return MetaOptimizerAction(
            lr_scale=0.02,
            momentum_coef=0.9,
            grad_clip_threshold=1.0,
            weight_decay_this_step=0.0,
        )

    policy = _fixed_policy if policy_callable is None else policy_callable

    env = MetaOptimizerEnvironment(max_steps=max_steps, loss_threshold=loss_threshold)
    obs = env.reset(seed=seed, task_id=task_id, task_spec=task_spec)
    losses: List[float] = [obs.loss]
    while not obs.done:
        obs = env.step(policy(obs))
        losses.append(obs.loss)

    reached = obs.steps_to_threshold
    tail = losses[-10:]  # up to the last 10 losses (always non-empty)
    return {
        "steps_to_threshold": max_steps if reached is None else reached,
        "success": reached is not None,
        "final_loss": obs.loss,
        "mean_last_10_loss": sum(tail) / len(tail),
        "loss_auc": sum(losses) / len(losses),
        "loss_trajectory": losses,
    }
class MetaOptimizerEnvironment(Environment[MetaOptimizerAction, MetaOptimizerObservation, State]):
    """
    Meta-learning optimizer environment: the agent chooses LR scale, momentum,
    grad-clip threshold and weight decay for every update of a small MLP.

    Reward: each step yields a dense term
    ``DENSE_REWARD_SCALE * (prev_loss - current_loss)`` (loss decrease). When the
    episode ends, a terminal term ``-steps_to_threshold`` (or ``-max_steps`` if
    the threshold was never reached) is added. The episode ends at ``max_steps``
    or as soon as the loss first drops below ``loss_threshold`` (early
    termination). Tasks are drawn from ``TRAIN_TASK_IDS`` unless a ``task_id``
    or an explicit (e.g. held-out) ``task_spec`` is passed to ``reset``.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(
        self,
        loss_threshold: float = LOSS_THRESHOLD,
        max_steps: int = MAX_STEPS,
        **kwargs: Any,
    ):
        """Configure the episode limits; model and task are created in ``reset``."""
        super().__init__(**kwargs)
        self.loss_threshold = loss_threshold
        self.max_steps = max_steps
        self._device = torch.device("cpu")
        # Episode state (set in reset)
        self._task_spec: Optional[TaskSpec] = None
        self._model: Optional[nn.Module] = None
        self._velocities: Optional[List[torch.Tensor]] = None  # momentum buffers, one per parameter
        self._step_count: int = 0
        self._current_loss: float = 0.0
        self._prev_loss: float = 0.0  # loss from the previous step, for the dense reward
        self._steps_to_threshold: Optional[int] = None
        self._action_log: List[Dict[str, Any]] = []
        self._episode_id: Optional[str] = None

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_id: Optional[int] = None,
        task_spec: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> MetaOptimizerObservation:
        """Start a new episode with a fresh model and zeroed momentum buffers.

        Args:
            seed: if given, reseeds the process-global ``random`` and ``torch``
                RNGs. NOTE(review): these are shared across concurrent sessions
                despite ``SUPPORTS_CONCURRENT_SESSIONS``.
            episode_id: id to report in ``state``; a UUID4 string if omitted.
            task_id: registered task to use. If neither this nor ``task_spec``
                is given, one is drawn from ``TRAIN_TASK_IDS``.
            task_spec: explicit task spec dict; takes precedence over ``task_id``.
            **kwargs: ignored.

        Returns:
            Initial observation with the pre-update loss on batch 0,
            ``reward=None``, ``grad_norm=None`` and ``done=False``.
        """
        if seed is not None:
            random.seed(seed)
            torch.manual_seed(seed)
        if task_spec is not None:
            self._task_spec = task_spec_from_dict(task_spec)
        else:
            tid = task_id if task_id is not None else random.choice(TRAIN_TASK_IDS)
            self._task_spec = get_task(tid)
        self._model = _build_model(self._task_spec).to(self._device)
        self._velocities = [torch.zeros_like(p) for p in self._model.parameters()]
        self._step_count = 0
        self._steps_to_threshold = None
        self._action_log = []
        self._episode_id = episode_id or str(uuid4())
        # Initial loss (no update yet) on batch 0; steps then use batches 1, 2, ...
        X, y = _get_batch(self._task_spec, 0, self._device)
        with torch.no_grad():
            out = self._model(X)
        self._current_loss = nn.functional.mse_loss(out, y).item()
        self._prev_loss = self._current_loss
        return self._observation(reward=None, grad_norm=None)

    def step(
        self,
        action: MetaOptimizerAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> MetaOptimizerObservation:
        """Apply one agent-parameterised heavy-ball SGD update and score it.

        The update is, per parameter ``p`` with gradient ``g`` (after optional
        global-norm clipping) and momentum buffer ``v``::

            v <- momentum_coef * v + g
            p <- p - lr_scale * v
            p <- (1 - weight_decay_this_step) * p      # only if wd > 0

        Args:
            action: the optimizer hyper-parameters for this step.
            timeout_s: accepted for interface compatibility; unused.
            **kwargs: ignored.

        Returns:
            Observation with the post-update loss on this step's batch, the
            pre-clip gradient norm, ``done`` and the step reward.

        Raises:
            RuntimeError: if called before ``reset``.

        NOTE(review): calling ``step`` after ``done`` keeps training and re-emits
        the terminal reward, because the ``done`` condition stays true.
        """
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if self._model is None or self._task_spec is None:
            raise RuntimeError("reset() must be called before step()")
        prev_loss = self._prev_loss
        lr = action.lr_scale
        momentum = action.momentum_coef
        clip = action.grad_clip_threshold
        wd = action.weight_decay_this_step
        self._action_log.append({
            "step": self._step_count,
            "lr_scale": lr,
            "momentum_coef": momentum,
            "grad_clip_threshold": clip,
            "weight_decay_this_step": wd,
        })
        # Batch 0 was consumed by reset(), so step k trains on batch k + 1.
        X, y = _get_batch(self._task_spec, self._step_count + 1, self._device)
        self._model.train()
        out = self._model(X)
        loss = nn.functional.mse_loss(out, y)
        self._model.zero_grad()
        loss.backward()
        grads = [p.grad.clone() for p in self._model.parameters()]
        # Global L2 norm of the raw gradients: computed once, reported as
        # grad_norm and reused for clipping.
        total_norm = sum(g.pow(2).sum() for g in grads).sqrt()
        grad_norm = total_norm.item()
        if clip > 0 and total_norm > clip:
            scale = clip / (total_norm + 1e-8)
            grads = [g * scale for g in grads]
        with torch.no_grad():
            for p, g, v in zip(self._model.parameters(), grads, self._velocities):
                v.mul_(momentum).add_(g)
                p.sub_(v, alpha=lr)
                if wd > 0:
                    # Decoupled decay applied after the step, not scaled by lr.
                    p.sub_(p, alpha=wd)
        # Post-update loss, evaluated on the same batch.
        with torch.no_grad():
            new_out = self._model(X)
        self._current_loss = nn.functional.mse_loss(new_out, y).item()
        self._step_count += 1
        if self._steps_to_threshold is None and self._current_loss < self.loss_threshold:
            self._steps_to_threshold = self._step_count
        # Dense reward for loss decrease. Over an episode these terms telescope to
        # DENSE_REWARD_SCALE * (reset_loss - final_loss).
        dense_reward = DENSE_REWARD_SCALE * (prev_loss - self._current_loss)
        self._prev_loss = self._current_loss
        # End the episode at max_steps or when the loss first crosses the threshold.
        done = self._step_count >= self.max_steps or self._steps_to_threshold is not None
        if done:
            terminal = -(self._steps_to_threshold if self._steps_to_threshold is not None else self.max_steps)
            reward = dense_reward + terminal
        else:
            reward = dense_reward
        return self._observation(reward=reward, grad_norm=grad_norm, done=done)

    def _observation(
        self,
        reward: Optional[float] = None,
        grad_norm: Optional[float] = None,
        done: bool = False,
    ) -> MetaOptimizerObservation:
        """Package the current episode state into an observation.

        ``metadata`` carries ``steps_to_threshold`` once reached and, on the
        final observation, the episode's action log (the live list, not a copy).
        """
        meta: Dict[str, Any] = {}
        if self._steps_to_threshold is not None:
            meta["steps_to_threshold"] = self._steps_to_threshold
        if done and self._action_log:
            meta["action_log"] = self._action_log
        return MetaOptimizerObservation(
            loss=self._current_loss,
            step_count=self._step_count,
            grad_norm=grad_norm,
            steps_to_threshold=self._steps_to_threshold,
            done=done,
            reward=reward,
            metadata=meta,
        )

    @property
    def state(self) -> State:
        """Current episode id and number of steps taken."""
        return State(
            episode_id=self._episode_id,
            step_count=self._step_count,
        )

    def get_episode_action_log(self) -> List[Dict[str, Any]]:
        """Return the action log for the current episode (for in-process viz or eval).

        The returned list is a shallow copy: the per-step dicts are shared.
        """
        return list(self._action_log)