|
|
from __future__ import annotations |
|
|
|
|
|
import itertools |
|
|
import random |
|
|
from typing import Any, Dict, Optional |
|
|
|
|
|
from edgeeda.agents.base import Action, Agent |
|
|
from edgeeda.config import Config |
|
|
from edgeeda.utils import sanitize_variant_prefix, stable_hash |
|
|
|
|
|
|
|
|
class RandomSearchAgent(Agent):
    """Baseline agent that samples every tuning knob independently and uniformly.

    Each call to :meth:`propose` draws a fresh knob assignment from the ranges
    declared in ``cfg.tuning.knobs`` and always requests the first configured
    fidelity (``cfg.flow.fidelities[0]``). :meth:`observe` ignores results, so
    the search never adapts to what it has seen.
    """

    # Number of decimal places float-typed knobs are rounded to after sampling.
    float_decimals: int = 3

    # Optional dedicated RNG. ``None`` means "use the module-level ``random``
    # state". This class-level default also covers instances whose __init__ did
    # not set the attribute.
    rng: Optional[random.Random] = None

    def __init__(self, cfg: Config, rng: Optional[random.Random] = None):
        """Create the agent.

        Args:
            cfg: Experiment configuration. Reads ``experiment.name`` (variant
                prefix), ``tuning.knobs`` (search space) and
                ``flow.fidelities`` (fidelity to request).
            rng: Optional ``random.Random`` used for all sampling, e.g.
                ``random.Random(seed)`` for reproducible, isolated runs. When
                omitted, the global ``random`` module is used, as before.
        """
        self.cfg = cfg
        # Monotonic proposal index; embedded in every variant name.
        self.counter = 0
        self.variant_prefix = sanitize_variant_prefix(cfg.experiment.name)
        self.rng = rng

    def _sample_knobs(self) -> Dict[str, Any]:
        """Draw one value per configured knob.

        Knobs with ``type == "int"`` are drawn uniformly from the inclusive
        integer range ``[int(min), int(max)]``. Every other type is treated as
        a float: it is drawn uniformly between ``min`` and ``max``, rounded to
        ``float_decimals`` places, and clamped so that rounding never leaves
        the declared range.

        Returns:
            Mapping of knob name to sampled value, in ``cfg.tuning.knobs``
            iteration order.
        """
        rng = self.rng if self.rng is not None else random
        out: Dict[str, Any] = {}
        for name, spec in self.cfg.tuning.knobs.items():
            if spec.type == "int":
                # randint includes both endpoints.
                out[name] = rng.randint(int(spec.min), int(spec.max))
                continue

            lo, hi = float(spec.min), float(spec.max)
            # uniform(a, b) computes a + (b - a) * random(), so this consumes
            # exactly one draw, just like the original hand-written formula.
            value = round(rng.uniform(lo, hi), self.float_decimals)
            # Rounding can step past a bound that is not aligned to
            # float_decimals (e.g. min=0.0004 rounds down to 0.0). Clamp back
            # inside the interval, whichever order the bounds were given in.
            bottom, top = (lo, hi) if lo <= hi else (hi, lo)
            out[name] = min(max(value, bottom), top)
        return out

    def propose(self) -> Action:
        """Return the next action to evaluate.

        The variant name combines the sanitized experiment prefix, a zero-padded
        proposal counter and a hash of the sampled knobs.

        Raises:
            IndexError: if ``cfg.flow.fidelities`` is empty.
        """
        self.counter += 1
        knobs = self._sample_knobs()
        variant = f"{self.variant_prefix}_t{self.counter:05d}_{stable_hash(str(knobs))}"
        # Random search never escalates; always use the first listed fidelity.
        fidelity = self.cfg.flow.fidelities[0]
        return Action(variant=variant, fidelity=fidelity, knobs=knobs)

    def observe(self, action: Action, ok: bool, reward: Optional[float], metrics_flat: Optional[Dict[str, Any]]) -> None:
        """Accept feedback for a completed action.

        This is a deliberate no-op: random search keeps no state about past
        results.
        """
        return
|
|
|