File size: 17,620 Bytes
f725085
 
 
 
 
 
 
 
 
 
 
9b3e2a4
f725085
 
 
 
 
 
 
 
 
 
9b3e2a4
2a4e3dd
f725085
9b3e2a4
2a4e3dd
9b3e2a4
2a4e3dd
9b3e2a4
f725085
 
 
 
 
 
 
 
2a044a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f725085
 
 
 
2a044a1
f725085
 
 
 
 
 
 
 
 
 
 
 
2a044a1
f725085
 
 
 
 
2a044a1
f725085
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a044a1
f725085
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b3e2a4
 
 
 
 
 
 
2a4e3dd
 
9b3e2a4
2a4e3dd
 
 
 
9b3e2a4
 
 
2a4e3dd
9b3e2a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a4e3dd
9b3e2a4
 
 
 
 
 
 
 
2a044a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f725085
2a4e3dd
 
 
 
 
 
 
 
 
 
2a044a1
 
 
 
 
 
f725085
2a044a1
 
f725085
2a044a1
 
f725085
2a044a1
2a4e3dd
2a044a1
f725085
 
2a044a1
f725085
2a4e3dd
b7645dd
2a4e3dd
 
 
 
f725085
 
2a044a1
 
 
f725085
2a044a1
 
 
 
 
 
 
 
 
f725085
 
2a044a1
 
 
 
 
f725085
2a044a1
 
 
 
 
 
f725085
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
import tensorflow as tf
import numpy as np
from PIL import Image
import tensorflow as tf
import logging
import numpy as np
from PIL import Image
from keras.applications.efficientnet_v2 import preprocess_input as effnet_preprocess
import io
from tf_keras_vis.gradcam import Gradcam,GradcamPlusPlus
from tf_keras_vis.utils import normalize
import hashlib
import numpy as np
import tensorflow as tf
from tf_keras_vis.saliency import Saliency
from tf_keras_vis.utils import normalize
import numpy as np
import tensorflow as tf
from tf_keras_vis.saliency import Saliency
from tf_keras_vis.utils import normalize
import logging
import time
import os
import diskcache as dc
from typing import TypedDict, Callable, Any

CACHE_DIR = './cache'
os.makedirs(CACHE_DIR, exist_ok=True)
cache = dc.Cache(CACHE_DIR)

logging.basicConfig(
    level=logging.INFO,  # ou logging.DEBUG
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
)
logger = logging.getLogger(__name__)
confidence_threshold=0.55
entropy_threshold=2




class TFLiteDynamicModel:
    def __init__(self, tflite_path, img_size=224):
        logger.info(f"🚀 Chargement du modèle TFLite depuis : {tflite_path}")
        self.img_size = img_size
        self.interpreter = tf.lite.Interpreter(model_path=tflite_path)
        self.interpreter.allocate_tensors()

        input_details = self.interpreter.get_input_details()
        output_details = self.interpreter.get_output_details()

        self.input_index = input_details[0]['index']
        self.input_dtype = input_details[0]['dtype']
        self.input_scale, self.input_zero_point = input_details[0]['quantization']
        self.output_index = output_details[0]['index']

        logger.info(f"🔍 Input tensor index : {self.input_index}, dtype : {self.input_dtype}, scale : {self.input_scale}, zero_point : {self.input_zero_point}")
        logger.info(f"🔍 Output tensor index : {self.output_index}")


    def preprocess(self, pil_image):
        logger.info(f"🎨 Prétraitement image, redimension à {self.img_size}x{self.img_size}")
        img = pil_image.resize((self.img_size, self.img_size))
        img = np.array(img)

        # 📸 Gestion des images grayscale ou RGBA
        if img.ndim == 2:  # Grayscale -> RGB
            logger.debug("⚪ Image grayscale détectée, conversion en RGB")
            img = np.stack([img] * 3, axis=-1)
        elif img.shape[-1] == 4:  # RGBA -> RGB
            logger.debug("🖼️ Image RGBA détectée, suppression canal alpha")
            img = img[..., :3]

        if self.input_dtype in [np.uint8, np.int8]:
            logger.info("🗜️ Modèle quantifié PTQ détecté (entrée int8 ou uint8)")

            # Pas de division par 255 ici ! On quantifie directement selon l'échelle et le zéro-point
            img = img.astype(np.float32)
            img = img / self.input_scale + self.input_zero_point

            # On clip selon le type
            if self.input_dtype == np.uint8:
                img = np.clip(img, 0, 255)
            else:  # np.int8
                img = np.clip(img, -128, 127)

            img = img.astype(self.input_dtype)

        else:
            logger.info("🌊 Modèle dynamique ou float32 détecté (entrée float32 normalisée)")
            img = img.astype(self.input_dtype)  # Normalisation classique

        input_data = np.expand_dims(img, axis=0)
        logger.info(f"✅ Image prétraitée avec forme {input_data.shape} et dtype {input_data.dtype}")
        return input_data


    def preprocess_old(self, pil_image):
        logger.info(f"🎨 Prétraitement image, redimension à {self.img_size}x{self.img_size}")
        img = pil_image.resize((self.img_size, self.img_size))
        img = np.array(img).astype(np.float32)

        if img.ndim == 2:  # grayscale -> RGB
            logger.debug("⚪ Image grayscale détectée, conversion en RGB")
            img = np.stack([img]*3, axis=-1)
        elif img.shape[-1] == 4:  # RGBA -> RGB
            logger.debug("🖼️ Image RGBA détectée, suppression canal alpha")
            img = img[..., :3]

        if self.input_dtype in [np.uint8, np.int8]:
            if self.input_scale > 0:
                logger.debug(f"⚙️ Application quantification dynamique avec scale {self.input_scale} et zero_point {self.input_zero_point}")
                img = img / 255.0
                img = img / self.input_scale + self.input_zero_point
                img = np.clip(img, 0, 255 if self.input_dtype == np.uint8 else 127)
            img = img.astype(self.input_dtype)
        else:
            img = img.astype(self.input_dtype)

        input_data = np.expand_dims(img, axis=0)
        logger.info(f"✅ Image prétraitée avec forme {input_data.shape} et dtype {input_data.dtype}")
        return input_data

    

    def predict_dyna(self, pil_image):
        logger.info("⚡ Début de prédiction (modèle dynamique ou float32)")

        # Prétraitement
        logger.info("🔄 Prétraitement de l'image en cours")
        input_data = self.preprocess(pil_image)
        logger.debug(f"✅ Image prétraitée - Shape : {input_data.shape} - Dtype : {input_data.dtype}")

        # Injection des données dans le modèle
        logger.info("📥 Injection des données dans le modèle")
        self.interpreter.set_tensor(self.input_index, input_data)

        # Invocation du modèle
        logger.info("🚀 Exécution du modèle TFLite")
        self.interpreter.invoke()

        # Récupération de la sortie
        logger.info("📤 Récupération des résultats bruts")
        output_data = self.interpreter.get_tensor(self.output_index)
        logger.debug(f"✅ Logits récupérés - Shape : {output_data.shape} - Dtype : {output_data.dtype}")

        # Calcul des probabilités
        logger.info("🧮 Calcul des probabilités")
        probas=output_data[0]
        logger.debug(f"✅ Probabilités : {probas}")

    
        logger.info("🎯 Prédiction terminée")

        return  probas


    def predict_ptq(self, pil_image):
        logger.info("⚡ Début de prédiction")
        
        # Prétraitement
        logger.info("🔄 Prétraitement de l'image en cours")
        input_data = self.preprocess(pil_image)
        logger.debug(f"✅ Image prétraitée - Shape : {input_data.shape} - Dtype : {input_data.dtype}")

        # Injection des données dans le modèle
        logger.info("📥 Injection des données dans le modèle")
        self.interpreter.set_tensor(self.input_index, input_data)

        # Invocation du modèle
        logger.info("🚀 Exécution du modèle TFLite")
        self.interpreter.invoke()

        # Récupération de la sortie
        logger.info("📤 Récupération des résultats bruts")
        output_details = self.interpreter.get_output_details()[0]
        output_data = self.interpreter.get_tensor(output_details['index'])
        logger.debug(f"✅ Logits quantifiés récupérés - Shape : {output_data.shape} - Dtype : {output_data.dtype}")

        # Paramètres de quantification
        output_scale, output_zero_point = output_details['quantization']
        logger.debug(f"ℹ️ Paramètres de quantification - Scale: {output_scale}, Zero Point: {output_zero_point}")

        # Déquantification
        logger.info("🔓 Déquantification des logits")
        logits = (output_data.astype(np.float32) - output_zero_point) * output_scale
        logger.debug(f"✅ Logits déquantifiés : {logits}")

        # Calcul des probabilités
        logger.info("🧮 Calcul des probabilités avec softmax")
        probas = logits[0]
        logger.debug(f"✅ Probabilités : {probas}")

        logger.info("🎯 Prédiction terminée")

        return probas
    
    def predict (self, pil_image):
        if self.input_dtype in [np.uint8, np.int8]:
            logger.info("🗜️ Modèle quantifié PTQ détecté")
            return self.predict_ptq(pil_image)
        else:
            logger.info("🌊 Modèle dynamique ou float32 détecté")
            return self.predict_dyna(pil_image)


class ModelStruct(TypedDict):
    model_name: str
    model: tf.keras.Model
    gradcam_model:tf.keras.Model
    fast_model:TFLiteDynamicModel
    preprocess_input: Callable[[np.ndarray], Any]
    target_size: tuple[int, int]
    last_conv_layer:str
    gradcam_type:str


_model_cache: list[ModelStruct] | None = None
def load_model() -> list[ModelStruct]:
    global _model_cache
    if _model_cache is None:
        print("📦 Chargement du modèle EfficientNetV2M...")
        model = tf.keras.models.load_model("model/best_efficientnetv2m_gradcam.keras", compile=False)
        fast_model=TFLiteDynamicModel("model/efficientnetv2m_float16.tflite", img_size=480)

        _model_cache = [{
            "model_name": "EfficientNetV2M",
            "model": model,
            "gradcam_model": model, 
            "fast_model":fast_model,
            "preprocess_input": effnet_preprocess,
            "target_size": (480, 480),
            "last_conv_layer": "block7a_expand_conv",
            "gradcam_type": "gradcam++"
        }]
    return _model_cache



def compute_gradcam(model, image_array, class_index=None, layer_name=None,gradcam_type="gradcam"):
    """
    Calcule la carte Grad-CAM pour une image et un modèle Keras.

    Args:
        model: tf.keras.Model.
        image_array: np.array (H, W, 3), float32, pré-traitée.
        class_index: int ou None, index de la classe cible. Si None, classe prédite.
        layer_name: str ou None, nom de la couche convolutionnelle à utiliser. Si None, dernière conv.

    Returns:
        gradcam_map: np.array (H, W), normalisée entre 0 et 1.
    """
    logging.info(f"Lancement calcul de la gradcam avec le type {gradcam_type}")

    if image_array.ndim == 3:
        input_tensor = np.expand_dims(image_array, axis=0)
    else:
        input_tensor = image_array
    if gradcam_type=="gradcam++":
        gradcam = GradcamPlusPlus(model, clone=False)
    else: 
        gradcam = Gradcam(model, clone=False)

    def loss(output):
        if class_index is None:
            class_index_local = tf.argmax(output[0])
        else:
            class_index_local = class_index
        return output[:, class_index_local]

    # Choisir la couche à utiliser pour GradCAM
    if layer_name is None:
        # Si non spécifié, chercher la dernière couche conv 2D
        for layer in reversed(model.layers):
            if 'conv' in layer.name and len(layer.output_shape) == 4:
                layer_name = layer.name
                break
        if layer_name is None:
            raise ValueError("Aucune couche convolutionnelle 2D trouvée dans le modèle.")

    cam = gradcam(loss, input_tensor, penultimate_layer=layer_name)
    cam = cam[0]

    # Normaliser entre 0 et 1
    cam = normalize(cam)

    return cam

 

def preprocess_image(image_bytes, target_size, preprocess_input):
    try:
        logger.info("📤 Lecture des bytes et conversion en image PIL")
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except Exception as e:
        logger.exception("❌ Erreur lors de l'ouverture de l'image")
        raise ValueError("Impossible de décoder l'image") from e

    logger.info(f"📐 Redimensionnement de l'image à la taille {target_size}")
    image = image.resize(target_size)
    image_array = np.array(image).astype(np.float32)

    logger.debug(f"🔍 Shape de l'image après conversion en tableau : {image_array.shape}")

    if image_array.ndim != 3 or image_array.shape[-1] != 3:
        logger.error(f"❌ Image invalide : shape={image_array.shape}")
        raise ValueError("Image must have 3 channels (RGB)")

    logger.info("🎨 Conversion et prétraitement de l'image")

    # Préparation pour la prédiction
    preprocessed_input = preprocess_input(image_array.copy())
    preprocessed_input = np.expand_dims(preprocessed_input, axis=0)

    # Préparation pour Grad-CAM (non prétraitée, mais batchifiée et en float32)
    raw_input = np.expand_dims(image_array / 255.0, axis=0)  # Mise à l’échelle simple

    logger.debug(f"🧪 Shape après ajout de la dimension batch : {preprocessed_input.shape}")
    return preprocessed_input, raw_input



def compute_entropy_safe(probas):
    probas = np.array(probas)
    # On garde uniquement les probabilités strictement positives
    mask = probas > 0
    entropy = -np.sum(probas[mask] * np.log(probas[mask]))
    return entropy


def hash_image_bytes(image_bytes):
    return hashlib.md5(image_bytes).hexdigest()


def get_heatmap(config, image_bytes: bytes, predicted_class_index):
    result = {}
    try:
        hash_key = hash_image_bytes(image_bytes)
        heatmap_key = f"{hash_key}_heatmap"

       # Vérification cache mémoire d'abord
        if heatmap_key in cache:
            logger.info(f"✅ Heatmap trouvée dans le cache {heatmap_key}")
            result["heatmap"] = cache[heatmap_key]
            return result


        #
        # Calcul si non trouvé dans le cache
        _, raw_input = preprocess_image(
            image_bytes, config["target_size"], config["preprocess_input"]
        )
        logger.info("✅ Début de la génération de la heatmap")
        start_time = time.time()

        heatmap = compute_gradcam(
            config["gradcam_model"],
            raw_input,
            class_index=predicted_class_index,
            layer_name=config["last_conv_layer"],
            gradcam_type=config["gradcam_type"],
        )

        elapsed_time = time.time() - start_time
        logger.info(f"✅ Heatmap générée en {elapsed_time:.2f} secondes")

        # Conversion en liste pour le JSON
        heatmap_list = heatmap.tolist()
        result["heatmap"] = heatmap_list
        cache[heatmap_key] = heatmap_list
    except Exception as e:
        logger.error(f"❌ Erreur lors de la génération de la heatmap: {e}")
        result["heatmap"] = []

    return result


def get_heatmap_old(config, image_bytes: bytes,predicted_class_index):
    result={}
    try:
        _,raw_input = preprocess_image(image_bytes,config["target_size"],config["preprocess_input"])
        logger.info("✅ Début de la génération de la heatmap")
        start_time = time.time()

        # Vérification des entrées
        logger.info(f"🖼️ Image d'entrée shape: {raw_input.shape}")
        logger.info(f"🎯 Index de classe prédite: {predicted_class_index}")
        logger.info(f"🛠️ Dernière couche utilisée: {config['last_conv_layer']}")

        # Calcul de la heatmap
        heatmap = compute_gradcam(config["gradcam_model"], raw_input, class_index=predicted_class_index, layer_name=config["last_conv_layer"],gradcam_type=config["gradcam_type"])

        elapsed_time = time.time() - start_time
        logger.info(f"✅ Heatmap générée en {elapsed_time:.2f} secondes")

        # Conversion en liste pour le JSON
        result["heatmap"] = heatmap.tolist()

    except Exception as e:
        logger.error(f"❌ Erreur lors de la génération de la heatmap: {e}")
        result["heatmap"] = []
    return result




def predict_with_cache(config, image_bytes: bytes):
    hash_key = hash_image_bytes(image_bytes)
    pred_key = f"{hash_key}_pred"

    if pred_key in cache:
        logger.info(f"✅ prédiction trouvée dans le cache {hash_key}")
        return cache[pred_key]
    
    try:
        logger.info("📤 Lecture des bytes et conversion en image PIL")
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except Exception as e:
        logger.exception("❌ Erreur lors de l'ouverture de l'image")
        raise ValueError("Impossible de décoder l'image") from e
    logger.info("🤖 Lancement de la prédiction avec le modèle")
    preds = config["fast_model"].predict(image)
    logger.info(f"📈 Prédictions brutes : {preds.tolist()}")

    predicted_class_index = int(np.argmax(preds))
    confidence = float(np.max(preds))
    entropy=float(compute_entropy_safe(preds))
    logger.info(f"✅ Prédiction : classe={predicted_class_index}, confiance={confidence:.4f},entropy={entropy:.4f}")
    result={
        "preds": preds.tolist(),
        "predicted_class": predicted_class_index,
        "confidence": confidence,
        "entropy":entropy
    }
    cache[pred_key] = result
    return  result

def predict_with_model(config, image_bytes: bytes):
   return predict_with_cache(config, image_bytes)
   


def predict_with_model_old(config, image_bytes: bytes):
   
    #input_array,raw_input = preprocess_image(image_bytes,config["target_size"],config["preprocess_input"])

    try:
        logger.info("📤 Lecture des bytes et conversion en image PIL")
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except Exception as e:
        logger.exception("❌ Erreur lors de l'ouverture de l'image")
        raise ValueError("Impossible de décoder l'image") from e
    logger.info("🤖 Lancement de la prédiction avec le modèle")
    preds = config["fast_model"].predict(image)
    logger.info(f"📈 Prédictions brutes : {preds[0].tolist()}")


    
    predicted_class_index = int(np.argmax(preds[0]))
    confidence = float(preds[0][predicted_class_index])
    entropy=float(compute_entropy_safe(preds))
    logger.info(f"✅ Prédiction : classe={predicted_class_index}, confiance={confidence:.4f},entropy={entropy:.4f}")

    return  {
        "preds": preds[0].tolist(),
        "predicted_class": predicted_class_index,
        "confidence": confidence,
        "entropy":entropy
    }