vineetshukla.work@gmail.com
final commit
c5c9261
"""
Ensemble Detector — Combines multiple backbone models for superior detection.
Supports late fusion, learned fusion, and confidence-weighted strategies.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict
import logging
logger = logging.getLogger(__name__)
class LateFusionEnsemble(nn.Module):
    """
    Late Fusion: Average (or weighted average) of per-model probabilities.
    Simple, effective, and robust.

    Args:
        models: Sub-models. Each is called as ``model(input_values, attention_mask)``
            and must return logits of identical shape, with labels on the last axis.
        weights: Optional non-negative per-model weights. They are normalized to
            sum to 1, so ``[1, 3]`` and ``[0.25, 0.75]`` are equivalent.
            Defaults to a uniform average.

    Raises:
        ValueError: If ``models`` is empty, or ``weights`` has the wrong length,
            contains a negative value, or sums to zero.
    """

    def __init__(self, models: List[nn.Module], weights: List[float] = None):
        super().__init__()
        if len(models) == 0:
            raise ValueError("LateFusionEnsemble requires at least one model")
        self.models = nn.ModuleList(models)

        if weights is None:
            weights = [1.0] * len(models)
        if len(weights) != len(models):
            raise ValueError(
                f"Expected {len(models)} weights (one per model), got {len(weights)}"
            )
        weight_tensor = torch.tensor(weights, dtype=torch.float32)
        if bool((weight_tensor < 0).any()):
            raise ValueError(f"Ensemble weights must be non-negative, got {weights}")
        total = weight_tensor.sum()
        if not bool(total > 0):
            raise ValueError(f"Ensemble weights must sum to a positive value, got {weights}")

        # Normalize so the fused output is a true weighted *average*.
        # Non-persistent buffer: it moves with .to()/.cuda() together with the
        # module but is kept out of state_dict, so existing checkpoints still load.
        self.register_buffer("weights", weight_tensor / total, persistent=False)

    def forward(self, input_values: torch.Tensor,
                attention_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Fuse the sub-models' predictions.

        Returns:
            ``log(sum_i w_i * softmax(logits_i))`` with the same shape as a single
            model's output. These are log-probabilities, suitable for ``nn.NLLLoss``.
        """
        # Shape (num_models, *model_output_shape).
        log_probs = torch.stack(
            [F.log_softmax(model(input_values, attention_mask), dim=-1)
             for model in self.models],
            dim=0,
        )
        log_w = torch.log(self.weights.to(device=log_probs.device, dtype=log_probs.dtype))
        # Broadcast one weight per model over every remaining axis.
        log_w = log_w.view(-1, *([1] * (log_probs.dim() - 1)))
        # The weighted average is computed in log space. Tiny probabilities do not
        # underflow to 0, and no epsilon clamp flattens log-probs (or their
        # gradients) for confident predictions.
        return torch.logsumexp(log_probs + log_w, dim=0)
class LearnedFusionEnsemble(nn.Module):
    """
    Learned Fusion: Small MLP trained on concatenated model outputs.
    More expressive than late fusion but requires end-to-end training.

    Args:
        models: Sub-models. Each is called as ``model(input_values, attention_mask)``
            and must return ``num_labels`` logits on its last axis.
        num_labels: Number of output classes per sub-model and of the fused output.
        hidden_dim: Width of the fusion MLP's hidden layer (default 64).
        dropout: Dropout probability inside the fusion MLP (default 0.2).

    Raises:
        ValueError: If ``models`` is empty.
    """

    def __init__(self, models: List[nn.Module], num_labels: int = 2,
                 hidden_dim: int = 64, dropout: float = 0.2):
        super().__init__()
        if len(models) == 0:
            raise ValueError("LearnedFusionEnsemble requires at least one model")
        self.models = nn.ModuleList(models)
        total_labels = num_labels * len(models)
        # Layer order (and therefore state_dict keys) matches the original head.
        self.fusion_head = nn.Sequential(
            nn.Linear(total_labels, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_labels),
        )

    def forward(self, input_values: torch.Tensor,
                attention_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Concatenate sub-model logits along the last axis and fuse them with the MLP.

        Returns:
            Raw fused logits with ``num_labels`` entries on the last axis.
            Softmax is not applied.

        Raises:
            ValueError: If the sub-models' combined label count does not match the
                size the fusion head was built for.
        """
        all_logits = [model(input_values, attention_mask) for model in self.models]
        concatenated = torch.cat(all_logits, dim=-1)  # (batch, num_labels * num_models)
        expected = self.fusion_head[0].in_features
        if concatenated.shape[-1] != expected:
            # Fail with a clear message instead of an opaque matmul shape error.
            raise ValueError(
                f"Sub-models produced {concatenated.shape[-1]} concatenated logits, "
                f"but the fusion head expects {expected}; check num_labels"
            )
        return self.fusion_head(concatenated)
class ConfidenceWeightedEnsemble(nn.Module):
    """
    Confidence-Weighted: Each model's vote is weighted by how confident it is.
    Models that are uncertain contribute less to the final prediction.

    A model's confidence is its maximum (temperature-scaled) class probability.
    The per-sample weights are a softmax over the models' confidences.

    NOTE: confidences lie in [1/num_labels, 1], so the softmax weighting is gentle.
    For example, confidences 0.5 vs 0.75 give weights of about 0.44 / 0.56.

    Args:
        models: Sub-models. Each is called as ``model(input_values, attention_mask)``
            and must return logits of identical shape, with labels on the last axis.
        temperature: Positive softmax temperature applied to every model's logits.
            It also affects the returned distribution, not only the weights.

    Raises:
        ValueError: If ``models`` is empty or ``temperature`` is not positive.
    """

    def __init__(self, models: List[nn.Module], temperature: float = 1.0):
        super().__init__()
        if len(models) == 0:
            raise ValueError("ConfidenceWeightedEnsemble requires at least one model")
        # Written as `not > 0` so NaN is rejected as well.
        if not temperature > 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.models = nn.ModuleList(models)
        self.temperature = temperature

    def forward(self, input_values: torch.Tensor,
                attention_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Fuse the sub-models' predictions with per-sample confidence weights.

        Returns:
            Log of the confidence-weighted mixture of temperature-scaled softmax
            probabilities, with the same shape as a single model's output.
        """
        # Shape (num_models, *model_output_shape).
        log_probs = torch.stack(
            [F.log_softmax(model(input_values, attention_mask) / self.temperature, dim=-1)
             for model in self.models],
            dim=0,
        )
        # Max probability as confidence: (num_models, *batch_dims).
        confidences = log_probs.max(dim=-1).values.exp()
        # Softmax across models, kept in log space: (num_models, *batch_dims, 1).
        log_weights = F.log_softmax(confidences, dim=0).unsqueeze(-1)
        # The weighted mixture is computed in log space. Tiny probabilities do not
        # underflow to 0, and no epsilon clamp flattens log-probs or gradients.
        return torch.logsumexp(log_probs + log_weights, dim=0)
class EnsembleDetector:
    """Factory for creating ensemble models from config."""

    # Strategy names accepted by ``create``.
    STRATEGIES = ("late_fusion", "learned_fusion", "confidence_weighted")

    @staticmethod
    def create(models: List[nn.Module], strategy: str = "late_fusion",
               weights: List[float] = None, num_labels: int = 2,
               temperature: float = 1.0) -> nn.Module:
        """
        Create an ensemble from a list of individual models.

        Args:
            models: List of DeepfakeClassifier instances
            strategy: "late_fusion", "learned_fusion", or "confidence_weighted"
            weights: Optional weights, used only by "late_fusion". A warning is
                logged if they are passed with another strategy.
            num_labels: Number of output classes, used only by "learned_fusion".
            temperature: Softmax temperature, used only by "confidence_weighted".

        Returns:
            Ensemble nn.Module

        Raises:
            ValueError: If ``strategy`` is not one of ``STRATEGIES``.
        """
        # Validate before touching any ensemble class so a typo fails fast.
        if strategy not in EnsembleDetector.STRATEGIES:
            raise ValueError(
                f"Unknown ensemble strategy: {strategy} "
                f"(expected one of: {', '.join(EnsembleDetector.STRATEGIES)})"
            )
        if weights is not None and strategy != "late_fusion":
            # Previously these were dropped silently; make the no-op visible.
            logger.warning("Ignoring `weights` for ensemble strategy %r "
                           "(only used by 'late_fusion')", strategy)

        if strategy == "late_fusion":
            logger.info("Creating Late Fusion Ensemble (%d models)", len(models))
            return LateFusionEnsemble(models, weights)
        if strategy == "learned_fusion":
            logger.info("Creating Learned Fusion Ensemble (%d models)", len(models))
            return LearnedFusionEnsemble(models, num_labels)
        logger.info("Creating Confidence-Weighted Ensemble (%d models)", len(models))
        return ConfidenceWeightedEnsemble(models, temperature=temperature)