File size: 8,484 Bytes
5ccf219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
"""
Unified registry utilities and simple JSON-based save/load helpers.

This module provides:
- create_registry: factory to create (registry dict, register decorator, get_class)
- capture_init_args: decorator to record __init__ kwargs on instances as _init_args
- save_object / load_object: serialize/deserialize object configs via registry
"""

from __future__ import annotations

import inspect
import json
from typing import Dict, Type, Callable, Optional, Tuple, TypeVar, Any
import torch

T = TypeVar("T")


def create_registry(
    registry_name: str,
    case_insensitive: bool = False,
) -> Tuple[Dict[str, Type[T]], Callable[..., Type[T]], Callable[[str], Type[T]]]:
    """
    Create a registry system with register and get functions.

    Args:
        registry_name: Name used in error messages (e.g., "projector")
        case_insensitive: Whether to store lowercase versions of names

    Returns:
        (registry_dict, register_function, get_function)
    """

    registry: Dict[str, Type[T]] = {}

    def register(cls_or_name=None, name: Optional[str] = None):
        """Register a class in the registry. Supports multiple usage patterns.

        Usage:
            @register
            class Foo: ...

            @register("foo")
            class Foo: ...

            @register(name="foo")
            class Foo: ...
        """

        def _register(c: Type[T]) -> Type[T]:
            # Determine the name to use
            if isinstance(cls_or_name, str):
                class_name = cls_or_name
            elif name is not None:
                class_name = name
            else:
                class_name = c.__name__

            registry[class_name] = c
            if case_insensitive:
                registry[class_name.lower()] = c
            return c

        if cls_or_name is not None and not isinstance(cls_or_name, str):
            # Called as @register or register(cls)
            return _register(cls_or_name)
        else:
            # Called as @register("name") or @register(name="name")
            return _register

    def get_class(name: str) -> Type[T]:
        """Get class by name from registry."""
        if name not in registry:
            # Build readable available list without duplicates when case_insensitive
            seen = set()
            available = []
            for k in registry.keys():
                if k.lower() in seen:
                    continue
                seen.add(k.lower())
                available.append(k)
            raise ValueError(
                f"Unknown {registry_name} class: {name}. Available: {available}"
            )
        return registry[name]

    return registry, register, get_class


def capture_init_args(cls):
    """
    Decorator to capture initialization arguments of a class.

    Stores the mapping of the constructor's parameters to the values supplied
    at instantiation time into `self._init_args` for later serialization.
    """
    original_init = cls.__init__

    def new_init(self, *args, **kwargs):
        # Store all initialization arguments
        init_args: Dict[str, Any] = {}

        # Get parameter names from the original __init__ method
        sig = inspect.signature(original_init)
        param_names = list(sig.parameters.keys())[1:]  # Skip 'self'

        # Map positional args to parameter names
        for i, arg in enumerate(args):
            if i < len(param_names):
                init_args[param_names[i]] = arg

        # Add keyword args
        init_args.update(kwargs)

        self._init_args = init_args

        # Call the original __init__
        original_init(self, *args, **kwargs)

    cls.__init__ = new_init
    return cls


# -------------------------
# Serialization utilities
# -------------------------

def _encode_value(value: Any) -> Any:
    """Best-effort JSON encoding for common ML types."""
    # Primitives and None
    if value is None or isinstance(value, (bool, int, float, str)):
        return value

    # Tuples -> lists
    if isinstance(value, tuple):
        return [
            _encode_value(v) for v in value
        ]

    # Lists
    if isinstance(value, list):
        return [
            _encode_value(v) for v in value
        ]

    # Dicts
    if isinstance(value, dict):
        return {k: _encode_value(v) for k, v in value.items()}

    # torch-specific types
    if torch is not None:
        # torch.dtype
        if isinstance(value, type(getattr(torch, "float32", object))):
            # Guard: torch.dtype is not a class; rely on str(value) format
            s = str(value)
            if s.startswith("torch."):
                return {"__type__": "torch.dtype", "value": s.split(".")[-1]}

        # torch.device
        if isinstance(value, getattr(torch, "device", ())):
            return {"__type__": "torch.device", "value": str(value)}

    # Fallback to string representation
    return {"__type__": "str", "value": str(value)}


def _decode_value(value: Any) -> Any:
    """Decode values produced by _encode_value, recursively for containers."""
    # Lists: decode each element
    if isinstance(value, list):
        return [_decode_value(v) for v in value]

    # Dicts: either a typed-marker dict or a regular mapping that needs recursive decoding
    if isinstance(value, dict):
        if "__type__" in value:
            t = value.get("__type__")
            v = value.get("value")

            if t == "torch.dtype" and torch is not None:
                dtype = getattr(torch, str(v), None)
                if dtype is None:
                    raise ValueError(f"Unknown torch.dtype: {v}")
                return dtype

            if t == "torch.device" and torch is not None:
                return torch.device(v)

            if t == "str":
                return str(v)

            # Unknown type marker; return raw as-is
            return value

        # Regular dict: decode values recursively
        return {k: _decode_value(v) for k, v in value.items()}

    # Primitives and anything else: return as-is
    return value


def save_object(obj: Any, file_path: str) -> None:
    """
    Save an object's construction config to a JSON file.

    The object is expected to have been decorated with capture_init_args,
    so that `obj._init_args` exists.
    """
    class_name = obj.__class__.__name__
    init_args = getattr(obj, "_init_args", {})

    serializable_args = _encode_value(init_args)
    payload = {
        "class": class_name,
        "init_args": serializable_args,
    }

    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


def load_object(
    file_path: str,
    get_class_fn: Callable[[str], Type[T]],
    override_args: Optional[Dict[str, Any]] = None,
) -> T:
    """
    Load an object from a JSON config file previously saved by save_object.

    Args:
        file_path: Path to JSON file
        get_class_fn: Function to resolve class names from registry
        override_args: Optional dict to override stored init args

    Returns:
        Instantiated object of type T
    """
    with open(file_path, "r", encoding="utf-8") as f:
        payload = json.load(f)

    class_name = payload["class"]
    encoded_args = payload.get("init_args", {})
    init_args = _decode_value(encoded_args)

    if override_args:
        init_args.update(override_args)

    cls = get_class_fn(class_name)
    return cls(**init_args)


def dumps_object_config(obj: Any) -> str:
    """Return a JSON string with the object's class and init args."""
    class_name = obj.__class__.__name__
    init_args = getattr(obj, "_init_args", {})
    serializable_args = _encode_value(init_args)
    return json.dumps({"class": class_name, "init_args": serializable_args}, indent=2)


def loads_object_config(
    s: str,
    get_class_fn: Callable[[str], Type[T]],
    override_args: Optional[Dict[str, Any]] = None,
) -> T:
    """Instantiate an object from a JSON string produced by dumps_object_config."""
    payload = json.loads(s)
    class_name = payload["class"]
    encoded_args = payload.get("init_args", {})
    init_args = _decode_value(encoded_args)
    if override_args:
        init_args.update(override_args)
    cls = get_class_fn(class_name)
    return cls(**init_args)


# Model Registry System (case-insensitive for backward compatibility)
PROJECTOR_REGISTRY, register_model, get_projector_class = create_registry(
    "projector", case_insensitive=True
)