File size: 14,510 Bytes
8c4cb40
 
 
 
 
 
 
 
 
 
ef9dd71
8c4cb40
 
 
 
 
 
 
 
 
 
 
a9fc500
8c4cb40
 
 
 
 
 
 
58919ac
16244b7
8c4cb40
 
 
 
 
 
 
 
ef9dd71
8c4cb40
 
 
7f8732e
58919ac
8c4cb40
7f8732e
58919ac
8c4cb40
 
 
 
 
 
 
 
 
ef9dd71
 
8c4cb40
 
 
 
 
 
 
 
ef9dd71
 
 
 
8c4cb40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef9dd71
8c4cb40
 
 
 
 
 
 
 
 
 
 
 
ef9dd71
8c4cb40
 
 
 
 
ef9dd71
 
 
 
8c4cb40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9fc500
 
 
8c4cb40
a9fc500
 
 
 
 
 
ef9dd71
8c4cb40
a9fc500
 
 
 
 
 
 
 
 
 
 
 
8c4cb40
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
import tensorflow as tf
import numpy as np
from PIL import Image
import tensorflow as tf
import logging
import numpy as np
from PIL import Image
from keras.applications.efficientnet_v2 import preprocess_input as effnet_preprocess
from keras.applications.resnet_v2 import preprocess_input as resnet_preprocess
import io
from tf_keras_vis.gradcam import Gradcam,GradcamPlusPlus
from tf_keras_vis.utils import normalize

import numpy as np
import tensorflow as tf
from tf_keras_vis.saliency import Saliency
from tf_keras_vis.utils import normalize
import numpy as np
import tensorflow as tf
from tf_keras_vis.saliency import Saliency
from tf_keras_vis.utils import normalize
import logging
import time

from typing import TypedDict, Callable, Any
logging.basicConfig(
    level=logging.INFO,  # ou logging.DEBUG
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
)
logger = logging.getLogger(__name__)
confidence_threshold=0.5
entropy_threshold=2

class ModelStruct(TypedDict):
    model_name: str
    model: tf.keras.Model
    gradcam_model:tf.keras.Model
    preprocess_input: Callable[[np.ndarray], Any]
    target_size: tuple[int, int]
    last_conv_layer:str
    gradcam_type:str


def load_models() -> list[ModelStruct]:
    model1 = tf.keras.models.load_model("model/best_efficientnetv2m_gradcam.keras",compile=False)
    model1_for_gradcam=model1
    
    model2 = tf.keras.models.load_model("model/best_ResNet50V2_gradcam.keras",compile=False)
    model2_for_gradcam=model2

    return [
    
        {
        "model_name":"EfficientNetV2M",
        "model": model1,
        "gradcam_model":model1_for_gradcam,
        "preprocess_input": effnet_preprocess,
        "target_size": (480, 480),
        "last_conv_layer":"top_activation",
        "gradcam_type":"gradcam++"

    },
    {
        "model_name":"ResNet50V2",
        "gradcam_model":model2_for_gradcam,
        "model":model2,
        "preprocess_input":resnet_preprocess,
        "target_size":(224, 224),
        "last_conv_layer":"conv5_block3_out",
        "gradcam_type":"gradcam"
        #"gradcam_type":"gradcam++"


    }
   
    ]



def compute_saliency_map(model, image_array, class_index=None):
    """
    Calcule la carte de saillance avec tf-keras-vis Saliency.

    Args:
        model: tf.keras.Model.
        image_array: np.array, shape (H, W, 3), float32, pré-traitée.
        class_index: int ou None. Si None, prend la classe prédite.

    Returns:
        saliency_map: np.array float32, normalisée entre 0 et 1, shape (H, W).
    """
    logging.info("Début du calcul de la carte de saillance")

    if image_array.ndim == 3:
        input_tensor = np.expand_dims(image_array, axis=0)
        logging.debug(f"Image d'entrée dimensionnée de {image_array.shape} à {input_tensor.shape} (batch)")
    else:
        input_tensor = image_array
        logging.debug(f"Image d'entrée déjà batchée avec shape {input_tensor.shape}")

    saliency = Saliency(model)
    logging.info("Objet Saliency initialisé")

    def loss(output):
        # output shape: (batch_size, num_classes)
        if class_index is None:
            class_index_local = tf.argmax(output[0])
            logging.info(f"Classe cible non spécifiée, utilisation de la classe prédite: {class_index_local.numpy()}")
        else:
            class_index_local = class_index
            logging.info(f"Classe cible spécifiée: {class_index_local}")
        return output[:, class_index_local]

    saliency_map = saliency(loss, input_tensor)
    logging.info("Calcul de la carte de saillance terminé")

    saliency_map = saliency_map[0]  # shape (H, W, 3)
    logging.debug(f"Shape de la carte brute: {saliency_map.shape}")

    # Prendre le max absolu sur les canaux couleurs pour avoir une carte 2D
    
    if saliency_map.ndim == 3:
        saliency_map = np.max(np.abs(saliency_map), axis=-1)
    else:
        saliency_map = np.abs(saliency_map)
    logging.debug(f"Shape de la carte après réduction canaux: {saliency_map.shape}")

    # Normaliser la carte entre 0 et 1
    saliency_map = normalize(saliency_map)
    logging.info("Normalisation de la carte de saillance terminée")

    return saliency_map


def compute_gradcam(model, image_array, class_index=None, layer_name=None,gradcam_type="gradcam"):
    """
    Calcule la carte Grad-CAM pour une image et un modèle Keras.

    Args:
        model: tf.keras.Model.
        image_array: np.array (H, W, 3), float32, pré-traitée.
        class_index: int ou None, index de la classe cible. Si None, classe prédite.
        layer_name: str ou None, nom de la couche convolutionnelle à utiliser. Si None, dernière conv.

    Returns:
        gradcam_map: np.array (H, W), normalisée entre 0 et 1.
    """
    logging.info(f"Lancement calcul de la gradcam avec le type {gradcam_type}")

    if image_array.ndim == 3:
        input_tensor = np.expand_dims(image_array, axis=0)
    else:
        input_tensor = image_array
    if gradcam_type=="gradcam++":
        gradcam = GradcamPlusPlus(model, clone=False)
    else: 
        gradcam = Gradcam(model, clone=False)

    def loss(output):
        if class_index is None:
            class_index_local = tf.argmax(output[0])
        else:
            class_index_local = class_index
        return output[:, class_index_local]

    # Choisir la couche à utiliser pour GradCAM
    if layer_name is None:
        # Si non spécifié, chercher la dernière couche conv 2D
        for layer in reversed(model.layers):
            if 'conv' in layer.name and len(layer.output_shape) == 4:
                layer_name = layer.name
                break
        if layer_name is None:
            raise ValueError("Aucune couche convolutionnelle 2D trouvée dans le modèle.")

    cam = gradcam(loss, input_tensor, penultimate_layer=layer_name)
    cam = cam[0]

    # Normaliser entre 0 et 1
    cam = normalize(cam)

    return cam


def make_gradcam_heatmap(model,img_array,  pred_index=None):
    """
    Calcule la Grad-CAM heatmap pour une image donnée et un modèle Keras.

    Args:
      img_array (numpy array): image d'entrée preprocessée, shape (1, H, W, C)
      model (tf.keras.Model): modèle Keras complet (avec la tête softmax)
      last_conv_layer_name (str): nom de la dernière couche conv (ex: 'top_conv' dans EfficientNetB0)
      pred_index (int, optional): indice de la classe ciblée. Par défaut la classe prédite par le modèle.

    Returns:
      heatmap (numpy array): heatmap 2D normalisée entre 0 et 1
    """

    # 1. Récupérer la dernière couche conv
    back_bone=model.get_layer("efficientnetv2-m")
    last_conv_layer=back_bone.get_layer('top_activation')
    # last_conv_layer = next(
    # x for x in back_bone.layers[::-1] if isinstance(x,tf.keras.layers.Conv2D) #tf.keras.layers.MBConvBlock)
    # )

    
    # 2. Créer un modèle intermédiaire qui donne les activations de la dernière couche conv
    # et la sortie finale du modèle
    grad_model = tf.keras.models.Model(
        [model.inputs], [last_conv_layer.output, model.output]
    )

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model([img_array])
        if pred_index is None:
            pred_index = tf.argmax(predictions[0])
        class_channel = predictions[:, pred_index]

    # 3. Calculer le gradient des logits de la classe cible par rapport aux activations conv
    grads = tape.gradient(class_channel, conv_outputs)

    # 4. Moyenne globale des gradients sur les axes spatiaux (H, W)
    pooled_grads = tf.reduce_mean(grads, axis=(1, 2))

    # 5. Pondérer chaque canal des activations par les gradients moyens
    conv_outputs = conv_outputs[0]
    pooled_grads = pooled_grads[0]

    heatmap = tf.reduce_sum(tf.multiply(pooled_grads, conv_outputs), axis=-1)

    # 6. Normaliser la heatmap entre 0 et 1
    heatmap = tf.maximum(heatmap, 0) / (tf.reduce_max(heatmap) + 1e-8)
    heatmap = heatmap.numpy()

    return heatmap

def compute_saliency_map_basic(model, image_array, class_index=None):
    logger.info("🧠 Début du calcul de la salience...")
    """
    Calcule la carte de saillance (saliency map) d'une image pour un modèle donné.

    Args:
        model : modèle Keras TensorFlow.
        image_array : np.array, image d'entrée pré-traitée, shape (H, W, 3), float32.
        class_index : int ou None, index de la classe cible. 
                      Si None, on prend la prédiction la plus probable.

    Returns:
        saliency_map : np.array, carte de saillance normalisée (valeurs entre 0 et 1), shape (H, W).
    """
    if image_array.ndim == 3:
        image_tensor = tf.expand_dims(image_array, axis=0)
    else:
        image_tensor = tf.convert_to_tensor(image_array)

    image_tensor = tf.cast(image_tensor, tf.float32)
    with tf.GradientTape() as tape:
        tape.watch(image_tensor)
        preds = model(image_tensor,training=False)
        if class_index is None:
            class_index = tf.argmax(preds[0])
        class_score = preds[:, class_index]

    grads = tape.gradient(class_score, image_tensor)  # gradient de la sortie par rapport à l'image
    saliency = tf.reduce_max(tf.abs(grads), axis=-1)[0]  # max sur canaux couleur, retirer batch

    # Normaliser entre 0 et 1
    saliency = (saliency - tf.reduce_min(saliency)) / (tf.reduce_max(saliency) - tf.reduce_min(saliency) + 1e-10)
    
    logger.info("✅ Fin du calcul de la salience.")
    return saliency.numpy()
 


def preprocess_image(image_bytes, target_size, preprocess_input):
    try:
        logger.info("📤 Lecture des bytes et conversion en image PIL")
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except Exception as e:
        logger.exception("❌ Erreur lors de l'ouverture de l'image")
        raise ValueError("Impossible de décoder l'image") from e

    logger.info(f"📐 Redimensionnement de l'image à la taille {target_size}")
    image = image.resize(target_size)
    image_array = np.array(image).astype(np.float32)

    logger.debug(f"🔍 Shape de l'image après conversion en tableau : {image_array.shape}")

    if image_array.ndim != 3 or image_array.shape[-1] != 3:
        logger.error(f"❌ Image invalide : shape={image_array.shape}")
        raise ValueError("Image must have 3 channels (RGB)")

    logger.info("🎨 Conversion et prétraitement de l'image")

    # Préparation pour la prédiction
    preprocessed_input = preprocess_input(image_array.copy())
    preprocessed_input = np.expand_dims(preprocessed_input, axis=0)

    # Préparation pour Grad-CAM (non prétraitée, mais batchifiée et en float32)
    raw_input = np.expand_dims(image_array / 255.0, axis=0)  # Mise à l’échelle simple

    logger.debug(f"🧪 Shape après ajout de la dimension batch : {preprocessed_input.shape}")
    return preprocessed_input, raw_input


def preprocess_image_v1(image_bytes, target_size,preprocess_input) -> np.ndarray:
     
    try:
        logger.info("📤 Lecture des bytes et conversion en image PIL")
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except Exception as e:
        logger.exception("❌ Erreur lors de l'ouverture de l'image")
        raise ValueError("Impossible de décoder l'image") from e
    
    logger.info(f"📐 Redimensionnement de l'image à la taille {target_size}")
    image = image.resize(target_size)
    image_array = np.array(image)

    logger.debug(f"🔍 Shape de l'image après conversion en tableau : {image_array.shape}")

    # Forcer RGB si jamais ce n’est pas déjà le cas
    if image_array.ndim != 3 or image_array.shape[-1] != 3:
        logger.error(f"❌ Image invalide : shape={image_array.shape}")
        raise ValueError("Image must have 3 channels (RGB)")

    logger.info("🎨 Conversion et prétraitement de l'image")
    image_array = preprocess_input(image_array)

    processed = np.expand_dims(image_array, axis=0)
    logger.debug(f"🧪 Shape après ajout de la dimension batch : {processed.shape}")
    return processed


def compute_entropy_safe(probas):
    probas = np.array(probas)
    # On garde uniquement les probabilités strictement positives
    mask = probas > 0
    entropy = -np.sum(probas[mask] * np.log(probas[mask]))
    return entropy


def predict_with_model(config, image_bytes: bytes,show_heatmap=False):
   
    input_array,raw_input = preprocess_image(image_bytes,config["target_size"],config["preprocess_input"])

    logger.info("🤖 Lancement de la prédiction avec le modèle")
    preds = config["model"].predict(input_array)
    logger.debug(f"📈 Prédictions brutes : {preds[0].tolist()}")

    predicted_class_index = int(np.argmax(preds[0]))
    confidence = float(preds[0][predicted_class_index])
    entropy=float(compute_entropy_safe(preds))
    is_uncertain_model= (confidence<confidence_threshold) or (entropy>entropy_threshold)
    logger.info(f"✅ Prédiction : classe={predicted_class_index}, confiance={confidence:.4f},entropy={entropy:.4f},is_uncertain_model={is_uncertain_model}")

    result= {
        "preds": preds[0].tolist(),
        "predicted_class": predicted_class_index,
        "confidence": confidence,
        "entropy":entropy,
        "is_uncertain_model":is_uncertain_model
    }
    if show_heatmap and not is_uncertain_model:
        try:
            logger.info("✅ Début de la génération de la heatmap")
            start_time = time.time()

            # Vérification des entrées
            logger.info(f"🖼️ Image d'entrée shape: {raw_input.shape}")
            logger.info(f"🎯 Index de classe prédite: {predicted_class_index}")
            logger.info(f"🛠️ Dernière couche utilisée: {config['last_conv_layer']}")

            # Calcul de la heatmap
            heatmap = compute_gradcam(config["gradcam_model"], raw_input, class_index=predicted_class_index, layer_name=config["last_conv_layer"],gradcam_type=config["gradcam_type"])

            elapsed_time = time.time() - start_time
            logger.info(f"✅ Heatmap générée en {elapsed_time:.2f} secondes")

            # Conversion en liste pour le JSON
            result["heatmap"] = heatmap.tolist()

        except Exception as e:
            logger.error(f"❌ Erreur lors de la génération de la heatmap: {e}")
            result["heatmap"] = []
    else:
        logger.info("ℹ️ Heatmap non générée (option désactivée ou modèle incertain)")
        result["heatmap"] = []


    return result