import torch
import torch.nn as nn


class VideoEncoder(nn.Module):
    """A single post-LN transformer encoder block for video frame features.

    Applies multi-head self-attention followed by a position-wise MLP, each
    wrapped in a residual connection and LayerNorm (post-norm ordering:
    ``norm(sublayer(x) + x)``).

    Args:
        dim: Feature dimension of each frame token (embed_dim).
        num_heads: Number of attention heads; must divide ``dim``.
        dropout: Dropout probability used inside attention and after the MLP.
    """

    def __init__(self, dim, num_heads=8, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.norm1 = nn.LayerNorm(dim)
        # Standard transformer MLP with a 4x hidden expansion.
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
            nn.Dropout(dropout),
        )
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x, key_padding_mask=None):
        """Encode a batch of frame-feature sequences.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, dim)``.
            key_padding_mask: Optional bool tensor of shape
                ``(batch_size, seq_len)`` where ``True`` marks padded frames
                to be ignored by attention. Default ``None`` (no masking),
                which preserves the original behavior.

        Returns:
            Tensor of shape ``(batch_size, seq_len, dim)``.
        """
        # Self-attention sub-layer with residual + post-norm.
        residual = x
        attn_output, _ = self.attention(
            x, x, x, key_padding_mask=key_padding_mask
        )
        x = self.norm1(attn_output + residual)

        # Feed-forward sub-layer with residual + post-norm.
        residual = x
        x = self.mlp(x)
        x = self.norm2(x + residual)
        return x  # shape: (batch_size, seq_len, dim)