File size: 6,452 Bytes
38691ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
"""

Interface abstraite pour les classifieurs de produits Rakuten.



Ce module définit le contrat que tous les classifieurs doivent respecter,

permettant une intégration facile des modèles développés par l'équipe.



Architecture:

- BaseClassifier: Classe abstraite définissant l'interface

- ClassificationResult: Dataclass pour structurer les résultats

- Les implémentations concrètes héritent de BaseClassifier

"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Optional, List, Tuple
import numpy as np
from PIL import Image


@dataclass
class ClassificationResult:
    """

    Résultat structuré d'une classification de produit.



    Attributes:

        category: Code de la catégorie prédite (ex: "2583")

        confidence: Score de confiance [0, 1] pour la prédiction principale

        top_k_predictions: Liste des (code_catégorie, score) triées par score décroissant

        source: Source de la prédiction ("image", "text", "multimodal", "mock")

        raw_probabilities: Vecteur complet des probabilités (27 classes)

    """
    category: str
    confidence: float
    top_k_predictions: List[Tuple[str, float]] = field(default_factory=list)
    source: str = "unknown"
    raw_probabilities: Optional[np.ndarray] = None

    def __post_init__(self):
        """Validation des données après initialisation."""
        if not 0 <= self.confidence <= 1:
            raise ValueError(f"Confidence must be in [0, 1], got {self.confidence}")

    def to_dict(self) -> dict:
        """Convertit le résultat en dictionnaire pour affichage."""
        return {
            "category": self.category,
            "confidence": self.confidence,
            "top_k": self.top_k_predictions,
            "source": self.source,
        }


class BaseClassifier(ABC):
    """

    Classe abstraite définissant l'interface pour tous les classifieurs.



    Cette interface permet d'intégrer facilement différents types de modèles:

    - Classifieur image (ResNet50 features + ML)

    - Classifieur texte (TF-IDF + ML)

    - Classifieur multimodal (fusion des deux)

    - Mock classifieur (pour tests UI)



    Usage:

        class MyClassifier(BaseClassifier):

            def predict(self, image=None, text=None) -> ClassificationResult:

                # Implémentation spécifique

                pass



            def load_model(self, path):

                # Chargement du modèle

                pass



    Les modèles des collègues doivent respecter cette interface pour

    être intégrés dans l'application Streamlit.

    """

    # Liste des 27 codes de catégories Rakuten
    CATEGORY_CODES = [
        "10", "40", "50", "60", "1140", "1160", "1180", "1280", "1281",
        "1300", "1301", "1302", "1320", "1560", "1920", "1940", "2060",
        "2220", "2280", "2403", "2462", "2522", "2582", "2583", "2585",
        "2705", "2905"
    ]

    NUM_CLASSES = 27

    @abstractmethod
    def predict(

        self,

        image: Optional[Image.Image] = None,

        text: Optional[str] = None,

        top_k: int = 5

    ) -> ClassificationResult:
        """

        Effectue une prédiction de catégorie pour un produit.



        Args:

            image: Image PIL du produit (optionnel selon le type de classifieur)

            text: Texte du produit (designation + description, optionnel)

            top_k: Nombre de prédictions à retourner



        Returns:

            ClassificationResult contenant la catégorie prédite et les scores



        Raises:

            ValueError: Si ni image ni texte n'est fourni (selon le type)

        """
        pass

    @abstractmethod
    def load_model(self, path: str) -> None:
        """

        Charge un modèle depuis un fichier.



        Args:

            path: Chemin vers le fichier du modèle (.joblib, .h5, etc.)



        Raises:

            FileNotFoundError: Si le fichier n'existe pas

            RuntimeError: Si le chargement échoue

        """
        pass

    @property
    @abstractmethod
    def is_ready(self) -> bool:
        """

        Vérifie si le classifieur est prêt à effectuer des prédictions.



        Returns:

            True si le modèle est chargé et fonctionnel

        """
        pass

    @property
    def name(self) -> str:
        """Retourne le nom du classifieur."""
        return self.__class__.__name__

    def _validate_inputs(

        self,

        image: Optional[Image.Image],

        text: Optional[str],

        require_image: bool = False,

        require_text: bool = False

    ) -> None:
        """

        Valide les entrées selon les exigences du classifieur.



        Args:

            image: Image à valider

            text: Texte à valider

            require_image: Si True, l'image est obligatoire

            require_text: Si True, le texte est obligatoire



        Raises:

            ValueError: Si les entrées requises sont manquantes

        """
        if require_image and image is None:
            raise ValueError(f"{self.name} requires an image input")
        if require_text and (text is None or text.strip() == ""):
            raise ValueError(f"{self.name} requires a text input")

    def _probabilities_to_predictions(

        self,

        probabilities: np.ndarray,

        top_k: int = 5

    ) -> List[Tuple[str, float]]:
        """

        Convertit un vecteur de probabilités en liste de prédictions triées.



        Args:

            probabilities: Array de shape (27,) avec les probabilités par classe

            top_k: Nombre de prédictions à retourner



        Returns:

            Liste de tuples (code_catégorie, score) triés par score décroissant

        """
        if len(probabilities) != self.NUM_CLASSES:
            raise ValueError(
                f"Expected {self.NUM_CLASSES} probabilities, got {len(probabilities)}"
            )

        # Indices triés par probabilité décroissante
        sorted_indices = np.argsort(probabilities)[::-1]

        predictions = []
        for idx in sorted_indices[:top_k]:
            category_code = self.CATEGORY_CODES[idx]
            score = float(probabilities[idx])
            predictions.append((category_code, score))

        return predictions