File size: 971 Bytes
4853fdc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
from typing import Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.torch_utilities import create_mask_from_length
class MlpVideoEncoder(nn.Module):
    """Project per-frame video features into an embedding space.

    The encoder is a single linear layer applied to L2-normalised frame
    features. Alongside the projected features it returns a mask built
    from the per-sample frame counts.
    """

    def __init__(self, video_feat_dim: int, embed_dim: int):
        """Build the projection layer and initialise its parameters.

        Args:
            video_feat_dim: Size of the last dimension of the incoming
                frame features (input width of the linear layer).
            embed_dim: Size of the last dimension of the projected output.
        """
        super().__init__()
        self.mlp = nn.Linear(video_feat_dim, embed_dim)
        self.init_weights()

    @staticmethod
    def _reset_linear(module):
        """Re-initialise ``module`` in place if it is an ``nn.Linear``."""
        if not isinstance(module, nn.Linear):
            return
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0.)

    def init_weights(self):
        """Xavier-uniform every linear weight and zero every linear bias.

        ``nn.Module.apply`` visits this module and all of its submodules,
        so any ``nn.Linear`` registered on the encoder is covered.
        """
        self.apply(self._reset_linear)

    def forward(self, frames: torch.Tensor, frame_nums: Sequence[int]):
        """Encode frame features and build the matching mask.

        Args:
            frames: Frame features whose last dimension must equal
                ``video_feat_dim``. Presumably shaped
                (batch, num_frames, video_feat_dim) -- TODO confirm
                against the caller.
            frame_nums: Per-sample frame counts, passed unchanged to
                ``create_mask_from_length``.

        Returns:
            A dict with ``"output"`` (the projected features) and
            ``"mask"`` (the result of ``create_mask_from_length`` moved
            to the device of ``frames``).
        """
        # Normalise each feature vector to unit L2 norm along the last
        # dimension before projecting it.
        projected = self.mlp(F.normalize(frames, p=2, dim=-1))
        # NOTE(review): mask shape and polarity are defined by the
        # project helper -- confirm there before relying on them.
        mask = create_mask_from_length(frame_nums).to(frames.device)
        return {"output": projected, "mask": mask}
|