arvla-bridge / src /hf_compat.py
you2who's picture
Duplicate from you2who/paligemma-arvla-bridge
2672775
import copy
import dataclasses
import enum
import sys
import types
from typing import Any, Dict, Mapping, Type, TypeVar, Union, get_args, get_origin, get_type_hints

import yaml
# Type variable bound to Config; used by the Config.from_dict / Config.from_yaml
# classmethods so their return type is the class they were called on.
ConfigT = TypeVar("ConfigT", bound="Config")
class Config:
    """Base class for configuration objects driven by class annotations.

    Subclasses declare their fields as annotated class attributes. On
    construction every annotated field is filled from the keyword arguments
    or, failing that, from a deep copy of the class-level default, and the
    value is coerced with ``_convert_value`` according to its annotation.
    """

    def __init__(self, **kwargs: Any):
        """Populate the instance from ``kwargs`` and class-level defaults.

        Args:
            **kwargs: Field values. Keys matching an annotation are converted
                with ``_convert_value``; keys without a matching annotation
                are stored on the instance unchanged.
        """
        annotations = _collect_annotations(type(self))
        for key, annotation in annotations.items():
            if key in kwargs:
                value = kwargs[key]
            elif hasattr(type(self), key):
                # Deep-copy so mutable class-level defaults are never shared
                # between instances.
                value = copy.deepcopy(getattr(type(self), key))
            else:
                # Annotated but neither supplied nor defaulted: leave unset.
                continue
            setattr(self, key, _convert_value(annotation, value))

        # Keep extra, un-annotated keyword arguments verbatim.
        for key, value in kwargs.items():
            if key not in annotations:
                setattr(self, key, value)

        if not hasattr(self, "pretrain_config"):
            self.pretrain_config = EmptyConfig()
        self.__post_init__()

    def __post_init__(self):
        """Hook called at the end of ``__init__``; does nothing by default."""
        return

    @property
    def empty(self) -> bool:
        """Return True when the instance carries no attributes at all."""
        return not self.__dict__

    def as_json(self) -> Dict[str, Any]:
        """Return the instance attributes as a dict passed through ``_to_json``.

        The ``pretrain_config`` attribute is omitted from the result.
        """
        return {
            key: _to_json(value)
            for key, value in self.__dict__.items()
            if key != "pretrain_config"
        }

    @classmethod
    def from_dict(cls: Type[ConfigT], data: Mapping[str, Any]) -> ConfigT:
        """Build an instance by passing the items of ``data`` as keyword arguments."""
        return cls(**dict(data))

    @classmethod
    def from_yaml(cls: Type[ConfigT], file_path: str) -> ConfigT:
        """Load a YAML file and build an instance from its contents.

        Args:
            file_path: Path of the YAML document to read.

        Returns:
            An instance of ``cls``; an empty document yields an instance
            built from no arguments.
        """
        # Decode explicitly as UTF-8 instead of relying on the platform's
        # locale encoding, which differs between systems.
        with open(file_path, "r", encoding="utf-8") as file:
            data = yaml.safe_load(file)
        if data is None:
            data = {}
        return cls.from_dict(data)
class EmptyConfig(Config):
    """Config variant that stores its keyword arguments verbatim.

    The base-class initialisation (annotation lookup, conversion, defaults,
    ``__post_init__``) is skipped entirely.
    """

    def __init__(self, **kwargs: Any):
        """Copy ``kwargs`` straight into the instance dictionary."""
        # Write through the instance dict directly rather than via setattr.
        vars(self).update(kwargs)
class Configurable:
    """Object that keeps a reference to the configuration it was built with.

    Subscripting the class (``Configurable[X]``) returns the class itself.
    """

    ConfigT = Any

    @classmethod
    def __class_getitem__(cls, _):
        """Ignore the subscript argument and return the class unchanged."""
        return cls

    def __init__(self, config: Any):
        """Store ``config`` on the instance as ``self.config``."""
        self.config = config
class Template:
    """Placeholder base class that can be subscripted.

    ``Template[X]`` evaluates to ``Template`` itself (or the subclass the
    subscript was applied to).
    """

    ConfigT = Any

    @classmethod
    def __class_getitem__(cls, _):
        """Ignore the subscript argument and return the class unchanged."""
        return cls
def _collect_annotations(cls: Type[Any]) -> Dict[str, Any]:
annotations: Dict[str, Any] = {}
for base in reversed(cls.__mro__):
if base is object:
continue
module = sys.modules.get(base.__module__)
module_globals = vars(module) if module is not None else {}
try:
hints = get_type_hints(base, globalns=module_globals, localns=module_globals)
except Exception:
hints = getattr(base, "__annotations__", {})
annotations.update(hints)
return annotations
def _is_config_type(annotation: Any) -> bool:
return isinstance(annotation, type) and issubclass(annotation, Config)
def _convert_value(annotation: Any, value: Any) -> Any:
    """Coerce ``value`` into the shape described by ``annotation``.

    Conversions performed:

    * ``None`` is returned unchanged.
    * ``Config`` subclasses: an existing instance is kept, a mapping is
      expanded into keyword arguments of the annotated class.
    * ``enum.Enum`` subclasses: non-member values are passed to the enum
      constructor.
    * ``list`` / ``tuple`` / ``dict`` generics: elements are converted
      recursively. Fixed-arity tuples (``Tuple[A, B]``) are converted
      position by position when the number of items matches.
    * Unions (``Union[...]``, ``Optional[...]`` and ``X | Y``): each
      non-``None`` member is tried in order and the first conversion that
      does not raise is returned.

    Any other annotation returns ``value`` unchanged.

    Args:
        annotation: The (resolved) type annotation of the field.
        value: The raw value to convert.

    Returns:
        The converted value.
    """
    if value is None:
        return None

    origin = get_origin(annotation)
    args = get_args(annotation)

    if _is_config_type(annotation) and isinstance(value, annotation):
        return value
    if _is_config_type(annotation) and isinstance(value, Mapping):
        return annotation(**value)
    if (
        isinstance(annotation, type)
        and issubclass(annotation, enum.Enum)
        and not isinstance(value, annotation)
    ):
        return annotation(value)

    # Not a parameterized generic: nothing further to convert.
    if origin is None:
        return value

    if origin is list:
        item_type = args[0] if args else Any
        return [_convert_value(item_type, item) for item in value]

    if origin is tuple:
        items = list(value)
        variadic = len(args) == 2 and args[1] is Ellipsis
        if args and not variadic and len(args) == len(items):
            # Fixed-arity tuple: every position has its own declared type.
            return tuple(
                _convert_value(item_type, item)
                for item_type, item in zip(args, items)
            )
        # Tuple[T, ...], a bare tuple, or an arity mismatch: apply the first
        # declared type to every item (the historical behaviour).
        item_type = args[0] if args else Any
        return tuple(_convert_value(item_type, item) for item in items)

    if origin is dict:
        key_type = args[0] if len(args) > 0 else Any
        val_type = args[1] if len(args) > 1 else Any
        return {
            _convert_value(key_type, key): _convert_value(val_type, item)
            for key, item in value.items()
        }

    # Compare by identity: str(types.UnionType) is "<class 'types.UnionType'>",
    # so a string comparison against "types.UnionType" never matches ``X | Y``
    # annotations. types.UnionType is looked up defensively because it does
    # not exist on every Python version.
    if origin is Union or origin is getattr(types, "UnionType", None):
        non_none_args = [arg for arg in args if arg is not type(None)]
        for arg in non_none_args:
            try:
                return _convert_value(arg, value)
            except Exception:
                # This member cannot represent the value; try the next one.
                continue
        return value

    return value
def _to_json(value: Any) -> Any:
    """Recursively turn ``value`` into plain, JSON-friendly Python data.

    * ``Config`` instances are replaced by their ``as_json()`` result.
    * Enum members are replaced by their ``.value``.
    * Dataclass instances are converted with ``dataclasses.asdict`` and the
      result is normalised recursively.
    * Dict values and list / tuple items are converted recursively; tuples
      become lists. Dict keys are left as they are.

    Everything else is returned unchanged.

    Args:
        value: The object to convert.

    Returns:
        The converted representation of ``value``.
    """
    if isinstance(value, Config):
        return value.as_json()
    if isinstance(value, enum.Enum):
        return value.value
    # is_dataclass() is also true for dataclass *classes*, which asdict()
    # rejects with TypeError, so only instances are expanded here. The asdict()
    # output is fed back through _to_json so nested enums and tuples inside
    # the dataclass are normalised as well.
    if dataclasses.is_dataclass(value) and not isinstance(value, type):
        return _to_json(dataclasses.asdict(value))
    if isinstance(value, dict):
        return {key: _to_json(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_json(item) for item in value]
    return value