File size: 5,918 Bytes
df4a21a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
"""
Wrapper for dummy majority vote fusion model.
"""

import importlib.util
import sys
from pathlib import Path
from typing import Any, Dict, List

from app.core.errors import FusionError, ConfigurationError
from app.core.logging import get_logger
from app.models.wrappers.base_wrapper import BaseFusionWrapper

# Module-level logger obtained from the project's logging helper, keyed by this module's name.
logger = get_logger(__name__)


class DummyMajorityFusionWrapper(BaseFusionWrapper):
    """
    Wrapper for dummy majority vote fusion models.

    These models are hosted on Hugging Face and contain a predict.py
    with a predict() function that performs majority voting on submodel outputs.
    The wrapper imports that file dynamically (see load()) and normalises
    whatever the function returns into a fixed dictionary shape
    (see _standardize_output()).
    """

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ):
        """
        Initialize the wrapper.

        Args:
            repo_id: Hugging Face repository ID (e.g., "DeepFakeDetector/fusion-majority-test")
            config: Configuration from config.json; the optional "submodels"
                entry lists the submodel repository IDs.
            local_path: Local path where the model files are stored
        """
        super().__init__(repo_id, config, local_path)
        # `or []` also covers an explicit `"submodels": null` in config.json,
        # so the attribute is always a list, as annotated.
        self._submodel_repos: List[str] = config.get("submodels") or []
        logger.info(f"Initialized DummyMajorityFusionWrapper for {repo_id}")
        logger.info(f"Submodels: {self._submodel_repos}")

    @property
    def submodel_repos(self) -> List[str]:
        """Get list of submodel repository IDs."""
        return self._submodel_repos

    def load(self) -> None:
        """
        Load the fusion predict function from the downloaded repository.

        Dynamically imports predict.py from ``self.local_path`` and stores its
        ``predict`` callable on ``self._predict_fn``.

        Raises:
            ConfigurationError: If predict.py is missing, cannot be imported,
                or does not expose a callable named ``predict``.
        """
        fusion_path = Path(self.local_path) / "predict.py"

        if not fusion_path.exists():
            raise ConfigurationError(
                message=f"predict.py not found in {self.local_path}",
                details={"repo_id": self.repo_id, "expected_path": str(fusion_path)}
            )

        try:
            # Create a unique module name to avoid conflicts
            module_name = f"hf_model_{self.name.replace('-', '_')}_fusion"

            # Load the module dynamically
            spec = importlib.util.spec_from_file_location(module_name, fusion_path)
            if spec is None or spec.loader is None:
                raise ConfigurationError(
                    message=f"Could not load spec for {fusion_path}",
                    details={"repo_id": self.repo_id}
                )

            module = importlib.util.module_from_spec(spec)
            # Register before executing so the module can be found by name
            # while its own top-level code runs.
            sys.modules[module_name] = module
            try:
                spec.loader.exec_module(module)
            except BaseException:
                # Do not leave a half-initialised module behind in sys.modules.
                sys.modules.pop(module_name, None)
                raise

            # Get the predict function; a non-callable attribute is as
            # unusable as a missing one, so reject both here instead of
            # failing later at prediction time.
            predict_fn = getattr(module, "predict", None)
            if not callable(predict_fn):
                sys.modules.pop(module_name, None)
                raise ConfigurationError(
                    message="predict.py does not have a 'predict' function",
                    details={"repo_id": self.repo_id}
                )

            self._predict_fn = predict_fn
            logger.info(f"Loaded fusion predict function from {self.repo_id}")

        except ConfigurationError:
            raise
        except Exception as e:
            logger.error(f"Failed to load fusion function from {self.repo_id}: {e}")
            # Chain the original exception so its traceback is preserved.
            raise ConfigurationError(
                message=f"Failed to load fusion model: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            ) from e

    def predict(
        self,
        submodel_outputs: Dict[str, Dict[str, Any]],
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run fusion prediction on submodel outputs.

        Args:
            submodel_outputs: Dictionary mapping submodel name to its prediction output
            **kwargs: Additional arguments passed to the fusion function

        Returns:
            Standardized prediction dictionary with:
            - pred_int: 0 or 1
            - pred: "real" or "fake" (or the label supplied by the fusion function)
            - prob_fake: float
            - meta: dict

        Raises:
            FusionError: If load() has not stored a predict function yet, if
                the fusion function raises, or if its output cannot be
                standardized.
        """
        # getattr with a default: this class only assigns the attribute in
        # load(); whether the base class pre-initialises it is not visible here.
        predict_fn = getattr(self, "_predict_fn", None)
        if predict_fn is None:
            raise FusionError(
                message="Fusion model not loaded",
                details={"repo_id": self.repo_id}
            )

        try:
            # Call the actual fusion predict function from the HF repo
            result = predict_fn(submodel_outputs=submodel_outputs, **kwargs)

            # Validate and standardize the output
            return self._standardize_output(result)

        except FusionError:
            raise
        except Exception as e:
            logger.error(f"Fusion prediction failed for {self.repo_id}: {e}")
            # Chain the original exception so its traceback is preserved.
            raise FusionError(
                message=f"Fusion prediction failed: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            ) from e

    def _coerce_float(self, value: Any, field: str) -> float:
        """Convert a fusion output field to float, raising FusionError if it is not numeric."""
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError) as e:
            raise FusionError(
                message=f"Fusion output field '{field}' is not numeric: {value!r}",
                details={"repo_id": self.repo_id, "field": field, "error": str(e)}
            ) from e

    def _standardize_output(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """
        Standardize the fusion output to ensure consistent format.

        Rules applied:
        - ``pred_int`` is always returned as a plain ``int`` 0 or 1. A supplied
          value is thresholded at 0.5 (so 0/1/True/False and numeric scalars
          are normalised). If it is missing or None, it is derived from
          ``prob_fake`` with the same threshold, falling back to 0 when
          ``prob_fake`` is absent as well.
        - ``pred`` is kept as supplied, otherwise "fake"/"real" is derived
          from ``pred_int``.
        - ``prob_fake`` is converted to ``float``; if missing it mirrors
          ``pred_int``.
        - ``meta`` defaults to an empty dict when missing or None.

        Args:
            result: Raw fusion output

        Returns:
            Standardized dictionary with keys pred_int, pred, prob_fake, meta

        Raises:
            FusionError: If ``result`` is not a dict or a numeric field holds
                a non-numeric value.
        """
        if not isinstance(result, dict):
            raise FusionError(
                message=(
                    "Fusion predict function must return a dict, "
                    f"got {type(result).__name__}"
                ),
                details={"repo_id": self.repo_id}
            )

        # Parse prob_fake first: it doubles as the fallback source for pred_int.
        prob_fake = None
        raw_prob = result.get("prob_fake")
        if raw_prob is not None:
            prob_fake = self._coerce_float(raw_prob, "prob_fake")

        raw_pred_int = result.get("pred_int")
        if raw_pred_int is not None:
            vote = self._coerce_float(raw_pred_int, "pred_int")
        elif prob_fake is not None:
            # No explicit decision: follow the probability rather than
            # silently reporting "real" next to a high prob_fake.
            vote = prob_fake
        else:
            vote = 0.0

        # Ensure pred_int is a plain int 0 or 1
        pred_int = 1 if vote > 0.5 else 0

        # Generate pred label if not present
        pred = result.get("pred")
        if pred is None:
            pred = "fake" if pred_int == 1 else "real"

        # Generate prob_fake if not present
        if prob_fake is None:
            prob_fake = float(pred_int)

        meta = result.get("meta")
        if meta is None:
            meta = {}

        return {
            "pred_int": pred_int,
            "pred": pred,
            "prob_fake": prob_fake,
            "meta": meta
        }