deepdetection / modules /m5_fusion.py
akagtag's picture
align project with CLAUDE spec and hf space deploy
cf54850
from __future__ import annotations

import os
import warnings

import torch
import torch.nn as nn
class FusionMLP(nn.Module):
    """Small attention MLP that fuses three detector scores.

    A 3 -> 16 -> 3 network with a ReLU hidden layer produces one logit per
    input score. A softmax over the last dimension turns those logits into
    weights ``alpha`` that sum to 1. The fused score is the ``alpha``-weighted
    sum of the inputs. It is a convex combination, so it always lies between
    the smallest and largest input score.
    """

    def __init__(self):
        super().__init__()
        # These attribute names become the state_dict keys ("fc1.weight", ...)
        # that saved checkpoints are loaded against; do not rename them.
        self.fc1 = nn.Linear(3, 16)
        self.fc2 = nn.Linear(16, 3)

    def forward(self, scores: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(fused_score, alpha)`` for ``scores``.

        Args:
            scores: Tensor whose last dimension has size 3. It is either a
                single sample of shape ``(3,)`` or a batch of shape
                ``(..., 3)``.

        Returns:
            fused_score: Weighted sum over the last dimension, with shape
                ``scores.shape[:-1]``. A ``(3,)`` input gives a 0-d tensor.
            alpha: Softmax weights, with the same shape as ``scores``.
        """
        hidden = torch.relu(self.fc1(scores))
        alpha = torch.softmax(self.fc2(hidden), dim=-1)
        # Reduce only over the score axis so batched inputs keep one fused
        # score per sample. A bare .sum() would also collapse the batch.
        return (alpha * scores).sum(dim=-1), alpha
class FusionModule:
    """Inference wrapper that fuses three detector scores with a FusionMLP.

    The score order is fixed. Index 0 is reported as ``lip_sync``, 1 as
    ``fingerprint`` and 2 as ``graph_gnn``.
    """

    def __init__(self, weights_path: str = "weights/fusion_mlp.pt"):
        """Build the fusion MLP and load trained weights if they exist.

        Args:
            weights_path: Path to a ``state_dict`` saved with ``torch.save``.
                If the file does not exist, the freshly (randomly) initialised
                MLP is kept and a ``RuntimeWarning`` is emitted.
        """
        self.model = FusionMLP()
        if os.path.exists(weights_path):
            # NOTE(review): torch.load unpickles the file. Only load weights
            # from a trusted source; consider weights_only=True (torch>=1.13).
            state = torch.load(weights_path, map_location="cpu")
            self.model.load_state_dict(state)
        else:
            # Without a checkpoint the fusion weights come from nn.Linear's
            # random init. Outputs are arbitrary and, unless the RNG is
            # seeded, not reproducible. Keep working (best effort) but make
            # the condition visible instead of failing silently.
            warnings.warn(
                f"Fusion weights not found at {weights_path!r}; "
                "using an untrained FusionMLP.",
                RuntimeWarning,
                stacklevel=2,
            )
        # Switch to inference mode before fuse() is called.
        self.model.eval()

    def fuse(self, s1: float, s2: float, s3: float) -> dict:
        """Fuse three scalar scores into a single FakeScore.

        Args:
            s1: Score whose weight is reported as ``"lip_sync"``.
            s2: Score whose weight is reported as ``"fingerprint"``.
            s3: Score whose weight is reported as ``"graph_gnn"``.

        Returns:
            ``{"FakeScore": <fused score rounded to 4 dp>,
              "weights": {"lip_sync": w0, "fingerprint": w1, "graph_gnn": w2}}``
            where each weight is rounded to 3 dp.
        """
        scores = torch.tensor([s1, s2, s3], dtype=torch.float32)
        # Pure inference: skip autograd bookkeeping.
        with torch.no_grad():
            fakescore, alpha = self.model(scores)
        return {
            "FakeScore": round(float(fakescore.item()), 4),
            "weights": {
                "lip_sync": round(float(alpha[0].item()), 3),
                "fingerprint": round(float(alpha[1].item()), 3),
                "graph_gnn": round(float(alpha[2].item()), 3),
            },
        }