File size: 5,077 Bytes
c5c9261
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""
Ensemble Detector — Combines multiple backbone models for superior detection.
Supports late fusion, learned fusion, and confidence-weighted strategies.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict
import logging

logger = logging.getLogger(__name__)


class LateFusionEnsemble(nn.Module):
    """
    Late Fusion: Average (or weighted average) of per-model probabilities.
    Simple, effective, and robust.

    Each member's logits are converted to probabilities with a softmax over
    the last dimension. The probabilities are mixed with non-negative weights
    normalized to sum to 1, and the log of the fused distribution is returned.

    Args:
        models: Member models. Each is called as
            ``model(input_values, attention_mask)`` and must return logits whose
            last dimension is the label dimension. All members must return the
            same shape.
        weights: Optional per-model weights: one per model, non-negative, and
            not all zero. They are rescaled to sum to 1, so ``[1, 3]`` is
            equivalent to ``[0.25, 0.75]``. Defaults to a uniform average.

    Raises:
        ValueError: If ``models`` is empty, or if ``weights`` has the wrong
            length, contains a negative or non-finite value, or sums to zero.
    """

    def __init__(self, models: List[nn.Module], weights: List[float] = None):
        super().__init__()
        if not models:
            raise ValueError("LateFusionEnsemble requires at least one model")
        self.models = nn.ModuleList(models)

        if weights is None:
            weights = [1.0] * len(models)
        if len(weights) != len(models):
            raise ValueError(
                f"Expected {len(models)} weights (one per model), got {len(weights)}"
            )

        weight_tensor = torch.tensor(weights, dtype=torch.float32)
        if not torch.isfinite(weight_tensor).all() or (weight_tensor < 0).any():
            raise ValueError(
                f"Ensemble weights must be finite and non-negative, got {list(weights)}"
            )
        total = weight_tensor.sum().item()
        if total <= 0:
            raise ValueError("Ensemble weights must not all be zero")

        # Non-persistent buffer: it follows the module through .to()/.cuda() but
        # is kept out of state_dict, so existing checkpoints still load strictly.
        self.register_buffer("weights", weight_tensor / total, persistent=False)

    def forward(self, input_values: torch.Tensor,
                attention_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Fuse member predictions.

        Returns:
            Log of the weighted-average probabilities. It has the same shape as
            a single member's logits. The 1e-10 offset keeps log() finite when
            a class has zero fused probability.
        """
        # (num_models, *member_output_shape)
        stacked = torch.stack(
            [F.softmax(model(input_values, attention_mask), dim=-1)
             for model in self.models],
            dim=0,
        )

        # Reshape to (num_models, 1, ..., 1) so the weights broadcast over any
        # output rank. Keep them float32 so half-precision probabilities are
        # upcast and the 1e-10 epsilon below does not underflow.
        weights = self.weights.to(device=stacked.device, dtype=torch.float32)
        weights = weights.view(-1, *([1] * (stacked.dim() - 1)))
        fused = (stacked * weights).sum(dim=0)

        # Return log-probabilities (suitable for NLLLoss-style losses).
        return torch.log(fused + 1e-10)


class LearnedFusionEnsemble(nn.Module):
    """
    Learned Fusion: Small MLP trained on concatenated model outputs.
    More expressive than late fusion but requires end-to-end training.

    Member logits are concatenated along the last dimension and fed to a
    ``Linear -> ReLU -> Dropout -> Linear`` head.

    Args:
        models: Member models. Each is called as
            ``model(input_values, attention_mask)`` and is expected to return
            logits whose last dimension has size ``num_labels``.
        num_labels: Number of output classes. This is also the assumed size of
            each member's last output dimension.
        hidden_dim: Width of the fusion MLP's hidden layer.
        dropout: Dropout probability applied after the hidden activation.

    Raises:
        ValueError: If ``models`` is empty.
    """

    def __init__(self, models: List[nn.Module], num_labels: int = 2,
                 hidden_dim: int = 64, dropout: float = 0.2):
        super().__init__()
        if not models:
            raise ValueError("LearnedFusionEnsemble requires at least one model")
        self.models = nn.ModuleList(models)
        total_labels = num_labels * len(models)

        # Layer order is kept fixed so state_dict keys (fusion_head.0 / .3)
        # stay stable across configurations.
        self.fusion_head = nn.Sequential(
            nn.Linear(total_labels, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_labels),
        )

    def forward(self, input_values: torch.Tensor,
                attention_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Fuse member logits through the MLP head.

        Returns:
            Unnormalized fusion logits whose last dimension is ``num_labels``.
            No softmax or log is applied, so pair this with a logits-based
            loss such as CrossEntropyLoss.
        """
        # (..., num_labels * num_models)
        concatenated = torch.cat(
            [model(input_values, attention_mask) for model in self.models],
            dim=-1,
        )
        return self.fusion_head(concatenated)


class ConfidenceWeightedEnsemble(nn.Module):
    """
    Confidence-Weighted: Each model's vote is weighted by how confident it is.
    Models that are uncertain contribute less to the final prediction.

    A member's confidence is its maximum temperature-scaled class probability.
    The confidences are turned into per-sample mixing weights with a softmax
    across models, and the members' probabilities are mixed with those weights.

    Note: confidences lie in [1/num_labels, 1], so the softmax across models
    produces fairly soft weights. For example, confidences of 0.5 and 1.0 give
    weights of about 0.38 and 0.62.

    Args:
        models: Member models. Each is called as
            ``model(input_values, attention_mask)`` and must return logits whose
            last dimension is the label dimension. All members must return the
            same shape.
        temperature: Positive temperature that member logits are divided by
            before the softmax.

    Raises:
        ValueError: If ``models`` is empty or ``temperature`` is not positive.
    """

    def __init__(self, models: List[nn.Module], temperature: float = 1.0):
        super().__init__()
        if not models:
            raise ValueError("ConfidenceWeightedEnsemble requires at least one model")
        # `not (t > 0)` also rejects NaN, which would otherwise poison every output.
        if not temperature > 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.models = nn.ModuleList(models)
        self.temperature = temperature

    def forward(self, input_values: torch.Tensor,
                attention_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Fuse member predictions using per-sample confidence weights.

        Returns:
            Log of the confidence-weighted probabilities. It has the same shape
            as a single member's logits. The 1e-10 offset keeps log() finite.
        """
        # (num_models, *member_output_shape)
        prob_stack = torch.stack(
            [F.softmax(model(input_values, attention_mask) / self.temperature, dim=-1)
             for model in self.models],
            dim=0,
        )

        # Max class probability per model and sample: (num_models, *batch_dims)
        confidences = prob_stack.max(dim=-1).values
        # Softmax across models yields weights that sum to 1 per sample;
        # the trailing singleton dimension broadcasts over labels.
        conf_weights = F.softmax(confidences, dim=0).unsqueeze(-1)

        fused = (prob_stack * conf_weights).sum(dim=0)
        return torch.log(fused + 1e-10)


class EnsembleDetector:
    """Factory for creating ensemble models from config."""

    @staticmethod
    def create(models: List[nn.Module], strategy: str = "late_fusion",
               weights: List[float] = None, num_labels: int = 2,
               temperature: float = 1.0) -> nn.Module:
        """
        Create an ensemble from a list of individual models.

        Args:
            models: List of DeepfakeClassifier instances
            strategy: "late_fusion", "learned_fusion", or "confidence_weighted"
            weights: Optional weights for late_fusion. Other strategies ignore
                them and log a warning.
            num_labels: Number of output classes (passed to learned_fusion)
            temperature: Softmax temperature (passed to confidence_weighted)

        Returns:
            Ensemble nn.Module

        Raises:
            ValueError: If ``strategy`` is not one of the supported names.
        """
        supported = ("late_fusion", "learned_fusion", "confidence_weighted")
        if strategy not in supported:
            raise ValueError(f"Unknown ensemble strategy: {strategy}")

        if strategy == "late_fusion":
            logger.info("Creating Late Fusion Ensemble (%d models)", len(models))
            return LateFusionEnsemble(models, weights)

        # Only late_fusion consumes `weights`; say so instead of dropping them silently.
        if weights is not None:
            logger.warning(
                "Ignoring 'weights' for ensemble strategy %r (only used by late_fusion)",
                strategy,
            )

        if strategy == "learned_fusion":
            logger.info("Creating Learned Fusion Ensemble (%d models)", len(models))
            return LearnedFusionEnsemble(models, num_labels)

        logger.info("Creating Confidence-Weighted Ensemble (%d models)", len(models))
        return ConfidenceWeightedEnsemble(models, temperature=temperature)