File size: 2,408 Bytes
5374a2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
from typing import Any, Callable, Dict, List, Optional
import abc
from .decorators import EntryPoint
from .registry import ParamRegistry
class BaseOptimizer(abc.ABC):
    """Abstract base class for optimization routines.

    Wraps a ``ParamRegistry`` so subclasses can read and write parameters by
    name, and holds the program being optimized plus an optional evaluator.
    Subclasses must implement :meth:`optimize`.
    """

    def __init__(
        self,
        registry: ParamRegistry,
        program: Optional[Callable[..., Dict[str, Any]]] = None,
        evaluator: Optional[Callable[..., Any]] = None,
    ):
        """
        Parameters:
        - registry (ParamRegistry): parameter access layer
        - program (Callable, optional): entry function to optimize. If None,
          it is looked up via ``EntryPoint.get_entry()`` when the program is
          first resolved (see ``_resolve_program``).
        - evaluator (Callable, optional): function that evaluates the result
          dict and returns a float
        """
        self.registry = registry
        self.program = program
        self.evaluator = evaluator

    def get_param(self, name: str) -> Any:
        """Retrieve the current value of a parameter by name."""
        return self.registry.get(name)

    def set_param(self, name: str, value: Any) -> None:
        """Set the value of a parameter by name."""
        self.registry.set(name, value)

    def param_names(self) -> List[str]:
        """Return the list of all registered parameter names."""
        return self.registry.names()

    def get_current_cfg(self) -> Dict[str, Any]:
        """Return the current value of every registered parameter as a dict."""
        return {name: self.get_param(name) for name in self.param_names()}

    def apply_cfg(self, cfg: Dict[str, Any]) -> None:
        """Apply a configuration dictionary to the registered parameters.

        Keys that are not in ``self.registry.fields`` are silently skipped.
        Writes go through :meth:`set_param` so subclass overrides of it apply
        here as well.
        """
        for name, value in cfg.items():
            if name in self.registry.fields:
                self.set_param(name, value)

    def _resolve_program(self) -> Callable[..., Dict[str, Any]]:
        """Return the program to optimize, falling back to the registered entry.

        If no program was passed to the constructor, one is fetched from
        ``EntryPoint.get_entry()`` and stored on ``self.program``.

        Raises:
            RuntimeError: if no program was provided and none is registered.
        """
        if self.program is None:
            self.program = EntryPoint.get_entry()
        if self.program is None:
            raise RuntimeError("No entry function provided or registered.")
        return self.program

    @abc.abstractmethod
    def optimize(self):
        """Run the optimization loop. Must be implemented by subclasses.

        The base implementation only resolves the program, prints which entry
        is being used, and raises NotImplementedError. Subclasses that need
        the program should call ``self._resolve_program()`` instead of
        ``super().optimize()``.

        Returns (expected from implementations):
        - (best_cfg, history): best config found and full search history

        Raises:
        - RuntimeError: if no program is provided or registered
        - NotImplementedError: always, in this base implementation
        """
        program = self._resolve_program()
        # Not every callable has __name__ (e.g. functools.partial, callable
        # instances); fall back to repr instead of raising AttributeError.
        entry_name = getattr(program, "__name__", repr(program))
        print(f"Starting optimization from entry: {entry_name}")
        raise NotImplementedError