# deepdetection/lipfd/model.py
# Author: akagtag
# Commit eff3d67: "Implement ZeroGPU Space runtime"
from __future__ import annotations
import torch
import torch.nn as nn
class LipFDNet(nn.Module):
    """
    Minimal LipFD-compatible network wrapper for Space inference.

    The hosted checkpoint is loaded into this module by modules.m1_lipsync.
    The forward signature follows the app contract: visual lip crops plus an
    audio mel spectrogram produce frame-level logits.

    Submodule attribute names (``visual``, ``audio``, ``classifier``) and layer
    sizes define the state_dict layout, so they must stay in sync with the
    hosted checkpoint.
    """

    def __init__(self):
        super().__init__()
        # Visual branch: two stride-2 convs, then global average pooling and
        # flatten reduce each 3-channel crop to a 32-dim feature vector.
        self.visual = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
        )
        # Audio branch: embeds a single scalar audio level into 16 dims.
        self.audio = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU(),
        )
        # Head over the concatenated [visual (32) | audio (16)] = 48 features.
        self.classifier = nn.Sequential(
            nn.Linear(48, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )

    def forward(self, frames: torch.Tensor, audio: torch.Tensor) -> torch.Tensor:
        """Score each frame from its visual features plus one shared audio level.

        Args:
            frames: lip crops shaped (N, 3, H, W); a single crop shaped
                (3, H, W) is treated as a batch of one.
            audio: audio tensor of any shape. It is reduced to a single scalar
                (its mean) that is shared by every frame in the batch.
                Presumably the mel spectrogram of the same clip -- confirm
                against the caller.

        Returns:
            Logits of shape (N,), one per frame.
        """
        if frames.ndim == 3:
            frames = frames.unsqueeze(0)
        visual_feat = self.visual(frames)
        # Reduce in float32 for a stable mean, then match the visual features'
        # dtype and device. Otherwise the audio Linear fails after
        # .half()/.double(), or when audio is on CPU and the model is on GPU.
        audio_level = (
            audio.float()
            .mean()
            .to(visual_feat)
            .reshape(1, 1)
            .expand(visual_feat.size(0), 1)
        )
        audio_feat = self.audio(audio_level)
        fused = torch.cat([visual_feat, audio_feat], dim=-1)
        return self.classifier(fused).squeeze(-1)