# iLOVE2D's picture
# Upload 2846 files
# 5374a2d verified
from ...optimizers.engine.registry import ParamRegistry
from typing import Any, Optional, List, Dict
# from ...optimizers.engine.registry import OptimizableField
class MiproRegistry(ParamRegistry):
    """
    ParamRegistry variant that attaches input/output metadata to tracked fields.

    After delegating registration to ``ParamRegistry.track``, :meth:`track`
    sets four extra attributes on the registered field object:
    ``input_names`` / ``output_names`` (lists of str) and
    ``input_descs`` / ``output_descs`` (dicts mapping a name to a description).
    The remaining methods are read-only accessors for that metadata.
    """

    def track(
        self,
        root_or_obj: Any,
        path_or_attr: Optional[str] = None,
        *,
        name: Optional[str] = None,
        input_names: Optional[List[str]] = None,
        output_names: Optional[List[str]] = None,
        input_descs: Optional[Dict[str, str]] = None,
        output_descs: Optional[Dict[str, str]] = None,
    ) -> "MiproRegistry":
        """
        Register one field, or a batch of fields, together with I/O metadata.

        Args:
            root_or_obj: Either the object handed to the parent ``track``, or a
                list/tuple of batch items. Each batch item is a dict of keyword
                arguments for this method, or a list/tuple of the form
                ``(obj, attr[, name[, input_names[, output_names[, input_descs[, output_descs]]]]])``.
                Omitted trailing elements fall back to this method's defaults.
                Batch items of any other type are skipped.
            path_or_attr: Forwarded to the parent ``track``; also used as the
                registry key when ``name`` is falsy. Ignored in batch mode.
            name: Optional explicit registry key for the field.
            input_names: Names stored on the field as ``input_names``.
            output_names: Names stored on the field as ``output_names``.
            input_descs: Mapping stored on the field as ``input_descs``.
            output_descs: Mapping stored on the field as ``output_descs``.

        Returns:
            This registry, so calls can be chained.

        Raises:
            ValueError: If a batch tuple has fewer than 2 or more than 7 elements.
            KeyError: If, after the parent call, ``self.fields`` has no entry
                under ``name or path_or_attr``.
        """
        # Batch registration: every item is dispatched back through track().
        if isinstance(root_or_obj, (list, tuple)):
            # Keyword each optional tuple slot (index 2 onward) maps to.
            optional_keys = (
                "name",
                "input_names",
                "output_names",
                "input_descs",
                "output_descs",
            )
            max_len = 2 + len(optional_keys)
            for item in root_or_obj:
                if isinstance(item, dict):
                    self.track(**item)
                elif isinstance(item, (list, tuple)):
                    if not 2 <= len(item) <= max_len:
                        raise ValueError(
                            "Each tuple must be (obj, attr, name, input_names, "
                            "output_names, input_descs, output_descs)"
                            "; elements after attr may be omitted from the right"
                        )
                    # Shorter tuples simply leave the trailing keywords at
                    # their defaults, so (obj, attr) and (obj, attr, name)
                    # work as well as the full 7-element form.
                    extras = dict(zip(optional_keys, item[2:]))
                    self.track(item[0], item[1], **extras)
                # NOTE(review): items that are neither dict nor list/tuple are
                # silently ignored; confirm whether that is intentional.
            return self

        # Let the parent perform the actual registration.
        super().track(root_or_obj, path_or_attr, name=name)

        # Assumes the parent stored the field under `name or path_or_attr`
        # (TODO confirm against ParamRegistry.track).
        field = self.fields[name or path_or_attr]
        field.input_names = input_names or []
        field.output_names = output_names or []
        field.input_descs = input_descs or {}
        field.output_descs = output_descs or {}
        return self

    def get_input_names(self, name: str) -> List[str]:
        """
        Return the input_names of a registered field, or an empty list if unset.

        Raises:
            KeyError: If no field is registered under ``name``.
        """
        return getattr(self.fields[name], "input_names", None) or []

    def get_output_names(self, name: str) -> List[str]:
        """
        Return the output_names of a registered field, or an empty list if unset.

        Raises:
            KeyError: If no field is registered under ``name``.
        """
        return getattr(self.fields[name], "output_names", None) or []

    def get_input_desc_dict(self, name: str) -> Dict[str, str]:
        """
        Return the input_descs of a registered field, or an empty dict if unset.

        Raises:
            KeyError: If no field is registered under ``name``.
        """
        # `or {}` also covers an attribute that exists but is None, matching
        # the `or []` guard used by get_input_names().
        return getattr(self.fields[name], "input_descs", None) or {}

    def get_output_desc_dict(self, name: str) -> Dict[str, str]:
        """
        Return the output_descs of a registered field, or an empty dict if unset.

        Raises:
            KeyError: If no field is registered under ``name``.
        """
        # `or {}` also covers an attribute that exists but is None, matching
        # the `or []` guard used by get_output_names().
        return getattr(self.fields[name], "output_descs", None) or {}

    def get_input_desc(self, name: str, input_name: str) -> str:
        """
        Return the description of one input of a field, or "" if there is none.

        Raises:
            KeyError: If no field is registered under ``name``.
        """
        return self.get_input_desc_dict(name).get(input_name, "")

    def get_output_desc(self, name: str, output_name: str) -> str:
        """
        Return the description of one output of a field, or "" if there is none.

        Raises:
            KeyError: If no field is registered under ``name``.
        """
        return self.get_output_desc_dict(name).get(output_name, "")

    def describe(self) -> Dict[str, Dict[str, Any]]:
        """
        Return a snapshot of every registered field and its metadata.

        Returns:
            A dict keyed by field name. Each value holds ``"value"`` (the result
            of ``field.get()``), ``"input_names"`` / ``"output_names"`` (``None``
            when the field lacks the attribute) and ``"input_descs"`` /
            ``"output_descs"`` (``{}`` when the field lacks the attribute).
        """
        result = {}
        for name, field in self.fields.items():
            result[name] = {
                "value": field.get(),
                "input_names": getattr(field, "input_names", None),
                "output_names": getattr(field, "output_names", None),
                "input_descs": getattr(field, "input_descs", {}),
                "output_descs": getattr(field, "output_descs", {}),
            }
        return result