from __future__ import annotations

import torch
import torch.nn as nn


class LipFDNet(nn.Module):
    """Minimal LipFD-compatible network wrapper for Space inference.

    The hosted checkpoint is loaded into this module by modules.m1_lipsync.
    The forward signature follows the app contract: visual lip crops plus an
    audio mel spectrogram produce frame-level logits.

    The attribute names (``visual``, ``audio``, ``classifier``) and the
    position of each layer inside its ``nn.Sequential`` determine the
    state-dict keys, so they must stay as they are for a saved state dict
    to keep loading.
    """

    def __init__(self) -> None:
        super().__init__()
        # Two stride-2 convolutions followed by global average pooling:
        # every frame is reduced to a 32-dimensional vector regardless of
        # its spatial size.
        self.visual = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
        )
        # The audio branch only ever sees one scalar per frame (see forward).
        self.audio = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU(),
        )
        # 48 = 32 visual features + 16 audio features, concatenated.
        self.classifier = nn.Sequential(
            nn.Linear(48, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )

    def forward(self, frames: torch.Tensor, audio: torch.Tensor) -> torch.Tensor:
        """Return one logit per input frame.

        Args:
            frames: Image tensor with 3 channels, either batched
                ``(N, 3, H, W)`` or a single unbatched ``(3, H, W)`` frame,
                which is treated as a batch of one.
            audio: Audio tensor of any shape and dtype accepted by
                ``Tensor.float()``. It is reduced to its global mean, and
                that single scalar is shared by every frame in the batch.

        Returns:
            Tensor of shape ``(N,)`` holding one logit per frame.
        """
        if frames.ndim == 3:
            frames = frames.unsqueeze(0)
        visual_feat = self.visual(frames)

        # Reduce in float32 (also handles integer inputs), then follow the
        # visual features' device and dtype. Without this, the audio branch
        # raises a dtype/device mismatch whenever the module is not float32
        # or the audio tensor lives on a different device than the frames.
        audio_level = audio.float().mean().to(
            device=visual_feat.device, dtype=visual_feat.dtype
        )
        audio_level = audio_level.reshape(1, 1).expand(visual_feat.size(0), 1)
        audio_feat = self.audio(audio_level)

        fused = torch.cat([visual_feat, audio_feat], dim=-1)
        return self.classifier(fused).squeeze(-1)