from __future__ import annotations import random from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple import numpy as np from sklearn.ensemble import ExtraTreesRegressor from edgeeda.agents.base import Action, Agent from edgeeda.config import Config from edgeeda.utils import sanitize_variant_prefix, stable_hash @dataclass class Obs: x: np.ndarray y: float fidelity: str variant: str class SurrogateUCBAgent(Agent): """ Agentic tuner: - Generates candidates (random) - Fits a lightweight surrogate (ExtraTrees) on observed rewards (for a given fidelity) - Chooses next action via UCB: mean + kappa * std (std estimated across trees) Multi-fidelity policy: - Always start at cheapest fidelity for new variants - Promote a subset to next fidelity when budget allows """ def __init__(self, cfg: Config, kappa: float = 1.0, init_random: int = 6): self.cfg = cfg self.kappa = kappa self.init_random = init_random self.stage_names = cfg.flow.fidelities self.knob_names = list(cfg.tuning.knobs.keys()) self.variant_prefix = sanitize_variant_prefix(cfg.experiment.name) self.obs: List[Obs] = [] self.variant_stage: Dict[str, int] = {} self._variant_knobs: Dict[str, Dict[str, Any]] = {} # Initialize knob storage self.counter = 0 def _encode(self, knobs: Dict[str, Any]) -> np.ndarray: xs = [] for name in self.knob_names: spec = self.cfg.tuning.knobs[name] v = float(knobs[name]) # normalize to [0,1] xs.append((v - float(spec.min)) / max(1e-9, (float(spec.max) - float(spec.min)))) return np.array(xs, dtype=np.float32) def _sample_knobs(self) -> Dict[str, Any]: out: Dict[str, Any] = {} for name, spec in self.cfg.tuning.knobs.items(): if spec.type == "int": out[name] = random.randint(int(spec.min), int(spec.max)) else: out[name] = float(spec.min) + random.random() * (float(spec.max) - float(spec.min)) out[name] = round(out[name], 3) return out def _fit_surrogate(self, fidelity: str) -> Optional[ExtraTreesRegressor]: data = [o for o in self.obs if o.fidelity == fidelity] if len(data) < max(5, self.init_random): return None X = np.stack([o.x for o in data], axis=0) y = np.array([o.y for o in data], dtype=np.float32) model = ExtraTreesRegressor( n_estimators=128, random_state=0, min_samples_leaf=2, n_jobs=-1, ) model.fit(X, y) return model def _predict_ucb(self, model: ExtraTreesRegressor, Xcand: np.ndarray) -> np.ndarray: # estimate mean/std across trees preds = np.stack([t.predict(Xcand) for t in model.estimators_], axis=0) mu = preds.mean(axis=0) sd = preds.std(axis=0) return mu + self.kappa * sd def propose(self) -> Action: self.counter += 1 # With some probability, promote an existing promising variant to next fidelity promotable = [v for v, st in self.variant_stage.items() if st < len(self.stage_names) - 1] if promotable and random.random() < 0.35: # promote best observed (by latest reward) among promotable at current stage best_v = None best_y = float("-inf") for v in promotable: st = self.variant_stage[v] fid = self.stage_names[st] # best reward observed for this variant at its current fidelity ys = [o.y for o in self.obs if o.fidelity == fid and o.variant == v] if ys: y = max(ys) if y > best_y: best_y = y best_v = v if best_v is not None: st = self.variant_stage[best_v] + 1 self.variant_stage[best_v] = st # knobs are encoded in variant hash, but store explicitly: # easiest: resample from history by matching stable_hash prefix is messy; # we instead keep a variant->knobs cache. # If missing, fallback random. knobs = self._variant_knobs.get(best_v, self._sample_knobs()) return Action(variant=best_v, fidelity=self.stage_names[st], knobs=knobs) # Otherwise: propose a new variant at cheapest fidelity knobs = self._sample_knobs() x = self._encode(knobs) fid0 = self.stage_names[0] model = self._fit_surrogate(fid0) if model is not None: # do a small candidate search and pick best UCB cands = [] Xc = [] for _ in range(32): kk = self._sample_knobs() cands.append(kk) Xc.append(self._encode(kk)) Xc = np.stack(Xc, axis=0) ucb = self._predict_ucb(model, Xc) best_i = int(np.argmax(ucb)) knobs = cands[best_i] variant = f"{self.variant_prefix}_u{self.counter:05d}_{stable_hash(str(knobs))}" self.variant_stage[variant] = 0 self._variant_knobs[variant] = knobs return Action(variant=variant, fidelity=fid0, knobs=knobs) def observe(self, action: Action, ok: bool, reward: Optional[float], metrics_flat: Optional[Dict[str, Any]]) -> None: if ok and reward is not None: x = self._encode(action.knobs) self.obs.append(Obs(x=x, y=float(reward), fidelity=action.fidelity, variant=action.variant)) # keep knobs cache self._variant_knobs[action.variant] = action.knobs