# AIDAS-Omni-Modal-Diffusion/MMaDA/models/modeling_video_encoder.py
# jaeikkim
# Reinit Space without binary assets
# 7bfbdc3
import torch
import torch.nn as nn
class VideoEncoder(nn.Module):
    """Single post-norm Transformer encoder block (self-attention + MLP).

    The block applies multi-head self-attention followed by a two-layer
    GELU feed-forward network. Each sub-layer is wrapped in a residual
    connection, and LayerNorm is applied *after* the residual addition.

    Args:
        dim: Embedding size of the input and output features.
        num_heads: Number of attention heads passed to
            ``nn.MultiheadAttention``.
        dropout: Dropout probability used inside the attention module and
            after the second MLP projection.
    """

    def __init__(self, dim, num_heads=8, dropout=0.1):
        super().__init__()
        # batch_first=True: inputs are laid out as (batch, seq_len, dim).
        self.attention = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.norm1 = nn.LayerNorm(dim)
        # Feed-forward network with a 4x hidden expansion.
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
            nn.Dropout(dropout),
        )
        self.norm2 = nn.LayerNorm(dim)

    def forward(
        self,
        x: torch.Tensor,
        key_padding_mask: "torch.Tensor | None" = None,
        attn_mask: "torch.Tensor | None" = None,
    ) -> torch.Tensor:
        """Run the encoder block over a sequence of features.

        Args:
            x: Input of shape ``(batch_size, seq_len, dim)``.
            key_padding_mask: Optional mask forwarded unchanged to
                ``nn.MultiheadAttention``; for a boolean mask of shape
                ``(batch_size, seq_len)``, ``True`` marks positions that
                must be ignored as attention keys (e.g. padding).
            attn_mask: Optional attention mask forwarded unchanged to
                ``nn.MultiheadAttention``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # The attention weights are never used, so skip computing them;
        # this also allows PyTorch to pick its fused attention kernel.
        attn_output, _ = self.attention(
            x,
            x,
            x,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            need_weights=False,
        )
        # Post-norm residual around self-attention.
        x = self.norm1(attn_output + x)
        # Post-norm residual around the feed-forward network.
        x = self.norm2(self.mlp(x) + x)
        return x  # shape: (batch_size, seq_len, dim)