Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,553 Bytes
fc605f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from typing import Optional
import torch
class AlignModalities(torch.nn.Module):
    """Project a secondary feature stream and merge it into an anchor stream.

    The secondary ("target") features are passed through a kernel-size-1
    ``Conv1d`` (a per-timestep linear map from ``in_channels`` to
    ``out_channels``), moved to channels-last layout, optionally
    layer-normalized over the channel axis, and finally either returned as-is
    or added to the anchor scaled by ``tanh(gate)``.

    Args:
        in_channels: Channel count of the target features fed to the conv.
        out_channels: Channel count produced by the conv; also the size of
            the axis that ``LayerNorm`` normalizes.
        normalize: If true, apply ``LayerNorm(out_channels)`` after the conv.
            The ``layer_norm`` submodule only exists in that case.
        with_gate: If true, create a learnable scalar ``gate`` initialised to
            0.0. Because ``tanh(0) == 0``, the gated contribution is zero at
            initialisation. If false, ``gate`` is ``None``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        normalize: bool = True,
        with_gate: bool = True,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.normalize = normalize

        # Pointwise (kernel_size=1) conv: mixes channels only, never timesteps.
        self.conv = torch.nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
        )
        if normalize:
            self.layer_norm = torch.nn.LayerNorm(out_channels)

        # A plain ``None`` attribute (not a registered parameter) when ungated,
        # so ``state_dict`` only contains "gate" for gated instances.
        self.gate = torch.nn.Parameter(torch.tensor([0.0])) if with_gate else None

    def forward(
        self, anchor: torch.Tensor, tgt: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Align ``tgt`` features to ``anchor`` (e.g. video features to audio).

        Args:
            anchor: Anchor tensor in channels-last layout, ``(B, T, C)`` (or
                ``(T, C)`` when unbatched). In the gated path it is added to
                the projected target, so it must broadcast against
                ``(B, T, out_channels)``.
            tgt: Optional features to align, in channels-first layout
                ``(B, in_channels, T)`` (or ``(in_channels, T)`` when
                unbatched).

        Returns:
            ``anchor`` itself when ``tgt`` is ``None``; the projected (and
            optionally normalized) target alone when the module has no gate;
            otherwise ``anchor + tanh(gate) * projected_target``.
        """
        if tgt is None:
            return anchor

        # Swap the last two axes (channels-first -> channels-last). Using
        # negative indices instead of a fixed 3-D permutation keeps batched
        # results identical while also accepting unbatched (C, T) input,
        # which Conv1d and LayerNorm already support.
        projected = self.conv(tgt).transpose(-1, -2)
        if self.normalize:
            projected = self.layer_norm(projected)

        if self.gate is None:
            # NOTE: the ungated path returns the projection only; the anchor
            # is not added here.
            return projected
        return anchor + self.gate.tanh() * projected
|