DeepFakeDetectorBackend / app /models /wrappers /dummy_random_wrapper.py
lukhsaankumar's picture
Deploy DeepFake Detector API - 2026-03-07 09:12:00
df4a21a
"""
Wrapper for dummy random submodels.
"""
import importlib.util
import sys
from pathlib import Path
from typing import Any, Dict, Optional
from PIL import Image
from app.core.errors import InferenceError, ConfigurationError
from app.core.logging import get_logger
from app.models.wrappers.base_wrapper import BaseSubmodelWrapper
# Module-level logger used by DummyRandomWrapper for init/load/predict diagnostics.
logger = get_logger(__name__)
class DummyRandomWrapper(BaseSubmodelWrapper):
    """
    Wrapper for dummy random prediction models.

    These models are hosted on Hugging Face and contain a predict.py
    with a predict() function that returns random predictions. The
    wrapper imports that function dynamically in ``load()`` and
    normalizes whatever it returns into the standardized prediction
    dictionary in ``predict()``.
    """

    # Values strictly greater than this are treated as "fake" (pred_int == 1)
    # whenever a binary label has to be derived from a score. Subclasses may
    # override it.
    _FAKE_THRESHOLD = 0.5

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ):
        """
        Initialize the wrapper.

        Args:
            repo_id: Hugging Face repository ID (e.g., "DeepFakeDetector/test-random-a")
            config: Configuration from config.json
            local_path: Local path where the model files are stored
        """
        super().__init__(repo_id, config, local_path)
        logger.info(f"Initialized DummyRandomWrapper for {repo_id}")

    def load(self) -> None:
        """
        Load the predict function from the downloaded repository.

        Dynamically imports ``<local_path>/predict.py`` under a unique module
        name and stores its ``predict`` callable on ``self._predict_fn``.

        Raises:
            ConfigurationError: if predict.py is missing, cannot be imported,
                or does not define a callable ``predict``.
        """
        predict_path = Path(self.local_path) / "predict.py"
        if not predict_path.is_file():
            raise ConfigurationError(
                message=f"predict.py not found in {self.local_path}",
                details={"repo_id": self.repo_id, "expected_path": str(predict_path)}
            )
        try:
            # Create a unique, identifier-safe module name to avoid conflicts.
            # Every non-alphanumeric character (not only "-") is replaced so a
            # name containing "." or "/" is never read as a dotted package path.
            safe_name = "".join(ch if ch.isalnum() else "_" for ch in self.name)
            module_name = f"hf_model_{safe_name}_predict"

            # Load the module dynamically
            spec = importlib.util.spec_from_file_location(module_name, predict_path)
            if spec is None or spec.loader is None:
                raise ConfigurationError(
                    message=f"Could not load spec for {predict_path}",
                    details={"repo_id": self.repo_id}
                )
            module = importlib.util.module_from_spec(spec)
            # Register before executing, as the regular import system does,
            # so code inside predict.py can look itself up in sys.modules.
            sys.modules[module_name] = module
            try:
                spec.loader.exec_module(module)
                predict_fn = getattr(module, "predict", None)
                if not callable(predict_fn):
                    raise ConfigurationError(
                        message="predict.py does not have a 'predict' function",
                        details={"repo_id": self.repo_id}
                    )
            except BaseException:
                # Never leave a half-initialized / unusable module registered.
                sys.modules.pop(module_name, None)
                raise

            self._predict_fn = predict_fn
            logger.info(f"Loaded predict function from {self.repo_id}")
        except ConfigurationError:
            raise
        except Exception as e:
            logger.error(f"Failed to load predict function from {self.repo_id}: {e}")
            raise ConfigurationError(
                message=f"Failed to load model: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            ) from e

    def predict(
        self,
        image: Optional[Image.Image] = None,
        image_bytes: Optional[bytes] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run prediction on an image.

        Args:
            image: PIL Image object (optional for dummy model; accepted for
                interface compatibility but not forwarded to predict.py)
            image_bytes: Raw image bytes (optional for dummy model)
            **kwargs: Additional arguments passed to the model

        Returns:
            Standardized prediction dictionary with:
                - pred_int: 0 or 1
                - pred: "real" or "fake"
                - prob_fake: float
                - meta: dict

        Raises:
            InferenceError: if the model is not loaded or prediction fails.
        """
        # getattr so a wrapper whose load() never populated _predict_fn
        # reports "Model not loaded" rather than an AttributeError.
        predict_fn = getattr(self, "_predict_fn", None)
        if predict_fn is None:
            raise InferenceError(
                message="Model not loaded",
                details={"repo_id": self.repo_id}
            )
        try:
            # Call the actual predict function from the HF repo; only
            # image_bytes (plus kwargs) is forwarded.
            result = predict_fn(image_bytes=image_bytes, **kwargs)
            # Validate and standardize the output
            return self._standardize_output(result)
        except InferenceError:
            raise
        except Exception as e:
            logger.error(f"Prediction failed for {self.repo_id}: {e}")
            raise InferenceError(
                message=f"Prediction failed: {e}",
                details={"repo_id": self.repo_id, "error": str(e)}
            ) from e

    def _standardize_output(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """
        Standardize the model output to ensure consistent format.

        Missing fields are derived from those that are present:
        ``pred_int`` from an explicit ``pred_int``, else ``prob_fake``, else a
        "fake"/"real" ``pred`` label (default 0); ``pred`` from ``pred_int``;
        ``prob_fake`` from ``pred_int``.

        Args:
            result: Raw model output (dict-like; read via ``.get``)

        Returns:
            Standardized dictionary with ``pred_int`` (built-in int, 0 or 1),
            ``pred``, ``prob_fake`` (float) and ``meta`` ({} when absent/None).
        """
        threshold = self._FAKE_THRESHOLD
        raw_pred_int = result.get("pred_int")
        pred = result.get("pred")
        prob_fake = result.get("prob_fake")

        if raw_pred_int is not None:
            # Binarize: 0/1 map to themselves, other scores are thresholded,
            # and bool/float/numpy scalars become a plain (JSON-safe) int.
            pred_int = 1 if float(raw_pred_int) > threshold else 0
        elif prob_fake is not None:
            # No explicit label: derive it from the score so the returned
            # pred_int/pred never contradict prob_fake.
            pred_int = 1 if float(prob_fake) > threshold else 0
        elif isinstance(pred, str) and pred.lower() in ("fake", "real"):
            pred_int = 1 if pred.lower() == "fake" else 0
        else:
            pred_int = 0

        # Generate pred label if not present
        if pred is None:
            pred = "fake" if pred_int == 1 else "real"

        # Generate prob_fake if not present
        if prob_fake is None:
            prob_fake = float(pred_int)

        meta = result.get("meta")
        if meta is None:
            meta = {}

        return {
            "pred_int": pred_int,
            "pred": pred,
            "prob_fake": float(prob_fake),
            "meta": meta,
        }