rkonan's picture
correction bug
b7645dd
import tensorflow as tf
import numpy as np
from PIL import Image
import tensorflow as tf
import logging
import numpy as np
from PIL import Image
from keras.applications.efficientnet_v2 import preprocess_input as effnet_preprocess
import io
from tf_keras_vis.gradcam import Gradcam,GradcamPlusPlus
from tf_keras_vis.utils import normalize
import hashlib
import numpy as np
import tensorflow as tf
from tf_keras_vis.saliency import Saliency
from tf_keras_vis.utils import normalize
import numpy as np
import tensorflow as tf
from tf_keras_vis.saliency import Saliency
from tf_keras_vis.utils import normalize
import logging
import time
import os
import diskcache as dc
from typing import TypedDict, Callable, Any
CACHE_DIR = './cache'
os.makedirs(CACHE_DIR, exist_ok=True)
cache = dc.Cache(CACHE_DIR)
logging.basicConfig(
level=logging.INFO, # ou logging.DEBUG
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
)
logger = logging.getLogger(__name__)
confidence_threshold=0.55
entropy_threshold=2
class TFLiteDynamicModel:
def __init__(self, tflite_path, img_size=224):
logger.info(f"🚀 Chargement du modèle TFLite depuis : {tflite_path}")
self.img_size = img_size
self.interpreter = tf.lite.Interpreter(model_path=tflite_path)
self.interpreter.allocate_tensors()
input_details = self.interpreter.get_input_details()
output_details = self.interpreter.get_output_details()
self.input_index = input_details[0]['index']
self.input_dtype = input_details[0]['dtype']
self.input_scale, self.input_zero_point = input_details[0]['quantization']
self.output_index = output_details[0]['index']
logger.info(f"🔍 Input tensor index : {self.input_index}, dtype : {self.input_dtype}, scale : {self.input_scale}, zero_point : {self.input_zero_point}")
logger.info(f"🔍 Output tensor index : {self.output_index}")
def preprocess(self, pil_image):
logger.info(f"🎨 Prétraitement image, redimension à {self.img_size}x{self.img_size}")
img = pil_image.resize((self.img_size, self.img_size))
img = np.array(img)
# 📸 Gestion des images grayscale ou RGBA
if img.ndim == 2: # Grayscale -> RGB
logger.debug("⚪ Image grayscale détectée, conversion en RGB")
img = np.stack([img] * 3, axis=-1)
elif img.shape[-1] == 4: # RGBA -> RGB
logger.debug("🖼️ Image RGBA détectée, suppression canal alpha")
img = img[..., :3]
if self.input_dtype in [np.uint8, np.int8]:
logger.info("🗜️ Modèle quantifié PTQ détecté (entrée int8 ou uint8)")
# Pas de division par 255 ici ! On quantifie directement selon l'échelle et le zéro-point
img = img.astype(np.float32)
img = img / self.input_scale + self.input_zero_point
# On clip selon le type
if self.input_dtype == np.uint8:
img = np.clip(img, 0, 255)
else: # np.int8
img = np.clip(img, -128, 127)
img = img.astype(self.input_dtype)
else:
logger.info("🌊 Modèle dynamique ou float32 détecté (entrée float32 normalisée)")
img = img.astype(self.input_dtype) # Normalisation classique
input_data = np.expand_dims(img, axis=0)
logger.info(f"✅ Image prétraitée avec forme {input_data.shape} et dtype {input_data.dtype}")
return input_data
def preprocess_old(self, pil_image):
logger.info(f"🎨 Prétraitement image, redimension à {self.img_size}x{self.img_size}")
img = pil_image.resize((self.img_size, self.img_size))
img = np.array(img).astype(np.float32)
if img.ndim == 2: # grayscale -> RGB
logger.debug("⚪ Image grayscale détectée, conversion en RGB")
img = np.stack([img]*3, axis=-1)
elif img.shape[-1] == 4: # RGBA -> RGB
logger.debug("🖼️ Image RGBA détectée, suppression canal alpha")
img = img[..., :3]
if self.input_dtype in [np.uint8, np.int8]:
if self.input_scale > 0:
logger.debug(f"⚙️ Application quantification dynamique avec scale {self.input_scale} et zero_point {self.input_zero_point}")
img = img / 255.0
img = img / self.input_scale + self.input_zero_point
img = np.clip(img, 0, 255 if self.input_dtype == np.uint8 else 127)
img = img.astype(self.input_dtype)
else:
img = img.astype(self.input_dtype)
input_data = np.expand_dims(img, axis=0)
logger.info(f"✅ Image prétraitée avec forme {input_data.shape} et dtype {input_data.dtype}")
return input_data
def predict_dyna(self, pil_image):
logger.info("⚡ Début de prédiction (modèle dynamique ou float32)")
# Prétraitement
logger.info("🔄 Prétraitement de l'image en cours")
input_data = self.preprocess(pil_image)
logger.debug(f"✅ Image prétraitée - Shape : {input_data.shape} - Dtype : {input_data.dtype}")
# Injection des données dans le modèle
logger.info("📥 Injection des données dans le modèle")
self.interpreter.set_tensor(self.input_index, input_data)
# Invocation du modèle
logger.info("🚀 Exécution du modèle TFLite")
self.interpreter.invoke()
# Récupération de la sortie
logger.info("📤 Récupération des résultats bruts")
output_data = self.interpreter.get_tensor(self.output_index)
logger.debug(f"✅ Logits récupérés - Shape : {output_data.shape} - Dtype : {output_data.dtype}")
# Calcul des probabilités
logger.info("🧮 Calcul des probabilités")
probas=output_data[0]
logger.debug(f"✅ Probabilités : {probas}")
logger.info("🎯 Prédiction terminée")
return probas
def predict_ptq(self, pil_image):
logger.info("⚡ Début de prédiction")
# Prétraitement
logger.info("🔄 Prétraitement de l'image en cours")
input_data = self.preprocess(pil_image)
logger.debug(f"✅ Image prétraitée - Shape : {input_data.shape} - Dtype : {input_data.dtype}")
# Injection des données dans le modèle
logger.info("📥 Injection des données dans le modèle")
self.interpreter.set_tensor(self.input_index, input_data)
# Invocation du modèle
logger.info("🚀 Exécution du modèle TFLite")
self.interpreter.invoke()
# Récupération de la sortie
logger.info("📤 Récupération des résultats bruts")
output_details = self.interpreter.get_output_details()[0]
output_data = self.interpreter.get_tensor(output_details['index'])
logger.debug(f"✅ Logits quantifiés récupérés - Shape : {output_data.shape} - Dtype : {output_data.dtype}")
# Paramètres de quantification
output_scale, output_zero_point = output_details['quantization']
logger.debug(f"ℹ️ Paramètres de quantification - Scale: {output_scale}, Zero Point: {output_zero_point}")
# Déquantification
logger.info("🔓 Déquantification des logits")
logits = (output_data.astype(np.float32) - output_zero_point) * output_scale
logger.debug(f"✅ Logits déquantifiés : {logits}")
# Calcul des probabilités
logger.info("🧮 Calcul des probabilités avec softmax")
probas = logits[0]
logger.debug(f"✅ Probabilités : {probas}")
logger.info("🎯 Prédiction terminée")
return probas
def predict (self, pil_image):
if self.input_dtype in [np.uint8, np.int8]:
logger.info("🗜️ Modèle quantifié PTQ détecté")
return self.predict_ptq(pil_image)
else:
logger.info("🌊 Modèle dynamique ou float32 détecté")
return self.predict_dyna(pil_image)
class ModelStruct(TypedDict):
model_name: str
model: tf.keras.Model
gradcam_model:tf.keras.Model
fast_model:TFLiteDynamicModel
preprocess_input: Callable[[np.ndarray], Any]
target_size: tuple[int, int]
last_conv_layer:str
gradcam_type:str
_model_cache: list[ModelStruct] | None = None
def load_model() -> list[ModelStruct]:
global _model_cache
if _model_cache is None:
print("📦 Chargement du modèle EfficientNetV2M...")
model = tf.keras.models.load_model("model/best_efficientnetv2m_gradcam.keras", compile=False)
fast_model=TFLiteDynamicModel("model/efficientnetv2m_float16.tflite", img_size=480)
_model_cache = [{
"model_name": "EfficientNetV2M",
"model": model,
"gradcam_model": model,
"fast_model":fast_model,
"preprocess_input": effnet_preprocess,
"target_size": (480, 480),
"last_conv_layer": "block7a_expand_conv",
"gradcam_type": "gradcam++"
}]
return _model_cache
def compute_gradcam(model, image_array, class_index=None, layer_name=None,gradcam_type="gradcam"):
"""
Calcule la carte Grad-CAM pour une image et un modèle Keras.
Args:
model: tf.keras.Model.
image_array: np.array (H, W, 3), float32, pré-traitée.
class_index: int ou None, index de la classe cible. Si None, classe prédite.
layer_name: str ou None, nom de la couche convolutionnelle à utiliser. Si None, dernière conv.
Returns:
gradcam_map: np.array (H, W), normalisée entre 0 et 1.
"""
logging.info(f"Lancement calcul de la gradcam avec le type {gradcam_type}")
if image_array.ndim == 3:
input_tensor = np.expand_dims(image_array, axis=0)
else:
input_tensor = image_array
if gradcam_type=="gradcam++":
gradcam = GradcamPlusPlus(model, clone=False)
else:
gradcam = Gradcam(model, clone=False)
def loss(output):
if class_index is None:
class_index_local = tf.argmax(output[0])
else:
class_index_local = class_index
return output[:, class_index_local]
# Choisir la couche à utiliser pour GradCAM
if layer_name is None:
# Si non spécifié, chercher la dernière couche conv 2D
for layer in reversed(model.layers):
if 'conv' in layer.name and len(layer.output_shape) == 4:
layer_name = layer.name
break
if layer_name is None:
raise ValueError("Aucune couche convolutionnelle 2D trouvée dans le modèle.")
cam = gradcam(loss, input_tensor, penultimate_layer=layer_name)
cam = cam[0]
# Normaliser entre 0 et 1
cam = normalize(cam)
return cam
def preprocess_image(image_bytes, target_size, preprocess_input):
try:
logger.info("📤 Lecture des bytes et conversion en image PIL")
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
except Exception as e:
logger.exception("❌ Erreur lors de l'ouverture de l'image")
raise ValueError("Impossible de décoder l'image") from e
logger.info(f"📐 Redimensionnement de l'image à la taille {target_size}")
image = image.resize(target_size)
image_array = np.array(image).astype(np.float32)
logger.debug(f"🔍 Shape de l'image après conversion en tableau : {image_array.shape}")
if image_array.ndim != 3 or image_array.shape[-1] != 3:
logger.error(f"❌ Image invalide : shape={image_array.shape}")
raise ValueError("Image must have 3 channels (RGB)")
logger.info("🎨 Conversion et prétraitement de l'image")
# Préparation pour la prédiction
preprocessed_input = preprocess_input(image_array.copy())
preprocessed_input = np.expand_dims(preprocessed_input, axis=0)
# Préparation pour Grad-CAM (non prétraitée, mais batchifiée et en float32)
raw_input = np.expand_dims(image_array / 255.0, axis=0) # Mise à l’échelle simple
logger.debug(f"🧪 Shape après ajout de la dimension batch : {preprocessed_input.shape}")
return preprocessed_input, raw_input
def compute_entropy_safe(probas):
probas = np.array(probas)
# On garde uniquement les probabilités strictement positives
mask = probas > 0
entropy = -np.sum(probas[mask] * np.log(probas[mask]))
return entropy
def hash_image_bytes(image_bytes):
return hashlib.md5(image_bytes).hexdigest()
def get_heatmap(config, image_bytes: bytes, predicted_class_index):
result = {}
try:
hash_key = hash_image_bytes(image_bytes)
heatmap_key = f"{hash_key}_heatmap"
# Vérification cache mémoire d'abord
if heatmap_key in cache:
logger.info(f"✅ Heatmap trouvée dans le cache {heatmap_key}")
result["heatmap"] = cache[heatmap_key]
return result
#
# Calcul si non trouvé dans le cache
_, raw_input = preprocess_image(
image_bytes, config["target_size"], config["preprocess_input"]
)
logger.info("✅ Début de la génération de la heatmap")
start_time = time.time()
heatmap = compute_gradcam(
config["gradcam_model"],
raw_input,
class_index=predicted_class_index,
layer_name=config["last_conv_layer"],
gradcam_type=config["gradcam_type"],
)
elapsed_time = time.time() - start_time
logger.info(f"✅ Heatmap générée en {elapsed_time:.2f} secondes")
# Conversion en liste pour le JSON
heatmap_list = heatmap.tolist()
result["heatmap"] = heatmap_list
cache[heatmap_key] = heatmap_list
except Exception as e:
logger.error(f"❌ Erreur lors de la génération de la heatmap: {e}")
result["heatmap"] = []
return result
def get_heatmap_old(config, image_bytes: bytes,predicted_class_index):
result={}
try:
_,raw_input = preprocess_image(image_bytes,config["target_size"],config["preprocess_input"])
logger.info("✅ Début de la génération de la heatmap")
start_time = time.time()
# Vérification des entrées
logger.info(f"🖼️ Image d'entrée shape: {raw_input.shape}")
logger.info(f"🎯 Index de classe prédite: {predicted_class_index}")
logger.info(f"🛠️ Dernière couche utilisée: {config['last_conv_layer']}")
# Calcul de la heatmap
heatmap = compute_gradcam(config["gradcam_model"], raw_input, class_index=predicted_class_index, layer_name=config["last_conv_layer"],gradcam_type=config["gradcam_type"])
elapsed_time = time.time() - start_time
logger.info(f"✅ Heatmap générée en {elapsed_time:.2f} secondes")
# Conversion en liste pour le JSON
result["heatmap"] = heatmap.tolist()
except Exception as e:
logger.error(f"❌ Erreur lors de la génération de la heatmap: {e}")
result["heatmap"] = []
return result
def predict_with_cache(config, image_bytes: bytes):
hash_key = hash_image_bytes(image_bytes)
pred_key = f"{hash_key}_pred"
if pred_key in cache:
logger.info(f"✅ prédiction trouvée dans le cache {hash_key}")
return cache[pred_key]
try:
logger.info("📤 Lecture des bytes et conversion en image PIL")
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
except Exception as e:
logger.exception("❌ Erreur lors de l'ouverture de l'image")
raise ValueError("Impossible de décoder l'image") from e
logger.info("🤖 Lancement de la prédiction avec le modèle")
preds = config["fast_model"].predict(image)
logger.info(f"📈 Prédictions brutes : {preds.tolist()}")
predicted_class_index = int(np.argmax(preds))
confidence = float(np.max(preds))
entropy=float(compute_entropy_safe(preds))
logger.info(f"✅ Prédiction : classe={predicted_class_index}, confiance={confidence:.4f},entropy={entropy:.4f}")
result={
"preds": preds.tolist(),
"predicted_class": predicted_class_index,
"confidence": confidence,
"entropy":entropy
}
cache[pred_key] = result
return result
def predict_with_model(config, image_bytes: bytes):
return predict_with_cache(config, image_bytes)
def predict_with_model_old(config, image_bytes: bytes):
#input_array,raw_input = preprocess_image(image_bytes,config["target_size"],config["preprocess_input"])
try:
logger.info("📤 Lecture des bytes et conversion en image PIL")
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
except Exception as e:
logger.exception("❌ Erreur lors de l'ouverture de l'image")
raise ValueError("Impossible de décoder l'image") from e
logger.info("🤖 Lancement de la prédiction avec le modèle")
preds = config["fast_model"].predict(image)
logger.info(f"📈 Prédictions brutes : {preds[0].tolist()}")
predicted_class_index = int(np.argmax(preds[0]))
confidence = float(preds[0][predicted_class_index])
entropy=float(compute_entropy_safe(preds))
logger.info(f"✅ Prédiction : classe={predicted_class_index}, confiance={confidence:.4f},entropy={entropy:.4f}")
return {
"preds": preds[0].tolist(),
"predicted_class": predicted_class_index,
"confidence": confidence,
"entropy":entropy
}