File size: 5,812 Bytes
aa677e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from __future__ import annotations

import random
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from sklearn.ensemble import ExtraTreesRegressor

from edgeeda.agents.base import Action, Agent
from edgeeda.config import Config
from edgeeda.utils import sanitize_variant_prefix, stable_hash


@dataclass
class Obs:
    """A single recorded evaluation result used as surrogate / promotion data."""

    # knob settings in encoded (vector) form
    x: np.ndarray
    # scalar reward reported for the run
    y: float
    # fidelity (stage name) the run was evaluated at
    fidelity: str
    # identifier of the variant the run belongs to
    variant: str


class SurrogateUCBAgent(Agent):
    """
    Agentic tuner driven by an ExtraTrees surrogate and an upper-confidence-bound rule.

    Each call to ``propose`` does one of two things:
      - Promotion: with probability ``promote_prob`` (and only if some variant is
        below the last fidelity), take the variant with the best reward observed
        at its current fidelity and schedule it at the next fidelity, reusing its
        cached knobs.
      - Exploration: otherwise create a new variant at the cheapest fidelity
        (``stage_names[0]``). Once enough cheapest-fidelity observations exist,
        ``n_candidates`` random knob settings are scored with
        ``mean + kappa * std`` (std estimated across trees) and the best is used;
        before that, knobs are sampled uniformly at random.

    Args:
        cfg: Experiment config; reads ``cfg.flow.fidelities``, ``cfg.tuning.knobs``
            and ``cfg.experiment.name``.
        kappa: Weight of the across-tree standard deviation in the UCB score.
        init_random: Minimum number of cheapest-fidelity observations (never fewer
            than 5) before the surrogate is used.
        promote_prob: Probability of attempting a promotion on each ``propose`` call.
        n_candidates: Number of random candidates scored by the surrogate (min 1).
    """

    def __init__(
        self,
        cfg: Config,
        kappa: float = 1.0,
        init_random: int = 6,
        promote_prob: float = 0.35,
        n_candidates: int = 32,
    ):
        self.cfg = cfg
        self.kappa = kappa
        self.init_random = init_random
        self.promote_prob = promote_prob
        # np.stack of an empty candidate list would fail, so keep at least one.
        self.n_candidates = max(1, int(n_candidates))
        self.stage_names = cfg.flow.fidelities
        self.knob_names = list(cfg.tuning.knobs.keys())
        self.variant_prefix = sanitize_variant_prefix(cfg.experiment.name)

        # Successful evaluations with a finite reward (see observe()).
        self.obs: List[Obs] = []
        # variant -> index into stage_names of the fidelity it was last scheduled at
        self.variant_stage: Dict[str, int] = {}
        # variant -> knob settings, so a promoted variant reruns with the same knobs
        self._variant_knobs: Dict[str, Dict[str, Any]] = {}
        # Number of propose() calls so far; embedded in new variant names.
        self.counter = 0

    def _encode(self, knobs: Dict[str, Any]) -> np.ndarray:
        """Return ``knobs`` as a float32 vector ordered like ``self.knob_names``.

        Each value is min-max scaled by its spec range, so values inside
        [spec.min, spec.max] map to [0, 1]. A non-positive range is floored at
        1e-9 to avoid division by zero.
        """
        xs = []
        for name in self.knob_names:
            spec = self.cfg.tuning.knobs[name]
            lo, hi = float(spec.min), float(spec.max)
            xs.append((float(knobs[name]) - lo) / max(1e-9, hi - lo))
        return np.array(xs, dtype=np.float32)

    def _sample_knobs(self) -> Dict[str, Any]:
        """Draw one uniform random knob setting from the configured ranges.

        ``int`` knobs are drawn inclusively with ``random.randint``; all other
        knobs are drawn as floats and rounded to 3 decimals.
        """
        out: Dict[str, Any] = {}
        for name, spec in self.cfg.tuning.knobs.items():
            if spec.type == "int":
                out[name] = random.randint(int(spec.min), int(spec.max))
            else:
                out[name] = float(spec.min) + random.random() * (float(spec.max) - float(spec.min))
                out[name] = round(out[name], 3)
        return out

    def _fit_surrogate(self, fidelity: str) -> Optional[ExtraTreesRegressor]:
        """Fit an ExtraTrees regressor on the observations made at ``fidelity``.

        Returns None when fewer than ``max(5, self.init_random)`` such
        observations exist. The fixed ``random_state`` makes the fit repeatable
        for the same data.
        """
        data = [o for o in self.obs if o.fidelity == fidelity]
        if len(data) < max(5, self.init_random):
            return None
        X = np.stack([o.x for o in data], axis=0)
        y = np.array([o.y for o in data], dtype=np.float32)
        model = ExtraTreesRegressor(
            n_estimators=128,
            random_state=0,
            min_samples_leaf=2,
            n_jobs=-1,
        )
        model.fit(X, y)
        return model

    def _predict_ucb(self, model: ExtraTreesRegressor, Xcand: np.ndarray) -> np.ndarray:
        """Return ``mean + kappa * std`` of the per-tree predictions for each row of ``Xcand``."""
        preds = np.stack([t.predict(Xcand) for t in model.estimators_], axis=0)
        mu = preds.mean(axis=0)
        sd = preds.std(axis=0)
        return mu + self.kappa * sd

    def _best_promotable(self) -> Optional[str]:
        """Return the promotable variant with the highest reward at its current fidelity.

        A variant qualifies when it is below the last fidelity and has at least one
        recorded observation at the fidelity it currently sits at. Ties go to the
        variant that appears first in ``self.variant_stage``. Returns None when no
        variant qualifies.
        """
        last_stage = len(self.stage_names) - 1
        # One pass over obs: best reward per variant at its current fidelity.
        best_at_stage: Dict[str, float] = {}
        for o in self.obs:
            st = self.variant_stage.get(o.variant)
            if st is None or st >= last_stage or o.fidelity != self.stage_names[st]:
                continue
            prev = best_at_stage.get(o.variant)
            if prev is None or o.y > prev:
                best_at_stage[o.variant] = o.y

        best_v: Optional[str] = None
        best_y = float("-inf")
        for v in self.variant_stage:
            y = best_at_stage.get(v)
            if y is not None and y > best_y:
                best_v, best_y = v, y
        return best_v

    def _promote(self, variant: str) -> Action:
        """Advance ``variant`` one fidelity stage and return an action rerunning it."""
        st = self.variant_stage[variant] + 1
        self.variant_stage[variant] = st
        knobs = self._variant_knobs.get(variant)
        if knobs is None:
            # Not expected for variants created by this agent (propose and observe
            # both cache knobs). Fall back to random knobs, drawn lazily so the RNG
            # is only touched when actually needed, and cache them for later stages.
            knobs = self._sample_knobs()
            self._variant_knobs[variant] = knobs
        return Action(variant=variant, fidelity=self.stage_names[st], knobs=knobs)

    def _new_variant(self) -> Action:
        """Create a new variant at the cheapest fidelity, surrogate-guided when possible."""
        fid0 = self.stage_names[0]
        model = self._fit_surrogate(fid0)
        if model is None:
            # Not enough cheapest-fidelity data yet: plain random sampling.
            knobs = self._sample_knobs()
        else:
            # Small random candidate search; keep the candidate with the best UCB.
            cands = [self._sample_knobs() for _ in range(self.n_candidates)]
            Xc = np.stack([self._encode(kk) for kk in cands], axis=0)
            knobs = cands[int(np.argmax(self._predict_ucb(model, Xc)))]

        variant = f"{self.variant_prefix}_u{self.counter:05d}_{stable_hash(str(knobs))}"
        self.variant_stage[variant] = 0
        self._variant_knobs[variant] = knobs
        return Action(variant=variant, fidelity=fid0, knobs=knobs)

    def propose(self) -> Action:
        """Return the next action to evaluate (a promotion or a new variant).

        The chosen variant's stage and knobs are recorded immediately, before any
        result for it is observed.
        """
        self.counter += 1

        last_stage = len(self.stage_names) - 1
        # The random draw only happens when some variant could still be promoted.
        can_promote = any(st < last_stage for st in self.variant_stage.values())
        if can_promote and random.random() < self.promote_prob:
            best_v = self._best_promotable()
            if best_v is not None:
                return self._promote(best_v)

        return self._new_variant()

    def observe(self, action: Action, ok: bool, reward: Optional[float], metrics_flat: Optional[Dict[str, Any]]) -> None:
        """Record the outcome of ``action``.

        Only successful runs with a finite reward are stored as observations.
        NaN/inf rewards are dropped: ExtraTrees fitting validates that targets are
        finite, and an infinite reward would dominate promotion ranking. The
        action's knobs are cached in every case. ``metrics_flat`` is not used here.
        """
        if ok and reward is not None:
            y = float(reward)
            if np.isfinite(y):
                x = self._encode(action.knobs)
                self.obs.append(Obs(x=x, y=y, fidelity=action.fidelity, variant=action.variant))
        # keep knobs cache
        self._variant_knobs[action.variant] = action.knobs