"""
Base wrapper class for model wrappers.
"""

from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional

from PIL import Image


# Suffix removed from model names so they match the names used in fusion configs.
_FINAL_SUFFIX = "-final"


def _strip_final_suffix(name: str) -> str:
    """Return *name* with one trailing '-final' removed, if present.

    Only the end of the string is affected; occurrences elsewhere in the
    name (e.g. 'model-finalized') are left untouched.
    """
    if name.endswith(_FINAL_SUFFIX):
        return name[: -len(_FINAL_SUFFIX)]
    return name


class BaseModelWrapper(ABC):
    """
    Abstract base class for model wrappers.

    All model wrappers should inherit from this class and implement
    the abstract methods.
    """

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ):
        """
        Initialize the wrapper.

        Args:
            repo_id: Hugging Face repository ID
            config: Configuration from config.json
            local_path: Local path where the model files are stored
        """
        self.repo_id = repo_id
        self.config = config
        self.local_path = local_path
        # Set by subclasses in load(); is_loaded() reports whether it is set.
        self._predict_fn: Optional[Callable] = None

    @property
    def name(self) -> str:
        """
        Get the short name of the model.

        Prefers a truthy 'name' from config if available, otherwise uses
        the last '/'-separated component of repo_id. A trailing '-final'
        suffix is stripped to ensure consistency with fusion configs.

        Returns:
            The short model name.
        """
        # Try to get name from config first
        config_name = self.config.get("name")
        if config_name:
            return _strip_final_suffix(config_name)

        # Fall back to the last part of repo_id (e.g. "org/model" -> "model")
        repo_name = self.repo_id.split("/")[-1]
        return _strip_final_suffix(repo_name)

    @abstractmethod
    def load(self) -> None:
        """
        Load the model and prepare for inference.

        This method should import the predict function from the downloaded
        repository and store it for later use.
        """
        pass

    @abstractmethod
    def predict(self, *args, **kwargs) -> Dict[str, Any]:
        """
        Run prediction.

        Returns:
            Dictionary with standardized prediction fields:
            - pred_int: 0 (real) or 1 (fake)
            - pred: "real" or "fake"
            - prob_fake: float probability
            - meta: dict with any additional metadata
        """
        pass

    def is_loaded(self) -> bool:
        """Check if the model is loaded and ready for inference."""
        return self._predict_fn is not None

    def get_info(self) -> Dict[str, Any]:
        """
        Get model information.

        Returns:
            Dictionary with model info. Note that "config" is the wrapper's
            own config dict (not a copy), so mutating it affects the wrapper.
        """
        return {
            "repo_id": self.repo_id,
            "name": self.name,
            "config": self.config,
            "local_path": self.local_path,
            "is_loaded": self.is_loaded()
        }


class BaseSubmodelWrapper(BaseModelWrapper):
    """Base wrapper for submodels that process images."""

    @abstractmethod
    def predict(
        self,
        image: Optional[Image.Image] = None,
        image_bytes: Optional[bytes] = None,
        explain: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run prediction on an image.

        Args:
            image: PIL Image object
            image_bytes: Raw image bytes (alternative to image)
            explain: If True, include explainability heatmap in output
            **kwargs: Additional arguments

        Returns:
            Standardized prediction dictionary with:
            - pred_int: 0 (real) or 1 (fake)
            - pred: "real" or "fake"
            - prob_fake: float probability
            - heatmap_base64: Optional[str] (when explain=True)
            - explainability_type: Optional[str] (when explain=True)
        """
        pass


class BaseFusionWrapper(BaseModelWrapper):
    """Base wrapper for fusion models that combine submodel outputs."""

    @abstractmethod
    def predict(
        self,
        submodel_outputs: Dict[str, Dict[str, Any]],
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run fusion prediction on submodel outputs.

        Args:
            submodel_outputs: Dictionary mapping submodel name to its output
            **kwargs: Additional arguments

        Returns:
            Standardized prediction dictionary
        """
        pass