Spaces:
Runtime error
Runtime error
| import tensorflow as tf | |
| import numpy as np | |
| from PIL import Image | |
| import tensorflow as tf | |
| import logging | |
| import numpy as np | |
| from PIL import Image | |
| from keras.applications.efficientnet_v2 import preprocess_input as effnet_preprocess | |
| import io | |
| from tf_keras_vis.gradcam import Gradcam,GradcamPlusPlus | |
| from tf_keras_vis.utils import normalize | |
| import hashlib | |
| import numpy as np | |
| import tensorflow as tf | |
| from tf_keras_vis.saliency import Saliency | |
| from tf_keras_vis.utils import normalize | |
| import numpy as np | |
| import tensorflow as tf | |
| from tf_keras_vis.saliency import Saliency | |
| from tf_keras_vis.utils import normalize | |
| import logging | |
| import time | |
| import os | |
| import diskcache as dc | |
| from typing import TypedDict, Callable, Any | |
| CACHE_DIR = './cache' | |
| os.makedirs(CACHE_DIR, exist_ok=True) | |
| cache = dc.Cache(CACHE_DIR) | |
| logging.basicConfig( | |
| level=logging.INFO, # ou logging.DEBUG | |
| format="%(asctime)s [%(levelname)s] %(name)s: %(message)s" | |
| ) | |
| logger = logging.getLogger(__name__) | |
| confidence_threshold=0.55 | |
| entropy_threshold=2 | |
| class TFLiteDynamicModel: | |
| def __init__(self, tflite_path, img_size=224): | |
| logger.info(f"🚀 Chargement du modèle TFLite depuis : {tflite_path}") | |
| self.img_size = img_size | |
| self.interpreter = tf.lite.Interpreter(model_path=tflite_path) | |
| self.interpreter.allocate_tensors() | |
| input_details = self.interpreter.get_input_details() | |
| output_details = self.interpreter.get_output_details() | |
| self.input_index = input_details[0]['index'] | |
| self.input_dtype = input_details[0]['dtype'] | |
| self.input_scale, self.input_zero_point = input_details[0]['quantization'] | |
| self.output_index = output_details[0]['index'] | |
| logger.info(f"🔍 Input tensor index : {self.input_index}, dtype : {self.input_dtype}, scale : {self.input_scale}, zero_point : {self.input_zero_point}") | |
| logger.info(f"🔍 Output tensor index : {self.output_index}") | |
| def preprocess(self, pil_image): | |
| logger.info(f"🎨 Prétraitement image, redimension à {self.img_size}x{self.img_size}") | |
| img = pil_image.resize((self.img_size, self.img_size)) | |
| img = np.array(img) | |
| # 📸 Gestion des images grayscale ou RGBA | |
| if img.ndim == 2: # Grayscale -> RGB | |
| logger.debug("⚪ Image grayscale détectée, conversion en RGB") | |
| img = np.stack([img] * 3, axis=-1) | |
| elif img.shape[-1] == 4: # RGBA -> RGB | |
| logger.debug("🖼️ Image RGBA détectée, suppression canal alpha") | |
| img = img[..., :3] | |
| if self.input_dtype in [np.uint8, np.int8]: | |
| logger.info("🗜️ Modèle quantifié PTQ détecté (entrée int8 ou uint8)") | |
| # Pas de division par 255 ici ! On quantifie directement selon l'échelle et le zéro-point | |
| img = img.astype(np.float32) | |
| img = img / self.input_scale + self.input_zero_point | |
| # On clip selon le type | |
| if self.input_dtype == np.uint8: | |
| img = np.clip(img, 0, 255) | |
| else: # np.int8 | |
| img = np.clip(img, -128, 127) | |
| img = img.astype(self.input_dtype) | |
| else: | |
| logger.info("🌊 Modèle dynamique ou float32 détecté (entrée float32 normalisée)") | |
| img = img.astype(self.input_dtype) # Normalisation classique | |
| input_data = np.expand_dims(img, axis=0) | |
| logger.info(f"✅ Image prétraitée avec forme {input_data.shape} et dtype {input_data.dtype}") | |
| return input_data | |
| def preprocess_old(self, pil_image): | |
| logger.info(f"🎨 Prétraitement image, redimension à {self.img_size}x{self.img_size}") | |
| img = pil_image.resize((self.img_size, self.img_size)) | |
| img = np.array(img).astype(np.float32) | |
| if img.ndim == 2: # grayscale -> RGB | |
| logger.debug("⚪ Image grayscale détectée, conversion en RGB") | |
| img = np.stack([img]*3, axis=-1) | |
| elif img.shape[-1] == 4: # RGBA -> RGB | |
| logger.debug("🖼️ Image RGBA détectée, suppression canal alpha") | |
| img = img[..., :3] | |
| if self.input_dtype in [np.uint8, np.int8]: | |
| if self.input_scale > 0: | |
| logger.debug(f"⚙️ Application quantification dynamique avec scale {self.input_scale} et zero_point {self.input_zero_point}") | |
| img = img / 255.0 | |
| img = img / self.input_scale + self.input_zero_point | |
| img = np.clip(img, 0, 255 if self.input_dtype == np.uint8 else 127) | |
| img = img.astype(self.input_dtype) | |
| else: | |
| img = img.astype(self.input_dtype) | |
| input_data = np.expand_dims(img, axis=0) | |
| logger.info(f"✅ Image prétraitée avec forme {input_data.shape} et dtype {input_data.dtype}") | |
| return input_data | |
| def predict_dyna(self, pil_image): | |
| logger.info("⚡ Début de prédiction (modèle dynamique ou float32)") | |
| # Prétraitement | |
| logger.info("🔄 Prétraitement de l'image en cours") | |
| input_data = self.preprocess(pil_image) | |
| logger.debug(f"✅ Image prétraitée - Shape : {input_data.shape} - Dtype : {input_data.dtype}") | |
| # Injection des données dans le modèle | |
| logger.info("📥 Injection des données dans le modèle") | |
| self.interpreter.set_tensor(self.input_index, input_data) | |
| # Invocation du modèle | |
| logger.info("🚀 Exécution du modèle TFLite") | |
| self.interpreter.invoke() | |
| # Récupération de la sortie | |
| logger.info("📤 Récupération des résultats bruts") | |
| output_data = self.interpreter.get_tensor(self.output_index) | |
| logger.debug(f"✅ Logits récupérés - Shape : {output_data.shape} - Dtype : {output_data.dtype}") | |
| # Calcul des probabilités | |
| logger.info("🧮 Calcul des probabilités") | |
| probas=output_data[0] | |
| logger.debug(f"✅ Probabilités : {probas}") | |
| logger.info("🎯 Prédiction terminée") | |
| return probas | |
| def predict_ptq(self, pil_image): | |
| logger.info("⚡ Début de prédiction") | |
| # Prétraitement | |
| logger.info("🔄 Prétraitement de l'image en cours") | |
| input_data = self.preprocess(pil_image) | |
| logger.debug(f"✅ Image prétraitée - Shape : {input_data.shape} - Dtype : {input_data.dtype}") | |
| # Injection des données dans le modèle | |
| logger.info("📥 Injection des données dans le modèle") | |
| self.interpreter.set_tensor(self.input_index, input_data) | |
| # Invocation du modèle | |
| logger.info("🚀 Exécution du modèle TFLite") | |
| self.interpreter.invoke() | |
| # Récupération de la sortie | |
| logger.info("📤 Récupération des résultats bruts") | |
| output_details = self.interpreter.get_output_details()[0] | |
| output_data = self.interpreter.get_tensor(output_details['index']) | |
| logger.debug(f"✅ Logits quantifiés récupérés - Shape : {output_data.shape} - Dtype : {output_data.dtype}") | |
| # Paramètres de quantification | |
| output_scale, output_zero_point = output_details['quantization'] | |
| logger.debug(f"ℹ️ Paramètres de quantification - Scale: {output_scale}, Zero Point: {output_zero_point}") | |
| # Déquantification | |
| logger.info("🔓 Déquantification des logits") | |
| logits = (output_data.astype(np.float32) - output_zero_point) * output_scale | |
| logger.debug(f"✅ Logits déquantifiés : {logits}") | |
| # Calcul des probabilités | |
| logger.info("🧮 Calcul des probabilités avec softmax") | |
| probas = logits[0] | |
| logger.debug(f"✅ Probabilités : {probas}") | |
| logger.info("🎯 Prédiction terminée") | |
| return probas | |
| def predict (self, pil_image): | |
| if self.input_dtype in [np.uint8, np.int8]: | |
| logger.info("🗜️ Modèle quantifié PTQ détecté") | |
| return self.predict_ptq(pil_image) | |
| else: | |
| logger.info("🌊 Modèle dynamique ou float32 détecté") | |
| return self.predict_dyna(pil_image) | |
| class ModelStruct(TypedDict): | |
| model_name: str | |
| model: tf.keras.Model | |
| gradcam_model:tf.keras.Model | |
| fast_model:TFLiteDynamicModel | |
| preprocess_input: Callable[[np.ndarray], Any] | |
| target_size: tuple[int, int] | |
| last_conv_layer:str | |
| gradcam_type:str | |
| _model_cache: list[ModelStruct] | None = None | |
| def load_model() -> list[ModelStruct]: | |
| global _model_cache | |
| if _model_cache is None: | |
| print("📦 Chargement du modèle EfficientNetV2M...") | |
| model = tf.keras.models.load_model("model/best_efficientnetv2m_gradcam.keras", compile=False) | |
| fast_model=TFLiteDynamicModel("model/efficientnetv2m_float16.tflite", img_size=480) | |
| _model_cache = [{ | |
| "model_name": "EfficientNetV2M", | |
| "model": model, | |
| "gradcam_model": model, | |
| "fast_model":fast_model, | |
| "preprocess_input": effnet_preprocess, | |
| "target_size": (480, 480), | |
| "last_conv_layer": "block7a_expand_conv", | |
| "gradcam_type": "gradcam++" | |
| }] | |
| return _model_cache | |
| def compute_gradcam(model, image_array, class_index=None, layer_name=None,gradcam_type="gradcam"): | |
| """ | |
| Calcule la carte Grad-CAM pour une image et un modèle Keras. | |
| Args: | |
| model: tf.keras.Model. | |
| image_array: np.array (H, W, 3), float32, pré-traitée. | |
| class_index: int ou None, index de la classe cible. Si None, classe prédite. | |
| layer_name: str ou None, nom de la couche convolutionnelle à utiliser. Si None, dernière conv. | |
| Returns: | |
| gradcam_map: np.array (H, W), normalisée entre 0 et 1. | |
| """ | |
| logging.info(f"Lancement calcul de la gradcam avec le type {gradcam_type}") | |
| if image_array.ndim == 3: | |
| input_tensor = np.expand_dims(image_array, axis=0) | |
| else: | |
| input_tensor = image_array | |
| if gradcam_type=="gradcam++": | |
| gradcam = GradcamPlusPlus(model, clone=False) | |
| else: | |
| gradcam = Gradcam(model, clone=False) | |
| def loss(output): | |
| if class_index is None: | |
| class_index_local = tf.argmax(output[0]) | |
| else: | |
| class_index_local = class_index | |
| return output[:, class_index_local] | |
| # Choisir la couche à utiliser pour GradCAM | |
| if layer_name is None: | |
| # Si non spécifié, chercher la dernière couche conv 2D | |
| for layer in reversed(model.layers): | |
| if 'conv' in layer.name and len(layer.output_shape) == 4: | |
| layer_name = layer.name | |
| break | |
| if layer_name is None: | |
| raise ValueError("Aucune couche convolutionnelle 2D trouvée dans le modèle.") | |
| cam = gradcam(loss, input_tensor, penultimate_layer=layer_name) | |
| cam = cam[0] | |
| # Normaliser entre 0 et 1 | |
| cam = normalize(cam) | |
| return cam | |
| def preprocess_image(image_bytes, target_size, preprocess_input): | |
| try: | |
| logger.info("📤 Lecture des bytes et conversion en image PIL") | |
| image = Image.open(io.BytesIO(image_bytes)).convert("RGB") | |
| except Exception as e: | |
| logger.exception("❌ Erreur lors de l'ouverture de l'image") | |
| raise ValueError("Impossible de décoder l'image") from e | |
| logger.info(f"📐 Redimensionnement de l'image à la taille {target_size}") | |
| image = image.resize(target_size) | |
| image_array = np.array(image).astype(np.float32) | |
| logger.debug(f"🔍 Shape de l'image après conversion en tableau : {image_array.shape}") | |
| if image_array.ndim != 3 or image_array.shape[-1] != 3: | |
| logger.error(f"❌ Image invalide : shape={image_array.shape}") | |
| raise ValueError("Image must have 3 channels (RGB)") | |
| logger.info("🎨 Conversion et prétraitement de l'image") | |
| # Préparation pour la prédiction | |
| preprocessed_input = preprocess_input(image_array.copy()) | |
| preprocessed_input = np.expand_dims(preprocessed_input, axis=0) | |
| # Préparation pour Grad-CAM (non prétraitée, mais batchifiée et en float32) | |
| raw_input = np.expand_dims(image_array / 255.0, axis=0) # Mise à l’échelle simple | |
| logger.debug(f"🧪 Shape après ajout de la dimension batch : {preprocessed_input.shape}") | |
| return preprocessed_input, raw_input | |
| def compute_entropy_safe(probas): | |
| probas = np.array(probas) | |
| # On garde uniquement les probabilités strictement positives | |
| mask = probas > 0 | |
| entropy = -np.sum(probas[mask] * np.log(probas[mask])) | |
| return entropy | |
| def hash_image_bytes(image_bytes): | |
| return hashlib.md5(image_bytes).hexdigest() | |
| def get_heatmap(config, image_bytes: bytes, predicted_class_index): | |
| result = {} | |
| try: | |
| hash_key = hash_image_bytes(image_bytes) | |
| heatmap_key = f"{hash_key}_heatmap" | |
| # Vérification cache mémoire d'abord | |
| if heatmap_key in cache: | |
| logger.info(f"✅ Heatmap trouvée dans le cache {heatmap_key}") | |
| result["heatmap"] = cache[heatmap_key] | |
| return result | |
| # | |
| # Calcul si non trouvé dans le cache | |
| _, raw_input = preprocess_image( | |
| image_bytes, config["target_size"], config["preprocess_input"] | |
| ) | |
| logger.info("✅ Début de la génération de la heatmap") | |
| start_time = time.time() | |
| heatmap = compute_gradcam( | |
| config["gradcam_model"], | |
| raw_input, | |
| class_index=predicted_class_index, | |
| layer_name=config["last_conv_layer"], | |
| gradcam_type=config["gradcam_type"], | |
| ) | |
| elapsed_time = time.time() - start_time | |
| logger.info(f"✅ Heatmap générée en {elapsed_time:.2f} secondes") | |
| # Conversion en liste pour le JSON | |
| heatmap_list = heatmap.tolist() | |
| result["heatmap"] = heatmap_list | |
| cache[heatmap_key] = heatmap_list | |
| except Exception as e: | |
| logger.error(f"❌ Erreur lors de la génération de la heatmap: {e}") | |
| result["heatmap"] = [] | |
| return result | |
| def get_heatmap_old(config, image_bytes: bytes,predicted_class_index): | |
| result={} | |
| try: | |
| _,raw_input = preprocess_image(image_bytes,config["target_size"],config["preprocess_input"]) | |
| logger.info("✅ Début de la génération de la heatmap") | |
| start_time = time.time() | |
| # Vérification des entrées | |
| logger.info(f"🖼️ Image d'entrée shape: {raw_input.shape}") | |
| logger.info(f"🎯 Index de classe prédite: {predicted_class_index}") | |
| logger.info(f"🛠️ Dernière couche utilisée: {config['last_conv_layer']}") | |
| # Calcul de la heatmap | |
| heatmap = compute_gradcam(config["gradcam_model"], raw_input, class_index=predicted_class_index, layer_name=config["last_conv_layer"],gradcam_type=config["gradcam_type"]) | |
| elapsed_time = time.time() - start_time | |
| logger.info(f"✅ Heatmap générée en {elapsed_time:.2f} secondes") | |
| # Conversion en liste pour le JSON | |
| result["heatmap"] = heatmap.tolist() | |
| except Exception as e: | |
| logger.error(f"❌ Erreur lors de la génération de la heatmap: {e}") | |
| result["heatmap"] = [] | |
| return result | |
| def predict_with_cache(config, image_bytes: bytes): | |
| hash_key = hash_image_bytes(image_bytes) | |
| pred_key = f"{hash_key}_pred" | |
| if pred_key in cache: | |
| logger.info(f"✅ prédiction trouvée dans le cache {hash_key}") | |
| return cache[pred_key] | |
| try: | |
| logger.info("📤 Lecture des bytes et conversion en image PIL") | |
| image = Image.open(io.BytesIO(image_bytes)).convert("RGB") | |
| except Exception as e: | |
| logger.exception("❌ Erreur lors de l'ouverture de l'image") | |
| raise ValueError("Impossible de décoder l'image") from e | |
| logger.info("🤖 Lancement de la prédiction avec le modèle") | |
| preds = config["fast_model"].predict(image) | |
| logger.info(f"📈 Prédictions brutes : {preds.tolist()}") | |
| predicted_class_index = int(np.argmax(preds)) | |
| confidence = float(np.max(preds)) | |
| entropy=float(compute_entropy_safe(preds)) | |
| logger.info(f"✅ Prédiction : classe={predicted_class_index}, confiance={confidence:.4f},entropy={entropy:.4f}") | |
| result={ | |
| "preds": preds.tolist(), | |
| "predicted_class": predicted_class_index, | |
| "confidence": confidence, | |
| "entropy":entropy | |
| } | |
| cache[pred_key] = result | |
| return result | |
| def predict_with_model(config, image_bytes: bytes): | |
| return predict_with_cache(config, image_bytes) | |
| def predict_with_model_old(config, image_bytes: bytes): | |
| #input_array,raw_input = preprocess_image(image_bytes,config["target_size"],config["preprocess_input"]) | |
| try: | |
| logger.info("📤 Lecture des bytes et conversion en image PIL") | |
| image = Image.open(io.BytesIO(image_bytes)).convert("RGB") | |
| except Exception as e: | |
| logger.exception("❌ Erreur lors de l'ouverture de l'image") | |
| raise ValueError("Impossible de décoder l'image") from e | |
| logger.info("🤖 Lancement de la prédiction avec le modèle") | |
| preds = config["fast_model"].predict(image) | |
| logger.info(f"📈 Prédictions brutes : {preds[0].tolist()}") | |
| predicted_class_index = int(np.argmax(preds[0])) | |
| confidence = float(preds[0][predicted_class_index]) | |
| entropy=float(compute_entropy_safe(preds)) | |
| logger.info(f"✅ Prédiction : classe={predicted_class_index}, confiance={confidence:.4f},entropy={entropy:.4f}") | |
| return { | |
| "preds": preds[0].tolist(), | |
| "predicted_class": predicted_class_index, | |
| "confidence": confidence, | |
| "entropy":entropy | |
| } | |