Spaces:
Paused
Paused
| from __future__ import annotations | |
| import os | |
| import torch | |
| import torch.nn as nn | |
class FusionMLP(nn.Module):
    """Two-layer MLP that turns three scores into a weighted fusion of them.

    The network maps the three input scores through a 16-unit hidden layer
    to three logits, applies a softmax to obtain non-negative weights that
    sum to one, and returns the weighted sum of the *input* scores.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 16)
        self.fc2 = nn.Linear(16, 3)

    def forward(self, scores: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Fuse the scores along the last dimension.

        Args:
            scores: Tensor whose last dimension has size 3 (required by
                ``fc1``). Any leading dimensions are treated as batch
                dimensions.

        Returns:
            A ``(fused, alpha)`` pair. ``alpha`` has the same shape as
            ``scores`` and holds the softmax weights; ``fused`` is the
            weighted sum ``(alpha * scores)`` over the last dimension, so it
            has the shape of ``scores`` without its last dimension (a 0-dim
            tensor for a 1-D input).
        """
        hidden = torch.relu(self.fc1(scores))
        alpha = torch.softmax(self.fc2(hidden), dim=-1)
        # Reduce only over the score axis: a bare ``.sum()`` would also add
        # up every sample of a batched input into one scalar.
        fused = (alpha * scores).sum(dim=-1)
        return fused, alpha
class FusionModule:
    """Inference wrapper around ``FusionMLP`` for fusing three float scores."""

    def __init__(self, weights_path: str = "weights/fusion_mlp.pt"):
        """Build the model and load weights from ``weights_path`` if present.

        Args:
            weights_path: Path to a saved ``state_dict`` for ``FusionMLP``.
                If the file does not exist, the model keeps its freshly
                initialised weights and a ``RuntimeWarning`` is emitted.
        """
        self.model = FusionMLP()
        if os.path.exists(weights_path):
            # NOTE(security): torch.load unpickles the file; only point this
            # at trusted checkpoints (consider ``weights_only=True`` on
            # PyTorch versions that support it).
            state_dict = torch.load(weights_path, map_location="cpu")
            self.model.load_state_dict(state_dict)
        else:
            # Keep the best-effort fallback, but do not hide it: without a
            # checkpoint the layers hold their random initialisation, so the
            # fused scores are not meaningful.
            import warnings

            warnings.warn(
                f"Fusion weights not found at {weights_path!r}; "
                "using randomly initialised FusionMLP weights.",
                RuntimeWarning,
                stacklevel=2,
            )
        self.model.eval()

    def fuse(self, s1: float, s2: float, s3: float) -> dict:
        """Fuse three scores into one score plus the per-input weights.

        Args:
            s1: Score reported under the ``"lip_sync"`` weight.
            s2: Score reported under the ``"fingerprint"`` weight.
            s3: Score reported under the ``"graph_gnn"`` weight.

        Returns:
            A dict with ``"FakeScore"`` (fused score rounded to 4 decimals)
            and ``"weights"`` (the three softmax weights rounded to 3
            decimals, keyed ``"lip_sync"``, ``"fingerprint"``,
            ``"graph_gnn"``).
        """
        scores = torch.tensor([s1, s2, s3], dtype=torch.float32)
        with torch.no_grad():  # inference only, no autograd graph needed
            fakescore, alpha = self.model(scores)
        return {
            "FakeScore": round(float(fakescore.item()), 4),
            "weights": {
                "lip_sync": round(float(alpha[0].item()), 3),
                "fingerprint": round(float(alpha[1].item()), 3),
                "graph_gnn": round(float(alpha[2].item()), 3),
            },
        }