File size: 4,279 Bytes
df4a21a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 | """
Base wrapper class for model wrappers.
"""
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional
from PIL import Image
class BaseModelWrapper(ABC):
    """
    Abstract base class for model wrappers.

    All model wrappers should inherit from this class and implement the
    abstract methods ``load`` and ``predict``.
    """

    # Trailing suffix removed from model names so they line up with the
    # names used in fusion configs.
    _FINAL_SUFFIX = "-final"

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ) -> None:
        """
        Initialize the wrapper.

        Args:
            repo_id: Hugging Face repository ID
            config: Configuration from config.json
            local_path: Local path where the model files are stored
        """
        self.repo_id = repo_id
        self.config = config
        self.local_path = local_path
        # Expected to be set by load(); is_loaded() treats None as
        # "not loaded yet".
        self._predict_fn: Optional[Callable] = None

    @property
    def name(self) -> str:
        """
        Get the short name of the model.

        Prefers a non-empty 'name' from config; otherwise falls back to the
        last '/'-separated component of repo_id. A trailing '-final' suffix
        is stripped so the name is consistent with fusion configs. Only the
        suffix is removed: '-final' appearing elsewhere in the name (e.g.
        'my-finalist') is preserved.
        """
        config_name = self.config.get("name")
        raw_name = config_name if config_name else self.repo_id.split("/")[-1]
        return self._strip_final_suffix(raw_name)

    @classmethod
    def _strip_final_suffix(cls, value: str) -> str:
        """Return ``value`` without a trailing ``_FINAL_SUFFIX``, if present."""
        suffix = cls._FINAL_SUFFIX
        if value.endswith(suffix):
            return value[:-len(suffix)]
        return value

    @abstractmethod
    def load(self) -> None:
        """
        Load the model and prepare for inference.

        This method should import the predict function from the downloaded
        repository and store it for later use.
        """
        pass

    @abstractmethod
    def predict(self, *args, **kwargs) -> Dict[str, Any]:
        """
        Run prediction.

        Returns:
            Dictionary with standardized prediction fields:
                - pred_int: 0 (real) or 1 (fake)
                - pred: "real" or "fake"
                - prob_fake: float probability
                - meta: dict with any additional metadata
        """
        pass

    def is_loaded(self) -> bool:
        """Check if the model is loaded and ready for inference."""
        return self._predict_fn is not None

    def get_info(self) -> Dict[str, Any]:
        """
        Get model information.

        Returns:
            Dictionary with the repo ID, short name, config (the same dict
            object held by the wrapper, not a copy), local path, and
            whether the model is loaded.
        """
        return {
            "repo_id": self.repo_id,
            "name": self.name,
            "config": self.config,
            "local_path": self.local_path,
            "is_loaded": self.is_loaded()
        }
class BaseSubmodelWrapper(BaseModelWrapper):
    """Abstract base for submodel wrappers that run on a single image."""

    @abstractmethod
    def predict(
        self,
        image: Optional[Image.Image] = None,
        image_bytes: Optional[bytes] = None,
        explain: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Predict whether an image is real or fake.

        The input can be given either as a decoded PIL image or as raw
        image bytes.

        Args:
            image: PIL Image to run the prediction on.
            image_bytes: Raw image bytes, an alternative to ``image``.
            explain: When True, the output should also include an
                explainability heatmap.
            **kwargs: Extra implementation-specific arguments.

        Returns:
            A standardized prediction dictionary containing:
                - pred_int: 0 (real) or 1 (fake)
                - pred: "real" or "fake"
                - prob_fake: float probability
                - heatmap_base64: Optional[str] (when explain=True)
                - explainability_type: Optional[str] (when explain=True)
        """
        ...
class BaseFusionWrapper(BaseModelWrapper):
    """Abstract base for fusion wrappers that merge several submodel outputs."""

    @abstractmethod
    def predict(
        self,
        submodel_outputs: Dict[str, Dict[str, Any]],
        **kwargs
    ) -> Dict[str, Any]:
        """
        Combine submodel predictions into one fused prediction.

        Args:
            submodel_outputs: Mapping from each submodel's name to the
                prediction dictionary that submodel produced.
            **kwargs: Extra implementation-specific arguments.

        Returns:
            A standardized prediction dictionary.
        """
        ...
|