|
|
|
|
|
from ...optimizers.engine.registry import ParamRegistry |
|
|
from typing import Any, Optional, List, Dict |
|
|
|
|
|
|
|
|
class MiproRegistry(ParamRegistry):
    """
    Extended ParamRegistry that supports storing input_names and output_names
    for each optimizable field. Compatible with all original track() usages.

    Every field registered through :meth:`track` is annotated with four
    metadata attributes: ``input_names`` (list), ``output_names`` (list),
    ``input_descs`` (dict) and ``output_descs`` (dict).
    """

    def track(
        self,
        root_or_obj: Any,
        path_or_attr: Optional[str] = None,
        *,
        name: Optional[str] = None,
        input_names: Optional[List[str]] = None,
        output_names: Optional[List[str]] = None,
        input_descs: Optional[Dict[str, str]] = None,
        output_descs: Optional[Dict[str, str]] = None,
    ) -> "MiproRegistry":
        """Register one optimizable field, or a batch of them, with metadata.

        Single registration::

            registry.track(obj, "attr", name="alias",
                           input_names=["question"], output_names=["answer"])

        Batch registration: pass a list/tuple as ``root_or_obj`` whose items
        are either

        * dicts of keyword arguments for this method, or
        * tuples ``(obj, attr[, name[, input_names[, output_names[,
          input_descs[, output_descs]]]]])`` with 2 to 7 entries; omitted
          trailing entries default to ``None``.

        In batch mode the other arguments of the outer call are ignored.

        Args:
            root_or_obj: Object/root passed on to ``ParamRegistry.track``, or
                a batch as described above.
            path_or_attr: Attribute name or path passed on to
                ``ParamRegistry.track``.
            name: Optional registry key; when omitted, ``path_or_attr`` is
                used to look the field up after registration.
            input_names: Input names associated with the field.
            output_names: Output names associated with the field.
            input_descs: Mapping from input name to its description.
            output_descs: Mapping from output name to its description.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: If a batch tuple has fewer than 2 or more than 7
                entries.
            TypeError: If a batch item is neither a dict nor a list/tuple.
        """
        if isinstance(root_or_obj, (list, tuple)):
            for item in root_or_obj:
                if isinstance(item, dict):
                    self.track(**item)
                elif isinstance(item, (list, tuple)):
                    if not 2 <= len(item) <= 7:
                        raise ValueError(
                            "Each tuple must be (obj, attr[, name[, input_names"
                            "[, output_names[, input_descs[, output_descs]]]]])"
                        )
                    # Pad omitted trailing entries with None so shorter forms
                    # such as (obj, attr) or (obj, attr, name) are accepted.
                    (obj, attr, item_name, in_names, out_names,
                     in_descs, out_descs) = tuple(item) + (None,) * (7 - len(item))
                    self.track(
                        obj,
                        attr,
                        name=item_name,
                        input_names=in_names,
                        output_names=out_names,
                        input_descs=in_descs,
                        output_descs=out_descs,
                    )
                else:
                    # Previously such items were skipped silently, leaving the
                    # field unregistered without any signal to the caller.
                    raise TypeError(
                        f"Unsupported batch item of type {type(item).__name__}; "
                        "expected a dict or a list/tuple"
                    )
            return self

        super().track(root_or_obj, path_or_attr, name=name)

        # NOTE(review): assumes ParamRegistry.track stores the field under
        # ``name or path_or_attr``; confirm against the base implementation.
        field = self.fields[name or path_or_attr]

        # Store copies so later mutation of the caller's containers does not
        # silently change the registry's metadata.
        field.input_names = list(input_names) if input_names else []
        field.output_names = list(output_names) if output_names else []
        field.input_descs = dict(input_descs) if input_descs else {}
        field.output_descs = dict(output_descs) if output_descs else {}

        return self

    def get_input_names(self, name: str) -> List[str]:
        """Return the input_names for a registered field, or an empty list if not set.

        ``name`` must already be registered (it is looked up in ``self.fields``).
        """
        return getattr(self.fields[name], "input_names", None) or []

    def get_output_names(self, name: str) -> List[str]:
        """Return the output_names for a registered field, or an empty list if not set.

        ``name`` must already be registered (it is looked up in ``self.fields``).
        """
        return getattr(self.fields[name], "output_names", None) or []

    def get_input_desc_dict(self, name: str) -> Dict[str, str]:
        """Return the input_descs for a registered field, or an empty dict if not set.

        The stored dict itself is returned when present. ``{}`` is returned when
        the attribute is missing or ``None``, so callers can always use ``.get``.
        """
        descs = getattr(self.fields[name], "input_descs", None)
        return descs if descs is not None else {}

    def get_output_desc_dict(self, name: str) -> Dict[str, str]:
        """Return the output_descs for a registered field, or an empty dict if not set.

        The stored dict itself is returned when present. ``{}`` is returned when
        the attribute is missing or ``None``, so callers can always use ``.get``.
        """
        descs = getattr(self.fields[name], "output_descs", None)
        return descs if descs is not None else {}

    def get_input_desc(self, name: str, input_name: str) -> str:
        """Return the input_desc for a registered field, or an empty string if not set."""
        return self.get_input_desc_dict(name).get(input_name, "")

    def get_output_desc(self, name: str, output_name: str) -> str:
        """Return the output_desc for a registered field, or an empty string if not set."""
        return self.get_output_desc_dict(name).get(output_name, "")

    def describe(self) -> Dict[str, Dict[str, Any]]:
        """
        Returns a dict of all fields and their metadata, including input/output names if present.

        Each entry maps a field key to ``{"value", "input_names", "output_names",
        "input_descs", "output_descs"}``. ``value`` comes from ``field.get()``.
        Names are ``None`` for fields that never received them, and descs default
        to ``{}``. Containers are the stored objects, not copies.
        """
        return {
            key: {
                "value": field.get(),
                "input_names": getattr(field, "input_names", None),
                "output_names": getattr(field, "output_names", None),
                "input_descs": getattr(field, "input_descs", {}),
                "output_descs": getattr(field, "output_descs", {}),
            }
            for key, field in self.fields.items()
        }