| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | """A mixin class for saving and loading models."""
|
| |
|
| | from typing import Any
|
| |
|
| | import dill
|
| |
|
| |
|
class SerializableModelMixin:
    """
    Provides convenient methods for saving and loading models.

    Serialization is delegated to ``dill``: :meth:`save` dumps the whole instance to a file,
    and :meth:`load` reads an object back and verifies that it is an instance of the class
    the method was called on.
    """

    def save(self, file_name: str) -> None:
        """
        Saves this model to the specified file. Internally, the model is serialized via ``dill``.
        All parameters are saved, including a primitive instance that is referenced by internal
        objects. That means if a model is loaded from a file and is used, for instance, for
        inference, the same primitive will be used even if a cloud primitive was used.

        Args:
            file_name: a file name or path where to save the model. The file is opened in
                binary write mode, so an existing file at this path is overwritten.
        """
        with open(file_name, "wb") as handler:
            dill.dump(self, handler)

    @classmethod
    def load(cls, file_name: str) -> Any:
        """
        Loads a model from the file. If the loaded object is not an instance of the class whose
        method was called (or of one of its subclasses), a ``TypeError`` is raised and nothing
        is returned.

        .. warning::
            The file is deserialized via ``dill``, which, like ``pickle``, can execute arbitrary
            code while loading. Only load files that come from a trusted source.

        Args:
            file_name: a file name or path to load a model from.

        Returns:
            A loaded model.

        Raises:
            TypeError: if a loaded model is not an instance of the expected class.
        """
        # SECURITY: deserialization happens before any validation, so the type check below
        # does not protect against a malicious file; it only catches honest mix-ups.
        with open(file_name, "rb") as handler:
            model = dill.load(handler)

        if not isinstance(model, cls):
            raise TypeError(f"Loaded model is of class {type(model)}. Expected class: {cls}.")

        return model
|
| |
|