|
|
import re |
|
|
from typing import Any, Callable, Dict, List |
|
|
from ...prompts.template import PromptTemplate |
|
|
import copy |
|
|
import warnings |
|
|
_INDEX_RE = re.compile(r'^(.*?)\[(.*?)\]$') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_PATH_RE = re.compile(r""" |
|
|
([a-zA-Z_]\w*) | # attr name |
|
|
\[\s*(-?\d+)\s*\] | # list index (allow negative) |
|
|
\[\s*['"]([^'\"]+)['"]\s*\] # dict key (inside quotes) |
|
|
""", re.VERBOSE) |
|
|
|
|
|
class OptimizableField: |
|
|
""" |
|
|
Represents a parameter that can be optimized. |
|
|
|
|
|
This class encapsulates a runtime attribute using dynamic getter and setter |
|
|
functions. It allows the parameter to be exposed and manipulated by an external |
|
|
optimizer. An initial snapshot of the field can be stored and later used to reset |
|
|
the field to its original value. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
name: str, |
|
|
getter: Callable[[], Any], |
|
|
setter: Callable[[Any], None] |
|
|
): |
|
|
""" |
|
|
Initialize an OptimizableField instance. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
name : str |
|
|
The alias used to register the field in the registry. |
|
|
getter : Callable[[], Any] |
|
|
A function that returns the current value of the field. |
|
|
setter : Callable[[Any], None] |
|
|
A function that sets a new value to the field. |
|
|
""" |
|
|
self.name = name |
|
|
self._get = getter |
|
|
self._set = setter |
|
|
self._initial_value = None |
|
|
|
|
|
def get(self) -> Any: |
|
|
""" |
|
|
Retrieve the current value of the field. |
|
|
|
|
|
Returns |
|
|
------- |
|
|
Any |
|
|
The current value of the field. |
|
|
""" |
|
|
return self._get() |
|
|
|
|
|
def set(self, value: Any) -> None: |
|
|
""" |
|
|
Update the field with a new value. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
value : Any |
|
|
The new value to assign to the field. |
|
|
""" |
|
|
self._set(value) |
|
|
|
|
|
def init_snapshot(self) -> None: |
|
|
""" |
|
|
Capture a snapshot of the current field value. |
|
|
|
|
|
This method stores a deep copy of the current field value so that it |
|
|
can be restored later using `reset()`. |
|
|
""" |
|
|
current = self.get() |
|
|
self._initial_value = safe_deepcopy(current) |
|
|
|
|
|
def reset(self) -> None: |
|
|
""" |
|
|
Reset the field to its initial value. |
|
|
|
|
|
If the current value object defines a `__reset__()` method, it will be |
|
|
called to perform the reset. Otherwise, the field is reset to the deep-copied |
|
|
initial value stored by `init_snapshot()`. |
|
|
|
|
|
Raises |
|
|
------ |
|
|
ValueError |
|
|
If `init_snapshot()` has not been called before `reset()`. |
|
|
""" |
|
|
current = self.get() |
|
|
|
|
|
if self._initial_value is None: |
|
|
raise ValueError(f"Field '{self.name}' has no snapshot. Call init_snapshot() first.") |
|
|
|
|
|
if hasattr(current, "__reset__") and callable(current.__reset__): |
|
|
current.__reset__() |
|
|
else: |
|
|
self.set(safe_deepcopy(self._initial_value)) |
|
|
|
|
|
|
|
|
class ParamRegistry: |
|
|
""" |
|
|
Central registry for all parameters that can be exposed to optimization. |
|
|
|
|
|
Allows dynamic binding and tracking of runtime attributes via dot-paths, |
|
|
dictionary keys, or list indices. Provides getter/setter access to all |
|
|
registered parameters for optimizers. |
|
|
""" |
|
|
def __init__(self) -> None: |
|
|
"""Initialize an empty registry of optimizable fields.""" |
|
|
self.fields: Dict[str, OptimizableField] = {} |
|
|
|
|
|
def register_field(self, field: OptimizableField): |
|
|
"""Manually register an OptimizableField with its alias name.""" |
|
|
field.init_snapshot() |
|
|
self.fields[field.name] = field |
|
|
|
|
|
def get(self, name: str) -> Any: |
|
|
"""Retrieve the current value of a registered field by name.""" |
|
|
return self.fields[name].get() |
|
|
|
|
|
def get_field(self, name: str) -> OptimizableField: |
|
|
"""Retrieve the OptimizableField object by name.""" |
|
|
if name not in self.fields: |
|
|
raise ValueError(f"Field '{name}' is not registered.") |
|
|
else: |
|
|
return self.fields[name] |
|
|
|
|
|
def set(self, name: str, value: Any): |
|
|
"""Set the value of a registered field by name.""" |
|
|
self.fields[name].set(value) |
|
|
|
|
|
def names(self) -> List[str]: |
|
|
"""Return a list of all registered field names (aliases).""" |
|
|
return list(self.fields.keys()) |
|
|
|
|
|
def reset(self): |
|
|
"""Roll back all registered fields to their initial values.""" |
|
|
for field in self.fields.values(): |
|
|
field.reset() |
|
|
|
|
|
def reset_field(self, name: str): |
|
|
"""Roll back a registered field to its initial value.""" |
|
|
self.fields[name].reset() |
|
|
|
|
|
def track(self, root_or_obj: Any, path_or_attr: str, *, name: str | None = None): |
|
|
""" |
|
|
Register a parameter to be optimized. Supports both nested paths and direct attributes. |
|
|
|
|
|
Parameters: |
|
|
- root_or_obj (Any): the base object or container |
|
|
- path_or_attr (str): a path like 'prompt.template' or a direct attribute like 'template' |
|
|
- name (str | None): optional alias for this parameter |
|
|
|
|
|
Supported formats: |
|
|
- registry.track(program, "prompt.template") # nested attribute |
|
|
- registry.track(program, "metadata['style']") # dictionary key |
|
|
- registry.track(program, "components[2].prefix") # list index |
|
|
- registry.track(program.prompt, "template") # direct object + attribute |
|
|
- registry.track([ |
|
|
(program, "prompt.template"), |
|
|
(program, "metadata['style']", "style"), |
|
|
(program.prompt, "prefix", "prompt_prefix") |
|
|
]) # batch registration |
|
|
- registry.track(program, "prompt.template").track(program, "prompt.prefix") # chained calls |
|
|
|
|
|
- registry.track(program, "prompt_template_obj") # register a prompt_template instance |
|
|
|
|
|
Returns: |
|
|
- self (PromptRegistry): for chaining |
|
|
""" |
|
|
if isinstance(root_or_obj, list | tuple): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for item in root_or_obj: |
|
|
if len(item) == 2: |
|
|
self.track(item[0], item[1]) |
|
|
elif len(item) == 3: |
|
|
self.track(item[0], item[1], name=item[2]) |
|
|
return self |
|
|
|
|
|
if "." in path_or_attr or "[" in path_or_attr: |
|
|
return self._track_path(root_or_obj, path_or_attr, name) |
|
|
else: |
|
|
key = name or path_or_attr |
|
|
|
|
|
def getter(): |
|
|
return getattr(root_or_obj, path_or_attr) |
|
|
|
|
|
def setter(v): |
|
|
setattr(root_or_obj, path_or_attr, v) |
|
|
|
|
|
field = OptimizableField(key, getter, setter) |
|
|
if key in self.fields: |
|
|
import warnings |
|
|
warnings.warn(f"Field '{key}' is already registered. Overwriting.") |
|
|
self.register_field(field) |
|
|
return self |
|
|
|
|
|
def _track_path(self, root: Any, path: str, name: str | None = None): |
|
|
""" |
|
|
Internal helper that registers a nested field (via dot path, index, or key) |
|
|
as an OptimizableField by dynamically creating getter and setter functions. |
|
|
|
|
|
Parameters: |
|
|
- root (Any): the root object to start walking from |
|
|
- path (str): dot-separated path supporting list/dict access |
|
|
- name (Optional[str]): alias for the parameter (defaults to last path segment) |
|
|
|
|
|
Returns: |
|
|
- self |
|
|
""" |
|
|
key = name if name is not None else path |
|
|
parent, leaf = self._walk(root, path) |
|
|
|
|
|
def getter(): |
|
|
return parent[leaf] if isinstance(parent, (list, dict)) else getattr(parent, leaf) |
|
|
|
|
|
def setter(v): |
|
|
if isinstance(parent, (list, dict)): |
|
|
parent[leaf] = v |
|
|
else: |
|
|
setattr(parent, leaf, v) |
|
|
|
|
|
field = OptimizableField(key, getter, setter) |
|
|
self.register_field(field) |
|
|
return self |
|
|
|
|
|
|
|
|
def _walk(self, root, path: str): |
|
|
""" |
|
|
Internal helper to resolve a dot-separated path string into its parent container |
|
|
and the leaf attribute/key/index for assignment or retrieval. |
|
|
|
|
|
Supports: |
|
|
- Nested attributes: e.g. "a.b.c" |
|
|
- Dict key access: e.g. "config['key']" |
|
|
- List index access: e.g. "layers[0]" |
|
|
|
|
|
Parameters: |
|
|
- root (Any): root object to walk from |
|
|
- path (str): path string to resolve |
|
|
- create_missing (bool): unused placeholder for future extensions |
|
|
|
|
|
Returns: |
|
|
- (parent, leaf): where parent[leaf] or getattr(parent, leaf) is the target |
|
|
""" |
|
|
cur = root |
|
|
parts = [] |
|
|
for match in _PATH_RE.finditer(path): |
|
|
attr, idx, key = match.groups() |
|
|
if attr: |
|
|
parts.append(attr) |
|
|
elif idx: |
|
|
parts.append(int(idx)) |
|
|
elif key: |
|
|
parts.append(key) |
|
|
|
|
|
for part in parts[:-1]: |
|
|
if isinstance(part, int): |
|
|
cur = cur[part] |
|
|
else: |
|
|
cur = getattr(cur, part) if hasattr(cur, part) else cur[part] |
|
|
|
|
|
leaf = parts[-1] |
|
|
parent = cur |
|
|
return parent, leaf |
|
|
|
|
|
def _walk_old(self, root, path: str): |
|
|
""" |
|
|
Unused Function |
|
|
Internal helper to resolve a dot-separated path string into its parent container |
|
|
and the leaf attribute/key/index for assignment or retrieval. |
|
|
|
|
|
Supports: |
|
|
- Nested attributes: e.g. "a.b.c" |
|
|
- Dict key access: e.g. "config['key']" |
|
|
- List index access: e.g. "layers[0]" |
|
|
|
|
|
Parameters: |
|
|
- root (Any): root object to walk from |
|
|
- path (str): path string to resolve |
|
|
- create_missing (bool): unused placeholder for future extensions |
|
|
|
|
|
Returns: |
|
|
- (parent, leaf): where parent[leaf] or getattr(parent, leaf) is the target |
|
|
""" |
|
|
cur = root |
|
|
parts = path.split(".") |
|
|
for part in parts[:-1]: |
|
|
m = _INDEX_RE.match(part) |
|
|
if m: |
|
|
attr, idx = m.groups() |
|
|
cur = getattr(cur, attr) if attr else cur |
|
|
idx = idx.strip() |
|
|
if (idx.startswith("'") and idx.endswith("'")) or (idx.startswith('"') and idx.endswith('"')): |
|
|
idx = idx[1:-1] |
|
|
elif idx.isdigit(): |
|
|
idx = int(idx) |
|
|
cur = cur[idx] |
|
|
else: |
|
|
cur = getattr(cur, part) |
|
|
|
|
|
leaf = parts[-1] |
|
|
m = _INDEX_RE.match(leaf) |
|
|
if m: |
|
|
attr, idx = m.groups() |
|
|
parent = getattr(cur, attr) if attr else cur |
|
|
idx = idx.strip() |
|
|
if (idx.startswith("'") and idx.endswith("'")) or (idx.startswith('"') and idx.endswith('"')): |
|
|
idx = idx[1:-1] |
|
|
elif idx.isdigit(): |
|
|
idx = int(idx) |
|
|
return parent, idx |
|
|
return cur, leaf |
|
|
|
|
|
|
|
|
def safe_deepcopy(obj): |
|
|
""" |
|
|
Safely attempt to deep copy any Python object, with graceful fallback behavior. |
|
|
|
|
|
This function performs a standard `copy.deepcopy` when possible. If that fails |
|
|
(e.g., due to the presence of uncopyable components such as file handles, threads, |
|
|
or custom classes that don't support deep copying), it falls back to a more resilient strategy: |
|
|
|
|
|
1. Attempts to create a blank instance of the object's class using `__new__`. |
|
|
2. Recursively copies all attributes found in the object's `__dict__`, using: |
|
|
- `safe_deepcopy` for deep recursive copy, |
|
|
- `copy.copy` as a shallow fallback, |
|
|
- or the original reference as a last resort. |
|
|
3. If the object has no `__dict__` or cannot be instantiated, returns the original object. |
|
|
|
|
|
Parameters: |
|
|
obj (Any): The object to be deep copied. |
|
|
|
|
|
Returns: |
|
|
Any: A deep copy of the input object if possible, or a best-effort fallback copy. |
|
|
|
|
|
Warnings: |
|
|
Issues a `warnings.warn()` message whenever: |
|
|
- The deep copy fails and fallback mechanisms are used. |
|
|
- An attribute copy fails and falls back to a shallower or direct reference. |
|
|
- The class cannot be re-instantiated and the original reference is returned. |
|
|
|
|
|
Notes: |
|
|
- This function is intended for robust copying in systems where user-defined objects, |
|
|
templates, or agents may not support strict deep copying. |
|
|
- It is not guaranteed to preserve identity semantics or copy objects with `__slots__`. |
|
|
- For critical correctness or mutation isolation, ensure your objects are deepcopy-compatible. |
|
|
|
|
|
Example: |
|
|
>>> obj = CustomObject() |
|
|
>>> obj_copy = safe_deepcopy(obj) |
|
|
""" |
|
|
try: |
|
|
return copy.deepcopy(obj) |
|
|
except Exception: |
|
|
warnings.warn(f"Failed to deepcopy {obj.__class__.__name__}. Falling back to advanced handling.") |
|
|
pass |
|
|
|
|
|
try: |
|
|
|
|
|
new_instance = obj.__class__.__new__(obj.__class__) |
|
|
except Exception: |
|
|
warnings.warn(f"Failed to create a blank instance of {obj.__class__.__name__}. Falling back to reference.") |
|
|
return obj |
|
|
|
|
|
for attr, value in getattr(obj, "__dict__", {}).items(): |
|
|
try: |
|
|
setattr(new_instance, attr, safe_deepcopy(value)) |
|
|
except Exception: |
|
|
try: |
|
|
warnings.warn(f"Failed to copy {attr} of {obj.__class__.__name__}. Falling back to shallow copy.") |
|
|
setattr(new_instance, attr, copy.copy(value)) |
|
|
except Exception: |
|
|
warnings.warn(f"Failed to copy {attr} of {obj.__class__.__name__}. Falling back to reference.") |
|
|
setattr(new_instance, attr, value) |
|
|
|
|
|
return new_instance |
|
|
|
|
|
|
|
|
class PromptTemplateRegister(ParamRegistry): |
|
|
""" |
|
|
Unused Class |
|
|
Enhanced parameter registry that supports directly registering PromptTemplate instances |
|
|
or prompt strings as a single optimizable object. |
|
|
""" |
|
|
|
|
|
def track(self, root_or_obj: Any, path_or_attr: str, *, name: str | None = None): |
|
|
if isinstance(root_or_obj, (list, tuple)): |
|
|
for item in root_or_obj: |
|
|
if len(item) == 2: |
|
|
self.track(item[0], item[1]) |
|
|
elif len(item) == 3: |
|
|
self.track(item[0], item[1], name=item[2]) |
|
|
return self |
|
|
|
|
|
if '.' in path_or_attr or '[' in path_or_attr: |
|
|
return self._track_path(root_or_obj, path_or_attr, name) |
|
|
else: |
|
|
key = name or path_or_attr |
|
|
|
|
|
try: |
|
|
value = getattr(root_or_obj, path_or_attr) |
|
|
except AttributeError: |
|
|
return super().track(root_or_obj, path_or_attr, name=name) |
|
|
|
|
|
if isinstance(value, (str, PromptTemplate)): |
|
|
|
|
|
field = OptimizableField( |
|
|
key, |
|
|
getter=lambda: getattr(root_or_obj, path_or_attr), |
|
|
setter=lambda v: setattr(root_or_obj, path_or_attr, v) |
|
|
) |
|
|
self.register_field(field) |
|
|
return self |
|
|
|
|
|
|
|
|
return super().track(root_or_obj, path_or_attr, name=name) |