File size: 6,103 Bytes
94dc344
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import warnings
from collections import defaultdict
from typing import Dict, List, Optional, Union

import torch
from pytorch3d.implicitron.tools.config import Configurable


class Autodecoder(Configurable, torch.nn.Module):
    """
    Autodecoder which maps a list of integer or string keys to optimizable embeddings.

    Settings:
        encoding_dim: Embedding dimension for the decoder.
        n_instances: The maximum number of instances stored by the autodecoder.
        init_scale: Scale factor for the initial autodecoder weights.
        ignore_input: If `True`, optimizes a single code for any input.
    """

    encoding_dim: int = 0
    n_instances: int = 1
    init_scale: float = 1.0
    ignore_input: bool = False

    def __post_init__(self):
        if self.n_instances <= 0:
            raise ValueError(f"Invalid n_instances {self.n_instances}")

        self._autodecoder_codes = torch.nn.Embedding(
            self.n_instances,
            self.encoding_dim,
            scale_grad_by_freq=True,
        )
        with torch.no_grad():
            # weight has been initialised from Normal(0, 1); rescale in-place
            # by init_scale.
            self._autodecoder_codes.weight *= self.init_scale

        self._key_map = self._build_key_map()
        # _key_map is a plain dict (neither a Parameter nor a buffer), so
        # register hooks for correct handling of saving/loading it together
        # with the module's state dict.
        self._register_load_state_dict_pre_hook(self._load_key_map_hook)
        self._register_state_dict_hook(_save_key_map_hook)

    def _build_key_map(
        self, key_map_dict: Optional[Dict[str, int]] = None
    ) -> Dict[str, int]:
        """
        Create the key -> embedding-id mapping.

        Args:
            key_map_dict: A dictionary used to initialize the key_map.

        Returns:
            key_map: a defaultdict of key: id pairs whose factory hands out
                the next free id (0, 1, ..., n_instances - 1) whenever a new
                key is looked up.

        Raises:
            ValueError: If replaying `key_map_dict` does not reproduce the
                stored key: id assignment.
        """
        # The default factory advances a shared counter on every missing key;
        # after n_instances ids have been handed out it raises StopIteration.
        key_map = defaultdict(iter(range(self.n_instances)).__next__)
        if key_map_dict is not None:
            # Assign all keys from the loaded key_map_dict to self._key_map.
            # Since this is done in the original order, it should generate
            # the same set of key:id pairs. Verify this explicitly rather
            # than with `assert`, which is stripped under `python -O`.
            for x, x_id in key_map_dict.items():
                x_id_ = key_map[x]
                if x_id != x_id_:
                    raise ValueError(
                        f"Inconsistent Autodecoder key map: key {x!r} was "
                        f"assigned id {x_id_} but was stored with id {x_id}."
                    )
        return key_map

    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
        """
        Return the mean of the squared embedding weights as a scalar tensor
        (e.g. usable as a regularization term).
        """
        # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `weight`.
        return (self._autodecoder_codes.weight**2).mean()

    def get_encoding_dim(self) -> int:
        """Return the dimensionality of the per-key autodecoder codes."""
        return self.encoding_dim

    def forward(self, x: Union[torch.LongTensor, List[str]]) -> Optional[torch.Tensor]:
        """
        Args:
            x: A batch of `N` identifiers. Either a long tensor of size
            `(N,)` keys in [0, n_instances), or a list of `N` string keys that
            are hashed to codes (without collisions).

        Returns:
            codes: A tensor of shape `(N, self.encoding_dim)` containing the
                key-specific autodecoder codes.

        Raises:
            ValueError: If more distinct string keys than `n_instances` have
                been requested over the module's lifetime.
        """
        if self.ignore_input:
            # A single shared code is optimized regardless of the input keys.
            x = ["singleton"]

        # Guard len(x) > 0 so empty inputs do not raise IndexError on x[0].
        if len(x) > 0 and isinstance(x[0], str):
            try:
                # pyre-fixme[9]: x has type `Union[List[str], LongTensor]`; used as
                #  `Tensor`.
                x = torch.tensor(
                    # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ...
                    [self._key_map[elem] for elem in x],
                    dtype=torch.long,
                    device=next(self.parameters()).device,
                )
            except StopIteration:
                # The id iterator backing _key_map's default factory ran out.
                raise ValueError("Not enough n_instances in the autodecoder") from None
        elif not torch.is_tensor(x):
            # An empty list of keys: look the codes up with an empty index
            # tensor so the result is an empty `(0, encoding_dim)` tensor.
            x = torch.zeros(0, dtype=torch.long, device=next(self.parameters()).device)

        # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        return self._autodecoder_codes(x)

    def _load_key_map_hook(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """
        Load-state-dict pre-hook that restores `self._key_map` from the
        state dict (and removes it so strict loading does not flag it).

        Args:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            prefix (str): the prefix for parameters and buffers used in this
                module
            local_metadata (dict): a dict containing the metadata for this module.
            strict (bool): whether to strictly enforce that the keys in
                :attr:`state_dict` with :attr:`prefix` match the names of
                parameters and buffers in this module
            missing_keys (list of str): if ``strict=True``, add missing keys to
                this list
            unexpected_keys (list of str): if ``strict=True``, add unexpected
                keys to this list
            error_msgs (list of str): error messages should be added to this
                list, and will be reported together in
                :meth:`~torch.nn.Module.load_state_dict`

        Returns:
            Constructed key_map if it exists in the state_dict
            else raises a warning only.
        """
        key_map_key = prefix + "_key_map"
        if key_map_key in state_dict:
            key_map_dict = state_dict.pop(key_map_key)
            self._key_map = self._build_key_map(key_map_dict=key_map_dict)
        else:
            warnings.warn("No key map in Autodecoder state dict!")


def _save_key_map_hook(
    self,
    state_dict,
    prefix,
    local_metadata,
) -> None:
    """
    State-dict hook that stores the module's `_key_map` under the module's
    prefix so it is serialized alongside the parameters.

    Args:
        state_dict (dict): a dict containing parameters and
            persistent buffers.
        prefix (str): the prefix for parameters and buffers used in this
            module
        local_metadata (dict): a dict containing the metadata for this module.
    """
    # Snapshot the key map as a plain dict (it may be a defaultdict with a
    # non-picklable factory) before placing it into the state dict.
    state_dict[prefix + "_key_map"] = dict(self._key_map.items())