File size: 3,716 Bytes
aa677e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from __future__ import annotations

import random
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

from edgeeda.agents.base import Action, Agent
from edgeeda.config import Config
from edgeeda.utils import sanitize_variant_prefix, stable_hash


@dataclass
class Candidate:
    """One sampled knob configuration tracked by ``SuccessiveHalvingAgent``.

    Field order matters: the agent builds instances positionally as
    ``Candidate(variant, knobs, stage_idx, last_reward)``.
    """

    # Run name; results are matched back on this together with the fidelity.
    variant: str
    # Sampled knob values keyed by knob name.
    knobs: Dict[str, Any]
    # Index into the agent's fidelity list for this candidate's evaluation.
    stage_idx: int
    # Reward from the latest evaluation at ``stage_idx``; None when not yet
    # observed or when the run failed.
    last_reward: Optional[float]


class SuccessiveHalvingAgent(Agent):
    """
    Simple multi-fidelity baseline (successive halving over a rolling pool):
      - sample a pool of ``pool_size`` random knob settings at fidelity 0
      - after each round, within every non-final stage, rank candidates by
        reward, keep the top ``eta`` fraction that actually produced a reward
        and promote them to the next fidelity
      - candidates that completed the final fidelity are retired
      - refill the pool with fresh fidelity-0 samples up to ``pool_size``

    Every evaluated candidate competes for promotion within its own stage, so
    fresh samples added mid-run are not evaluated in vain.
    """

    def __init__(self, cfg: Config, pool_size: int = 12, eta: float = 0.5):
        if pool_size < 1:
            raise ValueError(f"pool_size must be >= 1, got {pool_size}")
        self.cfg = cfg
        self.pool_size = pool_size
        # Fraction of each stage's candidates considered for promotion (at least one).
        self.eta = eta
        # Ordered fidelity names; Candidate.stage_idx indexes into this sequence.
        self.stage_names = cfg.flow.fidelities
        if not self.stage_names:
            raise ValueError("cfg.flow.fidelities must list at least one fidelity")
        self.variant_prefix = sanitize_variant_prefix(cfg.experiment.name)
        # Monotonic counter so every sampled candidate gets a distinct variant
        # name, even when two samples draw identical knobs.
        self._next_id = 0
        self.pool: List[Candidate] = []
        self._init_pool()
        self._queue: List[Action] = []
        self._rebuild_queue()

    def _sample_knobs(self) -> Dict[str, Any]:
        """Draw one random knob setting from the ranges in ``cfg.tuning.knobs``.

        Knobs whose ``type`` is ``"int"`` are drawn with ``random.randint`` over
        ``[min, max]`` (inclusive). Every other knob is treated as a float drawn
        from ``[min, max)``, rounded to 3 decimals and clamped back into range.
        """
        out: Dict[str, Any] = {}
        for name, spec in self.cfg.tuning.knobs.items():
            if spec.type == "int":
                out[name] = random.randint(int(spec.min), int(spec.max))
            else:
                lo, hi = float(spec.min), float(spec.max)
                value = round(lo + random.random() * (hi - lo), 3)
                # round() can step outside [lo, hi] when the bounds carry more
                # than 3 decimals; clamp so sampled values always respect the spec.
                out[name] = min(max(value, lo), hi)
        return out

    def _new_candidate(self, tag: str) -> Candidate:
        """Sample fresh knobs and wrap them in a stage-0 candidate with a unique variant name."""
        knobs = self._sample_knobs()
        variant = f"{self.variant_prefix}_{tag}{self._next_id:03d}_{stable_hash(str(knobs))}"
        self._next_id += 1
        return Candidate(variant=variant, knobs=knobs, stage_idx=0, last_reward=None)

    def _init_pool(self) -> None:
        """Replace the pool with ``pool_size`` freshly sampled stage-0 candidates."""
        self.pool = [self._new_candidate("sh") for _ in range(self.pool_size)]

    def _rebuild_queue(self) -> None:
        """Queue one action per pool candidate, at that candidate's current fidelity."""
        self._queue = [
            Action(variant=c.variant, fidelity=self.stage_names[c.stage_idx], knobs=c.knobs)
            for c in self.pool
        ]

    def propose(self) -> Action:
        """Return the next queued action, advancing the pool once the round is exhausted."""
        if not self._queue:
            # Every candidate of the current round has been proposed: promote.
            self._promote()
            self._rebuild_queue()
        return self._queue.pop(0)

    def _promote(self) -> None:
        """Advance the pool by one successive-halving round.

        Within each non-final stage, candidates are ranked by ``last_reward``
        and the top ``max(1, int(n * eta))`` are considered; only those that
        actually produced a reward move up one stage. Candidates at the final
        stage are retired. The pool is then topped up with fresh stage-0
        samples. If nothing was promoted, the whole pool is resampled.
        """
        final_stage = len(self.stage_names) - 1
        promoted: List[Candidate] = []
        # Walk from the highest non-final stage down so higher-fidelity runs are
        # queued first, matching the original "promoted first, fresh last" order.
        for stage in range(final_stage - 1, -1, -1):
            group = [c for c in self.pool if c.stage_idx == stage]
            if not group:
                continue
            # Missing rewards (failed or never observed) rank last.
            group.sort(
                key=lambda c: float("-inf") if c.last_reward is None else c.last_reward,
                reverse=True,
            )
            k = max(1, int(len(group) * self.eta))
            promoted.extend(
                Candidate(c.variant, c.knobs, stage + 1, None)
                for c in group[:k]
                # Don't spend a higher fidelity on a run that produced no reward.
                if c.last_reward is not None
            )

        if not promoted:
            # Nothing to carry forward (single fidelity, everything finished the
            # final stage, or every run failed): start a fresh pool.
            self._init_pool()
            return

        # len(promoted) <= number of non-final candidates <= pool_size, so this is >= 0.
        fresh_needed = self.pool_size - len(promoted)
        self.pool = promoted + [self._new_candidate("shR") for _ in range(fresh_needed)]

    def observe(self, action: Action, ok: bool, reward: Optional[float], metrics_flat: Optional[Dict[str, Any]]) -> None:
        """Record the outcome of ``action`` on the matching pool candidate.

        The candidate is matched on both variant name and fidelity. A failed run
        (``ok`` false) is stored as a missing reward. ``metrics_flat`` is not
        used by this agent. Results for actions no longer in the pool (e.g.
        after a resample) are silently ignored.
        """
        for c in self.pool:
            if c.variant == action.variant and self.stage_names[c.stage_idx] == action.fidelity:
                c.last_reward = reward if ok else None
                return