File size: 4,279 Bytes
df4a21a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""
Base wrapper class for model wrappers.
"""

from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional

from PIL import Image


class BaseModelWrapper(ABC):
    """
    Abstract base class for model wrappers.

    All model wrappers should inherit from this class and implement
    the abstract methods ``load`` and ``predict``.

    Attributes:
        repo_id: Hugging Face repository ID.
        config: Configuration from config.json.
        local_path: Local path where the model files are stored.
    """

    # Trailing suffix removed from model names so they match the names used
    # in fusion configs. Subclasses may override it.
    _NAME_SUFFIX = "-final"

    def __init__(
        self,
        repo_id: str,
        config: Dict[str, Any],
        local_path: str
    ):
        """
        Initialize the wrapper.

        Args:
            repo_id: Hugging Face repository ID
            config: Configuration from config.json
            local_path: Local path where the model files are stored
        """
        self.repo_id = repo_id
        self.config = config
        self.local_path = local_path
        # Populated by subclasses in load(); None means "not loaded yet"
        # (see is_loaded()).
        self._predict_fn: Optional[Callable] = None

    @property
    def name(self) -> str:
        """
        Get the short name of the model.

        Prefers a non-empty 'name' from config; otherwise uses the last
        '/'-separated component of repo_id. A trailing '-final' suffix is
        stripped to ensure consistency with fusion configs. Only the
        trailing suffix is removed; '-final' appearing elsewhere in the name
        (e.g. 'deepfake-finalist') is left intact.
        """
        raw_name = self.config.get("name") or self.repo_id.split("/")[-1]
        suffix = self._NAME_SUFFIX
        # The guard on `suffix` avoids `raw_name[:-0]`, which would return "".
        if suffix and raw_name.endswith(suffix):
            return raw_name[: -len(suffix)]
        return raw_name

    @abstractmethod
    def load(self) -> None:
        """
        Load the model and prepare for inference.

        This method should import the predict function from the downloaded
        repository and store it for later use.
        """
        pass

    @abstractmethod
    def predict(self, *args, **kwargs) -> Dict[str, Any]:
        """
        Run prediction.

        Returns:
            Dictionary with standardized prediction fields:
            - pred_int: 0 (real) or 1 (fake)
            - pred: "real" or "fake"
            - prob_fake: float probability
            - meta: dict with any additional metadata
        """
        pass

    def is_loaded(self) -> bool:
        """Check if the model is loaded and ready for inference."""
        return self._predict_fn is not None

    def get_info(self) -> Dict[str, Any]:
        """
        Get model information.

        Returns:
            Dictionary with keys 'repo_id', 'name', 'config', 'local_path'
            and 'is_loaded'. 'config' is the wrapper's own dict (not a copy).
        """
        return {
            "repo_id": self.repo_id,
            "name": self.name,
            "config": self.config,
            "local_path": self.local_path,
            "is_loaded": self.is_loaded()
        }


class BaseSubmodelWrapper(BaseModelWrapper):
    """
    Base wrapper for submodels that process images.

    Concrete submodels implement ``predict`` to score a single image, given
    either as a PIL image or as raw image bytes.
    """

    @abstractmethod
    def predict(
        self,
        image: Optional[Image.Image] = None,
        image_bytes: Optional[bytes] = None,
        explain: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Classify one image as real or fake.

        Args:
            image: PIL image to classify.
            image_bytes: Raw image bytes, an alternative to ``image``.
            explain: When True, the output also carries an explainability
                heatmap.
            **kwargs: Extra implementation-specific options.

        Returns:
            A standardized prediction dictionary containing:
            - pred_int: 1 for fake, 0 for real
            - pred: "fake" or "real"
            - prob_fake: float probability
            - heatmap_base64: Optional[str], included when explain=True
            - explainability_type: Optional[str], included when explain=True
        """
        ...


class BaseFusionWrapper(BaseModelWrapper):
    """
    Base wrapper for fusion models that combine submodel outputs.

    A fusion model receives the per-submodel prediction dictionaries and
    turns them into one combined verdict.
    """

    @abstractmethod
    def predict(
        self,
        submodel_outputs: Dict[str, Dict[str, Any]],
        **kwargs
    ) -> Dict[str, Any]:
        """
        Fuse the outputs of several submodels into one prediction.

        Args:
            submodel_outputs: Mapping from each submodel's name to the
                output dictionary it produced.
            **kwargs: Extra implementation-specific options.

        Returns:
            A standardized prediction dictionary.
        """
        ...