| |
| |
| |
| |
| |
|
|
| """ |
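Meta-optimizer environment: train an RL agent to act as the optimizer for small MLPs on sinusoidal regression tasks.

The agent chooses the learning rate, momentum, gradient-clip threshold and weight decay for every update.
Episodes use one of the 50 training tasks or a held-out evaluation task spec. The reward favours fast
convergence: a dense per-step loss-decrease term plus a terminal penalty equal to the number of steps taken
to reach the loss threshold (or the step budget if it is never reached). The per-episode action log is
exposed for visualizing emergent optimizer behaviour, and Adam/SGD baselines report the same metrics for comparison.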
| """ |
|
|
| import math |
| import random |
| from typing import Any, Dict, List, Optional |
| from uuid import uuid4 |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from openenv.core.env_server.interfaces import Environment |
| from openenv.core.env_server.types import State |
|
|
| from my_env.models import MetaOptimizerAction, MetaOptimizerObservation |
| from .tasks import TRAIN_TASK_IDS, get_task, task_spec_from_dict, TaskSpec |
|
|
| |
# A run counts as converged once the MSE on the current batch, measured right after an
# update, drops below this value.
LOSS_THRESHOLD: float = 0.1
# Default update budget for an episode / baseline run.
MAX_STEPS: int = 100
# Number of samples in every regression batch produced by _get_batch.
BATCH_SIZE: int = 32
# Multiplier applied to the per-step loss decrease to form the dense reward.
DENSE_REWARD_SCALE: float = 0.2
|
|
|
|
def _build_model(spec: "TaskSpec") -> nn.Module:
    """
    Build the task's 2-layer MLP: Linear(input_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, output_dim).

    Weights are initialised deterministically from ``spec.arch_seed``. The CPU RNG is forked
    around construction, so seeding it here does not clobber the caller's global torch RNG
    state (e.g. an RL agent's sampling stream after every ``reset()``).

    NOTE(review): only the CPU default generator is seeded. This assumes the model is created on
    the CPU default device, as every caller in this module does before ``.to(cpu)`` — confirm
    if a non-CPU default device is ever used.
    """
    with torch.random.fork_rng(devices=[]):
        torch.default_generator.manual_seed(spec.arch_seed)
        return nn.Sequential(
            nn.Linear(spec.input_dim, spec.hidden_dim),
            nn.ReLU(),
            nn.Linear(spec.hidden_dim, spec.output_dim),
        )
|
|
|
|
def _get_batch(spec: "TaskSpec", step: int, device: torch.device):
    """
    Return a deterministic sinusoidal-regression batch ``(X, y)`` for ``spec`` at ``step``.

    X: shape (BATCH_SIZE, spec.input_dim), uniform in [0, 1).
    y: shape (BATCH_SIZE, 1), amplitude * sin(2*pi*freq*x0 + phase) + 0.05 * N(0, 1),
       where x0 is the first column of X (the remaining columns do not affect y).

    All randomness comes from a private generator seeded with ``spec.data_seed + step``,
    so the same (spec, step) always yields the same batch and the global RNG is untouched.
    """
    g = torch.Generator(device=device)
    g.manual_seed(spec.data_seed + step)
    X = torch.rand(BATCH_SIZE, spec.input_dim, device=device, generator=g)

    x = X[:, 0:1]
    y = spec.amplitude * torch.sin(2 * math.pi * spec.freq * x + spec.phase)
    # torch.randn_like does not accept a ``generator`` keyword in many PyTorch releases
    # (TypeError); torch.randn does, and draws from the same distribution.
    noise = torch.randn(y.shape, generator=g, dtype=y.dtype, device=device)
    y = y + 0.05 * noise
    return X, y
|
|
|
|
def run_adam_baseline(
    task_id: Optional[int] = None,
    task_spec: Optional[Dict[str, Any]] = None,
    max_steps: int = MAX_STEPS,
    loss_threshold: float = LOSS_THRESHOLD,
    lr: float = 1e-2,
    seed: Optional[int] = None,
    return_metrics: bool = False,
):
    """
    Train one task's MLP with ``torch.optim.Adam`` as a reference baseline.

    Exactly one of ``task_id`` (loaded with ``get_task``) or ``task_spec`` (a dict for
    ``task_spec_from_dict``) must be given. Iteration ``i`` (0-based) trains on batch ``i``
    and records the MSE on that same batch after the update.

    Returns:
        With ``return_metrics=False``: the 1-based step at which the recorded loss first fell
        below ``loss_threshold``, or ``max_steps`` if it never did.
        With ``return_metrics=True``: a dict with ``steps_to_threshold``, ``success``,
        ``final_loss``, ``mean_last_10_loss``, ``loss_auc`` (mean recorded loss) and
        ``loss_trajectory``. If no step was run, the three loss summaries are ``inf``.

    Raises:
        ValueError: if both or neither of ``task_id`` / ``task_spec`` are given.
    """
    if (task_id is None) == (task_spec is None):
        raise ValueError("Provide exactly one of task_id or task_spec")
    if seed is not None:
        # NOTE(review): model init is driven by spec.arch_seed inside _build_model and batches by
        # spec.data_seed, so this seed may have no effect on the result — confirm intent.
        torch.manual_seed(seed)

    cpu = torch.device("cpu")
    if task_spec is None:
        spec = get_task(task_id)
    else:
        spec = task_spec_from_dict(task_spec)

    net = _build_model(spec).to(cpu)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    mse = nn.functional.mse_loss

    history: List[float] = []
    first_hit: Optional[int] = None
    net.train()
    for step in range(max_steps):
        inputs, targets = _get_batch(spec, step, cpu)
        optimizer.zero_grad()
        mse(net(inputs), targets).backward()
        optimizer.step()
        with torch.no_grad():
            post_loss = mse(net(inputs), targets).item()
        history.append(post_loss)
        if first_hit is None and post_loss < loss_threshold:
            first_hit = step + 1

    reported_steps = max_steps if first_hit is None else first_hit
    if not return_metrics:
        return reported_steps

    if history:
        final_loss = history[-1]
        window = history[-min(10, len(history)):]
        mean_tail = sum(window) / len(window)
        auc = sum(history) / len(history)
    else:
        final_loss = mean_tail = auc = float("inf")

    return {
        "steps_to_threshold": reported_steps,
        "success": first_hit is not None,
        "final_loss": final_loss,
        "mean_last_10_loss": mean_tail,
        "loss_auc": auc,
        "loss_trajectory": history,
    }
|
|
|
|
def run_sgd_baseline(
    task_id: Optional[int] = None,
    task_spec: Optional[Dict[str, Any]] = None,
    max_steps: int = MAX_STEPS,
    loss_threshold: float = LOSS_THRESHOLD,
    lr: float = 1e-2,
    momentum: float = 0.9,
    seed: Optional[int] = None,
    return_metrics: bool = False,
):
    """
    Train one task's MLP with ``torch.optim.SGD`` (momentum ``momentum``) as a reference baseline.

    Exactly one of ``task_id`` (loaded with ``get_task``) or ``task_spec`` (a dict for
    ``task_spec_from_dict``) must be given. Iteration ``i`` (0-based) trains on batch ``i``
    and records the MSE on that same batch after the update.

    Returns:
        With ``return_metrics=False``: the 1-based step at which the recorded loss first fell
        below ``loss_threshold``, or ``max_steps`` if it never did.
        With ``return_metrics=True``: a dict with ``steps_to_threshold``, ``success``,
        ``final_loss``, ``mean_last_10_loss``, ``loss_auc`` (mean recorded loss) and
        ``loss_trajectory``. If no step was run, the three loss summaries are ``inf``.

    Raises:
        ValueError: if both or neither of ``task_id`` / ``task_spec`` are given.
    """
    if (task_id is None) == (task_spec is None):
        raise ValueError("Provide exactly one of task_id or task_spec")
    if seed is not None:
        # NOTE(review): model init is driven by spec.arch_seed inside _build_model and batches by
        # spec.data_seed, so this seed may have no effect on the result — confirm intent.
        torch.manual_seed(seed)

    cpu = torch.device("cpu")
    spec = get_task(task_id) if task_spec is None else task_spec_from_dict(task_spec)
    net = _build_model(spec).to(cpu)
    sgd = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum)

    def _train_step(batch_idx: int) -> float:
        # One SGD update on batch ``batch_idx``; returns the post-update loss on that batch.
        inputs, targets = _get_batch(spec, batch_idx, cpu)
        net.train()
        sgd.zero_grad()
        nn.functional.mse_loss(net(inputs), targets).backward()
        sgd.step()
        with torch.no_grad():
            return nn.functional.mse_loss(net(inputs), targets).item()

    trajectory = [_train_step(i) for i in range(max_steps)]
    reached = next(
        (i + 1 for i, value in enumerate(trajectory) if value < loss_threshold),
        None,
    )
    budget_or_hit = max_steps if reached is None else reached

    if not return_metrics:
        return budget_or_hit

    if trajectory:
        final_loss = trajectory[-1]
        tail = trajectory[-min(10, len(trajectory)):]
        mean_tail = sum(tail) / len(tail)
        auc = sum(trajectory) / len(trajectory)
    else:
        final_loss = mean_tail = auc = float("inf")

    return {
        "steps_to_threshold": budget_or_hit,
        "success": reached is not None,
        "final_loss": final_loss,
        "mean_last_10_loss": mean_tail,
        "loss_auc": auc,
        "loss_trajectory": trajectory,
    }
|
|
|
|
def run_meta_optimizer_trajectory(
    task_id: Optional[int] = None,
    task_spec: Optional[Dict[str, Any]] = None,
    max_steps: int = MAX_STEPS,
    loss_threshold: float = LOSS_THRESHOLD,
    seed: Optional[int] = None,
    policy_callable: Optional[Any] = None,
) -> Dict[str, Any]:
    """
    Roll out one episode of ``MetaOptimizerEnvironment`` and summarise it like the baselines.

    ``policy_callable`` maps an observation to a ``MetaOptimizerAction``. When omitted, a fixed
    policy is used every step (lr_scale=0.02, momentum 0.9, clip 1.0, no weight decay).

    Returns a dict with ``steps_to_threshold`` (``max_steps`` if never reached), ``success``,
    ``final_loss``, ``mean_last_10_loss``, ``loss_auc`` and ``loss_trajectory``.

    NOTE(review): unlike the baselines, ``loss_trajectory`` starts with the pre-training loss
    from ``reset()``, and the env ends the episode as soon as the threshold is reached. The
    loss summaries therefore cover a different span than the baselines' — confirm this is
    intended before comparing them directly.

    Raises:
        ValueError: if both or neither of ``task_id`` / ``task_spec`` are given.
    """
    if (task_id is None) == (task_spec is None):
        raise ValueError("Provide exactly one of task_id or task_spec")
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)

    def _fixed_policy(_obs):
        # Constant hyper-parameters, independent of the observation.
        return MetaOptimizerAction(
            lr_scale=0.02,
            momentum_coef=0.9,
            grad_clip_threshold=1.0,
            weight_decay_this_step=0.0,
        )

    env = MetaOptimizerEnvironment(max_steps=max_steps, loss_threshold=loss_threshold)
    obs = env.reset(seed=seed, task_id=task_id, task_spec=task_spec)
    policy = _fixed_policy if policy_callable is None else policy_callable

    losses: List[float] = [obs.loss]
    while not obs.done:
        obs = env.step(policy(obs))
        losses.append(obs.loss)

    reached = obs.steps_to_threshold
    tail = losses[-min(10, len(losses)):]
    return {
        "steps_to_threshold": max_steps if reached is None else reached,
        "success": reached is not None,
        "final_loss": obs.loss,
        "mean_last_10_loss": sum(tail) / len(tail),
        "loss_auc": sum(losses) / len(losses),
        "loss_trajectory": losses,
    }
|
|
|
|
class MetaOptimizerEnvironment(Environment[MetaOptimizerAction, MetaOptimizerObservation, State]):
    """
    Meta-learning optimizer environment: the agent acts as the optimizer for a small task MLP.

    Each step the agent chooses ``lr_scale``, ``momentum_coef``, ``grad_clip_threshold``
    (a global-norm clip; values <= 0 disable clipping) and ``weight_decay_this_step``, and the
    environment applies one update with them:
        v <- momentum_coef * v + g;  p <- p - lr_scale * v;  then p <- p - wd * p if wd > 0.

    Reward: a dense term DENSE_REWARD_SCALE * (prev_loss - current_loss) every step, plus a
    terminal term -steps_to_threshold (or -max_steps if the threshold was never reached) on
    the final step. The episode ends as soon as the loss drops below ``loss_threshold``, after
    ``max_steps`` updates, or when the loss becomes non-finite (divergence).

    Tasks are drawn from TRAIN_TASK_IDS unless ``reset`` receives a ``task_id`` or an explicit
    ``task_spec`` (e.g. a held-out evaluation task).
    """

    # NOTE(review): reset() seeds the process-global `random` and torch RNGs when given a seed;
    # with several sessions in one process those calls affect each other — confirm this is
    # acceptable for concurrent sessions.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(
        self,
        loss_threshold: float = LOSS_THRESHOLD,
        max_steps: int = MAX_STEPS,
        **kwargs: Any,
    ):
        """
        Args:
            loss_threshold: loss below which the episode counts as converged and ends.
            max_steps: maximum number of updates per episode.
            **kwargs: forwarded to ``Environment.__init__``.
        """
        super().__init__(**kwargs)
        self.loss_threshold = loss_threshold
        self.max_steps = max_steps
        self._device = torch.device("cpu")

        # Per-episode state, (re)initialised by reset().
        self._task_spec: Optional[TaskSpec] = None
        self._model: Optional[nn.Module] = None
        # One momentum buffer per model parameter, in model.parameters() order.
        self._velocities: Optional[List[torch.Tensor]] = None
        self._step_count: int = 0
        self._current_loss: float = 0.0
        self._prev_loss: float = 0.0
        self._steps_to_threshold: Optional[int] = None
        self._action_log: List[Dict[str, Any]] = []
        self._episode_id: Optional[str] = None

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_id: Optional[int] = None,
        task_spec: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> MetaOptimizerObservation:
        """
        Start a new episode on a freshly built task model with zeroed momentum buffers.

        Args:
            seed: if given, seeds the global ``random`` and torch RNGs. ``random`` picks the
                task when neither ``task_id`` nor ``task_spec`` is given.
            episode_id: id reported by ``state``; a new UUID4 string if omitted or empty.
            task_id: task to load with ``get_task``.
            task_spec: explicit task dict; takes precedence over ``task_id``.
            **kwargs: accepted and ignored.

        Returns:
            The initial observation: loss of the untrained model on batch 0, with
            ``reward=None``, ``grad_norm=None`` and ``done=False``.
        """
        if seed is not None:
            random.seed(seed)
            torch.manual_seed(seed)
        if task_spec is not None:
            self._task_spec = task_spec_from_dict(task_spec)
        else:
            tid = task_id if task_id is not None else random.choice(TRAIN_TASK_IDS)
            self._task_spec = get_task(tid)
        self._model = _build_model(self._task_spec).to(self._device)
        self._velocities = [torch.zeros_like(p) for p in self._model.parameters()]
        self._step_count = 0
        self._steps_to_threshold = None
        self._action_log = []
        self._episode_id = episode_id or str(uuid4())

        # Initial loss is measured on batch 0; step() then trains on batches 1, 2, ...
        X, y = _get_batch(self._task_spec, 0, self._device)
        with torch.no_grad():
            out = self._model(X)
        self._current_loss = nn.functional.mse_loss(out, y).item()
        self._prev_loss = self._current_loss

        return self._observation(reward=None, grad_norm=None)

    def step(
        self,
        action: MetaOptimizerAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> MetaOptimizerObservation:
        """
        Apply one agent-parameterised update on batch ``step_count + 1`` and observe the result.

        1. Compute the MSE gradient g; its global L2 norm (before clipping) is reported as
           ``grad_norm``.
        2. If ``grad_clip_threshold > 0`` and that norm exceeds it, rescale g to the threshold.
        3. v <- momentum_coef * v + g;  p <- p - lr_scale * v.
        4. If ``weight_decay_this_step > 0``: p <- p - weight_decay_this_step * p.
        5. Re-measure the loss on the same batch.

        The reward is described in the class docstring. If the new loss is NaN or inf, the
        episode ends with a dense term of 0, so the reward stays finite, plus the failure
        terminal (-max_steps). ``timeout_s`` and ``kwargs`` are accepted and ignored. Calling
        step() after ``done`` is not guarded: the model keeps training.

        Raises:
            RuntimeError: if called before ``reset()``.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if self._model is None or self._task_spec is None or self._velocities is None:
            raise RuntimeError("step() called before reset()")
        prev_loss = self._prev_loss
        lr = action.lr_scale
        momentum = action.momentum_coef
        clip = action.grad_clip_threshold
        wd = action.weight_decay_this_step

        self._action_log.append({
            "step": self._step_count,
            "lr_scale": lr,
            "momentum_coef": momentum,
            "grad_clip_threshold": clip,
            "weight_decay_this_step": wd,
        })

        X, y = _get_batch(self._task_spec, self._step_count + 1, self._device)
        self._model.train()
        out = self._model(X)
        loss = nn.functional.mse_loss(out, y)
        self._model.zero_grad()
        loss.backward()

        grads = [p.grad.clone() for p in self._model.parameters()]
        # Global L2 norm, computed once and reused for both the observation and clipping.
        total_norm = sum(g.pow(2).sum() for g in grads).sqrt()
        grad_norm = total_norm.item()

        if clip > 0 and total_norm > clip:
            scale = clip / (total_norm + 1e-8)
            grads = [g * scale for g in grads]

        with torch.no_grad():
            for p, g, v in zip(self._model.parameters(), grads, self._velocities):
                v.mul_(momentum).add_(g)
                p.sub_(v, alpha=lr)
                if wd > 0:
                    # Decoupled weight decay: shrink p by a factor (1 - wd), not scaled by lr.
                    p.sub_(p, alpha=wd)

        with torch.no_grad():
            new_out = self._model(X)
            self._current_loss = nn.functional.mse_loss(new_out, y).item()

        self._step_count += 1
        if self._steps_to_threshold is None and self._current_loss < self.loss_threshold:
            self._steps_to_threshold = self._step_count

        # A diverged update (NaN/inf loss) would otherwise produce NaN/inf rewards for the rest
        # of the episode and poison the agent's learning signal.
        diverged = not math.isfinite(self._current_loss)
        dense_reward = 0.0 if diverged else DENSE_REWARD_SCALE * (prev_loss - self._current_loss)
        self._prev_loss = self._current_loss

        done = (
            diverged
            or self._step_count >= self.max_steps
            or self._steps_to_threshold is not None
        )
        if done:
            terminal = -(self._steps_to_threshold if self._steps_to_threshold is not None else self.max_steps)
            reward = dense_reward + terminal
        else:
            reward = dense_reward

        return self._observation(reward=reward, grad_norm=grad_norm, done=done)

    def _observation(
        self,
        reward: Optional[float] = None,
        grad_norm: Optional[float] = None,
        done: bool = False,
    ) -> MetaOptimizerObservation:
        """
        Build an observation from the current episode state.

        ``metadata`` carries ``steps_to_threshold`` once it is reached, and, on the final
        observation of an episode, a copy of the action log.
        """
        meta: Dict[str, Any] = {}
        if self._steps_to_threshold is not None:
            meta["steps_to_threshold"] = self._steps_to_threshold
        if done and self._action_log:
            # Copy so later appends to the internal log cannot mutate an emitted observation.
            meta["action_log"] = list(self._action_log)
        return MetaOptimizerObservation(
            loss=self._current_loss,
            step_count=self._step_count,
            grad_norm=grad_norm,
            steps_to_threshold=self._steps_to_threshold,
            done=done,
            reward=reward,
            metadata=meta,
        )

    @property
    def state(self) -> State:
        """Current episode id and number of updates taken so far."""
        return State(
            episode_id=self._episode_id,
            step_count=self._step_count,
        )

    def get_episode_action_log(self) -> List[Dict[str, Any]]:
        """Return a copy of the action log for the current episode (for in-process viz or eval)."""
        return list(self._action_log)
|
|