Spaces:
Sleeping
Sleeping
File size: 5,077 Bytes
c5c9261 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 | """
Ensemble Detector — Combines multiple backbone models for superior detection.
Supports late fusion, learned fusion, and confidence-weighted strategies.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict
import logging
logger = logging.getLogger(__name__)
class LateFusionEnsemble(nn.Module):
"""
Late Fusion: Average (or weighted average) of per-model probabilities.
Simple, effective, and robust.
"""
def __init__(self, models: List[nn.Module], weights: List[float] = None):
super().__init__()
self.models = nn.ModuleList(models)
if weights is None:
weights = [1.0 / len(models)] * len(models)
assert len(weights) == len(models)
self.weights = torch.tensor(weights, dtype=torch.float32)
def forward(self, input_values: torch.Tensor,
attention_mask: torch.Tensor = None) -> torch.Tensor:
all_probs = []
for model in self.models:
logits = model(input_values, attention_mask)
probs = F.softmax(logits, dim=-1)
all_probs.append(probs)
# Stack and compute weighted average
stacked = torch.stack(all_probs, dim=0) # (num_models, batch, num_labels)
weights = self.weights.to(stacked.device).view(-1, 1, 1)
fused = (stacked * weights).sum(dim=0) # (batch, num_labels)
# Return log probabilities (compatible with loss functions)
return torch.log(fused + 1e-10)
class LearnedFusionEnsemble(nn.Module):
"""
Learned Fusion: Small MLP trained on concatenated model outputs.
More expressive than late fusion but requires end-to-end training.
"""
def __init__(self, models: List[nn.Module], num_labels: int = 2):
super().__init__()
self.models = nn.ModuleList(models)
total_labels = num_labels * len(models)
self.fusion_head = nn.Sequential(
nn.Linear(total_labels, 64),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(64, num_labels),
)
def forward(self, input_values: torch.Tensor,
attention_mask: torch.Tensor = None) -> torch.Tensor:
all_logits = []
for model in self.models:
logits = model(input_values, attention_mask)
all_logits.append(logits)
concatenated = torch.cat(all_logits, dim=-1) # (batch, num_labels * num_models)
return self.fusion_head(concatenated)
class ConfidenceWeightedEnsemble(nn.Module):
"""
Confidence-Weighted: Each model's vote is weighted by how confident it is.
Models that are uncertain contribute less to the final prediction.
"""
def __init__(self, models: List[nn.Module], temperature: float = 1.0):
super().__init__()
self.models = nn.ModuleList(models)
self.temperature = temperature
def forward(self, input_values: torch.Tensor,
attention_mask: torch.Tensor = None) -> torch.Tensor:
all_probs = []
confidences = []
for model in self.models:
logits = model(input_values, attention_mask)
probs = F.softmax(logits / self.temperature, dim=-1)
confidence = probs.max(dim=-1)[0] # Max probability as confidence
all_probs.append(probs)
confidences.append(confidence)
# Normalize confidences to sum to 1
conf_stack = torch.stack(confidences, dim=0) # (num_models, batch)
conf_weights = F.softmax(conf_stack, dim=0).unsqueeze(-1) # (num_models, batch, 1)
prob_stack = torch.stack(all_probs, dim=0) # (num_models, batch, num_labels)
fused = (prob_stack * conf_weights).sum(dim=0) # (batch, num_labels)
return torch.log(fused + 1e-10)
class EnsembleDetector:
"""Factory for creating ensemble models from config."""
@staticmethod
def create(models: List[nn.Module], strategy: str = "late_fusion",
weights: List[float] = None, num_labels: int = 2) -> nn.Module:
"""
Create an ensemble from a list of individual models.
Args:
models: List of DeepfakeClassifier instances
strategy: "late_fusion", "learned_fusion", or "confidence_weighted"
weights: Optional weights for late_fusion
num_labels: Number of output classes
Returns:
Ensemble nn.Module
"""
if strategy == "late_fusion":
logger.info(f"Creating Late Fusion Ensemble ({len(models)} models)")
return LateFusionEnsemble(models, weights)
elif strategy == "learned_fusion":
logger.info(f"Creating Learned Fusion Ensemble ({len(models)} models)")
return LearnedFusionEnsemble(models, num_labels)
elif strategy == "confidence_weighted":
logger.info(f"Creating Confidence-Weighted Ensemble ({len(models)} models)")
return ConfidenceWeightedEnsemble(models)
else:
raise ValueError(f"Unknown ensemble strategy: {strategy}")
|