File size: 3,874 Bytes
d13c106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""
Model Registry - Central place to register and manage all models.

This module makes it easy to add new models for different datasets.
Each model handler should implement the BaseModelHandler interface.
"""

from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Any
import numpy as np
from PIL import Image


class PredictionResult:
    """Bundle of outputs produced by a single model prediction.

    Attributes:
        label: The top predicted class label.
        confidence: Score associated with ``label``.
        all_labels: Every label the model scored.
        all_confidences: Scores for ``all_labels`` (presumably aligned
            index-by-index with that list — confirm with the handler).
        explanation_image: Optional visual explanation such as a Grad-CAM
            or attention map; ``None`` when the handler did not produce one.
    """

    def __init__(
        self,
        label: str,
        confidence: float,
        all_labels: List[str],
        all_confidences: List[float],
        explanation_image: Optional[np.ndarray] = None,
    ):
        # Top-1 prediction.
        self.label, self.confidence = label, confidence
        # Full score table as supplied by the handler.
        self.all_labels, self.all_confidences = all_labels, all_confidences
        # Grad-CAM / attention map, if any.
        self.explanation_image = explanation_image


class CalibrationResult:
    """Outcome of a calibration analysis performed on a model.

    Attributes:
        ece: Scalar calibration-error score (presumably Expected
            Calibration Error — confirm with the producing handler).
        bin_accuracies: Per-bin accuracy values.
        bin_confidences: Per-bin confidence values.
        bin_counts: Per-bin counts (presumably samples per bin).
        reliability_diagram: Optional diagram object; its type is not
            constrained here.
        source: Optional free-form string; its meaning is defined by the
            handler that builds the result.
    """

    def __init__(
        self,
        ece: float,
        bin_accuracies: List[float],
        bin_confidences: List[float],
        bin_counts: List[int],
        reliability_diagram: Optional[Any] = None,
        source: Optional[str] = None,
    ):
        # Headline metric.
        self.ece = ece
        # Per-bin breakdown, stored exactly as passed in.
        self.bin_accuracies, self.bin_confidences, self.bin_counts = (
            bin_accuracies,
            bin_confidences,
            bin_counts,
        )
        # Optional extras.
        self.reliability_diagram, self.source = reliability_diagram, source


class BaseModelHandler(ABC):
    """
    Interface that every model handler must satisfy.

    Subclass this, implement each ``@abstractmethod``, then add an instance
    of the subclass to ``MODEL_REGISTRY`` (for example via
    ``register_model(key, handler)``). Only ``get_calibration_data`` ships
    with a default implementation.
    """

    @abstractmethod
    def get_model_name(self) -> str:
        """Human-readable name of the model."""

    @abstractmethod
    def get_dataset_name(self) -> str:
        """Name of the dataset the model was trained on."""

    @abstractmethod
    def get_data_type(self) -> str:
        """Kind of input consumed: 'image', 'text', or 'multimodal'."""

    @abstractmethod
    def get_class_labels(self) -> List[str]:
        """The model's class labels."""

    @abstractmethod
    def get_model_info(self) -> Dict[str, str]:
        """Display-oriented model details (architecture, params, etc.)."""

    @abstractmethod
    def predict(self, input_data) -> PredictionResult:
        """
        Run the model on a single input and return a ``PredictionResult``.

        Expected ``input_data`` by data type:
          * image      -> PIL Image or numpy array
          * text       -> str
          * multimodal -> (image, text) tuple
        """

    @abstractmethod
    def get_example_inputs(self) -> List[Any]:
        """Sample inputs to offer in the demo."""

    def get_calibration_data(
        self, max_samples: Optional[int] = None
    ) -> Optional[CalibrationResult]:
        """
        Return a calibration analysis, or ``None`` when unsupported.

        The base implementation ignores ``max_samples`` and always returns
        ``None``; override it in a subclass to enable calibration display.
        """
        return None


# Module-level registry mapping a model key to its handler instance.
# Populate it with register_model(); query it via the lookup helpers below it.
MODEL_REGISTRY: Dict[str, BaseModelHandler] = dict()


def register_model(key: str, handler: BaseModelHandler):
    """Store *handler* in ``MODEL_REGISTRY`` under *key*.

    An existing entry with the same key is silently replaced.
    """
    MODEL_REGISTRY.update({key: handler})


def get_model_handler(key: str) -> Optional[BaseModelHandler]:
    """Look up the handler registered under *key*; ``None`` if absent."""
    if key not in MODEL_REGISTRY:
        return None
    return MODEL_REGISTRY[key]


def get_all_model_keys() -> List[str]:
    """Return a new list of the registered keys, in insertion order."""
    return list(MODEL_REGISTRY)


def get_models_by_type(data_type: str) -> Dict[str, BaseModelHandler]:
    """Return the registry entries whose handler reports *data_type*.

    Builds a new dict in registry (insertion) order; ``get_data_type()`` is
    called once on every registered handler.
    """
    matching: Dict[str, BaseModelHandler] = {}
    for model_key, handler in MODEL_REGISTRY.items():
        if handler.get_data_type() == data_type:
            matching[model_key] = handler
    return matching