File size: 3,958 Bytes
5374a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92

from ...optimizers.engine.registry import ParamRegistry
from typing import Any, Optional, List, Dict
# from ...optimizers.engine.registry import OptimizableField

class MiproRegistry(ParamRegistry):
    """
    Extended ParamRegistry that supports storing input_names and output_names
    for each optimizable field. Compatible with all original track() usages.

    Every field registered through ``track`` gets four extra attributes:
    ``input_names`` / ``output_names`` (lists of names, default ``[]``) and
    ``input_descs`` / ``output_descs`` (name -> description dicts, default
    ``{}``).
    """

    # Optional trailing entries of a tuple/list batch item, in positional
    # order, after the mandatory (obj, attr) pair.
    _BATCH_OPTIONAL_FIELDS = (
        "name",
        "input_names",
        "output_names",
        "input_descs",
        "output_descs",
    )

    def track(
        self,
        root_or_obj: Any,
        path_or_attr: Optional[str] = None,
        *,
        name: Optional[str] = None,
        input_names: Optional[List[str]] = None,
        output_names: Optional[List[str]] = None,
        input_descs: Optional[Dict[str, str]] = None,
        output_descs: Optional[Dict[str, str]] = None,
    ) -> "MiproRegistry":
        """Register one field (or a batch of fields) together with I/O metadata.

        Args:
            root_or_obj: Forwarded to ``ParamRegistry.track``. If it is a list
                or tuple, it is instead treated as a batch of registrations
                (see ``_track_batch``) and the other arguments are ignored.
            path_or_attr: Forwarded to ``ParamRegistry.track``.
            name: Forwarded to ``ParamRegistry.track``; when falsy,
                ``path_or_attr`` is used to look the new field up.
            input_names: Stored on the field as ``input_names`` (``[]`` if falsy).
            output_names: Stored on the field as ``output_names`` (``[]`` if falsy).
            input_descs: Stored on the field as ``input_descs`` (``{}`` if falsy).
            output_descs: Stored on the field as ``output_descs`` (``{}`` if falsy).

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: A batch tuple/list item has fewer than 2 or more than 7
                entries.
            TypeError: A batch item is neither a dict nor a tuple/list.
            KeyError: The field cannot be found under the derived key after
                the parent registered it.
        """
        if isinstance(root_or_obj, (list, tuple)):
            return self._track_batch(root_or_obj)

        # Let the parent perform the actual tracking.
        super().track(root_or_obj, path_or_attr, name=name)

        # NOTE(review): assumes ParamRegistry.track stores the field under
        # ``name`` when given, else under ``path_or_attr`` — confirm against
        # the parent implementation.
        key = name or path_or_attr
        field = self.fields[key]
        field.input_names = input_names or []
        field.output_names = output_names or []
        field.input_descs = input_descs or {}
        field.output_descs = output_descs or {}

        return self

    def _track_batch(self, items: Any) -> "MiproRegistry":
        """Register each item of a batch by re-dispatching to ``track``.

        Each item is either a dict of ``track`` keyword arguments, or a
        tuple/list ``(obj, attr[, name[, input_names[, output_names[,
        input_descs[, output_descs]]]]])``. Items are registered in order;
        an invalid item raises after the preceding items were registered.
        """
        max_len = 2 + len(self._BATCH_OPTIONAL_FIELDS)
        for item in items:
            if isinstance(item, dict):
                self.track(**item)
            elif isinstance(item, (list, tuple)):
                if not 2 <= len(item) <= max_len:
                    raise ValueError(
                        "Each tuple must be (obj, attr[, name[, input_names"
                        "[, output_names[, input_descs[, output_descs]]]]])"
                    )
                obj, attr, *optional = item
                self.track(
                    obj,
                    attr,
                    **dict(zip(self._BATCH_OPTIONAL_FIELDS, optional)),
                )
            else:
                # Previously such items were skipped silently, hiding typos.
                raise TypeError(
                    "Batch items must be dicts or tuples/lists, got "
                    f"{type(item).__name__}"
                )
        return self

    def get_input_names(self, name: str) -> List[str]:
        """Return the input_names for a registered field, or an empty list if not set."""
        return getattr(self.fields[name], "input_names", None) or []

    def get_output_names(self, name: str) -> List[str]:
        """Return the output_names for a registered field, or an empty list if not set."""
        return getattr(self.fields[name], "output_names", None) or []

    def get_input_desc_dict(self, name: str) -> Dict[str, str]:
        """Return the input_descs for a registered field, or an empty dict if not set.

        The stored dict itself (not a copy) is returned when present.
        """
        descs = getattr(self.fields[name], "input_descs", None)
        # Normalize a stored ``None`` too, so get_input_desc can call ``.get``.
        return descs if descs is not None else {}

    def get_output_desc_dict(self, name: str) -> Dict[str, str]:
        """Return the output_descs for a registered field, or an empty dict if not set.

        The stored dict itself (not a copy) is returned when present.
        """
        descs = getattr(self.fields[name], "output_descs", None)
        # Normalize a stored ``None`` too, so get_output_desc can call ``.get``.
        return descs if descs is not None else {}

    def get_input_desc(self, name: str, input_name: str) -> str:
        """Return the input_desc for a registered field, or an empty string if not set."""
        return self.get_input_desc_dict(name).get(input_name, "")

    def get_output_desc(self, name: str, output_name: str) -> str:
        """Return the output_desc for a registered field, or an empty string if not set."""
        return self.get_output_desc_dict(name).get(output_name, "")

    def describe(self) -> Dict[str, Dict[str, Any]]:
        """
        Returns a dict of all fields and their metadata, including input/output names if present.

        For each registered field name the entry holds the field's current
        ``value`` (via ``field.get()``) plus its I/O metadata. ``input_names``
        / ``output_names`` are ``None`` when the field never received them
        (e.g. it was registered without going through this class's ``track``).
        """
        result = {}
        for name, field in self.fields.items():
            result[name] = {
                "value": field.get(),
                "input_names": getattr(field, "input_names", None),
                "output_names": getattr(field, "output_names", None),
                "input_descs": getattr(field, "input_descs", {}),
                "output_descs": getattr(field, "output_descs", {}),
            }
        return result