File size: 5,918 Bytes
"""
Wrapper for dummy majority vote fusion model.
"""
import importlib.util
import sys
from pathlib import Path
from typing import Any, Dict, List
from app.core.errors import FusionError, ConfigurationError
from app.core.logging import get_logger
from app.models.wrappers.base_wrapper import BaseFusionWrapper
# Module-level logger shared by everything in this file; obtained from the
# project's get_logger helper and keyed by this module's dotted name.
logger = get_logger(__name__)
class DummyMajorityFusionWrapper(BaseFusionWrapper):
    """
    Wrapper for dummy majority vote fusion models.

    These models are hosted on Hugging Face and contain a ``predict.py``
    with a ``predict()`` function that performs majority voting on submodel
    outputs. :meth:`load` imports that file from ``local_path`` and
    :meth:`predict` delegates to the imported function, normalising its
    return value.
    """

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ):
        """
        Initialize the wrapper.

        Args:
            repo_id: Hugging Face repository ID
                (e.g., "DeepFakeDetector/fusion-majority-test")
            config: Configuration from config.json; the optional
                "submodels" entry lists the submodel repository IDs
            local_path: Local path where the model files are stored
        """
        super().__init__(repo_id, config, local_path)

        # `or []` tolerates an explicit `"submodels": null`; a lone string is
        # treated as a single repository ID rather than split into characters.
        submodels = config.get("submodels") or []
        if isinstance(submodels, str):
            submodels = [submodels]
        # Copy so later mutation of the config does not silently change the
        # wrapper's view of its submodels.
        self._submodel_repos: List[str] = list(submodels)

        logger.info(f"Initialized DummyMajorityFusionWrapper for {repo_id}")
        logger.info(f"Submodels: {self._submodel_repos}")

    @property
    def submodel_repos(self) -> List[str]:
        """Get list of submodel repository IDs."""
        return self._submodel_repos

    def load(self) -> None:
        """
        Load the fusion predict function from the downloaded repository.

        Dynamically imports ``predict.py`` from ``local_path`` and stores its
        ``predict`` callable on the wrapper.

        Raises:
            ConfigurationError: If ``predict.py`` is missing, cannot be
                imported, or does not expose a callable ``predict``. Import
                failures are chained as the error's ``__cause__``.
        """
        fusion_path = Path(self.local_path) / "predict.py"
        if not fusion_path.exists():
            raise ConfigurationError(
                message=f"predict.py not found in {self.local_path}",
                details={"repo_id": self.repo_id, "expected_path": str(fusion_path)}
            )

        try:
            # Create a unique module name to avoid conflicts
            module_name = f"hf_model_{self.name.replace('-', '_')}_fusion"

            # Load the module dynamically
            spec = importlib.util.spec_from_file_location(module_name, fusion_path)
            if spec is None or spec.loader is None:
                raise ConfigurationError(
                    message=f"Could not load spec for {fusion_path}",
                    details={"repo_id": self.repo_id}
                )
            module = importlib.util.module_from_spec(spec)

            # The module is registered before execution (as the import system
            # itself does). Remember what was there so a failed load can be
            # rolled back instead of leaving a half-initialised module behind.
            previous = sys.modules.get(module_name)
            sys.modules[module_name] = module
            try:
                spec.loader.exec_module(module)

                # Get the predict function; it must actually be callable
                predict_fn = getattr(module, "predict", None)
                if not callable(predict_fn):
                    raise ConfigurationError(
                        message="predict.py does not have a 'predict' function",
                        details={"repo_id": self.repo_id}
                    )
            except BaseException:
                if previous is not None:
                    sys.modules[module_name] = previous
                else:
                    sys.modules.pop(module_name, None)
                raise

            self._predict_fn = predict_fn
            logger.info(f"Loaded fusion predict function from {self.repo_id}")
        except ConfigurationError:
            raise
        except Exception as e:
            logger.error(f"Failed to load fusion function from {self.repo_id}: {e}")
            raise ConfigurationError(
                message=f"Failed to load fusion model: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            ) from e

    def predict(
        self,
        submodel_outputs: Dict[str, Dict[str, Any]],
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run fusion prediction on submodel outputs.

        Args:
            submodel_outputs: Dictionary mapping submodel name to its
                prediction output
            **kwargs: Additional arguments passed to the fusion function

        Returns:
            Standardized prediction dictionary with:
                - pred_int: 0 or 1
                - pred: "real" or "fake"
                - prob_fake: float
                - meta: dict

        Raises:
            FusionError: If the model has not been loaded, or if the fusion
                function or output standardization fails (the original
                exception is chained as ``__cause__``).
        """
        if self._predict_fn is None:
            raise FusionError(
                message="Fusion model not loaded",
                details={"repo_id": self.repo_id}
            )

        try:
            # Call the actual fusion predict function from the HF repo
            result = self._predict_fn(submodel_outputs=submodel_outputs, **kwargs)
            # Validate and standardize the output
            return self._standardize_output(result)
        except FusionError:
            raise
        except Exception as e:
            logger.error(f"Fusion prediction failed for {self.repo_id}: {e}")
            raise FusionError(
                message=f"Fusion prediction failed: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            ) from e

    def _standardize_output(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """
        Standardize the fusion output to ensure consistent format.

        Args:
            result: Raw fusion output

        Returns:
            Dictionary with keys:
                - pred_int: plain ``int`` 0 or 1
                - pred: label from ``result`` or "fake"/"real" from pred_int
                - prob_fake: ``float``; defaults to ``float(pred_int)``
                - meta: ``result["meta"]`` or an empty dict
        """
        prob_fake = result.get("prob_fake")
        pred_int = result.get("pred_int")

        if pred_int is None:
            # No explicit decision: derive it from prob_fake when available so
            # the label cannot contradict the probability; otherwise fall
            # back to 0.
            pred_int = float(prob_fake) if prob_fake is not None else 0

        # Ensure pred_int is 0 or 1 (values above 0.5 count as fake)
        if pred_int not in (0, 1):
            pred_int = 1 if pred_int > 0.5 else 0
        # True/False and 1.0/0.0 pass the membership test above; collapse
        # them to a plain int so the output type is stable.
        pred_int = int(pred_int)

        # Generate pred label if not present
        pred = result.get("pred")
        if pred is None:
            pred = "fake" if pred_int == 1 else "real"

        # Generate prob_fake if not present
        if prob_fake is None:
            prob_fake = float(pred_int)

        # An explicit `"meta": None` is normalised to an empty dict
        meta = result.get("meta")
        if meta is None:
            meta = {}

        return {
            "pred_int": pred_int,
            "pred": pred,
            "prob_fake": float(prob_fake),
            "meta": meta
        }
|