File size: 862 Bytes
7bfbdc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
import torch.nn as nn

class VideoEncoder(nn.Module):
    """Single post-norm Transformer encoder layer.

    Applies multi-head self-attention followed by a position-wise MLP. Each
    sub-layer is wrapped in a residual connection and followed by LayerNorm
    (post-norm ordering). Tensors are batch-first, ``(batch_size, seq_len,
    dim)``. What each token represents (e.g. frame or patch embeddings) is up
    to the caller.

    Args:
        dim: Embedding size of each token. ``nn.MultiheadAttention`` requires
            it to be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability. It is applied to the attention weights
            inside ``nn.MultiheadAttention`` and to the MLP output.
    """

    def __init__(self, dim, num_heads=8, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(dim)
        # Feed-forward block with a 4x hidden expansion.
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
            nn.Dropout(dropout),
        )
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x, key_padding_mask=None, attn_mask=None):
        """Encode a batch of token sequences.

        Args:
            x: Input tensor of shape ``(batch_size, seq_len, dim)``.
            key_padding_mask: Optional ``(batch_size, seq_len)`` mask. With a
                boolean mask, ``True`` marks key positions that attention
                must ignore (e.g. padding). It is passed straight to
                ``nn.MultiheadAttention``.
            attn_mask: Optional attention mask, forwarded unchanged to
                ``nn.MultiheadAttention``. Its shape is ``(seq_len, seq_len)``
                or ``(batch_size * num_heads, seq_len, seq_len)``.

        Returns:
            Tensor of shape ``(batch_size, seq_len, dim)``.
        """
        # The attention weights are discarded, so they are not requested.
        # need_weights=False avoids computing and averaging them, and it lets
        # PyTorch use its fused attention kernels.
        attn_output, _ = self.attention(
            x,
            x,
            x,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            need_weights=False,
        )
        x = self.norm1(x + attn_output)

        # Position-wise MLP sub-layer with residual connection + post-norm.
        x = self.norm2(x + self.mlp(x))
        return x