File size: 12,009 Bytes
5374a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
from __future__ import annotations
import abc
import inspect
import random
import re
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple


# Regular expression to match indexing expressions like foo[0] or bar["key"]
_INDEX_RE = re.compile(r'^(.*?)\[(.*?)\]$')


# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# 1.  Runtime field helpers
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
class OptimizableField:
    """Expose a concrete runtime attribute via get/set."""
    def __init__(self,
                 name: str,
                 getter: Callable[[], Any],
                 setter: Callable[[Any], None]):
        self.name, self._get, self._set = name, getter, setter
    def get(self) -> Any:            return self._get()
    def set(self, value: Any) -> None: self._set(value)


class PromptRegistry:
    """Central registry for all runtime-patchable fields."""
    def __init__(self) -> None:
        self.fields: Dict[str, OptimizableField] = {}
    def register_field(self, field: OptimizableField):
        self.fields[field.name] = field
    # convenience
    def get(self, name: str) -> Any:
        return self.fields[name].get()
    def set(self, name: str, value: Any):
        self.fields[name].set(value)
    def names(self) -> List[str]:
        return list(self.fields.keys())

    # -- ๆ–ฐๅขž API ----------------------------------------------
    def register_path(self, root: Any, path: str, *, name: str|None=None):
        """็”จ็ฑปไผผ 'encoder.layers[3].dropout_p' ็š„ๅญ—็ฌฆไธฒไธ€ๆฌกๆ€งๆณจๅ†Œใ€‚"""
        key = name or path.split(".")[-1]          # ๅปบ่ฎฎ่ฎฉ็”จๆˆท่‡ช่ตทๆ›ด็Ÿญ alias
        parent, leaf = self._walk(root, path)

        def getter():                       # ่ฏป
            return parent[leaf] if isinstance(parent, (list, dict)) else getattr(parent, leaf)

        def setter(v):                      # ๅ†™
            if isinstance(parent, (list, dict)):
                parent[leaf] = v
            else:
                setattr(parent, leaf, v)

        field = OptimizableField(key, getter, setter)
        self.register_field(field)
        return field

    def _walk(self, root, path: str, create_missing=False):
        cur = root
        parts = path.split(".")
        for part in parts[:-1]:
            m = _INDEX_RE.match(part)
            if m:
                attr, idx = m.groups()
                cur = getattr(cur, attr) if attr else cur
                idx = idx.strip()
                if (idx.startswith("'") and idx.endswith("'")) or (idx.startswith('"') and idx.endswith('"')):
                    idx = idx[1:-1]  # strip quotes if it's a string key
                elif idx.isdigit():
                    idx = int(idx)
                cur = cur[idx]
            else:
                cur = getattr(cur, part)

        # ๆœ€ๅŽไธ€ไธชๅถๅญๅฑžๆ€ง
        leaf = parts[-1]
        m = _INDEX_RE.match(leaf)
        if m:
            attr, idx = m.groups()
            parent = getattr(cur, attr) if attr else cur
            idx = idx.strip()
            if (idx.startswith("'") and idx.endswith("'")) or (idx.startswith('"') and idx.endswith('"')):
                idx = idx[1:-1]
            elif idx.isdigit():
                idx = int(idx)
            return parent, idx
        return cur, leaf


# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# 2.  CodeBlock  (sync / async dualโ€‘compatible)
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# result = await block.run(cfg)    
class CodeBlock:
    """
    Parameters
    ----------
    name : str
        ้€ป่พ‘ๅ๏ผˆๆ—ฅๅฟ—ใ€่ฐƒ่ฏ•ๅ‹ๅฅฝ๏ผ‰
    func : Callable[[dict], Any]
        ๆ™ฎ้€šๅŒๆญฅๅ‡ฝๆ•ฐ๏ผŒ่พ“ๅ…ฅ cfg ๅญ—ๅ…ธ
    """

    def __init__(self, name: str, func: Callable[[Dict[str, Any]], Any]):
        self.name = name
        self._func = func

    def run(self, cfg: Dict[str, Any]) -> Any:
        """ๅŒๆญฅๆ‰ง่กŒๅฐ่ฃ…็š„ๅ‡ฝๆ•ฐใ€‚"""
        return self._func(cfg)

    def __call__(self, cfg: Dict[str, Any]) -> Any:
        return self.run(cfg)

    def __repr__(self):
        return f"<CodeBlock {self.name} (sync)>"




# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# 3.  BaseCodeBlockOptimizer
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
class BaseCodeBlockOptimizer(abc.ABC):
    """
    Abstract optimiser that:
      โ€ข performs sequential trials
      โ€ข writes sampled cfg back to runtime via PromptRegistry
      โ€ข validates that registered names appear in CodeBlock signature
    """

    def __init__(self,
                 registry: PromptRegistry,
                 metric: str,
                 maximize: bool = True,
                 max_trials: int = 30):
        self.registry   = registry
        self.metric     = metric
        self.maximize   = maximize
        self.max_trials = max_trials

    @abc.abstractmethod
    def sample_cfg(self) -> Dict[str, Any]:
        """Return a cfg dict (may include subset of registry names)."""

    @abc.abstractmethod
    def update(self, cfg: Dict[str, Any], score: float):
        """Update internal optimiser state."""

    def _apply_cfg(self, cfg: Dict[str, Any]):
        for k, v in cfg.items():
            if k in self.registry.fields:
                self.registry.set(k, v)

    def _check_codeblock_compat(self, code_block: CodeBlock):
        sig = inspect.signature(code_block._func)
        params = sig.parameters.values()

        has_kwargs = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params)
        accepts_cfg_dict = "cfg" in sig.parameters

        if has_kwargs or accepts_cfg_dict:
            return

        allowed_keys = set(sig.parameters)
        unknown = set(self.registry.names()) - allowed_keys
        if unknown:
            import warnings
            warnings.warn(f"PromptRegistry fields {unknown} are not present in "
                          f"{code_block.name}() signature; they will be ignored.")

    def run(self,
            code_block: CodeBlock,
            evaluator: Callable[[Dict[str, Any], Any], float]
            ) -> Tuple[Dict[str, Any], List[Tuple[Dict[str, Any], float]]]:

        self._check_codeblock_compat(code_block)

        best_cfg, best_score = None, -float("inf") if self.maximize else float("inf")
        history: List[Tuple[Dict[str, Any], float]] = []

        for _ in range(self.max_trials):
            cfg = self.sample_cfg()
            self._apply_cfg(cfg)
            result = code_block.run(cfg)
            score = evaluator(cfg, result)
            self.update(cfg, score)

            history.append((cfg, score))
            better = score > best_score if self.maximize else score < best_score
            if better:
                best_cfg, best_score = cfg, score

        return best_cfg, history



# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# Other  Helper: bind_cfg โ€“ write cfg into nested attributes
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
def bind_cfg(obj: Any, cfg: Dict[str, Any]) -> None:
    """Recursively write *cfg* values into (potentially nested) attributes
    of *obj*.  Key like "a.b.c" becomes obj.a.b.c = value.
    """
    for key, val in cfg.items():
        parts = key.split(".")
        cur = obj
        for part in parts[:-1]:
            cur = getattr(cur, part)
        setattr(cur, parts[-1], val)



# Demo
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ #
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ #
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ Demo: ไธšๅŠกๅฏน่ฑก & ๅทฅไฝœๆต โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ #
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ Demo: Workflow & Sampler โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ #
@dataclass
class Sampler:
    temperature: float = 0.7
    top_p: float = 0.9

class Workflow:

    def __init__(self):
        self.system_prompt = "You are a helpful assistant."
        self.few_shot = "Q: 1+1=?\nA: 2"
        self.sampler = Sampler()

    # @parameter_registry("name", ["a", "self.system_prompt"])
    def execute(self):
        # a = 000 
        pass 

    def run(self):
        prompt = f"{self.system_prompt}\n{self.few_shot}\nUser: Hi"
        return {"prompt": prompt, "score": random.uniform(0, 1)}


# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ Optimizer ๅฎž็Žฐ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ #
class RandomSearchOptimizer(BaseCodeBlockOptimizer):
    def sample_cfg(self) -> Dict[str, Any]:
        return {
            "sampler_temperature": random.uniform(0.3, 1.3),
            "sampler_top_p":       random.uniform(0.5, 1.0),
            "sys_prompt": random.choice([
                "You are a helpful assistant.",
                "You are a super-concise assistant."
            ]),
        }

    def update(self, cfg, score):
        pass


class GreedyLoggerOptimizer(BaseCodeBlockOptimizer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.best = None
        self.best_score = -float("inf") if self.maximize else float("inf")

    def sample_cfg(self):
        return {
            "sampler_temperature": random.uniform(0.3, 1.3),
            "sampler_top_p":       random.uniform(0.5, 1.0),
            "sys_prompt": random.choice([
                "You are a helpful assistant.",
                "You are a super-concise assistant."
            ]),
        }

    def update(self, cfg, score):
        if (self.maximize and score > self.best_score) or (not self.maximize and score < self.best_score):
            self.best = cfg
            self.best_score = score
            print(f"[New Best] score={score:.3f} cfg={cfg}")



# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ ๅฎž้ชŒๅ…ฅๅฃ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ #
def main():
    flow = Workflow()

    registry = PromptRegistry()
    registry.register_path(flow, "system_prompt", name="sys_prompt")
    registry.register_path(flow, "sampler.temperature")
    registry.register_path(flow, "sampler.top_p")

    code_block = CodeBlock("run_workflow", lambda cfg: flow.run())

    def evaluator(cfg, result) -> float:
        return result["score"]

    opt = RandomSearchOptimizer(registry, metric="score", max_trials=10)
    best_cfg, history = opt.run(code_block, evaluator)

    print("\n=== Trial history ===")
    for i, (cfg, score) in enumerate(history, 1):
        print(f"{i:02d}: score={score:.3f}, cfg={cfg}")

    print("\n=== Best ===")
    print(best_cfg)


if __name__ == "__main__":
    main()