|
|
from __future__ import annotations

import math
import random
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

from edgeeda.agents.base import Action, Agent
from edgeeda.config import Config
from edgeeda.utils import sanitize_variant_prefix, stable_hash
|
|
|
|
|
|
|
|
@dataclass
class Candidate:
    """A single sampled configuration tracked by ``SuccessiveHalvingAgent``.

    Field order matters: instances are built both by keyword and
    positionally (variant, knobs, stage_idx, last_reward).
    """

    # Run name; copied into ``Action.variant`` whenever this candidate is queued.
    variant: str
    # Sampled knob values keyed by knob name.
    knobs: Dict[str, Any]
    # Index into the agent's ordered list of fidelity stage names.
    stage_idx: int
    # Reward from the latest evaluation at ``stage_idx``; None before any
    # observation arrives or when that evaluation failed.
    last_reward: Optional[float]
|
|
|
|
|
|
|
|
class SuccessiveHalvingAgent(Agent):
    """
    Simple multi-fidelity baseline:

    - sample a pool
    - evaluate at fidelity0
    - keep top fraction
    - promote to next fidelity

    Each round queues one evaluation per candidate in ``self.pool`` at that
    candidate's current stage (``self.stage_names[c.stage_idx]``). When the
    round's queue is drained, the best ``eta`` fraction (at least one) of the
    candidates at the highest stage present moves up one stage and the pool
    is topped back up to ``pool_size`` with freshly sampled stage-0
    candidates. Once a candidate has reached the last stage, the whole pool
    is resampled from scratch.

    NOTE(review): fresh stage-0 candidates added during a promotion are never
    promoted themselves -- the next promotion only considers the highest
    stage present -- so their evaluations do not influence later rounds.
    Confirm this is intended.
    """

    def __init__(
        self,
        cfg: Config,
        pool_size: int = 12,
        eta: float = 0.5,
        rng: Optional[random.Random] = None,
    ):
        """
        Args:
            cfg: experiment config; reads ``cfg.flow.fidelities`` (ordered
                stage names), ``cfg.experiment.name`` (variant prefix) and
                ``cfg.tuning.knobs`` (search space).
            pool_size: number of candidates evaluated per round; must be >= 1.
            eta: fraction of the highest-stage candidates kept at each
                promotion (at least one always survives).
            rng: optional random source used for knob sampling, e.g. a seeded
                ``random.Random`` for reproducible runs. When omitted the
                module-level ``random`` functions (global RNG state) are used.

        Raises:
            ValueError: if ``pool_size`` < 1 or ``cfg.flow.fidelities`` is empty.
        """
        if pool_size < 1:
            raise ValueError(f"pool_size must be >= 1, got {pool_size}")
        self.cfg = cfg
        self.pool_size = pool_size
        self.eta = eta
        self.stage_names = cfg.flow.fidelities
        if not self.stage_names:
            raise ValueError("cfg.flow.fidelities must list at least one fidelity")
        self.variant_prefix = sanitize_variant_prefix(cfg.experiment.name)
        # The `random` module exposes the same randint()/random() API as a
        # random.Random instance, so the default keeps honouring global seeds.
        self._rng: Any = random if rng is None else rng
        self.pool: List[Candidate] = []
        self._init_pool()
        self._queue: List[Action] = []
        self._rebuild_queue()

    def _sample_knobs(self) -> Dict[str, Any]:
        """Draw one random configuration from ``cfg.tuning.knobs``.

        ``int`` knobs are drawn uniformly from [min, max] inclusive. Every
        other knob type is treated as a float drawn uniformly between min and
        max and rounded to 3 decimals.
        """
        out: Dict[str, Any] = {}
        for name, spec in self.cfg.tuning.knobs.items():
            if spec.type == "int":
                out[name] = self._rng.randint(int(spec.min), int(spec.max))
            else:
                lo, hi = float(spec.min), float(spec.max)
                out[name] = round(lo + self._rng.random() * (hi - lo), 3)
        return out

    def _new_candidate(self, tag: str, index: int) -> Candidate:
        """Sample fresh knobs and wrap them in a stage-0 candidate.

        The variant is named ``<prefix>_<tag><index:03d>_<stable hash of knobs>``.
        """
        knobs = self._sample_knobs()
        variant = f"{self.variant_prefix}_{tag}{index:03d}_{stable_hash(str(knobs))}"
        return Candidate(variant=variant, knobs=knobs, stage_idx=0, last_reward=None)

    def _init_pool(self) -> None:
        """Replace the pool with ``pool_size`` freshly sampled stage-0 candidates."""
        self.pool = [self._new_candidate("sh", i) for i in range(self.pool_size)]

    def _rebuild_queue(self) -> None:
        """Queue one action per pool candidate, in pool order, at its current stage."""
        self._queue = [
            Action(variant=c.variant, fidelity=self.stage_names[c.stage_idx], knobs=c.knobs)
            for c in self.pool
        ]

    def propose(self) -> Action:
        """Return the next action to evaluate, starting a new round when needed."""
        if not self._queue:
            # Every action of the current round has been handed out:
            # rank / promote (or resample) and queue the next round.
            self._promote()
            self._rebuild_queue()
        return self._queue.pop(0)

    @staticmethod
    def _rank_key(c: Candidate) -> float:
        """Promotion sort key: missing or NaN rewards rank below every real reward."""
        r = c.last_reward
        # NaN compares false against everything and would scramble sort order.
        if r is None or math.isnan(r):
            return float("-inf")
        return r

    def _promote(self) -> None:
        """Advance the pool to the next successive-halving round.

        If any candidate has already reached the last fidelity, the pool is
        resampled from scratch. Otherwise the top ``eta`` fraction (at least
        one) of the candidates at the highest stage present moves up one
        stage with its reward cleared, and the pool is refilled to
        ``pool_size`` with fresh stage-0 candidates. Candidates below the
        highest stage are dropped.
        """
        max_stage = max(c.stage_idx for c in self.pool)
        if max_stage >= len(self.stage_names) - 1:
            self._init_pool()
            return

        # Best first; sorting is stable (also with reverse=True), so ties
        # keep their pool order.
        current = sorted(
            (c for c in self.pool if c.stage_idx == max_stage),
            key=self._rank_key,
            reverse=True,
        )
        keep = max(1, int(len(current) * self.eta))
        promoted = [
            Candidate(c.variant, c.knobs, c.stage_idx + 1, None) for c in current[:keep]
        ]
        fresh = [
            self._new_candidate("shR", i) for i in range(self.pool_size - len(promoted))
        ]
        self.pool = promoted + fresh

    def observe(self, action: Action, ok: bool, reward: Optional[float], metrics_flat: Optional[Dict[str, Any]]) -> None:
        """Store the outcome of an evaluation on the matching pool candidate.

        The candidate is matched on both variant and fidelity, so a result for
        a stage the candidate has already left (or for a candidate no longer
        in the pool) is ignored. Failed runs (``ok`` false) are stored as
        ``None`` and therefore rank last. ``metrics_flat`` is not used.
        """
        for c in self.pool:
            if c.variant == action.variant and self.stage_names[c.stage_idx] == action.fidelity:
                c.last_reward = reward if ok else None
                return
|
|
|