from __future__ import annotations

import logging
import os

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


class FusionMLP(nn.Module):
    """Two-layer MLP that re-weights three input scores with softmax weights.

    The network maps the 3 scores to 16 hidden units (ReLU) and back to 3
    logits; a softmax over the last axis turns the logits into weights that
    sum to 1, and the fused score is the weighted sum of the input scores.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 16)
        self.fc2 = nn.Linear(16, 3)

    def forward(self, scores: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(fused, alpha)`` for ``scores`` whose last axis has size 3.

        Args:
            scores: Tensor of shape ``(3,)`` or ``(..., 3)``.

        Returns:
            ``fused``: weighted sum over the last axis (a 0-dim tensor for
            1-D input, shape ``scores.shape[:-1]`` otherwise).
            ``alpha``: softmax weights with the same shape as ``scores``.
        """
        hidden = torch.relu(self.fc1(scores))
        alpha = torch.softmax(self.fc2(hidden), dim=-1)
        # Reduce over the score axis only; a bare .sum() would also collapse
        # any leading batch axes into a single scalar.
        fused = (alpha * scores).sum(dim=-1)
        return fused, alpha


class FusionModule:
    """Inference wrapper that fuses three detector scores with ``FusionMLP``."""

    def __init__(self, weights_path: str = "weights/fusion_mlp.pt"):
        """Build the model and load weights from ``weights_path`` if it exists.

        When the file is absent the model keeps its random initialisation
        (best-effort fallback) and a warning is logged.
        """
        self.model = FusionMLP()
        if os.path.exists(weights_path):
            # NOTE(security): torch.load unpickles the file. Only load
            # checkpoints from trusted sources, or pass weights_only=True on
            # PyTorch versions that support it.
            state_dict = torch.load(weights_path, map_location="cpu")
            self.model.load_state_dict(state_dict)
        else:
            logger.warning(
                "Fusion weights not found at %r; using randomly initialised "
                "FusionMLP, fused scores will not be meaningful.",
                weights_path,
            )
        self.model.eval()

    def fuse(self, s1: float, s2: float, s3: float) -> dict:
        """Fuse three scores into one score plus the per-score weights.

        Args:
            s1: Score reported under the ``"lip_sync"`` weight.
            s2: Score reported under the ``"fingerprint"`` weight.
            s3: Score reported under the ``"graph_gnn"`` weight.

        Returns:
            A dict with ``"FakeScore"`` (rounded to 4 decimals) and
            ``"weights"`` (the three softmax weights, rounded to 3 decimals).
        """
        scores = torch.tensor([s1, s2, s3], dtype=torch.float32)
        with torch.no_grad():
            fakescore, alpha = self.model(scores)
        lip_sync, fingerprint, graph_gnn = alpha.tolist()
        return {
            "FakeScore": round(fakescore.item(), 4),
            "weights": {
                "lip_sync": round(lip_sync, 3),
                "fingerprint": round(fingerprint, 3),
                "graph_gnn": round(graph_gnn, 3),
            },
        }