from __future__ import annotations import random from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple from edgeeda.agents.base import Action, Agent from edgeeda.config import Config from edgeeda.utils import sanitize_variant_prefix, stable_hash @dataclass class Candidate: variant: str knobs: Dict[str, Any] stage_idx: int last_reward: Optional[float] class SuccessiveHalvingAgent(Agent): """ Simple multi-fidelity baseline: - sample a pool - evaluate at fidelity0 - keep top fraction - promote to next fidelity """ def __init__(self, cfg: Config, pool_size: int = 12, eta: float = 0.5): self.cfg = cfg self.pool_size = pool_size self.eta = eta self.stage_names = cfg.flow.fidelities self.variant_prefix = sanitize_variant_prefix(cfg.experiment.name) self.pool: List[Candidate] = [] self._init_pool() self._queue: List[Action] = [] self._rebuild_queue() def _sample_knobs(self) -> Dict[str, Any]: out: Dict[str, Any] = {} for name, spec in self.cfg.tuning.knobs.items(): if spec.type == "int": out[name] = random.randint(int(spec.min), int(spec.max)) else: out[name] = float(spec.min) + random.random() * (float(spec.max) - float(spec.min)) out[name] = round(out[name], 3) return out def _init_pool(self): self.pool = [] for i in range(self.pool_size): knobs = self._sample_knobs() variant = f"{self.variant_prefix}_sh{i:03d}_{stable_hash(str(knobs))}" self.pool.append(Candidate(variant=variant, knobs=knobs, stage_idx=0, last_reward=None)) def _rebuild_queue(self): self._queue = [] for c in self.pool: self._queue.append(Action(variant=c.variant, fidelity=self.stage_names[c.stage_idx], knobs=c.knobs)) def propose(self) -> Action: if not self._queue: # promote self._promote() self._rebuild_queue() return self._queue.pop(0) def _promote(self): # group by stage idx max_stage = max(c.stage_idx for c in self.pool) if max_stage >= len(self.stage_names) - 1: # already at final stage; resample fresh pool to continue self._init_pool() return # keep top fraction among candidates at current max stage current = [c for c in self.pool if c.stage_idx == max_stage] # if rewards missing, treat as very bad current.sort(key=lambda c: float("-inf") if c.last_reward is None else c.last_reward, reverse=True) k = max(1, int(len(current) * self.eta)) survivors = current[:k] # promote survivors to next stage; others replaced with new randoms at stage 0 promoted = [] for c in survivors: promoted.append(Candidate(c.variant, c.knobs, c.stage_idx + 1, None)) fresh_needed = self.pool_size - len(promoted) fresh = [] for i in range(fresh_needed): knobs = self._sample_knobs() variant = f"{self.variant_prefix}_shR{i:03d}_{stable_hash(str(knobs))}" fresh.append(Candidate(variant=variant, knobs=knobs, stage_idx=0, last_reward=None)) self.pool = promoted + fresh def observe(self, action: Action, ok: bool, reward: Optional[float], metrics_flat: Optional[Dict[str, Any]]) -> None: for c in self.pool: if c.variant == action.variant and self.stage_names[c.stage_idx] == action.fidelity: c.last_reward = reward if ok else None return