|
|
import re |
|
|
import sys |
|
|
import copy |
|
|
import types |
|
|
import inspect |
|
|
import keyword |
|
|
import builtins |
|
|
import functools |
|
|
import _thread |
|
|
from types import GenericAlias |
|
|
|
|
|
|
|
|
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = ['dataclass',
           'field',
           'Field',
           'FrozenInstanceError',
           'InitVar',
           'MISSING',

           # Helper functions.
           'fields',
           'asdict',
           'astuple',
           'make_dataclass',
           'replace',
           'is_dataclass',
           ]
|
|
|
|
|
class FrozenInstanceError(AttributeError):
    """Raised by the generated ``__setattr__``/``__delattr__`` of a frozen dataclass."""
|
|
|
|
|
|
|
|
class _HAS_DEFAULT_FACTORY_CLASS:
    """Type of the sentinel used as the generated ``__init__`` parameter
    default for fields that have a ``default_factory``."""

    def __repr__(self):
        # Shows up in inspect.signature() of the generated __init__.
        return '<factory>'


# The single sentinel instance; compared by identity in generated code.
_HAS_DEFAULT_FACTORY = _HAS_DEFAULT_FACTORY_CLASS()
|
|
|
|
|
|
|
|
class _MISSING_TYPE:
    """Type of the ``MISSING`` sentinel, meaning "no value was supplied".

    It is used where ``None`` is a legitimate user value, e.g. ``default=None``.
    """


MISSING = _MISSING_TYPE()
|
|
|
|
|
|
|
|
# Shared read-only empty mapping used as ``Field.metadata`` when no
# metadata is given, so fields without metadata do not each allocate one.
_EMPTY_METADATA = types.MappingProxyType(dict())
|
|
|
|
|
class _FIELD_BASE:
    """Named marker object whose ``repr`` is simply its name."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


# Markers stored in ``Field._field_type`` to classify an annotated attribute
# as a regular field, a ClassVar pseudo-field, or an InitVar pseudo-field.
_FIELD, _FIELD_CLASSVAR, _FIELD_INITVAR = (
    _FIELD_BASE(marker)
    for marker in ('_FIELD', '_FIELD_CLASSVAR', '_FIELD_INITVAR')
)
|
|
|
|
|
|
|
|
# Class attribute holding the name -> Field mapping of a processed class,
# class attribute holding its _DataclassParams, and the name of the optional
# hook the generated __init__ calls last.
_FIELDS, _PARAMS, _POST_INIT_NAME = (
    '__dataclass_fields__',
    '__dataclass_params__',
    '__post_init__',
)

# Splits a string annotation into an optional leading "module." prefix
# (group 1) and the identifier that follows it (group 2), e.g.
# "typing.ClassVar[int]" -> ("typing", "ClassVar").
_MODULE_IDENTIFIER_RE = re.compile(r'^(?:\s*(\w+)\s*\.)?\s*(\w+)')
|
|
|
|
|
class InitVar:
    """Marker for init-only pseudo-fields, written ``x: InitVar[T]``.

    Such fields become ``__init__`` parameters that are forwarded to
    ``__post_init__`` instead of being stored on the instance.
    """

    __slots__ = ('type', )

    def __init__(self, type):
        self.type = type

    def __repr__(self):
        # Here ``type`` is the builtin; the parameter above only shadows it
        # inside __init__.
        wrapped = self.type
        shown = wrapped.__name__ if isinstance(wrapped, type) else repr(wrapped)
        return f'dataclasses.InitVar[{shown}]'

    def __class_getitem__(cls, type):
        # InitVar[T] produces an InitVar instance that wraps T.
        return InitVar(type)
|
|
|
|
|
|
|
|
class Field:
    """Describes one dataclass field. Users normally create it through ``field()``.

    ``name``, ``type`` and ``_field_type`` start out as ``None``. They are
    filled in when the owning class is processed (see ``_get_field``).
    """

    __slots__ = (
        'name', 'type', 'default', 'default_factory', 'repr', 'hash',
        'init', 'compare', 'metadata', '_field_type',
    )

    def __init__(self, default, default_factory, init, repr, hash, compare,
                 metadata):
        self.name = self.type = self._field_type = None
        self.default, self.default_factory = default, default_factory
        self.init, self.repr, self.hash, self.compare = init, repr, hash, compare
        # User metadata is exposed read-only. Without metadata, share one
        # empty proxy.
        if metadata is None:
            self.metadata = _EMPTY_METADATA
        else:
            self.metadata = types.MappingProxyType(metadata)

    def __repr__(self):
        shown = [
            f'{attr}={getattr(self, attr)!r}'
            for attr in ('name', 'type', 'default', 'default_factory', 'init',
                         'repr', 'hash', 'compare', 'metadata')
        ]
        # _field_type is rendered with str(), not repr().
        shown.append(f'_field_type={self._field_type}')
        return 'Field(' + ','.join(shown) + ')'

    def __set_name__(self, owner, name):
        # Forward the descriptor __set_name__ protocol to the default value,
        # so a descriptor used as a default learns its owner and name.
        hook = getattr(type(self.default), '__set_name__', None)
        if hook:
            hook(self.default, owner, name)

    __class_getitem__ = classmethod(GenericAlias)
|
|
|
|
|
|
|
|
class _DataclassParams:
    """Record of the options passed to ``@dataclass`` for a single class."""

    __slots__ = ('init', 'repr', 'eq', 'order', 'unsafe_hash', 'frozen')

    def __init__(self, init, repr, eq, order, unsafe_hash, frozen):
        self.init, self.repr, self.eq = init, repr, eq
        self.order, self.unsafe_hash, self.frozen = order, unsafe_hash, frozen

    def __repr__(self):
        shown = ','.join(f'{slot}={getattr(self, slot)!r}'
                         for slot in _DataclassParams.__slots__)
        return f'_DataclassParams({shown})'
|
|
|
|
|
|
|
|
def field(*, default=MISSING, default_factory=MISSING, init=True, repr=True,
          hash=None, compare=True, metadata=None):
    """Return a ``Field`` describing per-field options for a dataclass.

    ``default`` and ``default_factory`` are mutually exclusive.

    Raises:
        ValueError: if both ``default`` and ``default_factory`` are given.
    """
    has_default = default is not MISSING
    has_factory = default_factory is not MISSING
    if has_default and has_factory:
        raise ValueError('cannot specify both default and default_factory')
    return Field(default, default_factory, init, repr, hash, compare, metadata)
|
|
|
|
|
|
|
|
def _tuple_str(obj_name, fields):
    """Return source text for a tuple of ``obj_name.<field>`` accessors.

    The text is ``'()'`` for no fields. Otherwise every element, including
    the last, is followed by a comma, so a single field still yields a tuple.
    """
    if not fields:
        return '()'
    items = ''.join(f'{obj_name}.{f.name},' for f in fields)
    return f'({items})'
|
|
|
|
|
|
|
|
def _recursive_repr(user_function):
    """Decorate a ``__repr__`` so that a re-entrant call on the same object
    returns ``'...'`` instead of recursing forever.

    Re-entry is keyed by (object id, thread id). A repr in progress on one
    thread therefore does not make another thread's repr of the same object
    return ``'...'``.
    """
    active = set()

    @functools.wraps(user_function)
    def wrapper(self):
        token = (id(self), _thread.get_ident())
        if token in active:
            return '...'
        active.add(token)
        try:
            return user_function(self)
        finally:
            active.discard(token)

    return wrapper
|
|
|
|
|
|
|
|
def _create_fn(name, args, body, *, globals=None, locals=None,
               return_type=MISSING):
    # Compile a function called `name` from source text and return it.
    #
    # args:        iterable of parameter source strings (joined with ',').
    # body:        iterable of statement source lines for the function body.
    # globals:     globals dict used by exec() for the generated code.
    # locals:      names made visible to the function body. They become the
    #              parameters of a generated factory `__create_fn__`, so the
    #              generated function sees them as closure variables.
    # return_type: if given, used as the return annotation via `_return_type`.

    if locals is None:
        locals = {}
    # Make the `builtins` module reachable as BUILTINS so generated code
    # (e.g. `BUILTINS.object.__setattr__` from _field_assign) does not depend
    # on what `object` means in `globals`.
    if 'BUILTINS' not in locals:
        locals['BUILTINS'] = builtins
    return_annotation = ''
    if return_type is not MISSING:
        locals['_return_type'] = return_type
        return_annotation = '->_return_type'
    args = ','.join(args)
    # Body lines are indented one level deeper than the inner `def` below.
    body = '\n'.join(f'  {b}' for b in body)

    # The function definition, indented to sit inside __create_fn__.
    txt = f' def {name}({args}){return_annotation}:\n{body}'

    # Wrap it in a factory whose parameters are the keys of `locals`, then
    # call the factory with those values to obtain the closure.
    local_vars = ', '.join(locals.keys())
    txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"

    ns = {}
    exec(txt, globals, ns)
    return ns['__create_fn__'](**locals)
|
|
|
|
|
|
|
|
def _field_assign(frozen, name, value, self_name):
    """Return source text that assigns ``value`` to attribute ``name``.

    Frozen classes get a ``__setattr__`` that raises, so their ``__init__``
    must go through ``object.__setattr__`` (via ``BUILTINS``) instead of a
    plain attribute assignment.
    """
    return (f'BUILTINS.object.__setattr__({self_name},{name!r},{value})'
            if frozen else f'{self_name}.{name}={value}')
|
|
|
|
|
|
|
|
def _field_init(f, frozen, globals, self_name):
    """Return the ``__init__`` source line that initializes field ``f``.

    Returns ``None`` when no line is needed. That is the case when the field
    is not an ``__init__`` parameter and has no default factory, or when it
    is an InitVar, which is only forwarded to ``__post_init__``.

    Side effect: a default or default factory is registered in ``globals``
    under ``_dflt_<name>``, even for InitVars, so the generated code can
    reference it.
    """
    default_name = f'_dflt_{f.name}'

    if f.default_factory is not MISSING:
        globals[default_name] = f.default_factory
        if f.init:
            # The parameter defaults to the _HAS_DEFAULT_FACTORY sentinel.
            # Call the factory only when the caller did not pass a value.
            value = (f'{default_name}() '
                     f'if {f.name} is _HAS_DEFAULT_FACTORY '
                     f'else {f.name}')
        else:
            # Not an __init__ parameter: always build a fresh default.
            value = f'{default_name}()'
    elif not f.init:
        # No factory and not an __init__ parameter: nothing to initialize.
        return None
    else:
        if f.default is not MISSING:
            globals[default_name] = f.default
        value = f.name

    if f._field_type is _FIELD_INITVAR:
        return None
    return _field_assign(frozen, f.name, value, self_name)
|
|
|
|
|
|
|
|
def _init_param(f):
    """Return the ``__init__`` parameter source text for field ``f``.

    The annotation refers to ``_type_<name>``. The default, if any, refers to
    ``_dflt_<name>`` or to the ``_HAS_DEFAULT_FACTORY`` sentinel. Those names
    must be in scope of the generated function.
    """
    if f.default is not MISSING:
        suffix = f'=_dflt_{f.name}'
    elif f.default_factory is not MISSING:
        suffix = '=_HAS_DEFAULT_FACTORY'
    else:
        # Neither default nor factory: a required parameter.
        suffix = ''
    return f'{f.name}:_type_{f.name}{suffix}'
|
|
|
|
|
|
|
|
def _init_fn(fields, frozen, has_post_init, self_name, globals):
    # Generate the dataclass __init__.
    #
    # fields:        regular fields and InitVars, in definition order.
    # frozen:        whether assignments must bypass __setattr__.
    # has_post_init: whether to call self.__post_init__(<initvars>) at the end.
    # self_name:     name used for the instance parameter.
    # globals:       module globals for the generated code.
    #
    # Raises TypeError if a required __init__ parameter follows one with a
    # default, since that would be an invalid Python signature.

    seen_default = False
    for f in fields:
        # Only fields that are __init__ parameters affect the signature.
        if f.init:
            if not (f.default is MISSING and f.default_factory is MISSING):
                seen_default = True
            elif seen_default:
                raise TypeError(f'non-default argument {f.name!r} '
                                'follows default argument')

    # Names visible to the generated function: one `_type_<name>` per field
    # for annotations, plus sentinels. _field_init adds `_dflt_<name>` entries
    # to this same dict.
    locals = {f'_type_{f.name}': f.type for f in fields}
    locals.update({
        'MISSING': MISSING,
        '_HAS_DEFAULT_FACTORY': _HAS_DEFAULT_FACTORY,
    })

    body_lines = []
    for f in fields:
        line = _field_init(f, frozen, locals, self_name)
        # _field_init returns None when the field needs no assignment.
        if line:
            body_lines.append(line)

    # Forward InitVar values (in field order) to __post_init__.
    if has_post_init:
        params_str = ','.join(f.name for f in fields
                              if f._field_type is _FIELD_INITVAR)
        body_lines.append(f'{self_name}.{_POST_INIT_NAME}({params_str})')

    # A function body cannot be empty.
    if not body_lines:
        body_lines = ['pass']

    return _create_fn('__init__',
                      [self_name] + [_init_param(f) for f in fields if f.init],
                      body_lines,
                      locals=locals,
                      globals=globals,
                      return_type=None)
|
|
|
|
|
|
|
|
def _repr_fn(fields, globals):
    """Build a recursion-safe ``__repr__``.

    The repr has the form ``QualName(a=<repr>, b=<repr>)`` over ``fields``.
    """
    # Double braces survive this f-string as single braces, so the generated
    # code is itself an f-string that reads self.<name> at call time.
    shown = ', '.join(f"{f.name}={{self.{f.name}!r}}" for f in fields)
    source = 'return self.__class__.__qualname__ + f"(' + shown + ')"'
    generated = _create_fn('__repr__', ('self',), [source], globals=globals)
    return _recursive_repr(generated)
|
|
|
|
|
|
|
|
def _frozen_get_del_attr(cls, fields, globals):
    # Build the (__setattr__, __delattr__) pair installed on frozen classes.
    #
    # For instances whose type is exactly `cls`, every assignment/deletion
    # raises FrozenInstanceError. For instances of a subclass, only the names
    # in `fields` are protected; other names fall through to the parent
    # class's __setattr__/__delattr__ via super(cls, self).
    locals = {'cls': cls,
              'FrozenInstanceError': FrozenInstanceError}
    # Source text for a tuple literal of the protected field names.
    if fields:
        fields_str = '(' + ','.join(repr(f.name) for f in fields) + ',)'
    else:
        fields_str = '()'
    return (_create_fn('__setattr__',
                      ('self', 'name', 'value'),
                      (f'if type(self) is cls or name in {fields_str}:',
                        ' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
                       f'super(cls, self).__setattr__(name, value)'),
                       locals=locals,
                       globals=globals),
            _create_fn('__delattr__',
                      ('self', 'name'),
                      (f'if type(self) is cls or name in {fields_str}:',
                        ' raise FrozenInstanceError(f"cannot delete field {name!r}")',
                       f'super(cls, self).__delattr__(name)'),
                       locals=locals,
                       globals=globals),
            )
|
|
|
|
|
|
|
|
def _cmp_fn(name, op, self_tuple, other_tuple, globals):
    """Build a comparison method ``name`` that applies ``op`` to field tuples.

    Objects are compared only when their classes are identical. Otherwise the
    method returns ``NotImplemented`` so Python can try the reflected operation.
    """
    body = [
        'if other.__class__ is self.__class__:',
        f' return {self_tuple}{op}{other_tuple}',
        'return NotImplemented',
    ]
    return _create_fn(name, ('self', 'other'), body, globals=globals)
|
|
|
|
|
|
|
|
def _hash_fn(fields, globals):
    """Build a ``__hash__`` that hashes the tuple of the given fields' values."""
    body = [f"return hash({_tuple_str('self', fields)})"]
    return _create_fn('__hash__', ('self',), body, globals=globals)
|
|
|
|
|
|
|
|
def _is_classvar(a_type, typing):
    """Return True if ``a_type`` is ``typing.ClassVar`` or ``ClassVar[...]``.

    ``typing`` is the typing module object, passed in so that this module
    does not need to import it.
    """
    if a_type is typing.ClassVar:
        return True
    return (type(a_type) is typing._GenericAlias
            and a_type.__origin__ is typing.ClassVar)
|
|
|
|
|
|
|
|
def _is_initvar(a_type, dataclasses):
    """Return True if ``a_type`` is ``InitVar`` itself or an ``InitVar[...]``
    instance. ``dataclasses`` is the module object that provides ``InitVar``."""
    initvar = dataclasses.InitVar
    return a_type is initvar or type(a_type) is initvar
|
|
|
|
|
|
|
|
def _is_type(annotation, cls, a_module, a_type, is_type_predicate):
    """Decide whether a *string* annotation names ``a_type``.

    Handles two spellings, without evaluating the string:

    * a bare identifier such as ``"ClassVar[int]"``: looked up in the
      namespace of the module that defines ``cls``;
    * ``"<mod>.<name>[...]"``: accepted only if ``<mod>`` in ``cls``'s module
      is ``a_module``, in which case ``<name>`` is looked up in the module
      that defines ``a_type``.

    The object found is passed to ``is_type_predicate(obj, a_module)``.

    Returns False, instead of raising AttributeError, when a needed module is
    not present in ``sys.modules`` (e.g. classes created by ``exec`` under a
    module name that was never imported).
    """
    match = _MODULE_IDENTIFIER_RE.match(annotation)
    if match:
        ns = None
        module_name = match.group(1)
        cls_module = sys.modules.get(cls.__module__)
        if not module_name:
            # No module prefix: assume cls's module imported the name itself.
            if cls_module is not None:
                ns = cls_module.__dict__
        elif cls_module and cls_module.__dict__.get(module_name) is a_module:
            # The prefix refers to a_module, so resolve the name where a_type
            # is defined.
            type_module = sys.modules.get(a_type.__module__)
            if type_module is not None:
                ns = type_module.__dict__
        if ns and is_type_predicate(ns.get(match.group(2)), a_module):
            return True
    return False
|
|
|
|
|
|
|
|
def _get_field(cls, a_name, a_type):
    # Build the Field for annotation `a_name: a_type` on `cls`.
    #
    # The class attribute of the same name (if any) is either a Field (from
    # field(...)) or a plain default value. The result is tagged as a regular
    # field, a ClassVar pseudo-field or an InitVar pseudo-field.
    #
    # Raises TypeError for a ClassVar/InitVar with a default_factory, and
    # ValueError for a regular field whose default is a list, dict or set.

    default = getattr(cls, a_name, MISSING)
    if isinstance(default, Field):
        f = default
    else:
        # A slot descriptor (from __slots__) is not a real default value.
        if isinstance(default, types.MemberDescriptorType):
            default = MISSING
        f = field(default=default)

    f.name = a_name
    f.type = a_type

    # Assume a regular field until proven otherwise.
    f._field_type = _FIELD

    # ClassVar detection only applies if typing has been imported somewhere;
    # without it, no annotation can be a typing.ClassVar. String annotations
    # are matched textually via _is_type.
    typing = sys.modules.get('typing')
    if typing:
        if (_is_classvar(a_type, typing)
            or (isinstance(f.type, str)
                and _is_type(f.type, cls, typing, typing.ClassVar,
                             _is_classvar))):
            f._field_type = _FIELD_CLASSVAR

    # InitVar detection, against this module's own InitVar.
    if f._field_type is _FIELD:
        dataclasses = sys.modules[__name__]
        if (_is_initvar(a_type, dataclasses)
            or (isinstance(f.type, str)
                and _is_type(f.type, cls, dataclasses, dataclasses.InitVar,
                             _is_initvar))):
            f._field_type = _FIELD_INITVAR

    # Pseudo-fields are never stored per instance, so a factory makes no sense.
    if f._field_type in (_FIELD_CLASSVAR, _FIELD_INITVAR):
        if f.default_factory is not MISSING:
            raise TypeError(f'field {f.name} cannot have a '
                            'default factory')

    # A mutable default would be shared by all instances; reject the common
    # builtin mutable types.
    if f._field_type is _FIELD and isinstance(f.default, (list, dict, set)):
        raise ValueError(f'mutable default {type(f.default)} for field '
                         f'{f.name} is not allowed: use default_factory')

    return f
|
|
|
|
|
|
|
|
def _set_new_attribute(cls, name, value):
    """Set ``cls.<name> = value`` unless ``name`` is already in ``cls.__dict__``.

    Returns True when the attribute already existed, in which case nothing
    was changed. Inherited attributes do not count, so they get overridden.
    """
    already_defined = name in cls.__dict__
    if not already_defined:
        setattr(cls, name, value)
    return already_defined
|
|
|
|
|
|
|
|
def _hash_set_none(cls, fields, globals):
    """Hash action: the caller assigns this result (None) to ``__hash__``,
    which makes instances unhashable."""
    unhashable = None
    return unhashable
|
|
|
|
|
def _hash_add(cls, fields, globals):
    """Hash action: generate ``__hash__`` over the participating fields.

    A field participates per its ``hash`` flag. When that flag is ``None``,
    the field's ``compare`` flag is used instead.
    """
    def participates(f):
        return f.compare if f.hash is None else f.hash

    return _hash_fn(list(filter(participates, fields)), globals)
|
|
|
|
|
def _hash_exception(cls, fields, globals):
    """Hash action: refuse to replace an explicitly defined ``__hash__``."""
    message = f'Cannot overwrite attribute __hash__ in class {cls.__name__}'
    raise TypeError(message)
|
|
|
|
|
|
|
|
# How to set __hash__, keyed by
#     (unsafe_hash, eq, frozen, has_explicit_hash)
# where has_explicit_hash means the class body defined __hash__ itself, not
# counting the implicit None that accompanies a class-defined __eq__ (see
# _process_class). Values:
#     None            -- leave __hash__ untouched
#     _hash_set_none  -- set __hash__ = None (instances unhashable)
#     _hash_add       -- generate __hash__ from the hashed fields
#     _hash_exception -- raise TypeError: unsafe_hash conflicts with an
#                        explicit __hash__
_hash_action = {(False, False, False, False): None,
                (False, False, False, True ): None,
                (False, False, True,  False): None,
                (False, False, True,  True ): None,
                # eq without frozen: mutable objects compared by value must
                # not be hashable.
                (False, True,  False, False): _hash_set_none,
                (False, True,  False, True ): None,
                # eq and frozen: safe to derive a hash from the fields.
                (False, True,  True,  False): _hash_add,
                (False, True,  True,  True ): None,
                # unsafe_hash=True: always generate, unless __hash__ is explicit.
                (True,  False, False, False): _hash_add,
                (True,  False, False, True ): _hash_exception,
                (True,  False, True,  False): _hash_add,
                (True,  False, True,  True ): _hash_exception,
                (True,  True,  False, False): _hash_add,
                (True,  True,  False, True ): _hash_exception,
                (True,  True,  True,  False): _hash_add,
                (True,  True,  True,  True ): _hash_exception,
                }
|
|
|
|
|
|
|
|
|
|
|
def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
    # Turn `cls` into a dataclass in place and return it.
    #
    # Collects fields from dataclass bases and from cls's own annotations,
    # records the fields and options on the class, and adds generated methods
    # (__init__, __repr__, __eq__, ordering, frozen __setattr__/__delattr__,
    # __hash__) according to the decorator options.
    #
    # Raises TypeError/ValueError for invalid combinations (see the checks
    # below).

    # name -> Field. Base-class fields are inserted first, so a subclass
    # redefinition replaces the entry while keeping its original position.
    fields = {}

    # Generated methods are exec'd with the defining module's namespace as
    # globals when that module is importable, otherwise with an empty dict.
    if cls.__module__ in sys.modules:
        globals = sys.modules[cls.__module__].__dict__
    else:
        globals = {}

    setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order,
                                           unsafe_hash, frozen))

    # Walk the MRO from the most basic class up to (excluding) cls itself,
    # so nearer bases override farther ones.
    any_frozen_base = False
    has_dataclass_bases = False
    for b in cls.__mro__[-1:0:-1]:
        base_fields = getattr(b, _FIELDS, None)
        if base_fields is not None:
            has_dataclass_bases = True
            for f in base_fields.values():
                fields[f.name] = f
            if getattr(b, _PARAMS).frozen:
                any_frozen_base = True

    # Only annotations defined directly on this class create new fields.
    cls_annotations = cls.__dict__.get('__annotations__', {})

    cls_fields = [_get_field(cls, name, type)
                  for name, type in cls_annotations.items()]
    for f in cls_fields:
        fields[f.name] = f

        # A Field object left as a class attribute is replaced by its plain
        # default, or removed when there is no default.
        if isinstance(getattr(cls, f.name, None), Field):
            if f.default is MISSING:
                delattr(cls, f.name)
            else:
                setattr(cls, f.name, f.default)

    # A field(...) without an annotation is a user error.
    for name, value in cls.__dict__.items():
        if isinstance(value, Field) and not name in cls_annotations:
            raise TypeError(f'{name!r} is a field but has no type annotation')

    # Frozen-ness must agree across the dataclass hierarchy.
    if has_dataclass_bases:
        if any_frozen_base and not frozen:
            raise TypeError('cannot inherit non-frozen dataclass from a '
                            'frozen one')

        if not any_frozen_base and frozen:
            raise TypeError('cannot inherit frozen dataclass from a '
                            'non-frozen one')

    setattr(cls, _FIELDS, fields)

    # __hash__ = None next to a class-defined __eq__ is what Python sets
    # implicitly when __eq__ is defined without __hash__, so it is not
    # treated as an explicit hash.
    class_hash = cls.__dict__.get('__hash__', MISSING)
    has_explicit_hash = not (class_hash is MISSING or
                             (class_hash is None and '__eq__' in cls.__dict__))

    if order and not eq:
        raise ValueError('eq must be true if order is true')

    if init:
        has_post_init = hasattr(cls, _POST_INIT_NAME)

        # __init__ takes regular fields and InitVars, not ClassVars. If a
        # field is literally named 'self', the instance parameter is renamed
        # to avoid a clash.
        flds = [f for f in fields.values()
                if f._field_type in (_FIELD, _FIELD_INITVAR)]
        _set_new_attribute(cls, '__init__',
                           _init_fn(flds,
                                    frozen,
                                    has_post_init,
                                    '__dataclass_self__' if 'self' in fields
                                            else 'self',
                                    globals,
                          ))

    # Only real fields (no ClassVar/InitVar) take part in the methods below.
    field_list = [f for f in fields.values() if f._field_type is _FIELD]

    if repr:
        flds = [f for f in field_list if f.repr]
        _set_new_attribute(cls, '__repr__', _repr_fn(flds, globals))

    if eq:
        flds = [f for f in field_list if f.compare]
        self_tuple = _tuple_str('self', flds)
        other_tuple = _tuple_str('other', flds)
        _set_new_attribute(cls, '__eq__',
                           _cmp_fn('__eq__', '==',
                                   self_tuple, other_tuple,
                                   globals=globals))

    if order:
        flds = [f for f in field_list if f.compare]
        self_tuple = _tuple_str('self', flds)
        other_tuple = _tuple_str('other', flds)
        # Unlike __eq__/__repr__, a user-defined ordering method is an error
        # rather than silently kept.
        for name, op in [('__lt__', '<'),
                         ('__le__', '<='),
                         ('__gt__', '>'),
                         ('__ge__', '>='),
                         ]:
            if _set_new_attribute(cls, name,
                                  _cmp_fn(name, op, self_tuple, other_tuple,
                                          globals=globals)):
                raise TypeError(f'Cannot overwrite attribute {name} '
                                f'in class {cls.__name__}. Consider using '
                                'functools.total_ordering')

    if frozen:
        for fn in _frozen_get_del_attr(cls, field_list, globals):
            if _set_new_attribute(cls, fn.__name__, fn):
                raise TypeError(f'Cannot overwrite attribute {fn.__name__} '
                                f'in class {cls.__name__}')

    # Look up what to do about __hash__ (see _hash_action for the table).
    hash_action = _hash_action[bool(unsafe_hash),
                               bool(eq),
                               bool(frozen),
                               has_explicit_hash]
    if hash_action:
        cls.__hash__ = hash_action(cls, field_list, globals)

    # Without a docstring, use the class name plus the __init__ signature.
    if not getattr(cls, '__doc__'):
        cls.__doc__ = (cls.__name__ +
                       str(inspect.signature(cls)).replace(' -> None', ''))

    return cls
|
|
|
|
|
|
|
|
def dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False,
              unsafe_hash=False, frozen=False):
    """Class decorator that adds generated special methods based on fields.

    Usable both bare (``@dataclass``) and with options
    (``@dataclass(frozen=True)``). In the latter form, ``cls`` is None and
    the configured decorator is returned.
    """
    def wrap(cls):
        return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen)

    return wrap if cls is None else wrap(cls)
|
|
|
|
|
|
|
|
def fields(class_or_instance):
    """Return a tuple of the regular ``Field`` objects of a dataclass.

    ClassVar and InitVar pseudo-fields are excluded.

    Raises:
        TypeError: if the argument is not a dataclass type or instance.
    """
    try:
        all_fields = getattr(class_or_instance, _FIELDS)
    except AttributeError:
        raise TypeError('must be called with a dataclass type or instance')

    return tuple(f for f in all_fields.values() if f._field_type is _FIELD)
|
|
|
|
|
|
|
|
def _is_dataclass_instance(obj):
    """Return True if obj is an instance (not the class) of a dataclass."""
    obj_type = type(obj)
    return hasattr(obj_type, _FIELDS)
|
|
|
|
|
|
|
|
def is_dataclass(obj):
    """Return True if obj is a dataclass or an instance of one."""
    if isinstance(obj, type):
        return hasattr(obj, _FIELDS)
    return hasattr(type(obj), _FIELDS)
|
|
|
|
|
|
|
|
def asdict(obj, *, dict_factory=dict):
    """Recursively convert a dataclass instance to a mapping.

    The mapping is built by ``dict_factory`` from ``(name, value)`` pairs.
    Values are converted recursively; see ``_asdict_inner``.

    Raises:
        TypeError: if obj is not a dataclass instance.
    """
    if _is_dataclass_instance(obj):
        return _asdict_inner(obj, dict_factory)
    raise TypeError("asdict() should be called on dataclass instances")
|
|
|
|
|
|
|
|
def _asdict_inner(obj, dict_factory):
    """Recursive worker for ``asdict``.

    * Dataclass instances become ``dict_factory([(name, value), ...])``.
    * namedtuples, lists, tuples and dicts are rebuilt with the same type.
    * Everything else is deep-copied.

    ``defaultdict`` (any dict type whose class has ``default_factory``) is
    rebuilt with its ``default_factory``. Its constructor takes that factory
    as the first argument, so passing an iterable of pairs would raise
    TypeError.
    """
    if _is_dataclass_instance(obj):
        result = []
        for f in fields(obj):
            value = _asdict_inner(getattr(obj, f.name), dict_factory)
            result.append((f.name, value))
        return dict_factory(result)
    elif isinstance(obj, tuple) and hasattr(obj, '_fields'):
        # namedtuple: its constructor takes positional arguments, not an
        # iterable.
        return type(obj)(*[_asdict_inner(v, dict_factory) for v in obj])
    elif isinstance(obj, (list, tuple)):
        return type(obj)(_asdict_inner(v, dict_factory) for v in obj)
    elif isinstance(obj, dict):
        obj_type = type(obj)
        if hasattr(obj_type, 'default_factory'):
            # defaultdict-like: construct empty with the same factory, then fill.
            result = obj_type(getattr(obj, 'default_factory'))
            for k, v in obj.items():
                result[_asdict_inner(k, dict_factory)] = _asdict_inner(v, dict_factory)
            return result
        return obj_type((_asdict_inner(k, dict_factory),
                         _asdict_inner(v, dict_factory))
                        for k, v in obj.items())
    else:
        return copy.deepcopy(obj)
|
|
|
|
|
|
|
|
def astuple(obj, *, tuple_factory=tuple):
    """Recursively convert a dataclass instance to a tuple of field values.

    The result is built with ``tuple_factory``.

    Raises:
        TypeError: if obj is not a dataclass instance.
    """
    if _is_dataclass_instance(obj):
        return _astuple_inner(obj, tuple_factory)
    raise TypeError("astuple() should be called on dataclass instances")
|
|
|
|
|
|
|
|
def _astuple_inner(obj, tuple_factory):
    """Recursive worker for ``astuple``.

    * Dataclass instances become ``tuple_factory([value, ...])``.
    * namedtuples, lists, tuples and dicts are rebuilt with the same type.
    * Everything else is deep-copied.

    ``defaultdict`` (any dict type whose class has ``default_factory``) is
    rebuilt with its ``default_factory``. Its constructor takes that factory
    as the first argument, so passing an iterable of pairs would raise
    TypeError.
    """
    if _is_dataclass_instance(obj):
        result = []
        for f in fields(obj):
            value = _astuple_inner(getattr(obj, f.name), tuple_factory)
            result.append(value)
        return tuple_factory(result)
    elif isinstance(obj, tuple) and hasattr(obj, '_fields'):
        # namedtuple: its constructor takes positional arguments, not an
        # iterable.
        return type(obj)(*[_astuple_inner(v, tuple_factory) for v in obj])
    elif isinstance(obj, (list, tuple)):
        return type(obj)(_astuple_inner(v, tuple_factory) for v in obj)
    elif isinstance(obj, dict):
        obj_type = type(obj)
        if hasattr(obj_type, 'default_factory'):
            # defaultdict-like: construct empty with the same factory, then fill.
            result = obj_type(getattr(obj, 'default_factory'))
            for k, v in obj.items():
                result[_astuple_inner(k, tuple_factory)] = _astuple_inner(v, tuple_factory)
            return result
        return obj_type((_astuple_inner(k, tuple_factory),
                         _astuple_inner(v, tuple_factory))
                        for k, v in obj.items())
    else:
        return copy.deepcopy(obj)
|
|
|
|
|
|
|
|
def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
                   repr=True, eq=True, order=False, unsafe_hash=False,
                   frozen=False):
    """Dynamically create a dataclass named ``cls_name``.

    Each entry of ``fields`` is ``name``, ``(name, type)`` or
    ``(name, type, spec)``. A bare name is annotated with the string
    ``'typing.Any'``. A ``spec`` (e.g. a ``field(...)``) is placed in the
    class namespace. ``namespace`` is copied, never mutated.

    Raises:
        TypeError: for a malformed entry, or a field name that is not an
            identifier, is a keyword, or is duplicated.
    """
    ns = {} if namespace is None else namespace.copy()

    # name -> annotation; also serves as the duplicate-name check.
    annotations = {}
    for item in fields:
        if isinstance(item, str):
            name, tp = item, 'typing.Any'
        elif len(item) == 2:
            name, tp = item
        elif len(item) == 3:
            name, tp, spec = item
            ns[name] = spec
        else:
            raise TypeError(f'Invalid field: {item!r}')

        if not isinstance(name, str) or not name.isidentifier():
            raise TypeError(f'Field names must be valid identifiers: {name!r}')
        if keyword.iskeyword(name):
            raise TypeError(f'Field names must not be keywords: {name!r}')
        if name in annotations:
            raise TypeError(f'Field name duplicated: {name!r}')
        annotations[name] = tp

    ns['__annotations__'] = annotations

    cls = types.new_class(cls_name, bases, {}, lambda body: body.update(ns))
    return dataclass(cls, init=init, repr=repr, eq=eq, order=order,
                     unsafe_hash=unsafe_hash, frozen=frozen)
|
|
|
|
|
|
|
|
def replace(obj, /, **changes):
    """Return a new instance of obj's class with ``changes`` applied.

    Fields not mentioned in ``changes`` are copied from ``obj``. InitVars
    without a default must be given explicitly, since their values are not
    stored on ``obj``.

    Raises:
        TypeError: if obj is not a dataclass instance.
        ValueError: if an ``init=False`` field is passed, or a required
            InitVar is missing.
    """
    if not _is_dataclass_instance(obj):
        raise TypeError("replace() should be called on dataclass instances")

    for f in getattr(obj, _FIELDS).values():
        # ClassVars are not constructor arguments.
        if f._field_type is _FIELD_CLASSVAR:
            continue

        # Fields outside __init__ cannot be passed to the constructor.
        if not f.init:
            if f.name in changes:
                raise ValueError(f'field {f.name} is declared with '
                                 'init=False, it cannot be specified with '
                                 'replace()')
            continue

        # Caller supplied a new value.
        if f.name in changes:
            continue

        if f._field_type is _FIELD_INITVAR and f.default is MISSING:
            raise ValueError(f"InitVar {f.name!r} "
                             'must be specified with replace()')
        changes[f.name] = getattr(obj, f.name)

    return obj.__class__(**changes)
|
|
|
|
|
from functools import partial |
|
|
from typing import List |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
import math |
|
|
from typing import List, Optional, Tuple |
|
|
|
|
|
import torch |
|
|
import torch.fft |
|
|
import torch.onnx |
|
|
from torch import Tensor |
|
|
from torch.autograd import Function |
|
|
|
|
|
|
|
|
|
|
|
def rfft(
    input: Tensor,
    n: Optional[int] = None,
    dim: int = -1,
    norm: Optional[str] = None,
) -> Tensor:
    """ONNX-exportable ``torch.fft.rfft``.

    Outside ONNX export this is exactly ``torch.fft.rfft``. During export it
    is lowered to the contrib Rfft op, and the result is in real ``(..., 2)``
    layout.

    Raises:
        TypeError: during export, if ``dim`` is not an int.
    """
    if torch.onnx.is_in_onnx_export():
        if not isinstance(dim, int):
            raise TypeError()
        return _rfft_onnx(input, (n,), (dim,), norm)
    return torch.fft.rfft(input, n=n, dim=dim, norm=norm)
|
|
|
|
|
|
|
|
def rfft2(
    input: Tensor,
    s: Optional[Tuple[int]] = None,
    dim: Tuple[int] = (-2, -1),
    norm: Optional[str] = None,
) -> Tensor:
    """ONNX-exportable ``torch.fft.rfft2``.

    Outside ONNX export this is exactly ``torch.fft.rfft2``. During export,
    ``dim`` must be a tuple of two axes, and the result is in real
    ``(..., 2)`` layout.

    Raises:
        ValueError: during export, if ``dim`` is not a 2-tuple.
    """
    if torch.onnx.is_in_onnx_export():
        two_axes = isinstance(dim, tuple) and len(dim) == 2
        if not two_axes:
            raise ValueError()
        return _rfft_onnx(input, s, dim, norm)
    return torch.fft.rfft2(input, s=s, dim=dim, norm=norm)
|
|
|
|
|
|
|
|
def irfft(
    input: Tensor,
    n: Optional[int] = None,
    dim: int = -1,
    norm: Optional[str] = None,
) -> Tensor:
    """ONNX-exportable ``torch.fft.irfft``.

    Outside ONNX export this is exactly ``torch.fft.irfft``. During export,
    ``input`` is expected in real ``(..., 2)`` layout and is lowered to the
    contrib Irfft op.

    Raises:
        TypeError: during export, if ``dim`` is not an int.
    """
    if torch.onnx.is_in_onnx_export():
        if not isinstance(dim, int):
            raise TypeError()
        return _irfft_onnx(input, (n,), (dim,), norm)
    return torch.fft.irfft(input, n=n, dim=dim, norm=norm)
|
|
|
|
|
|
|
|
def irfft2(
    input: Tensor,
    s: Optional[Tuple[int]] = None,
    dim: Tuple[int] = (-2, -1),
    norm: Optional[str] = None,
) -> Tensor:
    """ONNX-exportable ``torch.fft.irfft2``.

    Outside ONNX export this is exactly ``torch.fft.irfft2``. During export,
    ``dim`` must be a tuple of two axes, and ``input`` is expected in real
    ``(..., 2)`` layout.

    Raises:
        ValueError: during export, if ``dim`` is not a 2-tuple.
    """
    if torch.onnx.is_in_onnx_export():
        two_axes = isinstance(dim, tuple) and len(dim) == 2
        if not two_axes:
            raise ValueError()
        return _irfft_onnx(input, s, dim, norm)
    return torch.fft.irfft2(input, s=s, dim=dim, norm=norm)
|
|
|
|
|
|
|
|
def view_as_complex(input: Tensor) -> Tensor:
    """ONNX-exportable ``torch.view_as_complex``.

    During export the tensor is returned unchanged. Complex values stay in
    the real ``(..., 2)`` layout that the other helpers here operate on.

    Raises:
        ValueError: during export, if the last axis does not have size 2.
    """
    if torch.onnx.is_in_onnx_export():
        if input.size(-1) != 2:
            raise ValueError
        return input
    return torch.view_as_complex(input)
|
|
|
|
|
|
|
|
def real(input: Tensor) -> Tensor:
    """ONNX-exportable real part.

    Outside export this returns ``input.real``. During export, ``input`` is
    taken to be in real ``(..., 2)`` layout and channel 0 is returned.

    Raises:
        ValueError: during export, if the last axis does not have size 2.
    """
    if torch.onnx.is_in_onnx_export():
        if input.size(-1) != 2:
            raise ValueError()
        return input[..., 0]
    return input.real
|
|
|
|
|
|
|
|
def imag(input: Tensor) -> Tensor:
    """ONNX-exportable imaginary part.

    Outside export this returns ``input.imag``. During export, ``input`` is
    taken to be in real ``(..., 2)`` layout and channel 1 is returned.

    Raises:
        ValueError: during export, if the last axis does not have size 2
            (the offending size is the exception argument).
    """
    if torch.onnx.is_in_onnx_export():
        last = input.size(-1)
        if last != 2:
            raise ValueError(last)
        return input[..., 1]
    return input.imag
|
|
|
|
|
|
|
|
def _rfft_onnx(
    input: Tensor, s: Optional[Tuple[Optional[int]]], dim: Tuple[int], norm: str
) -> Tensor:
    """RFFT over ``dim`` (1 or 2 axes) for ONNX export via the contrib Rfft op.

    The contrib op transforms the trailing axes only, so other axes are first
    permuted to the end and the result is permuted back. The output carries
    an extra trailing (real, imag) axis.

    Raises:
        ValueError: if ``dim`` does not have 1 or 2 entries, or ``norm`` is
            unsupported.
        RuntimeError: if ``s`` would require padding/trimming.
    """
    if s is not None:
        _check_padding_rfft(s, dim, input.size())

    n_axes = len(dim)
    if n_axes not in [1, 2]:
        raise ValueError(n_axes)

    needs_perm = not _is_last_dims(dim, input.ndim)
    if needs_perm:
        perm_in, perm_out = _create_axes_perm(input.ndim, dim)
        # The output gains a trailing real/imag axis that must stay last.
        perm_out.append(len(perm_out))
        input = input.permute(perm_in)

    op = OnnxRfft2 if n_axes == 2 else OnnxRfft
    result = _scale_output_forward(op.apply(input), norm, input.size(), n_axes)
    return result.permute(perm_out) if needs_perm else result
|
|
|
|
|
|
|
|
def _irfft_onnx(
    input: Tensor, s: Optional[Tuple[Optional[int]]], dim: Tuple[int], norm: str
) -> Tensor:
    """IRFFT over ``dim`` (1 or 2 axes) for ONNX export via the contrib Irfft op.

    ``input`` holds complex values in real ``(..., 2)`` layout, so it has one
    more axis than the complex tensor that ``dim`` and ``s`` refer to. Shape
    checks and the "are the DFT axes last?" test are therefore done against
    the complex shape. Otherwise negative dims would land on the wrong axes,
    e.g. ``irfft2(x, s=(H, W))`` would be rejected as padding.

    Raises:
        ValueError: if ``dim`` does not have 1 or 2 entries, or ``norm`` is
            unsupported.
        RuntimeError: if ``s`` would require padding/trimming.
    """
    # Shape and rank of the complex tensor (drop the trailing real/imag axis).
    complex_sizes = input.size()[:-1]
    complex_ndim = input.ndim - 1

    if s is not None:
        _check_padding_irfft(s, dim, complex_sizes)

    ndim = len(dim)
    if ndim not in [1, 2]:
        raise ValueError(ndim)

    # Whether the DFT axes must first be moved to the end.
    perm = not _is_last_dims(dim, complex_ndim)

    if perm:
        perm_in, perm_out = _create_axes_perm(complex_ndim, dim)
        # Keep the real/imag axis last on the way in.
        perm_in.append(len(perm_in))
        input = input.permute(perm_in)

    irfft_func = OnnxIrfft if ndim == 1 else OnnxIrfft2
    output = irfft_func.apply(input)

    # Scaling reads the (possibly permuted) real-layout input shape.
    output = _scale_output_backward(output, norm, input.size(), ndim)

    if perm:
        output = output.permute(perm_out)

    return output
|
|
|
|
|
|
|
|
def _contrib_rfft(g: torch.Graph, input: torch.Value, ndim: int) -> torch.Value:
    """Emit an ONNX Runtime ``com.microsoft::Rfft`` node.

    The node is unnormalized and one-sided, over ``ndim`` (1 or 2) signal
    axes.

    Raises:
        ValueError: if ``ndim`` is not 1 or 2.
    """
    if ndim not in [1, 2]:
        raise ValueError(ndim)

    attrs = dict(normalized_i=0, onesided_i=1, signal_ndim_i=ndim)
    return g.op("com.microsoft::Rfft", input, **attrs)
|
|
|
|
|
|
|
|
def _contrib_irfft(g: torch.Graph, input: torch.Value, ndim: int) -> torch.Value:
    """Emit an ONNX Runtime ``com.microsoft::Irfft`` node.

    The node is unnormalized and one-sided, over ``ndim`` (1 or 2) signal
    axes.

    Raises:
        ValueError: if ``ndim`` is not 1 or 2.
    """
    if ndim not in [1, 2]:
        raise ValueError(ndim)

    attrs = dict(normalized_i=0, onesided_i=1, signal_ndim_i=ndim)
    return g.op("com.microsoft::Irfft", input, **attrs)
|
|
|
|
|
|
|
|
def _is_last_dims(dim: Tuple[int], inp_ndim: int) -> bool:
    """Return True if ``dim`` names exactly the trailing axes of an
    ``inp_ndim``-dimensional tensor, in order.

    Negative indices are allowed; each is normalized modulo ``inp_ndim``.
    """
    first = inp_ndim - len(dim)
    return all(axis % inp_ndim == first + pos for pos, axis in enumerate(dim))
|
|
|
|
|
|
|
|
def _check_padding_rfft(
    sizes: Tuple[Optional[int]], dim: Tuple[int], inp_sizes: Tuple[int]
) -> None:
    """Reject requested RFFT sizes that would need padding or trimming.

    Each requested size must equal the input length of its axis. ``None`` or
    negative entries are not checked.

    Raises:
        ValueError: if ``sizes`` and ``dim`` differ in length.
        RuntimeError: on a size mismatch (not supported by the contrib op).
    """
    if len(sizes) != len(dim):
        raise ValueError(f"{sizes}, {dim}")

    for requested, axis in zip(sizes, dim):
        if requested is None or requested < 0:
            continue
        if requested != inp_sizes[axis]:
            raise RuntimeError(
                f"Padding/trimming is not yet supported, "
                f"got sizes {sizes}, DFT dims {dim}, "
                f"input dims {inp_sizes}."
            )
|
|
|
|
|
|
|
|
def _check_padding_irfft(
    sizes: Tuple[Optional[int]], dim: Tuple[int], inp_sizes: Tuple[int]
) -> None:
    """Reject requested IRFFT output sizes that would need padding or trimming.

    Every axis but the last must keep its input length. The last requested
    size must be ``2 * (inp_sizes[dim[-1]] - 1)``, the full length implied by
    a one-sided spectrum. ``None``/negative entries (and ``0`` for the last
    axis) are not checked.

    Raises:
        ValueError: if ``sizes`` and ``dim`` differ in length.
        RuntimeError: on a size mismatch (not supported by the contrib op).
    """
    if len(sizes) != len(dim):
        raise ValueError(f"{sizes}, {dim}")

    for requested, axis in zip(sizes[:-1], dim):
        if requested is None or requested < 0:
            continue
        if requested != inp_sizes[axis]:
            raise RuntimeError(
                f"Padding/trimming is not yet supported, "
                f"got sizes {sizes}, DFT dims {dim}, "
                f"input dims {inp_sizes}."
            )

    last = sizes[-1]
    if last is None or last <= 0:
        return
    expected_size = 2 * (inp_sizes[dim[-1]] - 1)
    if last != expected_size:
        raise RuntimeError(
            f"Padding/trimming is not yet supported, got sizes {sizes}"
            f", DFT dims {dim}, input dims {inp_sizes}"
            f", expected last size {expected_size}."
        )
|
|
|
|
|
|
|
|
def _create_axes_perm(ndim: int, dims: Tuple[int]) -> Tuple[List[int], List[int]]:
    """Creates permuted axes indices for RFFT/IRFFT operators.

    Returns ``(perm_in, perm_out)``:

    * ``perm_in`` moves the axes listed in ``dims`` to the end of an
      ``ndim``-dimensional tensor (``dims[-1]`` becomes the last axis), so
      that a transform over trailing axes can be applied.
    * ``perm_out`` applies the same pairwise swaps in reverse order, i.e. it
      is the inverse permutation and restores the original axis order.

    Both lists are built by in-place swaps, so the loop order matters.
    """
    perm_in = list(range(ndim))
    perm_out = list(perm_in)
    # Swap each DFT axis into a trailing slot, starting from the last one.
    for i in range(-1, -(len(dims) + 1), -1):
        perm_in[dims[i]], perm_in[i] = perm_in[i], perm_in[dims[i]]
    # Undo the swaps in the opposite order to obtain the inverse permutation.
    for i in range(-len(dims), 0):
        perm_out[dims[i]], perm_out[i] = perm_out[i], perm_out[dims[i]]

    return perm_in, perm_out
|
|
|
|
|
|
|
|
def _scale_output_forward(
    output: Tensor, norm: str, sizes: torch.Size, ndim: int
) -> Tensor:
    """Scales the RFFT output according to norm parameter.

    ``output`` is expected to be unnormalized (``norm="backward"``).
    ``"forward"`` divides by the DFT size and ``"ortho"`` by its square root.

    Args:
        output: unnormalized RFFT result.
        norm: ``None``/``"backward"`` (no scaling), ``"forward"`` or ``"ortho"``.
        sizes: shape of the RFFT input; the transformed axes are assumed to be
            the last ``ndim`` entries. Entries may be Python ints or, as can
            happen while tracing, 0-d tensors.
        ndim: number of transformed axes.

    Raises:
        ValueError: if ``norm`` is not a supported mode.
    """
    norm = "backward" if norm is None else norm
    if norm not in ["forward", "backward", "ortho"]:
        raise ValueError(norm)

    if norm in ["forward", "ortho"]:
        dft_size = math.prod(sizes[-ndim:])
        if isinstance(dft_size, Tensor):
            # Tensor-valued sizes: keep the arithmetic in torch.
            dft_size = dft_size.float()
            denom = torch.sqrt(dft_size) if norm == "ortho" else dft_size
        else:
            # Plain ints have no ``.float()``; scale by a Python float instead.
            dft_size = float(dft_size)
            denom = math.sqrt(dft_size) if norm == "ortho" else dft_size
        output = output / denom

    return output
|
|
|
|
|
|
|
|
def _scale_output_backward(
    output: Tensor, norm: str, sizes: torch.Size, ndim: int
) -> Tensor:
    """Scales the IRFFT output according to norm parameter.

    ``output`` is multiplied by the full DFT size for ``"forward"`` and by its
    square root for ``"ortho"``. ``None``/``"backward"`` leave it unchanged.

    ``sizes`` is presumably the shape of the IRFFT input in real ``(..., 2)``
    layout: the last entry is skipped, ``sizes[-2]`` is taken as the one-sided
    length ``n // 2 + 1`` of the last DFT axis, and the ``ndim - 1`` entries
    before it as the other DFT axes. Entries may be Python ints or, as can
    happen while tracing, 0-d tensors.

    Raises:
        ValueError: if ``norm`` is unsupported, or ``sizes`` has fewer than
            ``ndim + 1`` entries when scaling is needed.
    """
    norm = "backward" if norm is None else norm
    if norm not in ["forward", "backward", "ortho"]:
        raise ValueError(norm)

    if norm in ["forward", "ortho"]:
        if not len(sizes) >= ndim + 1:
            raise ValueError
        dft_size = math.prod(sizes[-(ndim + 1) : -2])
        # Full length of the last axis reconstructed from the one-sided spectrum.
        dft_size *= 2 * (sizes[-2] - 1)
        if isinstance(dft_size, Tensor):
            # Tensor-valued sizes: keep the arithmetic in torch.
            dft_size = dft_size.float()
            scale = dft_size if norm == "forward" else torch.sqrt(dft_size)
        else:
            # Plain ints have no ``.float()``; scale by a Python float instead.
            dft_size = float(dft_size)
            scale = dft_size if norm == "forward" else math.sqrt(dft_size)
        output = scale * output

    return output
|
|
|
|
|
|
|
|
class OnnxRfft(Function):
    """Autograd ``Function`` for a 1-D RFFT over the last axis in ONNX export.

    ``forward`` may only run while exporting. It returns the unnormalized
    spectrum in real ``(..., 2)`` layout. ``symbolic`` delegates to
    ``_contrib_rfft`` with one signal axis.
    """

    @staticmethod
    def forward(ctx, input: Tensor) -> Tensor:
        if torch.onnx.is_in_onnx_export():
            spectrum = torch.fft.rfft(input, dim=-1, norm="backward")
            return torch.view_as_real(spectrum)
        raise ValueError("Must be called only during ONNX export.")

    @staticmethod
    def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
        """Symbolic representation for onnx graph"""
        return _contrib_rfft(g, input, ndim=1)
|
|
|
|
|
|
|
|
class OnnxRfft2(Function):
    """Autograd ``Function`` for a 2-D RFFT over the last two axes in ONNX export.

    ``forward`` may only run while exporting. It returns the unnormalized
    spectrum in real ``(..., 2)`` layout. ``symbolic`` delegates to
    ``_contrib_rfft`` with two signal axes.
    """

    # NOTE(review): OnnxRfft raises ValueError for the same misuse while this
    # raises AssertionError; consider unifying (callers may catch either).
    @staticmethod
    def forward(ctx, input: Tensor) -> Tensor:
        if torch.onnx.is_in_onnx_export():
            spectrum = torch.fft.rfft2(input, dim=(-2, -1), norm="backward")
            return torch.view_as_real(spectrum)
        raise AssertionError("Must be called only during ONNX export.")

    @staticmethod
    def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
        """Symbolic representation for onnx graph"""
        return _contrib_rfft(g, input, ndim=2)
|
|
|
|
|
|
|
|
class OnnxIrfft(Function):
    """Autograd ``Function`` for a 1-D IRFFT over the last complex axis in ONNX export.

    ``input`` is in real ``(..., 2)`` layout. ``forward`` may only run while
    exporting. ``symbolic`` delegates to ``_contrib_irfft`` with one signal
    axis.
    """

    @staticmethod
    def forward(ctx, input: Tensor) -> Tensor:
        if torch.onnx.is_in_onnx_export():
            spectrum = torch.view_as_complex(input)
            return torch.fft.irfft(spectrum, dim=-1, norm="backward")
        raise ValueError("Must be called only during ONNX export.")

    @staticmethod
    def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
        """Symbolic representation for onnx graph"""
        return _contrib_irfft(g, input, ndim=1)
|
|
|
|
|
|
|
|
class OnnxIrfft2(Function):
    """Inverse 2-D real FFT over the last two complex dims, ONNX export only.

    The input carries real and imaginary parts in a trailing dimension of size
    2; ``forward`` converts it with ``torch.view_as_complex`` before calling
    ``torch.fft.irfft2``. ``symbolic`` emits the node through ``_contrib_irfft``.
    """

    @staticmethod
    def forward(ctx, input: Tensor) -> Tensor:
        # Eager execution outside of an ONNX export is rejected.
        if torch.onnx.is_in_onnx_export():
            as_complex = torch.view_as_complex(input)
            return torch.fft.irfft2(as_complex, dim=(-2, -1), norm="backward")
        raise AssertionError("Must be called only during ONNX export.")

    @staticmethod
    def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
        """Symbolic representation for onnx graph"""
        return _contrib_irfft(g, input, ndim=2)
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class ModelMetaData:
    """Feature flags and identity shared by every Modulus model.

    ``amp_cpu``/``amp_gpu`` and ``onnx_cpu``/``onnx_gpu`` default to ``None``;
    ``__post_init__`` replaces a ``None`` with the value of ``amp`` or ``onnx``
    respectively.
    """

    name: str = "ModulusModule"
    jit: bool = False
    cuda_graphs: bool = False
    amp: bool = False
    amp_cpu: bool = None
    amp_gpu: bool = None
    torch_fx: bool = False
    bf16: bool = False
    onnx: bool = False
    onnx_gpu: bool = None
    onnx_cpu: bool = None
    onnx_runtime: bool = False
    trt: bool = False
    var_dim: int = -1
    func_torch: bool = False
    auto_grad: bool = False

    def __post_init__(self):
        # Per-device flags left as None inherit the corresponding global switch.
        inherit_from = (
            ("amp_cpu", "amp"),
            ("amp_gpu", "amp"),
            ("onnx_cpu", "onnx"),
            ("onnx_gpu", "onnx"),
        )
        for device_flag, global_flag in inherit_from:
            if getattr(self, device_flag) is None:
                setattr(self, device_flag, getattr(self, global_flag))
|
|
|
|
|
|
|
|
import importlib |
|
|
import inspect |
|
|
import json |
|
|
import logging |
|
|
import os |
|
|
import tarfile |
|
|
import tempfile |
|
|
from pathlib import Path |
|
|
from typing import Any, Dict, Union |
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
@dataclass
class ModelMetaData:
    """Data class for storing essential meta data needed for all Modulus Models

    ``amp_cpu``/``amp_gpu`` and ``onnx_cpu``/``onnx_gpu`` accept ``None``, meaning
    "inherit": ``__post_init__`` replaces ``None`` with the value of ``amp`` or
    ``onnx`` respectively.
    """

    name: str = "ModulusModule"
    jit: bool = False
    cuda_graphs: bool = False
    amp: bool = False
    # None -> inherit ``amp`` (resolved in __post_init__).
    amp_cpu: Union[bool, None] = None
    amp_gpu: Union[bool, None] = None
    torch_fx: bool = False
    bf16: bool = False
    onnx: bool = False
    # None -> inherit ``onnx`` (resolved in __post_init__).
    onnx_gpu: Union[bool, None] = None
    onnx_cpu: Union[bool, None] = None
    onnx_runtime: bool = False
    trt: bool = False
    var_dim: int = -1
    func_torch: bool = False
    auto_grad: bool = False

    def __post_init__(self):
        """Resolve per-device flags that were left as ``None``."""
        self.amp_cpu = self.amp if self.amp_cpu is None else self.amp_cpu
        self.amp_gpu = self.amp if self.amp_gpu is None else self.amp_gpu
        self.onnx_cpu = self.onnx if self.onnx_cpu is None else self.onnx_cpu
        self.onnx_gpu = self.onnx if self.onnx_gpu is None else self.onnx_gpu
|
|
|
|
|
|
|
|
|
|
|
from importlib.metadata import EntryPoint, entry_points |
|
|
from typing import List, Union |
|
|
|
|
|
|
|
|
import importlib_metadata |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelRegistry: |
|
|
_shared_state = {"_model_registry": None} |
|
|
|
|
|
def __new__(cls, *args, **kwargs): |
|
|
obj = super(ModelRegistry, cls).__new__(cls) |
|
|
obj.__dict__ = cls._shared_state |
|
|
if cls._shared_state["_model_registry"] is None: |
|
|
cls._shared_state["_model_registry"] = cls._construct_registry() |
|
|
return obj |
|
|
|
|
|
@staticmethod |
|
|
def _construct_registry() -> dict: |
|
|
registry = {} |
|
|
entrypoints = entry_points(group="modulus.models") |
|
|
for entry_point in entrypoints: |
|
|
registry[entry_point.name] = entry_point |
|
|
return registry |
|
|
|
|
|
def register(self, model: "modulus.Module", name: Union[str, None] = None) -> None: |
|
|
|
|
|
|
|
|
|
|
|
if not issubclass(model, modulus.Module): |
|
|
raise ValueError( |
|
|
f"Only subclasses of modulus.Module can be registered. " |
|
|
f"Provided model is of type {type(model)}" |
|
|
) |
|
|
|
|
|
|
|
|
if name is None: |
|
|
name = model.__name__ |
|
|
|
|
|
|
|
|
if name in self._model_registry: |
|
|
raise ValueError(f"Name {name} already in use") |
|
|
|
|
|
|
|
|
self._model_registry[name] = model |
|
|
|
|
|
def factory(self, name: str) -> "modulus.Module": |
|
|
|
|
|
|
|
|
model = self._model_registry.get(name) |
|
|
if model is not None: |
|
|
if isinstance(model, (EntryPoint, importlib_metadata.EntryPoint)): |
|
|
model = model.load() |
|
|
return model |
|
|
|
|
|
raise KeyError(f"No model is registered under the name {name}") |
|
|
|
|
|
def list_models(self) -> List[str]: |
|
|
|
|
|
return list(self._model_registry.keys()) |
|
|
|
|
|
def __clear_registry__(self): |
|
|
|
|
|
self._model_registry = {} |
|
|
|
|
|
def __restore_registry__(self): |
|
|
|
|
|
self._model_registry = self._construct_registry() |
|
|
|
|
|
|
|
|
|
|
|
import hashlib |
|
|
import json |
|
|
import logging |
|
|
import os |
|
|
import re |
|
|
import urllib.request |
|
|
import zipfile |
|
|
|
|
|
import fsspec |
|
|
import fsspec.implementations.cached |
|
|
import requests |
|
|
import s3fs |
|
|
from tqdm import tqdm |
|
|
|
|
|
logger = logging.getLogger(__name__)

# Root directory for downloaded files: $LOCAL_CACHE if set, otherwise
# ~/.cache/modulus. ``os.path.expanduser`` still resolves the home directory
# when $HOME is unset (e.g. Windows or minimal containers), where indexing
# ``os.environ["HOME"]`` raised KeyError at import time.
LOCAL_CACHE = os.environ.get(
    "LOCAL_CACHE", os.path.join(os.path.expanduser("~"), ".cache", "modulus")
)
|
|
|
|
|
|
|
|
def _cache_fs(fs):
    """Wrap ``fs`` in an fsspec ``CachingFileSystem`` whose cache lives in LOCAL_CACHE."""
    caching_filesystem = fsspec.implementations.cached.CachingFileSystem
    return caching_filesystem(fs=fs, cache_storage=LOCAL_CACHE)
|
|
|
|
|
|
|
|
def _get_fs(path):
    """Choose an fsspec filesystem for ``path``.

    ``s3://`` paths get an ``s3fs.S3FileSystem`` pointed at a hard-coded
    endpoint; every other path gets the local ``file`` filesystem.
    """
    if not path.startswith("s3://"):
        return fsspec.filesystem("file")
    return s3fs.S3FileSystem(client_kwargs=dict(endpoint_url="https://pbss.s8k.io"))
|
|
|
|
|
|
|
|
def _download_ngc_model_file(path: str, out_path: str, timeout: int = 300) -> str: |
|
|
|
|
|
|
|
|
suffix = "ngc://models/" |
|
|
|
|
|
pattern = re.compile(f"{suffix}[\w-]+(/[\w-]+)?/[\w-]+@[A-Za-z0-9.]+/[\w/](.*)") |
|
|
if not pattern.match(path): |
|
|
raise ValueError( |
|
|
"Invalid URL, should be of form ngc://models/<org_id/team_id/model_id>@<version>/<path/in/repo>" |
|
|
) |
|
|
|
|
|
path = path.replace(suffix, "") |
|
|
if len(path.split("@")[0].split("/")) == 3: |
|
|
(org, team, model_version, filename) = path.split("/", 3) |
|
|
(model, version) = model_version.split("@", 1) |
|
|
else: |
|
|
(org, model_version, filename) = path.split("/", 2) |
|
|
(model, version) = model_version.split("@", 1) |
|
|
team = None |
|
|
|
|
|
token = "" |
|
|
|
|
|
if "NGC_API_KEY" in os.environ: |
|
|
try: |
|
|
|
|
|
if os.environ["NGC_API_KEY"].startswith("nvapi-"): |
|
|
raise NotImplementedError("New personal keys not supported yet") |
|
|
|
|
|
|
|
|
else: |
|
|
session = requests.Session() |
|
|
session.auth = ("$oauthtoken", os.environ["NGC_API_KEY"]) |
|
|
headers = {"Accept": "application/json"} |
|
|
authn_url = f"https://authn.nvidia.com/token?service=ngc&scope=group/ngc:{org}&group/ngc:{org}/{team}" |
|
|
r = session.get(authn_url, headers=headers, timeout=5) |
|
|
r.raise_for_status() |
|
|
token = json.loads(r.content)["token"] |
|
|
except requests.exceptions.RequestException: |
|
|
logger.warning( |
|
|
"Failed to get JWT using the API set in NGC_API_KEY environment variable" |
|
|
) |
|
|
raise |
|
|
|
|
|
|
|
|
if len(token) > 0: |
|
|
|
|
|
if team: |
|
|
file_url = f"https://api.ngc.nvidia.com/v2/org/{org}/team/{team}/models/{model}/versions/{version}/files/{filename}" |
|
|
else: |
|
|
file_url = f"https://api.ngc.nvidia.com/v2/org/{org}/models/{model}/versions/{version}/files/{filename}" |
|
|
else: |
|
|
if team: |
|
|
file_url = f"https://api.ngc.nvidia.com/v2/models/{org}/{team}/{model}/versions/{version}/files/{filename}" |
|
|
else: |
|
|
file_url = f"https://api.ngc.nvidia.com/v2/models/{org}/{model}/versions/{version}/files/{filename}" |
|
|
|
|
|
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} |
|
|
|
|
|
with requests.get(file_url, headers=headers, stream=True, timeout=timeout) as r: |
|
|
r.raise_for_status() |
|
|
total_size_in_bytes = int(r.headers.get("content-length", 0)) |
|
|
chunk_size = 1024 |
|
|
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) |
|
|
progress_bar.set_description(f"Fetching {filename}") |
|
|
with open(out_path, "wb") as f: |
|
|
for chunk in r.iter_content(chunk_size=chunk_size): |
|
|
progress_bar.update(len(chunk)) |
|
|
f.write(chunk) |
|
|
progress_bar.close() |
|
|
|
|
|
|
|
|
if zipfile.is_zipfile(out_path) and path.endswith(".zip"): |
|
|
temp_path = out_path + ".zip" |
|
|
os.rename(out_path, temp_path) |
|
|
with zipfile.ZipFile(temp_path, "r") as zip_ref: |
|
|
zip_ref.extractall(out_path) |
|
|
|
|
|
os.remove(temp_path) |
|
|
|
|
|
return out_path |
|
|
|
|
|
|
|
|
def _download_cached(
    path: str, recursive: bool = False, local_cache_path: str = LOCAL_CACHE
) -> str:
    """Return a local filesystem path for ``path``, downloading it if it is remote.

    Remote resources (``s3://``, ``ngc://models/``, ``http://`` and ``https://``)
    are stored in ``local_cache_path`` under the SHA-256 hex digest of ``path``
    and reused on later calls. ``file://`` URLs are converted to a plain path;
    any other value is returned unchanged.

    Args:
        path: Local path or URL.
        recursive: Forwarded to the S3 filesystem ``get`` call.
        local_cache_path: Cache directory; created if it does not exist.

    Returns:
        Path to the cached copy, or the local path itself.

    Raises:
        OSError: if the cache directory cannot be created.
        requests.HTTPError: if an HTTP(S) download returns an error status.
    """
    # ``urllib.request`` (imported at module level) only pulls this in implicitly.
    import urllib.parse

    filename = hashlib.sha256(path.encode()).hexdigest()
    try:
        os.makedirs(local_cache_path, exist_ok=True)
    except PermissionError:
        logger.error(
            "Failed to create cache folder, check permissions or set a cache"
            + " location using the LOCAL_CACHE environment variable"
        )
        raise
    except OSError:
        logger.error(
            "Failed to create cache folder, set a cache"
            + " location using the LOCAL_CACHE environment variable"
        )
        raise

    cache_path = os.path.join(local_cache_path, filename)
    url = urllib.parse.urlparse(path)

    if os.path.exists(cache_path):
        logger.debug("Opening from cache: %s", cache_path)
        return cache_path

    logger.debug("Downloading %s to cache: %s", path, cache_path)
    if path.startswith("s3://"):
        fs = _get_fs(path)
        fs.get(path, cache_path, recursive=recursive)
    elif path.startswith("ngc://models/"):
        return _download_ngc_model_file(path, cache_path)
    elif url.scheme in ("http", "https"):
        # Previously only "http" was handled and https URLs fell through to the
        # final branch, returning the URL itself as if it were a local path.
        # Stream into a temp file and rename on success so an error page or an
        # interrupted transfer is never mistaken for a cache hit later.
        part_path = cache_path + ".part"
        try:
            with requests.get(path, stream=True, timeout=5) as response:
                response.raise_for_status()
                with open(part_path, "wb") as output:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            output.write(chunk)
            os.replace(part_path, cache_path)
        except BaseException:
            if os.path.exists(part_path):
                os.remove(part_path)
            raise
    elif url.scheme == "file":
        return os.path.join(url.netloc, url.path)
    else:
        return path

    return cache_path
|
|
|
|
|
|
|
|
class Package:
    """A root location (local directory or URL prefix) from which items are fetched.

    Item paths are joined to ``root`` with ``seperator`` and resolved through
    ``_download_cached``, so remote items are downloaded to the local cache.
    """

    def __init__(self, root: str, seperator: str = "/"):
        self.root, self.seperator = root, seperator

    def get(self, path: str, recursive: bool = False) -> str:
        """Get a local path to the item at ``path``

        ``path`` might be a remote file, in which case it is downloaded to a
        local cache at $LOCAL_CACHE or $HOME/.cache/modulus first.
        """
        full_path = self._fullpath(path)
        return _download_cached(full_path, recursive=recursive)

    def _fullpath(self, path):
        prefix = self.root + self.seperator
        return prefix + path
|
|
|
|
|
|
|
|
class Module(torch.nn.Module):
    """Base class for Modulus models with self-describing checkpoints.

    ``__new__`` records the arguments each instance was constructed with in
    ``self._args``. ``save`` writes them to ``args.json`` beside the weights, so
    that ``from_checkpoint`` can rebuild the model from the archive alone.

    Parameters
    ----------
    meta : ModelMetaData, optional
        Meta data describing the model, by default None
    """

    _file_extension = ".mdlus"
    # Version of the .mdlus archive layout (distinct from the modulus package
    # version). ``save`` writes it to metadata.json and ``_check_checkpoint``
    # requires an exact match when loading.
    __model_checkpoint_version__ = "0.1.0"

    def __new__(cls, *args, **kwargs):
        out = super().__new__(cls)

        # Bind the call against ``cls.__init__``; ``None`` fills the ``self`` slot.
        sig = inspect.signature(cls.__init__)
        bound_args = sig.bind_partial(*([None] + list(args)), **kwargs)
        bound_args.apply_defaults()

        # Flatten the bound arguments into one keyword dict: skip ``self`` and
        # merge the contents of a ``**kwargs`` parameter into the top level.
        instantiate_args = {}
        for param, (k, v) in zip(sig.parameters.values(), bound_args.arguments.items()):
            if k == "self":
                continue
            if param.kind == param.VAR_KEYWORD:
                instantiate_args.update(v)
            else:
                instantiate_args[k] = v

        # Everything ``instantiate`` needs to re-create this object.
        out._args = {
            "__name__": cls.__name__,
            "__module__": cls.__module__,
            "__args__": instantiate_args,
        }
        return out

    def __init__(self, meta: Union[ModelMetaData, None] = None):
        super().__init__()
        self.meta = meta
        # Empty buffer that moves with the module; the ``device`` property reads it.
        self.register_buffer("device_buffer", torch.empty(0))
        self._setup_logger()

    def _setup_logger(self):
        """Point ``self.logger`` at the shared ``core.module`` logger (WARNING level)."""
        self.logger = logging.getLogger("core.module")
        # ``getLogger`` returns the same process-wide logger for every instance.
        # Adding a handler unconditionally made each message print once per model
        # ever constructed, so only attach one when none is present.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                "[%(asctime)s - %(levelname)s] %(message)s", datefmt="%H:%M:%S"
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
        self.logger.setLevel(logging.WARNING)

    @staticmethod
    def _safe_members(tar, local_path):
        """Yield the members of ``tar`` that are safe to extract under ``local_path``.

        A member is skipped (and reported) when its name contains ``..``, is an
        absolute path, is a symbolic or hard link, or would resolve to a location
        outside ``local_path``. The previous condition was inverted and yielded
        exactly the ``..``/absolute members it was meant to reject.
        """
        root = os.path.realpath(local_path)
        for member in tar.getmembers():
            is_safe = (
                ".." not in member.name
                and not os.path.isabs(member.name)
                and not (member.issym() or member.islnk())
                and os.path.commonpath(
                    [root, os.path.realpath(os.path.join(root, member.name))]
                )
                == root
            )
            if is_safe:
                yield member
            else:
                print(f"Skipping potentially malicious file: {member.name}")

    @classmethod
    def instantiate(cls, arg_dict: Dict[str, Any]) -> "Module":
        """Instantiate a model from the ``_args`` dictionary recorded by ``__new__``.

        The class is resolved as, in order: ``cls`` itself when the names match,
        the ``ModelRegistry`` entry of that name, or ``__name__`` looked up in the
        recorded ``__module__``; if that attribute is missing, ``cls`` is used.
        """
        _cls_name = arg_dict["__name__"]
        registry = ModelRegistry()
        if cls.__name__ == arg_dict["__name__"]:
            _cls = cls
        elif _cls_name in registry.list_models():
            _cls = registry.factory(_cls_name)
        else:
            try:
                # Check if the model is importable from its recorded module.
                _mod = importlib.import_module(arg_dict["__module__"])
                _cls = getattr(_mod, arg_dict["__name__"])
            except AttributeError:
                # Fall back to the class this method was called on.
                _cls = cls
        return _cls(**arg_dict["__args__"])

    def debug(self):
        """Turn on debug logging"""
        self.logger.handlers.clear()
        handler = logging.StreamHandler()
        formatter = logging.Formatter(
            f"[%(asctime)s - %(levelname)s - {self.meta.name}] %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        handler.setFormatter(formatter)
        self.logger.addHandler(handler)
        self.logger.setLevel(logging.DEBUG)

    def save(self, file_name: Union[str, None] = None, verbose: bool = False) -> None:
        """Write the model to a ``.mdlus`` archive.

        The archive is a tar file holding ``model.pt`` (the state dict),
        ``args.json`` (constructor arguments) and ``metadata.json`` (modulus and
        archive versions, plus the git commit hash when ``verbose``).

        Parameters
        ----------
        file_name : str, optional
            Destination ending in ``.mdlus``; a local path or ``s3://`` URL.
            Defaults to ``<meta.name>.mdlus``.
        verbose : bool, optional
            Also record the current git commit hash, by default False.

        Raises
        ------
        ValueError
            If ``file_name`` does not end with ``.mdlus``.
        """
        if file_name is not None and not file_name.endswith(self._file_extension):
            raise ValueError(
                f"File name must end with {self._file_extension} extension"
            )

        with tempfile.TemporaryDirectory() as temp_dir:
            local_path = Path(temp_dir)

            torch.save(self.state_dict(), local_path / "model.pt")

            with open(local_path / "args.json", "w") as f:
                json.dump(self._args, f)

            metadata_info = {
                "modulus_version": modulus.__version__,
                "mdlus_file_version": self.__model_checkpoint_version__,
            }

            if verbose:
                import git

                try:
                    repo = git.Repo(search_parent_directories=True)
                    metadata_info["git_hash"] = repo.head.object.hexsha
                except git.InvalidGitRepositoryError:
                    metadata_info["git_hash"] = None

            with open(local_path / "metadata.json", "w") as f:
                json.dump(metadata_info, f)

            # Bundle everything written above into a single tar archive.
            with tarfile.open(local_path / "model.tar", "w") as tar:
                for file in local_path.iterdir():
                    tar.add(str(file), arcname=file.name)

            if file_name is None:
                file_name = self.meta.name + ".mdlus"

            # ``_get_fs`` selects S3 or the local filesystem from the name.
            fs = _get_fs(file_name)
            fs.put(str(local_path / "model.tar"), file_name)

    @staticmethod
    def _check_checkpoint(local_path: str) -> bool:
        """Validate an extracted checkpoint directory.

        ``local_path`` must be a ``pathlib.Path`` (``joinpath`` is used).

        Raises
        ------
        IOError
            If ``args.json``, ``metadata.json`` or ``model.pt`` is missing, or if
            the archive version differs from ``__model_checkpoint_version__``.
        """
        if not local_path.joinpath("args.json").exists():
            raise IOError("File 'args.json' not found in checkpoint")

        if not local_path.joinpath("metadata.json").exists():
            raise IOError("File 'metadata.json' not found in checkpoint")

        if not local_path.joinpath("model.pt").exists():
            raise IOError("Model weights 'model.pt' not found in checkpoint")

        with open(local_path.joinpath("metadata.json"), "r") as f:
            metadata_info = json.load(f)
            if (
                metadata_info["mdlus_file_version"]
                != Module.__model_checkpoint_version__
            ):
                # The message previously referenced the nonexistent
                # ``Module.__version__`` and raised AttributeError instead.
                raise IOError(
                    f"Model checkpoint version {metadata_info['mdlus_file_version']} is not compatible with current version {Module.__model_checkpoint_version__}"
                )

    def load(
        self,
        file_name: str,
        map_location: Union[None, str, torch.device] = None,
        strict: bool = True,
    ) -> None:
        """Load weights from a ``.mdlus`` archive into this model.

        Parameters
        ----------
        file_name : str
            Local path or URL of the archive (resolved through the download cache).
        map_location : str or torch.device, optional
            Device to load the weights onto; defaults to the model's device.
        strict : bool, optional
            Forwarded to ``load_state_dict``, by default True.
        """
        cached_file_name = _download_cached(file_name)

        with tempfile.TemporaryDirectory() as temp_dir:
            local_path = Path(temp_dir)

            with tarfile.open(cached_file_name, "r") as tar:
                tar.extractall(
                    path=local_path, members=list(Module._safe_members(tar, local_path))
                )

            Module._check_checkpoint(local_path)

            device = map_location if map_location is not None else self.device
            # NOTE(review): torch.load unpickles arbitrary objects; only load
            # checkpoints from trusted sources.
            model_dict = torch.load(
                local_path.joinpath("model.pt"), map_location=device
            )
            self.load_state_dict(model_dict, strict=strict)

    @classmethod
    def from_checkpoint(cls, file_name: str) -> "Module":
        """Build a model from a ``.mdlus`` archive (arguments and weights).

        The class and its constructor arguments are read from ``args.json`` and
        resolved via ``instantiate``; the weights are then loaded strictly.
        """
        cached_file_name = _download_cached(file_name)

        with tempfile.TemporaryDirectory() as temp_dir:
            local_path = Path(temp_dir)

            with tarfile.open(cached_file_name, "r") as tar:
                tar.extractall(
                    path=local_path, members=list(cls._safe_members(tar, local_path))
                )

            Module._check_checkpoint(local_path)

            with open(local_path.joinpath("args.json"), "r") as f:
                args = json.load(f)
            model = cls.instantiate(args)

            # NOTE(review): torch.load unpickles arbitrary objects; only load
            # checkpoints from trusted sources.
            model_dict = torch.load(
                local_path.joinpath("model.pt"), map_location=model.device
            )
            model.load_state_dict(model_dict)

        return model

    @staticmethod
    def from_torch(
        torch_model_class: torch.nn.Module, meta: ModelMetaData = None
    ) -> "Module":
        """Wrap a plain ``torch.nn.Module`` class as a registered Modulus ``Module``.

        The returned class, named ``<ClassName>ModulusModel``, forwards its
        constructor arguments to ``torch_model_class`` and its single-input
        ``forward`` to the wrapped instance. Its ``__init__`` signature mirrors the
        wrapped class so ``__new__`` records the real argument names.
        """

        class ModulusModel(Module):
            def __init__(self, *args, **kwargs):
                super().__init__(meta=meta)
                self.inner_model = torch_model_class(*args, **kwargs)

            def forward(self, x):
                return self.inner_model(x)

        # Collect argument names and defaults of the wrapped constructor.
        init_argspec = inspect.getfullargspec(torch_model_class.__init__)
        model_argnames = init_argspec.args[1:]
        model_defaults = init_argspec.defaults or []
        defaults_dict = dict(
            zip(model_argnames[-len(model_defaults) :], model_defaults)
        )

        # Build an explicit signature (self + the wrapped args).
        params = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)]
        params += [
            inspect.Parameter(
                argname,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                default=defaults_dict.get(argname, inspect.Parameter.empty),
            )
            for argname in model_argnames
        ]
        init_signature = inspect.Signature(params)

        # ``inspect.signature`` (used in ``__new__``) honours ``__signature__``.
        ModulusModel.__init__.__signature__ = init_signature

        new_class_name = f"{torch_model_class.__name__}ModulusModel"
        ModulusModel.__name__ = new_class_name

        # Register so ``instantiate`` can find the class by name later.
        registry = ModelRegistry()
        registry.register(ModulusModel, new_class_name)

        return ModulusModel

    @property
    def device(self) -> torch.device:
        """Device of the model, taken from the internal ``device_buffer``."""
        return self.device_buffer.device

    def num_parameters(self) -> int:
        """Return the total element count of all parameters.

        Every registered parameter is counted, including ones with
        ``requires_grad=False``.
        """
        return sum(param.numel() for _, param in self.named_parameters())
|
|
|
|
|
Tensor = torch.Tensor  # short alias used in the type hints of the layers below
|
|
|
|
|
import torch.fft |
|
|
|
|
|
class AFNOMlp(nn.Module):
    """Two-layer feed-forward network used as the MLP stage of ``Block``.

    Computes ``drop(fc2(drop(act(fc1(x)))))`` over the last dimension.

    Parameters
    ----------
    in_features : int
        Size of the last input dimension.
    latent_features : int
        Width of the hidden layer.
    out_features : int
        Size of the last output dimension.
    activation_fn : nn.Module, optional
        Activation applied after ``fc1``, by default ``nn.GELU()``.
    drop : float, optional
        Dropout probability applied after the activation and after ``fc2``.
    """

    def __init__(
        self,
        in_features: int,
        latent_features: int,
        out_features: int,
        activation_fn: nn.Module = nn.GELU(),
        drop: float = 0.0,
    ):
        super().__init__()
        # Submodules are created in this order so parameter init draws match.
        self.fc1 = nn.Linear(in_features, latent_features)
        self.act = activation_fn
        self.fc2 = nn.Linear(latent_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
|
|
|
|
|
|
|
class AFNO2DLayer(nn.Module):
    """Spectral mixing layer: block-diagonal complex MLP applied in 2-D Fourier space.

    ``forward`` takes ``x`` of shape ``(B, H, W, C)`` with ``C == hidden_size``
    (required by the reshape into ``num_blocks`` blocks), applies ``rfft2`` over
    ``H`` and ``W``, and runs a two-layer complex MLP (ReLU in between) with
    per-block weights on the retained modes. It then applies ``softshrink``
    with ``sparsity_threshold``, transforms back with ``irfft2`` and adds the
    input as a residual.

    Parameters
    ----------
    hidden_size : int
        Channel count; must be divisible by ``num_blocks``.
    num_blocks : int, optional
        Number of independent channel blocks, by default 8.
    sparsity_threshold : float, optional
        ``lambd`` of the final ``softshrink``, by default 0.01.
    hard_thresholding_fraction : float, optional
        Fraction of modes kept (see ``kept_modes`` in ``forward``), by default 1.
    hidden_size_factor : int, optional
        Expansion factor of the MLP hidden width per block, by default 1.
    """

    def __init__(
        self,
        hidden_size: int,
        num_blocks: int = 8,
        sparsity_threshold: float = 0.01,
        hard_thresholding_fraction: float = 1,
        hidden_size_factor: int = 1,
    ):
        super().__init__()
        if hidden_size % num_blocks != 0:
            raise ValueError(
                f"hidden_size {hidden_size} should be divisible by num_blocks {num_blocks}"
            )

        self.hidden_size = hidden_size
        self.sparsity_threshold = sparsity_threshold
        self.num_blocks = num_blocks
        self.block_size = self.hidden_size // self.num_blocks
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.hidden_size_factor = hidden_size_factor
        self.scale = 0.02

        latent = self.block_size * self.hidden_size_factor
        # Leading dim 2 holds (real, imaginary) parts. Creation order is fixed so
        # the random draws match: w1, b1, w2, b2.
        self.w1 = nn.Parameter(
            self.scale * torch.randn(2, self.num_blocks, self.block_size, latent)
        )
        self.b1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, latent))
        self.w2 = nn.Parameter(
            self.scale * torch.randn(2, self.num_blocks, latent, self.block_size)
        )
        self.b2 = nn.Parameter(
            self.scale * torch.randn(2, self.num_blocks, self.block_size)
        )

    def forward(self, x: Tensor) -> Tensor:
        residual = x
        in_dtype = x.dtype
        # The spectral computation runs in float32.
        x = x.float()
        B, H, W, C = x.shape
        n_freq_w = W // 2 + 1

        x = rfft2(x, dim=(1, 2), norm="ortho")
        xr, xi = real(x), imag(x)
        block_shape = (B, H, n_freq_w, self.num_blocks, self.block_size)
        xr = xr.reshape(block_shape)
        xi = xi.reshape(block_shape)

        hidden_shape = [
            B,
            H,
            n_freq_w,
            self.num_blocks,
            self.block_size * self.hidden_size_factor,
        ]
        o1_real = torch.zeros(hidden_shape, device=x.device)
        o1_imag = torch.zeros(hidden_shape, device=x.device)
        o2 = torch.zeros(xr.shape + (2,), device=x.device)

        total_modes = H // 2 + 1
        kept_modes = int(total_modes * self.hard_thresholding_fraction)
        # Selection of retained modes: rows [total-kept, total+kept), cols [:kept].
        sel = (
            slice(None),
            slice(total_modes - kept_modes, total_modes + kept_modes),
            slice(None, kept_modes),
        )
        eq = "nyxbi,bio->nyxbo"

        # First complex layer: (xr + i xi) @ (w1r + i w1i) + b1, then ReLU.
        xr_k, xi_k = xr[sel], xi[sel]
        o1_real[sel] = F.relu(
            torch.einsum(eq, xr_k, self.w1[0])
            - torch.einsum(eq, xi_k, self.w1[1])
            + self.b1[0]
        )
        o1_imag[sel] = F.relu(
            torch.einsum(eq, xi_k, self.w1[0])
            + torch.einsum(eq, xr_k, self.w1[1])
            + self.b1[1]
        )

        # Second complex layer, written into the trailing (real, imag) axis of o2.
        h_real, h_imag = o1_real[sel], o1_imag[sel]
        o2[sel + (Ellipsis, 0)] = (
            torch.einsum(eq, h_real, self.w2[0])
            - torch.einsum(eq, h_imag, self.w2[1])
            + self.b2[0]
        )
        o2[sel + (Ellipsis, 1)] = (
            torch.einsum(eq, h_imag, self.w2[0])
            + torch.einsum(eq, h_real, self.w2[1])
            + self.b2[1]
        )

        x = F.softshrink(o2, lambd=self.sparsity_threshold)
        x = view_as_complex(x)
        # During ONNX export the complex tensor keeps an explicit (re, im) axis.
        if torch.onnx.is_in_onnx_export():
            x = x.reshape(B, H, n_freq_w, C, 2)
        else:
            x = x.reshape(B, H, n_freq_w, C)

        x = irfft2(x, s=(H, W), dim=(1, 2), norm="ortho")
        return x.type(in_dtype) + residual
|
|
|
|
|
|
|
|
class Block(nn.Module):
    """One AFNO block: norm -> AFNO2DLayer -> [skip] -> norm -> AFNOMlp -> skip.

    With ``double_skip`` a residual connection is added after the spectral
    filter as well as after the MLP; otherwise only the outer one is used.

    Parameters
    ----------
    embed_dim : int
        Channel (last) dimension of the input.
    num_blocks : int, optional
        Forwarded to ``AFNO2DLayer``, by default 8.
    mlp_ratio : float, optional
        MLP hidden width as a multiple of ``embed_dim``, by default 4.0.
    drop : float, optional
        Dropout probability inside the MLP, by default 0.0.
    activation_fn : nn.Module, optional
        MLP activation, by default ``nn.GELU()``.
    norm_layer : callable, optional
        Factory called with ``embed_dim`` to build both norms, by default
        ``nn.LayerNorm``.
    double_skip : bool, optional
        Add the inner residual connection, by default True.
    sparsity_threshold, hard_thresholding_fraction : float, optional
        Forwarded to ``AFNO2DLayer``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_blocks: int = 8,
        mlp_ratio: float = 4.0,
        drop: float = 0.0,
        activation_fn: nn.Module = nn.GELU(),
        norm_layer: nn.Module = nn.LayerNorm,
        double_skip: bool = True,
        sparsity_threshold: float = 0.01,
        hard_thresholding_fraction: float = 1.0,
    ):
        super().__init__()
        # Submodule creation order (norm1, filter, norm2, mlp) fixes the RNG draws.
        self.norm1 = norm_layer(embed_dim)
        self.filter = AFNO2DLayer(
            hidden_size=embed_dim,
            num_blocks=num_blocks,
            sparsity_threshold=sparsity_threshold,
            hard_thresholding_fraction=hard_thresholding_fraction,
        )
        self.norm2 = norm_layer(embed_dim)
        self.mlp = AFNOMlp(
            in_features=embed_dim,
            latent_features=int(embed_dim * mlp_ratio),
            out_features=embed_dim,
            activation_fn=activation_fn,
            drop=drop,
        )
        self.double_skip = double_skip

    def forward(self, x: Tensor) -> Tensor:
        skip = x
        x = self.filter(self.norm1(x))

        if self.double_skip:
            x = x + skip
            skip = x

        x = self.mlp(self.norm2(x))
        return x + skip
|
|
|
|
|
|
|
|
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and project each to ``embed_dim``.

    ``forward`` maps ``(B, in_channels, H, W)`` with ``(H, W) == inp_shape`` to
    ``(B, num_patches, embed_dim)`` using a strided convolution.

    Parameters
    ----------
    inp_shape : List[int]
        Expected spatial size ``[H, W]``.
    in_channels : int
        Number of input channels.
    patch_size : List[int], optional
        Patch size ``[ph, pw]``, by default ``[16, 16]``.
    embed_dim : int, optional
        Output embedding size, by default 256.

    Raises
    ------
    ValueError
        If ``inp_shape`` or ``patch_size`` does not have length 2, or (in
        ``forward``) if the input's spatial size differs from ``inp_shape``.
    """

    def __init__(
        self,
        inp_shape: List[int],
        in_channels: int,
        patch_size: List[int] = [16, 16],
        embed_dim: int = 256,
    ):
        super().__init__()
        for label, value in (("inp_shape", inp_shape), ("patch_size", patch_size)):
            if len(value) != 2:
                raise ValueError(f"{label} should be a list of length 2")

        grid_h = inp_shape[0] // patch_size[0]
        grid_w = inp_shape[1] // patch_size[1]
        self.inp_shape = inp_shape
        self.patch_size = patch_size
        self.num_patches = grid_w * grid_h
        # kernel == stride -> one output pixel per non-overlapping patch.
        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x: Tensor) -> Tensor:
        _, _, height, width = x.shape
        expected_h, expected_w = self.inp_shape[0], self.inp_shape[1]
        if height != expected_h or width != expected_w:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({expected_h}*{expected_w})."
            )
        patches = self.proj(x)
        # (B, E, h, w) -> (B, E, h*w) -> (B, h*w, E)
        return patches.flatten(2).transpose(1, 2)
|
|
|
|
|
|
|
|
@dataclass
class MetaData(ModelMetaData):
    """Meta data for the AFNO / Fourcastnet model.

    Only overrides defaults of ``ModelMetaData``; field order is inherited from
    the base class. Fields not listed here keep the base defaults (for example,
    ``amp_cpu``/``amp_gpu`` inherit ``amp`` in ``__post_init__``).
    """

    name: str = "AFNO"
    jit: bool = False
    cuda_graphs: bool = True
    amp: bool = True
    onnx_cpu: bool = False
    onnx_gpu: bool = True
    onnx_runtime: bool = True
    var_dim: int = 1
    func_torch: bool = False
    auto_grad: bool = False
|
|
|
|
|
|
|
|
class Fourcastnet(Module):
    """AFNO image-to-image network (FourCastNet-style).

    The input ``(B, in_channels, H, W)`` with ``[H, W] == inp_shape`` is handled
    in four steps:

    1. Patch-embed the input and add a learned positional embedding.
    2. Apply ``depth`` AFNO ``Block``s on the ``(B, H/p0, W/p1, embed_dim)``
       token grid.
    3. Project every token back to a ``p0 x p1`` patch with ``out_channels``
       channels.
    4. Return a ``(B, out_channels, H, W)`` output.

    Parameters
    ----------
    params : optional
        Not used by the model. It was a required positional argument, which made
        ``Fourcastnet()`` fail; it now defaults to None and is still accepted.
    inp_shape : List[int], optional
        Spatial size ``[H, W]``; each must be divisible by ``patch_size``.
    in_channels : int, optional
        Number of input channels, by default 97.
    out_channels : int, optional
        Number of output channels, by default 93.
    patch_size : List[int], optional
        Patch size ``[p0, p1]``, by default ``[2, 2]``.
    embed_dim : int, optional
        Token embedding size, by default 256.
    depth : int, optional
        Number of AFNO blocks, by default 4.
    mlp_ratio, drop_rate, num_blocks, sparsity_threshold, hard_thresholding_fraction
        Forwarded to every ``Block`` (``drop_rate`` also drives the positional
        dropout).

    Raises
    ------
    ValueError
        If ``inp_shape``/``patch_size`` do not have length 2, or ``inp_shape`` is
        not divisible by ``patch_size``.
    """

    def __init__(
        self,
        params=None,
        inp_shape: List[int] = [120, 240],
        in_channels: int = 97,
        out_channels: int = 93,
        patch_size: List[int] = [2, 2],
        embed_dim: int = 256,
        depth: int = 4,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        num_blocks: int = 16,
        sparsity_threshold: float = 0.01,
        hard_thresholding_fraction: float = 1.0,
    ) -> None:
        super().__init__(meta=MetaData())
        if len(inp_shape) != 2:
            raise ValueError("inp_shape should be a list of length 2")
        if len(patch_size) != 2:
            raise ValueError("patch_size should be a list of length 2")

        if not (
            inp_shape[0] % patch_size[0] == 0 and inp_shape[1] % patch_size[1] == 0
        ):
            raise ValueError(
                f"input shape {inp_shape} should be divisible by patch_size {patch_size}"
            )

        self.in_chans = in_channels
        self.out_chans = out_channels
        self.inp_shape = inp_shape
        self.patch_size = patch_size
        self.num_features = self.embed_dim = embed_dim
        self.num_blocks = num_blocks
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = PatchEmbed(
            inp_shape=inp_shape,
            in_channels=self.in_chans,
            patch_size=self.patch_size,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Token grid size after patching.
        self.h = inp_shape[0] // self.patch_size[0]
        self.w = inp_shape[1] // self.patch_size[1]

        self.blocks = nn.ModuleList(
            [
                Block(
                    embed_dim=embed_dim,
                    num_blocks=self.num_blocks,
                    mlp_ratio=mlp_ratio,
                    drop=drop_rate,
                    norm_layer=norm_layer,
                    sparsity_threshold=sparsity_threshold,
                    hard_thresholding_fraction=hard_thresholding_fraction,
                )
                for _ in range(depth)
            ]
        )

        # Each token is projected back to a full patch of output channels.
        self.head = nn.Linear(
            embed_dim,
            self.out_chans * self.patch_size[0] * self.patch_size[1],
            bias=False,
        )

        torch.nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Init model weights"""
        if isinstance(m, nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x: Tensor) -> Tensor:
        """Forward pass of core AFNO"""
        B = x.shape[0]
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # (B, h*w, E) -> (B, h, w, E), the layout AFNO2DLayer expects.
        x = x.reshape(B, self.h, self.w, self.embed_dim)
        for blk in self.blocks:
            x = blk(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        x = self.forward_features(x)
        x = self.head(x)

        # (B, h, w, p0*p1*out) -> (B, h, w, p0, p1, out)
        out = x.view(list(x.shape[:-1]) + [self.patch_size[0], self.patch_size[1], -1])
        # -> (B, out, h, p0, w, p1)
        out = torch.permute(out, (0, 5, 1, 3, 2, 4))
        # -> (B, out, H, W)
        out = out.reshape(list(out.shape[:2]) + [self.inp_shape[0], self.inp_shape[1]])

        return out
|
|
|
|
|
from thop import profile |
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test and complexity report for the default 120x240 configuration.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ``params`` is not used by Fourcastnet; pass it explicitly so this script
    # also works where the constructor still requires it.
    net = Fourcastnet(params=None).to(device)

    x = torch.randn(1, 97, 120, 240).to(device)
    output = net(x)

    macs, params = profile(net, inputs=(x, ))

    print('macs: ', macs, 'params: ', params)
    print('macs: %.2f G, params: %.2f M' % (macs / 1000000000.0, params / 1000000.0))
    print(output.shape)
|
|
|