Spaces:
Paused
Paused
import base64
import io
import logging
from io import BytesIO

import aiohttp
import numpy as np
from PIL import Image
from scipy.spatial.distance import jensenshannon

# Imported before logging.basicConfig below, as in the original file order.
from app.model import predict_with_model, compute_entropy_safe

# Threshold on the score returned by compute_js_divergence: in multi-model
# voting, the prediction is flagged as globally uncertain only when at least
# one model reports itself uncertain AND the score exceeds this value.
# NOTE: the (misspelled) name is kept because other modules may import it.
shannon_threashold = 0.15

logging.basicConfig(
    level=logging.INFO,  # or logging.DEBUG
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)
def compute_js_divergence(all_probs):
    """Measure how much several model probability distributions disagree.

    For each model distribution ``p_i`` this computes
    ``scipy.spatial.distance.jensenshannon(p_i, mean, base=2)``, where
    ``mean`` is the element-wise average of all distributions, and returns
    the average of those values.

    Note: despite the function name, ``jensenshannon`` returns the
    Jensen-Shannon *distance* (the square root of the divergence), not the
    raw divergence. The value is left unchanged because the module's
    uncertainty threshold is compared against it. With ``base=2`` each term,
    and therefore the mean, lies in [0, 1].

    Args:
        all_probs: Sequence of per-model probability vectors (e.g. softmax
            outputs) of equal length. ``jensenshannon`` rescales each vector
            to sum to 1.

    Returns:
        float: The mean JS distance to the mean distribution, or 0.0 when
        fewer than two distributions are given (no disagreement possible).
    """
    if len(all_probs) < 2:
        return 0.0  # a single model cannot disagree with itself

    probs_array = np.asarray(all_probs)  # shape: (n_models, n_classes)
    mean_probs = probs_array.mean(axis=0)
    distances = [jensenshannon(probs, mean_probs, base=2) for probs in probs_array]
    # Cast so both branches return a plain Python float.
    return float(np.mean(distances))
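# Rule of thumb: a compute_js_divergence score above 0.1 suggests moderate disagreement between models
# (the voting code itself only flags global uncertainty above shannon_threashold).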
def _new_prediction_accumulator():
    """Return the empty per-model lists that soft_voting fills, in model order."""
    return {
        "all_probs": [],
        "models": [],
        "models_predictions": [],
        "models_confidences": [],
        "models_entropies": [],
        "models_uncertainties": [],
        "models_heatmaps": [],
    }


def _run_and_record_model(config, image_bytes, show_heatmap, acc):
    """Run one model via predict_with_model, append its outputs to *acc*, and return the raw prediction."""
    prediction = predict_with_model(config, image_bytes, show_heatmap)
    acc["all_probs"].append(prediction["preds"])
    acc["models_predictions"].append(prediction["predicted_class"])
    acc["models_confidences"].append(prediction["confidence"])
    acc["models_entropies"].append(prediction["entropy"])
    acc["models_uncertainties"].append(prediction["is_uncertain_model"])
    acc["models"].append(config["model_name"])
    if show_heatmap:
        heatmap = prediction.get("heatmap")
        # Explicit None/len test: a plain truthiness test would raise
        # ValueError if the heatmap were a multi-element NumPy array.
        if heatmap is not None and len(heatmap) > 0:
            acc["models_heatmaps"].append(heatmap)
        else:
            logger.warning(f"⚠️ Heatmap vide ou invalide pour le modèle {config['model_name']}")
    return prediction


def _aggregate_predictions(acc, single_model):
    """Soft-vote over the collected predictions and build soft_voting's result dict.

    Args:
        acc: Accumulator filled by _run_and_record_model (at least one model).
        single_model: When true, the global uncertainty flag is the first
            (default) model's own flag. Otherwise it is set only if some model
            is uncertain AND the disagreement score exceeds shannon_threashold.
    """
    all_probs = acc["all_probs"]
    mean_probs = np.mean(all_probs, axis=0)
    final_class = int(np.argmax(mean_probs))
    final_confidence = float(mean_probs[final_class])
    entropy = float(compute_entropy_safe(mean_probs))
    jsd_score = float(compute_js_divergence(all_probs))
    logger.debug(f"🧠 Moyenne des probabilités : {mean_probs.tolist()}")

    if single_model:
        is_global_uncertain = acc["models_uncertainties"][0]
    else:
        is_global_uncertain = any(acc["models_uncertainties"]) and jsd_score > shannon_threashold

    logger.info(
        f"✅ Prediction terminé : classe={final_class}\n"
        f"confiance={final_confidence:.4f}\n"
        f"entropy={entropy:.4f}\n"
        f"jsd_score={jsd_score:.4f}\n"
        f"is_global_uncertain={is_global_uncertain}\n"
    )
    return {
        "predicted_class": final_class,
        "confidence": final_confidence,
        "entropy": entropy,
        "jsd_score": jsd_score,
        "models": acc["models"],
        "is_global_uncertain": is_global_uncertain,
        "models_predictions": acc["models_predictions"],
        "models_confidences": acc["models_confidences"],
        "models_entropies": acc["models_entropies"],
        "models_uncertainties": acc["models_uncertainties"],
        "models_heatmaps": acc["models_heatmaps"],
    }


async def soft_voting(
    model_configs,
    image_bytes: bytes,
    mode,
    show_heatmap,
    default_model,
    *,
    confidence_threshold: float = 0.90,
):
    """Classify an image with the default model and, depending on *mode*, a soft vote of all models.

    The default model is matched case-insensitively on ``config["model_name"]``
    and always runs first. What happens next depends on *mode*:

    * ``"single"``: return the default model's result.
    * ``"automatic"``: return the default model's result if its confidence is
      ``>= confidence_threshold``; otherwise run every other model and vote.
    * anything else (e.g. ``"voting"``): run every other model and vote.

    Voting averages the per-model ``preds`` vectors (soft voting) and takes
    the arg-max class.

    Args:
        model_configs: List of model config dicts. Each needs a
            ``"model_name"`` key and is passed unchanged to predict_with_model.
        image_bytes: Raw image bytes forwarded to every model.
        mode: ``"single"``, ``"automatic"``, or any other value for full voting.
        show_heatmap: Forwarded to predict_with_model. When truthy, non-empty
            heatmaps are collected.
        default_model: Name of the model that always runs first.
        confidence_threshold: Keyword-only. Default-model confidence at or
            above which ``"automatic"`` mode skips the vote. Defaults to 0.90,
            the previously hard-coded value.

    Returns:
        A dict with keys predicted_class, confidence, entropy, jsd_score,
        models, is_global_uncertain, models_predictions, models_confidences,
        models_entropies, models_uncertainties and models_heatmaps, or None if
        *default_model* is not found in *model_configs*.

    Note:
        predict_with_model is called without ``await`` (its result is
        subscripted directly), so it appears to be synchronous: each inference
        blocks the event loop while it runs.
    """
    logger.info("🔁 Début de la prédiction multi-modèles")
    default_key = default_model.lower()

    default_config = next(
        (config for config in model_configs if config["model_name"].lower() == default_key),
        None,
    )
    if default_config is None:
        logger.error(f"❌ Modèle par défaut '{default_model}' introuvable dans les configurations.")
        return None

    acc = _new_prediction_accumulator()

    # The default model always runs first.
    logger.info(f"🚀 Prédiction avec le modèle par défaut : {default_model}")
    default_prediction = _run_and_record_model(default_config, image_bytes, show_heatmap, acc)

    if mode == "single":
        logger.info("🛑 Mode 'single' activé, utilisation uniquement du modèle par défaut.")
        return _aggregate_predictions(acc, single_model=True)

    if mode == "automatic" and default_prediction["confidence"] >= confidence_threshold:
        logger.info(f"✅ Confiance élevée ({default_prediction['confidence']:.2f}), pas besoin de voter.")
        return _aggregate_predictions(acc, single_model=True)

    # Voting mode, or automatic mode with low confidence: run every other model.
    logger.info(f"🔍 Mode '{mode}' : Prédictions complémentaires en cours.")
    for config in model_configs:
        if config["model_name"].lower() == default_key:
            continue  # already handled above
        _run_and_record_model(config, image_bytes, show_heatmap, acc)

    return _aggregate_predictions(acc, single_model=False)
async def soft_voting_v1(model_configs, image_bytes: bytes, mode, show_heatmap, default_model):
    """Legacy soft voting: run the models in configuration order and average their probabilities.

    *default_model* is accepted for signature compatibility but is not used.
    In ``"single"`` mode only the FIRST entry of *model_configs* is run. Any
    other mode runs every configured model.

    Args:
        model_configs: List of model config dicts. Each needs a
            ``"model_name"`` key and is passed unchanged to predict_with_model.
        image_bytes: Raw image bytes forwarded to every model.
        mode: ``"single"`` to stop after the first model; any other value votes.
        show_heatmap: Forwarded to predict_with_model. When truthy, non-empty
            heatmaps are collected.
        default_model: Unused.

    Returns:
        A dict with keys predicted_class, confidence, entropy, jsd_score,
        models, is_global_uncertain, models_predictions, models_confidences,
        models_entropies, models_uncertainties and models_heatmaps.

    Raises:
        Exception: "No predictions received." when *model_configs* is empty.
    """
    logger.info("🔁 Début du vote multi-modèles")
    all_probs = []
    models = []
    models_predictions = []
    models_confidences = []
    models_entropies = []
    models_uncertainties = []
    models_heatmaps = []

    # NOTE: predict_with_model is called without await, so each inference
    # appears to block the event loop while it runs.
    for config in model_configs:
        prediction = predict_with_model(config, image_bytes, show_heatmap)
        all_probs.append(prediction["preds"])
        models_predictions.append(prediction["predicted_class"])
        models_confidences.append(prediction["confidence"])
        models_entropies.append(prediction["entropy"])
        models_uncertainties.append(prediction["is_uncertain_model"])
        if show_heatmap:
            heatmap = prediction.get("heatmap")
            # Explicit None/len test: a plain truthiness test would raise
            # ValueError if the heatmap were a multi-element NumPy array.
            if heatmap is not None and len(heatmap) > 0:
                models_heatmaps.append(heatmap)
            else:
                logger.warning(f"⚠️ Heatmap vide ou invalide, non ajoutée pour le modèle {config['model_name']}")
            logger.info(f"Taille heatmaps :{len(models_heatmaps)}")
        models.append(config["model_name"])
        logger.info(f"📊 Prédictions ajoutées pour {config['model_name']}")
        if mode == "single":
            logger.info("🛑 Mode 'single' activé, arrêt après le premier modèle.")
            break

    if not all_probs:
        logger.warning("⚠️ Aucune prédiction reçue, vérifie les APIs appelées.")
        raise Exception("No predictions received.")

    mean_probs = np.mean(all_probs, axis=0)
    final_class = int(np.argmax(mean_probs))
    final_confidence = float(mean_probs[final_class])
    entropy = float(compute_entropy_safe(mean_probs))
    jsd_score = float(compute_js_divergence(all_probs))

    if mode == "single":
        is_global_uncertain = models_uncertainties[0]
    else:
        # Flag only when some model is unsure AND the models disagree noticeably.
        is_global_uncertain = any(models_uncertainties) and jsd_score > shannon_threashold

    logger.info(
        f"✅ Vote terminé : classe={final_class}\n"
        f"confiance={final_confidence:.4f}\n"
        f"entropy={entropy:.4f}\n"
        f"jsd_score={jsd_score:.4f}\n"
        f"is_global_uncertain={is_global_uncertain}\n"
    )
    logger.debug(f"🧠 Moyenne des probabilités : {mean_probs.tolist()}")
    return {
        "predicted_class": final_class,
        "confidence": final_confidence,
        "entropy": entropy,
        "jsd_score": jsd_score,
        "models": models,
        "is_global_uncertain": is_global_uncertain,
        "models_predictions": models_predictions,
        "models_confidences": models_confidences,
        "models_entropies": models_entropies,
        "models_uncertainties": models_uncertainties,
        "models_heatmaps": models_heatmaps,
    }