# edgeeda-agent / src/edgeeda/agents/surrogate_ucb.py
# (published by SamChYe, commit aa677e3)
from __future__ import annotations
import random
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from sklearn.ensemble import ExtraTreesRegressor
from edgeeda.agents.base import Action, Agent
from edgeeda.config import Config
from edgeeda.utils import sanitize_variant_prefix, stable_hash
@dataclass
class Obs:
    """A single completed evaluation, as stored by the agent for surrogate fitting."""

    # Knob values encoded/normalized to a [0, 1] feature vector (see SurrogateUCBAgent._encode)
    x: np.ndarray
    # Observed reward for this evaluation (higher is better)
    y: float
    # Name of the fidelity stage this reward was measured at
    fidelity: str
    # Name of the design variant that produced this observation
    variant: str
class SurrogateUCBAgent(Agent):
    """
    Agentic tuner:
    - Generates candidates (random)
    - Fits a lightweight surrogate (ExtraTrees) on observed rewards (for a given fidelity)
    - Chooses next action via UCB: mean + kappa * std (std estimated across trees)
    Multi-fidelity policy:
    - Always start at cheapest fidelity for new variants
    - Promote a subset to next fidelity when budget allows
    """

    # Probability of attempting a promotion instead of proposing a new variant.
    _PROMOTE_PROB = 0.35
    # Size of the random candidate pool scored by the surrogate per proposal.
    _N_CANDIDATES = 32

    def __init__(self, cfg: Config, kappa: float = 1.0, init_random: int = 6):
        """
        Args:
            cfg: Experiment configuration (fidelity stages + tunable knob specs).
            kappa: Exploration weight in the UCB acquisition (mean + kappa * std).
            init_random: Minimum observations per fidelity before a surrogate
                is fit; until then proposals are pure random search.
        """
        self.cfg = cfg
        self.kappa = kappa
        self.init_random = init_random
        self.stage_names = cfg.flow.fidelities
        self.knob_names = list(cfg.tuning.knobs.keys())
        self.variant_prefix = sanitize_variant_prefix(cfg.experiment.name)
        self.obs: List[Obs] = []
        # variant name -> index into stage_names of the variant's current fidelity
        self.variant_stage: Dict[str, int] = {}
        # variant name -> knob dict, so promotions reuse the exact same knobs
        self._variant_knobs: Dict[str, Dict[str, Any]] = {}
        # monotonically increasing proposal counter, used in variant names
        self.counter = 0

    def _encode(self, knobs: Dict[str, Any]) -> np.ndarray:
        """Encode a knob dict as a float32 vector, min-max normalized to [0, 1].

        Feature order follows ``self.knob_names`` so encodings are comparable
        across calls.
        """
        xs = []
        for name in self.knob_names:
            spec = self.cfg.tuning.knobs[name]
            lo, hi = float(spec.min), float(spec.max)
            v = float(knobs[name])
            # 1e-9 floor guards against a degenerate spec with min == max.
            xs.append((v - lo) / max(1e-9, hi - lo))
        return np.array(xs, dtype=np.float32)

    def _sample_knobs(self) -> Dict[str, Any]:
        """Draw one uniformly random knob assignment from the configured ranges.

        Integer knobs are sampled inclusively; float knobs are rounded to three
        decimals to keep variant hashes stable across trivially-close values.
        """
        out: Dict[str, Any] = {}
        for name, spec in self.cfg.tuning.knobs.items():
            if spec.type == "int":
                out[name] = random.randint(int(spec.min), int(spec.max))
            else:
                out[name] = float(spec.min) + random.random() * (float(spec.max) - float(spec.min))
                out[name] = round(out[name], 3)
        return out

    def _fit_surrogate(self, fidelity: str) -> Optional[ExtraTreesRegressor]:
        """Fit an ExtraTrees surrogate on observations at *fidelity*.

        Returns:
            The fitted model, or None while fewer than
            ``max(5, init_random)`` observations exist at this fidelity.
        """
        data = [o for o in self.obs if o.fidelity == fidelity]
        if len(data) < max(5, self.init_random):
            return None
        X = np.stack([o.x for o in data], axis=0)
        y = np.array([o.y for o in data], dtype=np.float32)
        model = ExtraTreesRegressor(
            n_estimators=128,
            random_state=0,
            min_samples_leaf=2,
            n_jobs=-1,
        )
        model.fit(X, y)
        return model

    def _predict_ucb(self, model: ExtraTreesRegressor, Xcand: np.ndarray) -> np.ndarray:
        """Return UCB scores (mean + kappa * std) for each row of Xcand.

        Mean/std are estimated empirically across the ensemble's trees.
        """
        preds = np.stack([t.predict(Xcand) for t in model.estimators_], axis=0)
        mu = preds.mean(axis=0)
        sd = preds.std(axis=0)
        return mu + self.kappa * sd

    def propose(self) -> Action:
        """Return the next action to evaluate.

        With probability ``_PROMOTE_PROB`` (and when a promotable variant with
        observations exists), promotes the best observed variant to its next
        fidelity; otherwise proposes a new variant at the cheapest fidelity.
        """
        self.counter += 1
        promotion = self._maybe_promote()
        if promotion is not None:
            return promotion
        return self._propose_new()

    def _maybe_promote(self) -> Optional[Action]:
        """Try to promote an existing variant one fidelity step up.

        Returns None when no variant is promotable, the promotion coin-flip
        fails, or no promotable variant has observations at its current
        fidelity yet.
        """
        promotable = [v for v, st in self.variant_stage.items() if st < len(self.stage_names) - 1]
        if not promotable or random.random() >= self._PROMOTE_PROB:
            return None
        # Pick the promotable variant with the best reward observed at its
        # own current fidelity.
        best_v: Optional[str] = None
        best_y = float("-inf")
        for v in promotable:
            fid = self.stage_names[self.variant_stage[v]]
            ys = [o.y for o in self.obs if o.fidelity == fid and o.variant == v]
            if ys and max(ys) > best_y:
                best_y = max(ys)
                best_v = v
        if best_v is None:
            return None
        st = self.variant_stage[best_v] + 1
        self.variant_stage[best_v] = st
        # Reuse the variant's cached knobs so the promoted run is the same
        # design point; fall back to a fresh random sample only if the cache
        # was somehow never populated.
        knobs = self._variant_knobs.get(best_v, self._sample_knobs())
        return Action(variant=best_v, fidelity=self.stage_names[st], knobs=knobs)

    def _propose_new(self) -> Action:
        """Propose a brand-new variant at the cheapest fidelity.

        Uses pure random sampling until enough observations exist to fit a
        surrogate; afterwards scores a random candidate pool by UCB and keeps
        the best.
        """
        knobs = self._sample_knobs()
        fid0 = self.stage_names[0]
        model = self._fit_surrogate(fid0)
        if model is not None:
            cands = []
            Xc = []
            for _ in range(self._N_CANDIDATES):
                kk = self._sample_knobs()
                cands.append(kk)
                Xc.append(self._encode(kk))
            ucb = self._predict_ucb(model, np.stack(Xc, axis=0))
            knobs = cands[int(np.argmax(ucb))]
        # Variant name encodes the proposal index and a stable hash of the
        # knob dict (insertion order is deterministic, so the hash is too).
        variant = f"{self.variant_prefix}_u{self.counter:05d}_{stable_hash(str(knobs))}"
        self.variant_stage[variant] = 0
        self._variant_knobs[variant] = knobs
        return Action(variant=variant, fidelity=fid0, knobs=knobs)

    def observe(self, action: Action, ok: bool, reward: Optional[float], metrics_flat: Optional[Dict[str, Any]]) -> None:
        """Record the outcome of an executed action.

        Failed runs (ok is False) and runs without a reward are dropped; the
        surrogate only ever sees successful, scored evaluations.
        """
        if ok and reward is not None:
            x = self._encode(action.knobs)
            self.obs.append(Obs(x=x, y=float(reward), fidelity=action.fidelity, variant=action.variant))
            # Keep the knobs cache fresh so later promotions reuse them.
            self._variant_knobs[action.variant] = action.knobs