# edgeeda-agent/src/edgeeda/agents/successive_halving.py
# SamChYe's picture
# Publish EdgeEDA agent
# aa677e3 verified
from __future__ import annotations
import random
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from edgeeda.agents.base import Action, Agent
from edgeeda.config import Config
from edgeeda.utils import sanitize_variant_prefix, stable_hash
@dataclass
class Candidate:
    """One sampled configuration tracked by the successive-halving pool.

    ``stage_idx`` and ``last_reward`` default to the values every newly
    sampled candidate starts with (stage 0, no reward yet), so callers only
    need to pass them when promoting or restoring a candidate.
    """

    # Variant name used to match an observed action back to this candidate.
    variant: str
    # Sampled knob values, keyed by knob name.
    knobs: Dict[str, Any]
    # Index of the stage (fidelity) this candidate is currently evaluated at.
    stage_idx: int = 0
    # Reward recorded for the current stage; None until one is recorded.
    last_reward: Optional[float] = None
class SuccessiveHalvingAgent(Agent):
    """
    Simple multi-fidelity baseline:
    - sample a pool
    - evaluate at fidelity0
    - keep top fraction
    - promote to next fidelity

    The pool is always refilled to ``pool_size`` with freshly sampled
    stage-0 candidates after each promotion round.
    """

    def __init__(self, cfg: Config, pool_size: int = 12, eta: float = 0.5):
        """Sample the initial pool and queue one action per candidate.

        Args:
            cfg: Experiment configuration; this class reads
                ``cfg.flow.fidelities``, ``cfg.experiment.name`` and
                ``cfg.tuning.knobs``.
            pool_size: Number of candidates kept in the pool each round.
            eta: Fraction of the top-stage candidates kept on promotion
                (at least one candidate always survives).

        Raises:
            ValueError: If ``pool_size`` is below 1 or no fidelities are
                configured; either would otherwise fail later with an
                opaque error inside ``propose``.
        """
        if pool_size < 1:
            raise ValueError(f"pool_size must be >= 1, got {pool_size}")
        self.cfg = cfg
        self.pool_size = pool_size
        self.eta = eta
        self.stage_names = cfg.flow.fidelities
        if not self.stage_names:
            raise ValueError("cfg.flow.fidelities must list at least one fidelity")
        self.variant_prefix = sanitize_variant_prefix(cfg.experiment.name)
        self.pool: List[Candidate] = []
        self._init_pool()
        self._queue: List[Action] = []
        self._rebuild_queue()

    def _sample_knobs(self) -> Dict[str, Any]:
        """Draw one random value per configured knob.

        Knobs whose ``type`` is ``"int"`` are drawn with ``random.randint``
        over ``[int(min), int(max)]``; every other knob is treated as a
        float, drawn uniformly and rounded to 3 decimals.
        """
        out: Dict[str, Any] = {}
        for name, spec in self.cfg.tuning.knobs.items():
            if spec.type == "int":
                out[name] = random.randint(int(spec.min), int(spec.max))
            else:
                lo, hi = float(spec.min), float(spec.max)
                value = round(lo + random.random() * (hi - lo), 3)
                # Rounding can step just outside a narrow range (e.g. bounds
                # with more than 3 decimals); clamp back into [lo, hi].
                out[name] = min(max(value, lo), hi)
        return out

    def _fresh_candidates(self, count: int, tag: str) -> List[Candidate]:
        """Sample ``count`` new stage-0 candidates named ``<prefix>_<tag>NNN_<hash>``."""
        fresh: List[Candidate] = []
        for i in range(count):
            knobs = self._sample_knobs()
            variant = f"{self.variant_prefix}_{tag}{i:03d}_{stable_hash(str(knobs))}"
            fresh.append(Candidate(variant=variant, knobs=knobs, stage_idx=0, last_reward=None))
        return fresh

    def _init_pool(self) -> None:
        """Replace the pool with ``pool_size`` freshly sampled stage-0 candidates."""
        self.pool = self._fresh_candidates(self.pool_size, "sh")

    def _rebuild_queue(self) -> None:
        """Queue one action per pool candidate at that candidate's current fidelity."""
        self._queue = [
            Action(variant=c.variant, fidelity=self.stage_names[c.stage_idx], knobs=c.knobs)
            for c in self.pool
        ]

    def propose(self) -> Action:
        """Return the next queued action.

        When the queue is exhausted, a promotion round runs first and the
        queue is rebuilt from the new pool.
        """
        if not self._queue:
            self._promote()
            self._rebuild_queue()
        return self._queue.pop(0)

    @staticmethod
    def _rank_key(candidate: Candidate) -> float:
        """Sort key for ranking: missing or NaN rewards rank as the worst.

        NaN is mapped to ``-inf`` because NaN compares false against
        everything, which would leave the sort order (and therefore the
        survivor set) dependent on the input order.
        """
        reward = candidate.last_reward
        if reward is None or reward != reward:  # `x != x` is true only for NaN
            return float("-inf")
        return reward

    def _promote(self) -> None:
        """Advance the best top-stage candidates and refill the pool.

        If any candidate already sits at the final stage, the whole pool is
        resampled from scratch. Otherwise the candidates at the highest
        stage are ranked by reward, the top ``eta`` fraction (at least one)
        moves to the next stage with its reward cleared, and the rest of the
        pool is replaced by fresh stage-0 candidates.

        NOTE(review): only candidates at the highest stage are ranked;
        candidates at lower stages (the stage-0 refills from a previous
        round) are dropped here without being considered for promotion.
        Confirm this is the intended schedule.
        """
        max_stage = max(c.stage_idx for c in self.pool)
        if max_stage >= len(self.stage_names) - 1:
            # already at final stage; resample fresh pool to continue
            self._init_pool()
            return

        # keep top fraction among candidates at current max stage
        current = [c for c in self.pool if c.stage_idx == max_stage]
        current.sort(key=self._rank_key, reverse=True)
        k = max(1, int(len(current) * self.eta))

        # promote survivors to next stage; others replaced with new randoms at stage 0
        promoted = [Candidate(c.variant, c.knobs, c.stage_idx + 1, None) for c in current[:k]]
        fresh = self._fresh_candidates(self.pool_size - len(promoted), "shR")
        self.pool = promoted + fresh

    def observe(self, action: Action, ok: bool, reward: Optional[float], metrics_flat: Optional[Dict[str, Any]]) -> None:
        """Record the outcome of an evaluated action on its pool candidate.

        The candidate is matched by variant name and current fidelity; only
        the first match is updated. A failed run (``ok`` false) clears the
        reward so the candidate ranks last. ``metrics_flat`` is not used by
        this agent. Actions that match no candidate are ignored.
        """
        for c in self.pool:
            if c.variant == action.variant and self.stage_names[c.stage_idx] == action.fidelity:
                c.last_reward = reward if ok else None
                return