import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models


class VideoEncoder(nn.Module):
    """Encode a batch of video clips into L2-normalized fixed-size embeddings.

    Per-frame features come from a ResNet-18 backbone with the classifier
    head removed, are projected down to ``embed_dim`` by a small MLP, and
    are then mean-pooled over the temporal axis.
    """

    def __init__(self, embed_dim: int = 128, pretrained: bool = True):
        """
        Args:
            embed_dim: Dimensionality of the output clip embedding.
            pretrained: If True, initialize the backbone with the default
                ImageNet weights; otherwise start from random weights.
        """
        super().__init__()
        # Modern torchvision weights API (the old `pretrained=` flag is deprecated).
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        backbone = models.resnet18(weights=weights)
        # Read the feature width off the backbone instead of hard-coding 512,
        # so swapping the backbone only requires changing one line above.
        self.feature_dim = backbone.fc.in_features
        # Keep everything up to (and including) the global average pool;
        # drop only the final fully-connected classification layer.
        self.feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
        self.frame_projection = nn.Sequential(
            nn.Linear(self.feature_dim, 256),
            nn.ReLU(),
            nn.Linear(256, embed_dim),
        )
        # Averages over the temporal axis (output length 1 == mean over T).
        self.temporal_pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of clips.

        Args:
            x: Video tensor of shape ``(B, T, 3, H, W)``.

        Returns:
            L2-normalized embeddings of shape ``(B, embed_dim)``.
        """
        B, T, C, H, W = x.shape
        # Fold time into the batch so the 2D CNN sees ordinary images.
        # `reshape` (not `view`) also accepts non-contiguous inputs,
        # e.g. a tensor the caller produced via permute().
        frames = x.reshape(B * T, C, H, W)
        feats = self.feature_extractor(frames)           # (B*T, feat, 1, 1)
        feats = feats.reshape(B, T, self.feature_dim)    # (B, T, feat)
        feats = self.frame_projection(feats)             # (B, T, embed_dim)
        feats = feats.permute(0, 2, 1)                   # (B, embed_dim, T)
        pooled = self.temporal_pool(feats).squeeze(-1)   # (B, embed_dim)
        # Unit-norm embeddings make cosine similarity a plain dot product.
        return F.normalize(pooled, dim=1)