|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Interfaces common to all Neural Modules and Models.""" |
|
|
from __future__ import annotations |
|
|
|
|
|
import copy |
|
|
import hashlib |
|
|
import inspect |
|
|
import os |
|
|
import shutil |
|
|
import traceback |
|
|
from abc import ABC, abstractmethod |
|
|
from collections.abc import Mapping, Sequence |
|
|
from contextlib import contextmanager |
|
|
from dataclasses import dataclass, field |
|
|
from enum import Enum |
|
|
from functools import total_ordering |
|
|
from pathlib import Path |
|
|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
|
|
|
import hydra |
|
|
import torch |
|
|
import wrapt |
|
|
from huggingface_hub import _CACHED_NO_EXIST, HfApi |
|
|
from huggingface_hub import get_token as get_hf_token |
|
|
from huggingface_hub import hf_hub_download, snapshot_download, try_to_load_from_cache |
|
|
from omegaconf import DictConfig, OmegaConf |
|
|
|
|
|
import nemo |
|
|
from nemo.core.classes.mixins.hf_io_mixin import HuggingFaceFileIO |
|
|
from nemo.core.config.templates.model_card import NEMO_DEFAULT_MODEL_CARD_TEMPLATE |
|
|
from nemo.core.connectors.save_restore_connector import SaveRestoreConnector |
|
|
from nemo.core.neural_types import NeuralType, NeuralTypeComparisonResult |
|
|
from nemo.utils import logging |
|
|
from nemo.utils.cloud import maybe_download_from_cloud |
|
|
from nemo.utils.data_utils import resolve_cache_dir |
|
|
from nemo.utils.model_utils import import_class_by_path, maybe_update_config_version |
|
|
|
|
|
__all__ = ['Typing', 'FileIO', 'Model', 'Serialization', 'typecheck', 'PretrainedModelInfo'] |
|
|
|
|
|
_TYPECHECK_ENABLED = True |
|
|
_TYPECHECK_SEMANTIC_CHECK_ENABLED = True |
|
|
|
|
|
_HAS_HYDRA = True |
|
|
|
|
|
|
|
|
|
|
|
ALLOWED_TARGET_PREFIXES = [ |
|
|
"nemo.collections.", |
|
|
"nemo.core.", |
|
|
"nemo.utils.", |
|
|
"nemo.lightning.", |
|
|
"tests.collections.", |
|
|
"torch.nn.", |
|
|
"torch.optim.", |
|
|
"torch.utils.data.", |
|
|
"lightning.pytorch.callbacks.", |
|
|
"lightning.pytorch.loggers.", |
|
|
"lightning.pytorch.strategies.", |
|
|
"lightning.pytorch.accelerators.", |
|
|
"omegaconf.", |
|
|
"megatron.", |
|
|
] |
|
|
|
|
|
|
|
|
def _is_target_allowed(target_path: str) -> bool: |
|
|
if not isinstance(target_path, str): |
|
|
return False |
|
|
return any(target_path.startswith(prefix) for prefix in ALLOWED_TARGET_PREFIXES) |
|
|
|
|
|
|
|
|
def _validate_config_targets_recursive(config_node: Any): |
|
|
if isinstance(config_node, Mapping): |
|
|
if "_target_" in config_node: |
|
|
target_path = config_node["_target_"] |
|
|
if not _is_target_allowed(target_path): |
|
|
raise ValueError( |
|
|
f"Instantiation of unsafe target '{target_path}' is blocked. " |
|
|
f"The '_target_' must point to a class or function within an approved namespace. " |
|
|
f"This restriction is in place to prevent potential arbitrary code execution." |
|
|
) |
|
|
for key, value in config_node.items(): |
|
|
_validate_config_targets_recursive(value) |
|
|
elif isinstance(config_node, Sequence) and not isinstance(config_node, str): |
|
|
for item in config_node: |
|
|
_validate_config_targets_recursive(item) |
|
|
|
|
|
|
|
|
def safe_instantiate(config: DictConfig, *args, **kwargs): |
|
|
""" |
|
|
A wrapper around hydra.utils.instantiate that first validates all _target_ |
|
|
fields in the config against an allow-list of prefixes. |
|
|
""" |
|
|
if config is not None: |
|
|
_validate_config_targets_recursive(config) |
|
|
return hydra.utils.instantiate(config, *args, **kwargs) |
|
|
|
|
|
|
|
|
def is_typecheck_enabled(): |
|
|
""" |
|
|
Getter method for typechecking state. |
|
|
""" |
|
|
return _TYPECHECK_ENABLED |
|
|
|
|
|
|
|
|
def is_semantic_typecheck_enabled(): |
|
|
""" |
|
|
Getter method for typechecking semantics state. |
|
|
""" |
|
|
return _TYPECHECK_SEMANTIC_CHECK_ENABLED |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class TypecheckMetadata: |
|
|
""" |
|
|
Metadata class for input/output neural types. |
|
|
|
|
|
# Primary attributes |
|
|
original_types: Preserve the dictionary of type information provided. |
|
|
|
|
|
ignore_collections: For backward compatibility, container support can be disabled explicitly |
|
|
using this flag. When set to True, all nesting is ignored and nest-depth checks are skipped. |
|
|
|
|
|
# Derived attributed |
|
|
mandatory_types: Sub-dictionary of `original_types` which contains only those types which |
|
|
are mandatory to include when calling the function. |
|
|
|
|
|
base_types: Dictionary of flattened `str: NeuralType` definitions, disregarding the nest level |
|
|
details into appropriate arguments. |
|
|
|
|
|
container_depth: Dictionary mapping `str: int` - such that the valid depth of the nest of this |
|
|
neural type is recorded. |
|
|
|
|
|
has_container_types: Bool flag declaring if any of the neural types declares a container nest |
|
|
in its signature. |
|
|
|
|
|
is_singular_container_type: Bool flag declaring if this is a single Neural Type with a container |
|
|
nest in its signature. Required for supporting python list expansion in return statement. |
|
|
|
|
|
""" |
|
|
|
|
|
original_types: Dict[str, NeuralType] |
|
|
ignore_collections: bool |
|
|
|
|
|
mandatory_types: Dict[str, NeuralType] = field(init=False) |
|
|
base_types: Dict[str, NeuralType] = field(init=False) |
|
|
|
|
|
container_depth: Dict[str, int] = field(init=False) |
|
|
has_container_types: bool = field(init=False) |
|
|
is_singular_container_type: bool = field(init=False) |
|
|
|
|
|
def __post_init__(self): |
|
|
|
|
|
has_container_types = False |
|
|
for type_val in self.original_types.values(): |
|
|
if isinstance(type_val, (list, tuple)): |
|
|
has_container_types = True |
|
|
break |
|
|
self.has_container_types = has_container_types |
|
|
|
|
|
|
|
|
if self.has_container_types and len(self.original_types) == 1: |
|
|
self.is_singular_container_type = True |
|
|
else: |
|
|
self.is_singular_container_type = False |
|
|
|
|
|
|
|
|
|
|
|
if self.has_container_types: |
|
|
self.base_types = {} |
|
|
self.container_depth = {} |
|
|
|
|
|
for type_key, type_val in self.original_types.items(): |
|
|
depth = 0 |
|
|
while isinstance(type_val, (list, tuple)): |
|
|
if len(type_val) > 1: |
|
|
raise TypeError( |
|
|
f"Neural Type `{type_key}`: {type_val} definition contains more than one element when " |
|
|
"declaring the nested container structure.\n" |
|
|
"Please ensure that you have only 1 NeuralType inside of the entire nested structure " |
|
|
"definition." |
|
|
) |
|
|
|
|
|
type_val = type_val[0] |
|
|
depth += 1 |
|
|
|
|
|
self.base_types[type_key] = type_val |
|
|
self.container_depth[type_key] = depth |
|
|
else: |
|
|
|
|
|
self.base_types = self.original_types |
|
|
self.container_depth = {type_key: 0 for type_key in self.base_types.keys()} |
|
|
|
|
|
|
|
|
self.mandatory_types = { |
|
|
type_key: type_val for type_key, type_val in self.base_types.items() if not type_val.optional |
|
|
} |
|
|
|
|
|
|
|
|
class Typing(ABC): |
|
|
""" |
|
|
An interface which endows module with neural types |
|
|
""" |
|
|
|
|
|
@property |
|
|
def input_types(self) -> Optional[Dict[str, NeuralType]]: |
|
|
"""Define these to enable input neural type checks""" |
|
|
return None |
|
|
|
|
|
@property |
|
|
def output_types(self) -> Optional[Dict[str, NeuralType]]: |
|
|
"""Define these to enable output neural type checks""" |
|
|
return None |
|
|
|
|
|
def _validate_input_types(self, input_types=None, ignore_collections=False, **kwargs): |
|
|
""" |
|
|
This function does a few things. |
|
|
|
|
|
1) It ensures that len(self.input_types <non-optional>) <= len(kwargs) <= len(self.input_types). |
|
|
2) For each (keyword name, keyword value) passed as input to the wrapped function: |
|
|
- Check if the keyword name exists in the list of valid self.input_types names. |
|
|
- Check if keyword value has the `neural_type` property. |
|
|
- If it does, then perform a comparative check and assert that neural types |
|
|
are compatible (SAME or GREATER). |
|
|
- Check if keyword value is a container type (list or tuple). If yes, |
|
|
then perform the elementwise test of neural type above on each element |
|
|
of the nested structure, recursively. |
|
|
|
|
|
Args: |
|
|
input_types: Either the `input_types` defined at class level, or the local function |
|
|
overridden type definition. |
|
|
ignore_collections: For backward compatibility, container support can be disabled explicitly |
|
|
using this flag. When set to True, all nesting is ignored and nest-depth checks are skipped. |
|
|
kwargs: Dictionary of argument_name:argument_value pairs passed to the wrapped |
|
|
function upon call. |
|
|
""" |
|
|
if input_types is not None: |
|
|
|
|
|
metadata = TypecheckMetadata(original_types=input_types, ignore_collections=ignore_collections) |
|
|
|
|
|
total_input_types = len(input_types) |
|
|
mandatory_input_types = len(metadata.mandatory_types) |
|
|
|
|
|
|
|
|
if len(kwargs) < mandatory_input_types or len(kwargs) > total_input_types: |
|
|
raise TypeError( |
|
|
f"Number of input arguments provided ({len(kwargs)}) is not as expected. Function has " |
|
|
f"{total_input_types} total inputs with {mandatory_input_types} mandatory inputs." |
|
|
) |
|
|
|
|
|
for key, value in kwargs.items(): |
|
|
|
|
|
if key not in input_types: |
|
|
raise TypeError( |
|
|
f"Input argument {key} has no corresponding input_type match. " |
|
|
f"Existing input_types = {input_types.keys()}" |
|
|
) |
|
|
|
|
|
|
|
|
if ( |
|
|
hasattr(value, 'neural_type') |
|
|
and is_semantic_typecheck_enabled() |
|
|
and not metadata.base_types[key].compare(value.neural_type) |
|
|
in ( |
|
|
NeuralTypeComparisonResult.SAME, |
|
|
NeuralTypeComparisonResult.GREATER, |
|
|
) |
|
|
): |
|
|
error_msg = [ |
|
|
f"{input_types[key].compare(value.neural_type)} :", |
|
|
f"Input type expected : {input_types[key]}", |
|
|
f"Input type found : {value.neural_type}", |
|
|
f"Argument: {key}", |
|
|
] |
|
|
for i, dict_tuple in enumerate(metadata.base_types[key].elements_type.type_parameters.items()): |
|
|
error_msg.insert(i + 2, f' input param_{i} : {dict_tuple[0]}: {dict_tuple[1]}') |
|
|
for i, dict_tuple in enumerate(value.neural_type.elements_type.type_parameters.items()): |
|
|
error_msg.append(f' input param_{i} : {dict_tuple[0]}: {dict_tuple[1]}') |
|
|
raise TypeError("\n".join(error_msg)) |
|
|
|
|
|
|
|
|
if hasattr(value, 'shape'): |
|
|
value_shape = value.shape |
|
|
type_shape = metadata.base_types[key].axes |
|
|
name = key |
|
|
|
|
|
if type_shape is not None and len(value_shape) != len(type_shape): |
|
|
raise TypeError( |
|
|
f"Input shape mismatch occured for {name} in module {self.__class__.__name__} : \n" |
|
|
f"Input shape expected = {metadata.base_types[key].axes} | \n" |
|
|
f"Input shape found : {value_shape}" |
|
|
) |
|
|
|
|
|
|
|
|
elif isinstance(value, list) or isinstance(value, tuple): |
|
|
for ind, val in enumerate(value): |
|
|
""" |
|
|
This initiates a DFS, tracking the depth count as it goes along the nested structure. |
|
|
Initial depth is 1 as we consider the current loop to be the 1st step inside the nest. |
|
|
""" |
|
|
self.__check_neural_type(val, metadata, depth=1, name=key) |
|
|
|
|
|
def _attach_and_validate_output_types(self, out_objects, ignore_collections=False, output_types=None): |
|
|
""" |
|
|
This function does a few things. |
|
|
|
|
|
1) It ensures that len(out_object) == len(self.output_types). |
|
|
2) If the output is a tensor (or list/tuple of list/tuple ... of tensors), it |
|
|
attaches a neural_type to it. For objects without the neural_type attribute, |
|
|
such as python objects (dictionaries and lists, primitive data types, structs), |
|
|
no neural_type is attached. |
|
|
|
|
|
Note: tensor.neural_type is only checked during _validate_input_types which is |
|
|
called prior to forward(). |
|
|
|
|
|
Args: |
|
|
output_types: Either the `output_types` defined at class level, or the local function |
|
|
overridden type definition. |
|
|
ignore_collections: For backward compatibility, container support can be disabled explicitly |
|
|
using this flag. When set to True, all nesting is ignored and nest-depth checks are skipped. |
|
|
out_objects: The outputs of the wrapped function. |
|
|
""" |
|
|
|
|
|
if output_types is not None: |
|
|
|
|
|
metadata = TypecheckMetadata(original_types=output_types, ignore_collections=ignore_collections) |
|
|
out_types_list = list(metadata.base_types.items()) |
|
|
mandatory_out_types_list = list(metadata.mandatory_types.items()) |
|
|
|
|
|
|
|
|
if isinstance(out_objects, (list, tuple)): |
|
|
out_container = out_objects |
|
|
else: |
|
|
out_container = [out_objects] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if metadata.is_singular_container_type: |
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
elif len(out_container) > len(out_types_list) or len(out_container) < len(mandatory_out_types_list): |
|
|
raise TypeError( |
|
|
"Number of output arguments provided ({}) is not as expected. " |
|
|
"It should be larger or equal than {} and less or equal than {}.\n" |
|
|
"This can be either because insufficient/extra number of output NeuralTypes were provided," |
|
|
"or the provided NeuralTypes {} should enable container support " |
|
|
"(add '[]' to the NeuralType definition)".format( |
|
|
len(out_container), len(out_types_list), len(mandatory_out_types_list), output_types |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
if not isinstance(out_objects, tuple) and not isinstance(out_objects, list): |
|
|
|
|
|
try: |
|
|
out_objects.neural_type = out_types_list[0][1] |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
|
|
|
if hasattr(out_objects, 'shape'): |
|
|
value_shape = out_objects.shape |
|
|
type_shape = out_types_list[0][1].axes |
|
|
name = out_types_list[0][0] |
|
|
|
|
|
if type_shape is not None and len(value_shape) != len(type_shape): |
|
|
raise TypeError( |
|
|
f"Output shape mismatch occured for {name} in module {self.__class__.__name__} : \n" |
|
|
f"Output shape expected = {type_shape} | \n" |
|
|
f"Output shape found : {value_shape}" |
|
|
) |
|
|
|
|
|
elif metadata.is_singular_container_type: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
depth = 1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if len(out_objects) == 1 and type(out_objects) == tuple: |
|
|
depth = 0 |
|
|
|
|
|
for ind, res in enumerate(out_objects): |
|
|
self.__attach_neural_type(res, metadata, depth=depth, name=out_types_list[0][0]) |
|
|
else: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for ind, res in enumerate(out_objects): |
|
|
self.__attach_neural_type(res, metadata, depth=0, name=out_types_list[ind][0]) |
|
|
|
|
|
def __check_neural_type(self, obj, metadata: TypecheckMetadata, depth: int, name: str = None): |
|
|
""" |
|
|
Recursively tests whether the obj satisfies the semantic neural type assertion. |
|
|
Can include shape checks if shape information is provided. |
|
|
|
|
|
Args: |
|
|
obj: Any python object that can be assigned a value. |
|
|
metadata: TypecheckMetadata object. |
|
|
depth: Current depth of recursion. |
|
|
name: Optional name used of the source obj, used when an error occurs. |
|
|
""" |
|
|
if isinstance(obj, tuple) or isinstance(obj, list): |
|
|
for elem in obj: |
|
|
self.__check_neural_type(elem, metadata, depth + 1, name=name) |
|
|
return |
|
|
|
|
|
type_val = metadata.base_types[name] |
|
|
|
|
|
|
|
|
if not metadata.ignore_collections and depth != metadata.container_depth[name]: |
|
|
raise TypeError( |
|
|
"While checking input neural types,\n" |
|
|
"Nested depth of value did not match container specification:\n" |
|
|
f"Current nested depth of NeuralType '{name}' ({type_val}): {depth}\n" |
|
|
f"Expected nested depth : {metadata.container_depth[name]}" |
|
|
) |
|
|
|
|
|
if ( |
|
|
hasattr(obj, 'neural_type') |
|
|
and is_semantic_typecheck_enabled() |
|
|
and not type_val.compare(obj.neural_type) |
|
|
in ( |
|
|
NeuralTypeComparisonResult.SAME, |
|
|
NeuralTypeComparisonResult.GREATER, |
|
|
) |
|
|
): |
|
|
raise TypeError( |
|
|
f"{type_val.compare(obj.neural_type)} : \n" |
|
|
f"Input type expected = {type_val} | \n" |
|
|
f"Input type found : {obj.neural_type}" |
|
|
) |
|
|
|
|
|
|
|
|
if hasattr(obj, 'shape'): |
|
|
value_shape = obj.shape |
|
|
type_shape = type_val.axes |
|
|
|
|
|
if type_shape is not None and len(value_shape) != len(type_shape): |
|
|
raise TypeError( |
|
|
f"Input shape mismatch occured for {name} in module {self.__class__.__name__} : \n" |
|
|
f"Input shape expected = {type_shape} | \n" |
|
|
f"Input shape found : {value_shape}" |
|
|
) |
|
|
|
|
|
def __attach_neural_type(self, obj, metadata: TypecheckMetadata, depth: int, name: str = None): |
|
|
""" |
|
|
Recursively attach neural types to a given object - as long as it can be assigned some value. |
|
|
|
|
|
Args: |
|
|
obj: Any python object that can be assigned a value. |
|
|
metadata: TypecheckMetadata object. |
|
|
depth: Current depth of recursion. |
|
|
name: Optional name used of the source obj, used when an error occurs. |
|
|
""" |
|
|
if isinstance(obj, tuple) or isinstance(obj, list): |
|
|
for elem in obj: |
|
|
self.__attach_neural_type(elem, metadata, depth=depth + 1, name=name) |
|
|
return |
|
|
|
|
|
type_val = metadata.base_types[name] |
|
|
|
|
|
|
|
|
if not metadata.ignore_collections and depth != metadata.container_depth[name]: |
|
|
raise TypeError( |
|
|
"While attaching output neural types,\n" |
|
|
"Nested depth of value did not match container specification:\n" |
|
|
f"Current nested depth of NeuralType '{name}' ({type_val}): {depth}\n" |
|
|
f"Expected nested depth : {metadata.container_depth[name]}" |
|
|
) |
|
|
|
|
|
try: |
|
|
obj.neural_type = type_val |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
|
|
|
if hasattr(obj, 'shape'): |
|
|
value_shape = obj.shape |
|
|
type_shape = type_val.axes |
|
|
|
|
|
if type_shape is not None and len(value_shape) != len(type_shape): |
|
|
raise TypeError( |
|
|
f"Output shape mismatch occured for {name} in module {self.__class__.__name__} : \n" |
|
|
f"Output shape expected = {type_shape} | \n" |
|
|
f"Output shape found : {value_shape}" |
|
|
) |
|
|
|
|
|
|
|
|
class Serialization(ABC): |
|
|
@classmethod |
|
|
def from_config_dict(cls, config: 'DictConfig', trainer: Optional['Trainer'] = None): |
|
|
"""Instantiates object using DictConfig-based configuration""" |
|
|
|
|
|
if _HAS_HYDRA: |
|
|
if isinstance(config, DictConfig): |
|
|
config = OmegaConf.to_container(config, resolve=True) |
|
|
config = OmegaConf.create(config) |
|
|
OmegaConf.set_struct(config, True) |
|
|
|
|
|
config = maybe_update_config_version(config) |
|
|
|
|
|
|
|
|
if ('cls' in config or 'target' in config) and 'params' in config and _HAS_HYDRA: |
|
|
|
|
|
instance = safe_instantiate(config=config) |
|
|
|
|
|
elif '_target_' in config and _HAS_HYDRA: |
|
|
|
|
|
instance = safe_instantiate(config=config) |
|
|
else: |
|
|
instance = None |
|
|
prev_error = "" |
|
|
|
|
|
if 'target' in config: |
|
|
target_cls_path = config["target"] |
|
|
imported_cls = None |
|
|
try: |
|
|
|
|
|
imported_cls = import_class_by_path(target_cls_path) |
|
|
|
|
|
|
|
|
if issubclass(cls, imported_cls): |
|
|
imported_cls = cls |
|
|
accepts_trainer = Serialization._inspect_signature_for_trainer(imported_cls) |
|
|
if accepts_trainer: |
|
|
instance = imported_cls(cfg=config, trainer=trainer) |
|
|
else: |
|
|
instance = imported_cls(cfg=config) |
|
|
except Exception as e: |
|
|
|
|
|
tb = traceback.format_exc() |
|
|
prev_error = ( |
|
|
f"Model instantiation failed!\nTarget class:\t{target_cls_path}" f"\nError(s):\t{e}\n{tb}" |
|
|
) |
|
|
logging.debug(prev_error + "\nFalling back to `cls`.") |
|
|
|
|
|
|
|
|
if instance is None: |
|
|
try: |
|
|
accepts_trainer = Serialization._inspect_signature_for_trainer(cls) |
|
|
if accepts_trainer: |
|
|
instance = cls(cfg=config, trainer=trainer) |
|
|
else: |
|
|
instance = cls(cfg=config) |
|
|
|
|
|
except Exception as e: |
|
|
|
|
|
if prev_error: |
|
|
logging.error(prev_error) |
|
|
raise e |
|
|
|
|
|
if not hasattr(instance, '_cfg'): |
|
|
instance._cfg = config |
|
|
return instance |
|
|
|
|
|
def to_config_dict(self) -> 'DictConfig': |
|
|
"""Returns object's configuration to config dictionary""" |
|
|
if hasattr(self, '_cfg') and self._cfg is not None: |
|
|
|
|
|
if _HAS_HYDRA and isinstance(self._cfg, DictConfig): |
|
|
config = OmegaConf.to_container(self._cfg, resolve=True) |
|
|
config = OmegaConf.create(config) |
|
|
OmegaConf.set_struct(config, True) |
|
|
|
|
|
config = maybe_update_config_version(config) |
|
|
|
|
|
self._cfg = config |
|
|
|
|
|
return self._cfg |
|
|
else: |
|
|
raise NotImplementedError( |
|
|
'to_config_dict() can currently only return object._cfg but current object does not have it.' |
|
|
) |
|
|
|
|
|
@classmethod |
|
|
def _inspect_signature_for_trainer(cls, check_cls): |
|
|
if hasattr(check_cls, '__init__'): |
|
|
signature = inspect.signature(check_cls.__init__) |
|
|
if 'trainer' in signature.parameters: |
|
|
return True |
|
|
else: |
|
|
return False |
|
|
else: |
|
|
return False |
|
|
|
|
|
|
|
|
class FileIO(ABC): |
|
|
def save_to(self, save_path: str): |
|
|
""" |
|
|
Standardized method to save a tarfile containing the checkpoint, config, and any additional artifacts. |
|
|
Implemented via :meth:`nemo.core.connectors.save_restore_connector.SaveRestoreConnector.save_to`. |
|
|
|
|
|
Args: |
|
|
save_path: str, path to where the file should be saved. |
|
|
""" |
|
|
raise NotImplementedError() |
|
|
|
|
|
@classmethod |
|
|
def restore_from( |
|
|
cls, |
|
|
restore_path: str, |
|
|
override_config_path: Optional[str] = None, |
|
|
map_location: Optional['torch.device'] = None, |
|
|
strict: bool = True, |
|
|
return_config: bool = False, |
|
|
trainer: Optional['Trainer'] = None, |
|
|
save_restore_connector: SaveRestoreConnector = None, |
|
|
): |
|
|
""" |
|
|
Restores model instance (weights and configuration) from a .nemo file |
|
|
|
|
|
Args: |
|
|
restore_path: path to .nemo file from which model should be instantiated |
|
|
override_config_path: path to a yaml config that will override the internal |
|
|
config file or an OmegaConf / DictConfig object representing the model config. |
|
|
map_location: Optional torch.device() to map the instantiated model to a device. |
|
|
By default (None), it will select a GPU if available, falling back to CPU otherwise. |
|
|
strict: Passed to load_state_dict. By default True |
|
|
return_config: If set to true, will return just the underlying config of the restored |
|
|
model as an OmegaConf DictConfig object without instantiating the model. |
|
|
trainer: An optional Trainer object, passed to the model constructor. |
|
|
save_restore_connector: An optional SaveRestoreConnector object that defines the implementation |
|
|
of the restore_from() method. |
|
|
""" |
|
|
raise NotImplementedError() |
|
|
|
|
|
@classmethod |
|
|
def from_config_file(cls, path2yaml_file: str): |
|
|
""" |
|
|
Instantiates an instance of NeMo Model from YAML config file. |
|
|
Weights will be initialized randomly. |
|
|
Args: |
|
|
path2yaml_file: path to yaml file with model configuration |
|
|
|
|
|
Returns: |
|
|
|
|
|
""" |
|
|
if issubclass(cls, Serialization): |
|
|
conf = OmegaConf.load(path2yaml_file) |
|
|
return cls.from_config_dict(config=conf) |
|
|
else: |
|
|
raise NotImplementedError() |
|
|
|
|
|
def to_config_file(self, path2yaml_file: str): |
|
|
""" |
|
|
Saves current instance's configuration to YAML config file. Weights will not be saved. |
|
|
Args: |
|
|
path2yaml_file: path2yaml_file: path to yaml file where model model configuration will be saved |
|
|
|
|
|
Returns: |
|
|
""" |
|
|
if hasattr(self, '_cfg'): |
|
|
self._cfg = maybe_update_config_version(self._cfg) |
|
|
with open(path2yaml_file, 'w', encoding='utf-8') as fout: |
|
|
OmegaConf.save(config=self._cfg, f=fout, resolve=True) |
|
|
else: |
|
|
raise NotImplementedError() |
|
|
|
|
|
|
|
|
@total_ordering |
|
|
@dataclass |
|
|
class PretrainedModelInfo: |
|
|
pretrained_model_name: str |
|
|
description: str |
|
|
location: str |
|
|
class_: 'Model' = None |
|
|
aliases: List[str] = None |
|
|
|
|
|
def __repr__(self): |
|
|
base = self.__class__.__name__ |
|
|
extras = ( |
|
|
"pretrained_model_name={pretrained_model_name},\n\t" |
|
|
"description={description},\n\t" |
|
|
"location={location}".format(**self.__dict__) |
|
|
) |
|
|
|
|
|
if self.class_ is not None: |
|
|
extras = "{extras},\n\t" "class_={class_}".format(extras=extras, **self.__dict__) |
|
|
|
|
|
representation = f"{base}(\n\t{extras}\n)" |
|
|
return representation |
|
|
|
|
|
def __hash__(self): |
|
|
|
|
|
|
|
|
location_hash = hash(self.location) |
|
|
return location_hash |
|
|
|
|
|
def __eq__(self, other): |
|
|
|
|
|
|
|
|
return hash(self) == hash(other) or self.pretrained_model_name == other.pretrained_model_name |
|
|
|
|
|
def __lt__(self, other): |
|
|
return self.pretrained_model_name < other.pretrained_model_name |
|
|
|
|
|
|
|
|
class Model(Typing, Serialization, FileIO, HuggingFaceFileIO): |
|
|
""" |
|
|
Abstract class offering interface which should be implemented by all NeMo models. |
|
|
""" |
|
|
|
|
|
@classmethod |
|
|
@abstractmethod |
|
|
def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]: |
|
|
""" |
|
|
Should list all pre-trained models available via NVIDIA NGC cloud. |
|
|
Note: There is no check that requires model names and aliases to be unique. In the case of a collision, |
|
|
whatever model (or alias) is listed first in the this returned list will be instantiated. |
|
|
|
|
|
Returns: |
|
|
A list of PretrainedModelInfo entries |
|
|
""" |
|
|
pass |
|
|
|
|
|
@classmethod |
|
|
def get_available_model_names(cls) -> List[str]: |
|
|
""" |
|
|
Returns the list of model names available via NVIDIA NGC cloud, |
|
|
to get the complete model description use list_available_models() |
|
|
Returns: |
|
|
A list of model names |
|
|
""" |
|
|
model_names = [] |
|
|
if cls.list_available_models() is not None: |
|
|
model_names = [model.pretrained_model_name for model in cls.list_available_models()] |
|
|
return model_names |
|
|
|
|
|
@classmethod |
|
|
def from_pretrained( |
|
|
cls, |
|
|
model_name: str, |
|
|
refresh_cache: bool = False, |
|
|
override_config_path: Optional[str] = None, |
|
|
map_location: Optional['torch.device'] = None, |
|
|
strict: bool = True, |
|
|
return_config: bool = False, |
|
|
trainer: Optional['Trainer'] = None, |
|
|
save_restore_connector: SaveRestoreConnector = None, |
|
|
return_model_file: Optional[bool] = False, |
|
|
): |
|
|
""" |
|
|
Instantiates an instance of NeMo from NVIDIA NGC cloud |
|
|
Use restore_from() to instantiate from a local .nemo file. |
|
|
Args: |
|
|
model_name: string key which will be used to find the module. |
|
|
refresh_cache: If set to True, then when fetching from cloud, this will re-fetch the file |
|
|
from cloud even if it is already found in a cache locally. |
|
|
override_config_path: path to a yaml config that will override the internal |
|
|
config file |
|
|
map_location: Optional torch.device() to map the instantiated model to a device. |
|
|
By default (None), it will select a GPU if available, falling back to CPU otherwise. |
|
|
strict: Passed to torch.load_state_dict. By default true. |
|
|
return_config: If set to true, will return just the underlying config of the restored |
|
|
model as an OmegaConf DictConfig object without instantiating the model. |
|
|
return_model_file: If set to true, will return just the downloaded model file in cache |
|
|
|
|
|
Returns: |
|
|
A model instance of a particular model class or its underlying config (if return_config is set). |
|
|
""" |
|
|
if save_restore_connector is None: |
|
|
save_restore_connector = SaveRestoreConnector() |
|
|
|
|
|
|
|
|
|
|
|
if '/' in model_name: |
|
|
class_, nemo_model_file_in_cache = cls._get_hf_hub_pretrained_model_info( |
|
|
model_name=model_name, refresh_cache=refresh_cache |
|
|
) |
|
|
|
|
|
|
|
|
if os.path.isdir(nemo_model_file_in_cache): |
|
|
|
|
|
save_restore_connector.model_extracted_dir = nemo_model_file_in_cache |
|
|
|
|
|
else: |
|
|
|
|
|
class_, nemo_model_file_in_cache = cls._get_ngc_pretrained_model_info( |
|
|
model_name=model_name, refresh_cache=refresh_cache |
|
|
) |
|
|
|
|
|
if return_model_file: |
|
|
return nemo_model_file_in_cache |
|
|
|
|
|
instance = class_.restore_from( |
|
|
restore_path=nemo_model_file_in_cache, |
|
|
override_config_path=override_config_path, |
|
|
map_location=map_location, |
|
|
strict=strict, |
|
|
return_config=return_config, |
|
|
trainer=trainer, |
|
|
save_restore_connector=save_restore_connector, |
|
|
) |
|
|
return instance |
|
|
|
|
|
@classmethod |
|
|
def _get_ngc_pretrained_model_info(cls, model_name: str, refresh_cache: bool = False) -> Tuple[type, str]: |
|
|
""" |
|
|
Resolve the NGC model pretrained information given a model name. |
|
|
Assumes the model subclass implements the `list_available_models()` inherited method. |
|
|
|
|
|
Args: |
|
|
model_name: Str name of the model. Must be the original name or an alias of the model, without any '/'. |
|
|
refresh_cache: Bool, determines whether cache must be refreshed (model is re-downloaded). |
|
|
|
|
|
Returns: |
|
|
A tuple of details describing : |
|
|
- The resolved class of the model. This requires subclass to implement PretrainedModelInfo.class_. |
|
|
If the class cannot be resolved, default to the class that called this method. |
|
|
- The path to the NeMo model (.nemo file) in some cached directory. |
|
|
""" |
|
|
location_in_the_cloud = None |
|
|
description = None |
|
|
class_ = None |
|
|
models = cls.list_available_models() |
|
|
if models is not None: |
|
|
for pretrained_model_info in cls.list_available_models(): |
|
|
found = False |
|
|
if pretrained_model_info.pretrained_model_name == model_name: |
|
|
found = True |
|
|
elif pretrained_model_info.aliases is not None: |
|
|
for alias in pretrained_model_info.aliases: |
|
|
if alias == model_name: |
|
|
found = True |
|
|
break |
|
|
if found: |
|
|
location_in_the_cloud = pretrained_model_info.location |
|
|
description = pretrained_model_info.description |
|
|
class_ = pretrained_model_info.class_ |
|
|
break |
|
|
|
|
|
if location_in_the_cloud is None: |
|
|
raise FileNotFoundError( |
|
|
f"Model {model_name} was not found. Check cls.list_available_models()\n" |
|
|
f"for the list of all available models." |
|
|
) |
|
|
filename = location_in_the_cloud.split("/")[-1] |
|
|
url = location_in_the_cloud.replace(filename, "") |
|
|
cache_dir = Path.joinpath(resolve_cache_dir(), f'{filename[:-5]}') |
|
|
|
|
|
cache_subfolder = hashlib.md5((location_in_the_cloud + description).encode('utf-8')).hexdigest() |
|
|
|
|
|
nemo_model_file_in_cache = maybe_download_from_cloud( |
|
|
url=url, filename=filename, cache_dir=cache_dir, subfolder=cache_subfolder, refresh_cache=refresh_cache |
|
|
) |
|
|
|
|
|
logging.info("Instantiating model from pre-trained checkpoint") |
|
|
|
|
|
if class_ is None: |
|
|
class_ = cls |
|
|
|
|
|
return class_, nemo_model_file_in_cache |
|
|
|
|
|
@classmethod |
|
|
def _get_hf_hub_pretrained_model_info(cls, model_name: str, refresh_cache: bool = False) -> Tuple[type, str]: |
|
|
""" |
|
|
Resolve the HuggingFace Hub model pretrained information given a model name. |
|
|
The model name must be of general syntax ``{source_repo}/{model_name}``. |
|
|
|
|
|
Note: |
|
|
The ``{source_repo}`` need not be ``nvidia``, it can be any public repository, even external to Nvidia. |
|
|
This allows public, externally contributed models to be run freely using Nvidia NeMo. |
|
|
|
|
|
Args: |
|
|
model_name: Str name of the model. Must be the original name or an alias of the model, without any '/'. |
|
|
refresh_cache: Bool, determines whether cache must be refreshed (model is re-downloaded). |
|
|
|
|
|
Returns: |
|
|
A tuple of details describing : |
|
|
- The resolved class of the model. Since the source is external to NeMo, always default to using |
|
|
the calling class. Depend on target class resolution by restore_from() for calling the correct class. |
|
|
- The path to the NeMo model (.nemo file) in some cached directory (managed by HF Hub). |
|
|
""" |
|
|
|
|
|
resolved_model_filename = model_name.split("/")[-1] + '.nemo' |
|
|
|
|
|
|
|
|
if not refresh_cache: |
|
|
path = try_to_load_from_cache(repo_id=model_name, filename=resolved_model_filename) |
|
|
if path is not None and path is not _CACHED_NO_EXIST: |
|
|
return cls, path |
|
|
|
|
|
|
|
|
hf_token = get_hf_token() |
|
|
|
|
|
|
|
|
api = HfApi(token=hf_token) |
|
|
|
|
|
|
|
|
nemo_file_exists = api.file_exists(repo_id=model_name, filename=resolved_model_filename, repo_type="model") |
|
|
|
|
|
if nemo_file_exists: |
|
|
|
|
|
path = hf_hub_download( |
|
|
repo_id=model_name, |
|
|
filename=resolved_model_filename, |
|
|
library_name='nemo', |
|
|
library_version=nemo.__version__, |
|
|
force_download=refresh_cache, |
|
|
token=hf_token, |
|
|
) |
|
|
else: |
|
|
repo_info = api.repo_info(repo_id=model_name, token=hf_token, files_metadata=True) |
|
|
|
|
|
|
|
|
cache_dir = Path.joinpath(resolve_cache_dir(), "hf_hub_cache", f'{model_name}') |
|
|
|
|
|
|
|
|
cache_subfolder = [] |
|
|
|
|
|
for sibling in repo_info.siblings: |
|
|
filename = sibling.rfilename.lower() |
|
|
|
|
|
if "readme" not in filename or "git" not in filename: |
|
|
cache_subfolder.append(sibling.blob_id) |
|
|
cache_subfolder = sorted(cache_subfolder) |
|
|
cache_subfolder = "".join(cache_subfolder) |
|
|
cache_subfolder = hashlib.md5(cache_subfolder.encode('utf-8')).hexdigest() |
|
|
|
|
|
|
|
|
save_path = os.path.join(cache_dir, cache_subfolder) |
|
|
|
|
|
|
|
|
if os.path.exists(cache_dir): |
|
|
num_files_in_dir = len(os.listdir(cache_dir)) |
|
|
if num_files_in_dir > 0: |
|
|
logging.info("Found {} files in cache directory {}".format(num_files_in_dir, cache_dir)) |
|
|
logging.info( |
|
|
f"Deleting old cache directory for model `{model_name}` in order to prevent duplicates..." |
|
|
) |
|
|
shutil.rmtree(cache_dir, ignore_errors=True) |
|
|
|
|
|
if not os.path.exists(save_path): |
|
|
logging.info(f"Downloading {model_name} from HuggingFace Hub to path: {save_path}") |
|
|
os.makedirs(save_path, exist_ok=True) |
|
|
|
|
|
path = snapshot_download( |
|
|
repo_id=model_name, |
|
|
library_name='nemo', |
|
|
library_version=nemo.__version__, |
|
|
force_download=refresh_cache, |
|
|
cache_dir=save_path, |
|
|
local_dir=save_path, |
|
|
local_dir_use_symlinks=False, |
|
|
token=hf_token, |
|
|
) |
|
|
|
|
|
return cls, path |
|
|
|
|
|
def generate_model_card( |
|
|
self, type: str = "hf", template: str = None, template_kwargs: Optional[Dict[str, str]] = None |
|
|
) -> object: |
|
|
""" |
|
|
Generates a ModelCard for the current model. This method is called when pushing the model to the Hub. |
|
|
|
|
|
Returns: |
|
|
An object that can be represented as a str representation of the model card, usually in Markdown format. |
|
|
""" |
|
|
if template is None: |
|
|
template = copy.deepcopy(NEMO_DEFAULT_MODEL_CARD_TEMPLATE) |
|
|
|
|
|
|
|
|
if template_kwargs is None: |
|
|
template_kwargs = {} |
|
|
|
|
|
if type == "hf": |
|
|
|
|
|
return self._get_hf_model_card(template=template, template_kwargs=template_kwargs) |
|
|
|
|
|
else: |
|
|
raise ValueError(f"Model card type {type} not supported.") |
|
|
|
|
|
|
|
|
class typecheck: |
|
|
""" |
|
|
A decorator which performs input-output neural type checks, and attaches |
|
|
neural types to the output of the function that it wraps. |
|
|
|
|
|
Requires that the class inherit from :class:`~nemo.core.Typing` in order to perform |
|
|
type checking, and will raise an error if that is not the case. |
|
|
|
|
|
# Usage (Class level type support) |
|
|
|
|
|
.. code-block:: python |
|
|
|
|
|
@typecheck() |
|
|
def fn(self, arg1, arg2, ...): |
|
|
... |
|
|
|
|
|
# Usage (Function level type support) |
|
|
|
|
|
.. code-block:: python |
|
|
|
|
|
@typecheck(input_types=..., output_types=...) |
|
|
def fn(self, arg1, arg2, ...): |
|
|
... |
|
|
|
|
|
Points to be noted: |
|
|
|
|
|
1) The brackets () in `@typecheck()` are necessary. |
|
|
|
|
|
You will encounter a TypeError: __init__() takes 1 positional argument but X |
|
|
were given without those brackets. |
|
|
|
|
|
2) The function can take any number of positional arguments during definition. |
|
|
|
|
|
When you call this function, all arguments must be passed using kwargs only. |
|
|
|
|
|
""" |
|
|
|
|
|
class TypeState(Enum): |
|
|
""" |
|
|
Placeholder to denote the default value of type information provided. |
|
|
If the constructor of this decorator is used to override the class level type definition, |
|
|
this enum value indicate that types will be overridden. |
|
|
""" |
|
|
|
|
|
UNINITIALIZED = 0 |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
input_types: Union[TypeState, Dict[str, NeuralType]] = TypeState.UNINITIALIZED, |
|
|
output_types: Union[TypeState, Dict[str, NeuralType]] = TypeState.UNINITIALIZED, |
|
|
ignore_collections: bool = False, |
|
|
): |
|
|
self.input_types = input_types |
|
|
self.output_types = output_types |
|
|
|
|
|
if input_types == self.TypeState.UNINITIALIZED: |
|
|
self.input_override = False |
|
|
else: |
|
|
self.input_override = True |
|
|
|
|
|
if output_types == self.TypeState.UNINITIALIZED: |
|
|
self.output_override = False |
|
|
else: |
|
|
self.output_override = True |
|
|
|
|
|
self.ignore_collections = ignore_collections |
|
|
|
|
|
def __call__(self, wrapped): |
|
|
return self.wrapped_call(wrapped) |
|
|
|
|
|
def unwrapped_call(self, wrapped): |
|
|
return wrapped |
|
|
|
|
|
@wrapt.decorator(enabled=is_typecheck_enabled) |
|
|
def wrapped_call(self, wrapped, instance: Typing, args, kwargs): |
|
|
""" |
|
|
Wrapper method that can be used on any function of a class that implements :class:`~nemo.core.Typing`. |
|
|
By default, it will utilize the `input_types` and `output_types` properties of the class inheriting Typing. |
|
|
|
|
|
Local function level overrides can be provided by supplying dictionaries as arguments to the decorator. |
|
|
|
|
|
Args: |
|
|
input_types: Union[TypeState, Dict[str, NeuralType]]. By default, uses the global `input_types`. |
|
|
output_types: Union[TypeState, Dict[str, NeuralType]]. By default, uses the global `output_types`. |
|
|
ignore_collections: Bool. Determines if container types should be asserted for depth checks, or |
|
|
if depth checks are skipped entirely. |
|
|
|
|
|
""" |
|
|
if instance is None: |
|
|
raise RuntimeError("Only classes which inherit nemo.core.Typing can use this decorator !") |
|
|
|
|
|
if not isinstance(instance, Typing): |
|
|
raise RuntimeError("Only classes which inherit nemo.core.Typing can use this decorator !") |
|
|
|
|
|
if hasattr(instance, 'input_ports') or hasattr(instance, 'output_ports'): |
|
|
raise RuntimeError( |
|
|
"Typing requires override of `input_types()` and `output_types()`, " |
|
|
"not `input_ports() and `output_ports()`" |
|
|
) |
|
|
|
|
|
|
|
|
if self.input_types is typecheck.TypeState.UNINITIALIZED: |
|
|
self.input_types = instance.input_types |
|
|
|
|
|
if self.output_types is typecheck.TypeState.UNINITIALIZED: |
|
|
self.output_types = instance.output_types |
|
|
|
|
|
|
|
|
if self.input_override: |
|
|
input_types = self.input_types |
|
|
else: |
|
|
input_types = instance.input_types |
|
|
|
|
|
if self.output_override: |
|
|
output_types = self.output_types |
|
|
else: |
|
|
output_types = instance.output_types |
|
|
|
|
|
|
|
|
if input_types is None and output_types is None: |
|
|
return wrapped(*args, **kwargs) |
|
|
|
|
|
|
|
|
if input_types is not None and len(args) > 0: |
|
|
raise TypeError("All arguments must be passed by kwargs only for typed methods") |
|
|
|
|
|
|
|
|
instance._validate_input_types(input_types=input_types, ignore_collections=self.ignore_collections, **kwargs) |
|
|
|
|
|
|
|
|
outputs = wrapped(*args, **kwargs) |
|
|
|
|
|
instance._attach_and_validate_output_types( |
|
|
output_types=output_types, ignore_collections=self.ignore_collections, out_objects=outputs |
|
|
) |
|
|
|
|
|
return outputs |
|
|
|
|
|
@staticmethod |
|
|
def set_typecheck_enabled(enabled: bool = True): |
|
|
""" |
|
|
Global method to enable/disable typechecking. |
|
|
|
|
|
Args: |
|
|
enabled: bool, when True will enable typechecking. |
|
|
""" |
|
|
global _TYPECHECK_ENABLED |
|
|
_TYPECHECK_ENABLED = enabled |
|
|
|
|
|
@staticmethod |
|
|
@contextmanager |
|
|
def disable_checks(): |
|
|
""" |
|
|
Context manager that temporarily disables type checking within its context. |
|
|
""" |
|
|
typecheck.set_typecheck_enabled(enabled=False) |
|
|
try: |
|
|
yield |
|
|
finally: |
|
|
typecheck.set_typecheck_enabled(enabled=True) |
|
|
|
|
|
@staticmethod |
|
|
def set_semantic_check_enabled(enabled: bool = True): |
|
|
""" |
|
|
Global method to enable/disable semantic typechecking. |
|
|
|
|
|
Args: |
|
|
enabled: bool, when True will enable semantic typechecking. |
|
|
""" |
|
|
global _TYPECHECK_SEMANTIC_CHECK_ENABLED |
|
|
_TYPECHECK_SEMANTIC_CHECK_ENABLED = enabled |
|
|
|
|
|
@staticmethod |
|
|
@contextmanager |
|
|
def disable_semantic_checks(): |
|
|
""" |
|
|
Context manager that temporarily disables semantic type checking within its context. |
|
|
""" |
|
|
typecheck.set_semantic_check_enabled(enabled=False) |
|
|
try: |
|
|
yield |
|
|
finally: |
|
|
typecheck.set_semantic_check_enabled(enabled=True) |
|
|
|
|
|
@staticmethod |
|
|
def enable_wrapping(enabled: bool = True): |
|
|
typecheck.set_typecheck_enabled(enabled) |
|
|
if enabled: |
|
|
typecheck.__call__ = nemo.core.classes.common.typecheck.wrapped_call |
|
|
else: |
|
|
typecheck.__call__ = nemo.core.classes.common.typecheck.unwrapped_call |
|
|
|