File size: 2,118 Bytes
b7d9967
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# This code is part of a Qiskit project.
#
# (C) Copyright IBM 2021, 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""A mixin class for saving and loading models."""

from typing import Any

import dill


class SerializableModelMixin:
    """
    Provides convenient methods for saving and loading models.

    Models are written to and read from files using ``dill`` serialization.
    """

    def save(self, file_name: str) -> None:
        """
        Saves this model to the specified file. Internally, the model is serialized via ``dill``.

        All parameters are saved, including a primitive instance that is referenced by internal
        objects. That means if a model is loaded from a file and is used, for instance, for
        inference, the same primitive will be used even if a cloud primitive was used.

        If the file already exists, it is overwritten.

        Args:
            file_name: a file name or path where to save the model.
        """
        # "wb" truncates an existing file, so a previous save at the same path is replaced.
        with open(file_name, "wb") as handler:
            dill.dump(self, handler)

    @classmethod
    def load(cls, file_name: str) -> Any:
        """
        Loads a model from the file. The loaded object must be an instance of the class whose
        ``load`` method was called (or of a subclass of it); otherwise a ``TypeError`` is raised.

        .. warning::

            The file is deserialized with ``dill``, which, like :mod:`pickle`, can execute
            arbitrary code while loading. The class check below happens only after
            deserialization, so it does not protect against malicious files. Only load files
            from trusted sources.

        Args:
            file_name: a file name or path to load a model from.

        Returns:
            The loaded model, an instance of ``cls``.

        Raises:
            TypeError: if the loaded object is not an instance of the expected class.
        """
        # SECURITY: dill.load may run code embedded in the file; never call this on
        # untrusted input.
        with open(file_name, "rb") as handler:
            model = dill.load(handler)
        # isinstance accepts subclasses of ``cls`` as well as ``cls`` itself.
        if not isinstance(model, cls):
            raise TypeError(f"Loaded model is of class {type(model)}. Expected class: {cls}.")
        return model