# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
from typing import List, Optional, Union
import torch
from pytorch3d.implicitron.tools.config import (
registry,
ReplaceableBase,
run_auto_creation,
)
from pytorch3d.renderer.implicit import HarmonicEmbedding
from .autodecoder import Autodecoder
class GlobalEncoderBase(ReplaceableBase):
    """
    Base class for encoders of global, frame-specific quantities.

    Such quantities include, for instance, a harmonic embedding of the
    frame's timestamp (`HarmonicTimeEncoder`) or an autodecoder embedding
    of the frame's sequence identifier (`SequenceAutodecoder`).
    """

    def get_encoding_dim(self):
        """
        Return the dimensionality of the encoding produced by `forward`.
        """
        raise NotImplementedError()

    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
        """
        Return the squared norm of the encoding, as a zero-dimensional
        tensor, to be reported as the `autodecoder_norm` loss of the model.
        """
        raise NotImplementedError()

    def forward(
        self,
        *,
        frame_timestamp: Optional[torch.Tensor] = None,
        sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Encode the given per-frame inputs into a single tensor.

        Args:
            frame_timestamp: Optional tensor of frame timestamps.
            sequence_name: Optional sequence identifiers, either a long
                tensor or a list of strings.

        Returns:
            encoding: The tensor containing the global encoding.
        """
        raise NotImplementedError()
# TODO: probabilistic embeddings?
@registry.register
class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module):
    """
    Global encoder that returns an autodecoder embedding of the frame's
    sequence identifier.
    """

    # pyre-fixme[13]: Attribute `autodecoder` is never initialized.
    autodecoder: Autodecoder

    def __post_init__(self):
        # Instantiates the configurable `autodecoder` member.
        run_auto_creation(self)

    def get_encoding_dim(self):
        # The encoding size is fully determined by the autodecoder.
        return self.autodecoder.get_encoding_dim()

    def forward(
        self,
        *,
        frame_timestamp: Optional[torch.Tensor] = None,
        sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Encode `sequence_name` with the autodecoder.

        Args:
            frame_timestamp: Ignored; accepted for interface compatibility.
            sequence_name: Sequence identifiers to encode; must not be None.

        Returns:
            The autodecoder embedding of `sequence_name`.

        Raises:
            ValueError: If `sequence_name` is None.
        """
        if sequence_name is None:
            raise ValueError("sequence_name must be provided.")
        # run dtype checks and pass sequence_name to self.autodecoder
        return self.autodecoder(sequence_name)

    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
        # Delegates the norm computation to the underlying autodecoder.
        return self.autodecoder.calculate_squared_encoding_norm()
@registry.register
class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
    """
    A global encoder implementation which provides harmonic embeddings
    of each frame's timestamp.

    Members:
        n_harmonic_functions: Number of harmonic functions used by the
            embedding.
        append_input: If True, the (scaled) input timestamp is appended
            to its harmonic embedding.
        time_divisor: Timestamps are divided by this constant before
            being embedded.
    """

    n_harmonic_functions: int = 10
    append_input: bool = True
    time_divisor: float = 1.0

    def __post_init__(self):
        self._harmonic_embedding = HarmonicEmbedding(
            n_harmonic_functions=self.n_harmonic_functions,
            append_input=self.append_input,
        )

    def get_encoding_dim(self):
        # The embedding acts on a single scalar per frame (the timestamp).
        return self._harmonic_embedding.get_output_dim(1)

    def forward(
        self,
        *,
        frame_timestamp: Optional[torch.Tensor] = None,
        sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute the harmonic embedding of `frame_timestamp / time_divisor`.

        Args:
            frame_timestamp: Tensor of timestamps; its last dimension must
                be of size 1.
            sequence_name: Ignored; accepted for interface compatibility.

        Returns:
            The harmonic embedding of the scaled timestamps.

        Raises:
            ValueError: If `frame_timestamp` is None or its last dimension
                is not of size 1.
        """
        if frame_timestamp is None:
            raise ValueError("frame_timestamp must be provided.")
        if frame_timestamp.shape[-1] != 1:
            # Fixed grammar of the user-facing message (the subject,
            # "last dimension", is singular).
            raise ValueError("Frame timestamp's last dimension should be one.")
        time = frame_timestamp / self.time_divisor
        # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        return self._harmonic_embedding(time)

    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
        # No regularization norm is defined for the time encoding.
        return None