File size: 4,036 Bytes
94dc344
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import List, Optional, Union

import torch
from pytorch3d.implicitron.tools.config import (
    registry,
    ReplaceableBase,
    run_auto_creation,
)
from pytorch3d.renderer.implicit import HarmonicEmbedding

from .autodecoder import Autodecoder


class GlobalEncoderBase(ReplaceableBase):
    """
    Base class for encoders of global, frame-specific quantities.

    Concrete implementations include the harmonic encoding of a frame's
    timestamp (`HarmonicTimeEncoder`) and an autodecoder embedding of the
    frame's sequence identifier (`SequenceAutodecoder`).
    """

    def get_encoding_dim(self):
        """
        Returns the dimensionality of the encoding produced by `forward`.
        """
        raise NotImplementedError()

    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
        """
        Returns the squared norm of the encoding as a zero-dimensional
        tensor; reported as the model's `autodecoder_norm` loss.
        """
        raise NotImplementedError()

    def forward(
        self,
        *,
        frame_timestamp: Optional[torch.Tensor] = None,
        sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Encodes the given inputs into a single tensor.

        Returns:
            encoding: A tensor holding the global encoding.
        """
        raise NotImplementedError()


# TODO: probabilistic embeddings?
@registry.register
class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module):
    """
    A global encoder which embeds a frame's sequence identifier via an
    autodecoder lookup.
    """

    # pyre-fixme[13]: Attribute `autodecoder` is never initialized.
    autodecoder: Autodecoder

    def __post_init__(self):
        run_auto_creation(self)

    def get_encoding_dim(self):
        # The encoding dimensionality is fully determined by the autodecoder.
        return self.autodecoder.get_encoding_dim()

    def forward(
        self,
        *,
        frame_timestamp: Optional[torch.Tensor] = None,
        sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Looks up the autodecoder embedding of `sequence_name`.

        Raises:
            ValueError: If `sequence_name` is not given.
        """
        if sequence_name is None:
            raise ValueError("sequence_name must be provided.")
        # dtype checks happen inside self.autodecoder
        return self.autodecoder(sequence_name)

    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
        # Delegated: the autodecoder owns the learnable embeddings.
        return self.autodecoder.calculate_squared_encoding_norm()


@registry.register
class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
    """
    A global encoder implementation which provides harmonic embeddings
    of each frame's timestamp.

    Members:
        n_harmonic_functions: The number of harmonic functions used
            to form the embedding.
        append_input: If True, the raw (scaled) timestamp is appended
            to the harmonic embedding.
        time_divisor: Timestamps are divided by this value before embedding.
    """

    n_harmonic_functions: int = 10
    append_input: bool = True
    time_divisor: float = 1.0

    def __post_init__(self):
        self._harmonic_embedding = HarmonicEmbedding(
            n_harmonic_functions=self.n_harmonic_functions,
            append_input=self.append_input,
        )

    def get_encoding_dim(self):
        # The embedded quantity (the timestamp) is 1-dimensional.
        return self._harmonic_embedding.get_output_dim(1)

    def forward(
        self,
        *,
        frame_timestamp: Optional[torch.Tensor] = None,
        sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Harmonically embeds the frame timestamps.

        Args:
            frame_timestamp: A tensor whose last dimension equals 1,
                holding the timestamps to embed.
            sequence_name: Unused by this encoder.

        Returns:
            The harmonic embedding of `frame_timestamp / time_divisor`.

        Raises:
            ValueError: If `frame_timestamp` is missing or its last
                dimension is not 1.
        """
        if frame_timestamp is None:
            raise ValueError("frame_timestamp must be provided.")
        if frame_timestamp.shape[-1] != 1:
            # fix: message previously read "last dimensions should be one"
            # (plural/singular mismatch).
            raise ValueError("Frame timestamp's last dimension should be one.")
        time = frame_timestamp / self.time_divisor
        # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        return self._harmonic_embedding(time)

    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
        # No learnable embedding parameters, hence no norm penalty to report.
        return None