""" inference.py ------------ Module d'inférence pour des modèles de classification binaire : FIRE (1) / NO_FIRE (0). Supporte : - EfficientNet-B0 (classification) - Inception v3 (classification) - YOLO (Ultralytics, détection) pour localiser le feu Retour principal des fonctions de prédiction : predicted_label (str), fire_prob (float), annotated_image (PIL.Image ou None) """ # ---------------------------- # 1) Imports # ---------------------------- import os import torch import torch.nn as nn from torchvision import transforms from PIL import Image import timm # YOLO (Ultralytics) – optionnel try: from ultralytics import YOLO except ImportError: YOLO = None # si la lib n'est pas installée, on gère ça proprement # ---------------------------- # 2) Constantes globales # ---------------------------- # Taille d'entrée par défaut si rien n'est précisé DEFAULT_IMAGE_SIZE = 224 # Stats ImageNet IMAGENET_MEAN = [0.485, 0.456, 0.406] IMAGENET_STD = [0.229, 0.224, 0.225] # Mapping idx -> label lisible IDX_TO_LABEL = { 0: "no_fire", 1: "fire", } # Registre des modèles de classification (timm) MODEL_REGISTRY = { "efficientnet_b0": { "timm_name": "efficientnet_b0", "image_size": 224, "classifier_attr": "classifier", }, "inception_v3": { "timm_name": "inception_v3", "image_size": 299, "classifier_attr": "fc", }, } # Modèle par défaut si on ne sait pas quoi choisir DEFAULT_MODEL_KEY = "efficientnet_b0" # IDs des classes "feu" pour YOLO (à adapter si besoin après entraînement) # Exemple : si model.names == {0: 'fire'} → [0] # si model.names == {0: 'no_fire', 1: 'fire'} → [1] FIRE_CLASS_IDS = [0] # ---------------------------- # 3) Device # ---------------------------- def get_device(): """ Retourne le device à utiliser pour l'inférence : - 'cuda' si un GPU est disponible - sinon 'cpu' """ if torch.cuda.is_available(): return torch.device("cuda") return torch.device("cpu") # ---------------------------- # 4) Détection du type de modèle # ---------------------------- def infer_model_key_from_path(weights_path: str) -> str: """ Devine une clé de modèle (model_key) à partir du nom de fichier. Exemple : - "yolov8_fire.pt" → "yolo" - "inception3_fire.pt" → "inception_v3" - "efficientnet_fire.pt" → "efficientnet_b0" (par défaut) """ filename = os.path.basename(weights_path).lower() if "yolo" in filename: return "yolo" if "inception" in filename: return "inception_v3" return DEFAULT_MODEL_KEY def get_model_config(model_key: str) -> dict: """ Renvoie la config du modèle à partir d'une model_key. Fallback : EfficientNet-B0 si model_key inconnue. """ if model_key in MODEL_REGISTRY: return MODEL_REGISTRY[model_key] return MODEL_REGISTRY[DEFAULT_MODEL_KEY] # ---------------------------- # 5) Construction modèle (classification) # ---------------------------- def build_model(model_key: str, num_classes: int = 2) -> torch.nn.Module: """ Construit un modèle de classification (EfficientNet, Inception...) et adapte la dernière couche à num_classes sorties. """ config = get_model_config(model_key) timm_name = config["timm_name"] classifier_attr = config["classifier_attr"] model = timm.create_model(timm_name, pretrained=False) classifier = getattr(model, classifier_attr) if isinstance(classifier, nn.Linear): in_features = classifier.in_features else: raise ValueError( f"Impossible de déterminer in_features pour la tête du modèle '{timm_name}'. " f"Attribut '{classifier_attr}' de type {type(classifier)} non supporté." ) new_classifier = nn.Linear(in_features, num_classes) setattr(model, classifier_attr, new_classifier) return model # ---------------------------- # 6) Nettoyage de state_dict # ---------------------------- def _clean_state_dict_keys(state_dict: dict) -> dict: """ Nettoie les clés pour gérer les prefixes 'model.' (Lightning) et 'module.' (DataParallel). """ new_state = {} for k, v in state_dict.items(): new_key = k if new_key.startswith("model."): new_key = new_key[len("model."):] if new_key.startswith("module."): new_key = new_key[len("module."):] new_state[new_key] = v return new_state # ---------------------------- # 7) Chargement du modèle # ---------------------------- def load_model(weights_path: str, map_location=None, model_key: str | None = None): """ Charge un modèle avec les poids entraînés. - Pour YOLO (Ultralytics) : charge un modèle de détection. - Pour EfficientNet / Inception : charge un modèle de classification binaire. Retour : -------- model : torch.nn.Module ou YOLO device : torch.device """ device = map_location if map_location is not None else get_device() # Détecter le type de modèle si non fourni if model_key is None: model_key = infer_model_key_from_path(weights_path) # 🔹 Cas YOLO : modèle de détection (Ultralytics) if model_key == "yolo": if YOLO is None: raise ImportError( "Le modèle YOLO est demandé mais la librairie 'ultralytics' " "n'est pas installée. Ajoutez 'ultralytics' dans requirements.txt." ) yolo_model = YOLO(weights_path) # YOLO gère déjà souvent le device en interne, on essaye juste par sécurité try: yolo_model.to(device) except Exception: pass return yolo_model, device # 🔹 Cas classification (EfficientNet, Inception...) checkpoint = torch.load(weights_path, map_location=device) if isinstance(checkpoint, dict) and "state_dict" in checkpoint: state_dict = checkpoint["state_dict"] else: state_dict = checkpoint state_dict = _clean_state_dict_keys(state_dict) model = build_model(model_key=model_key, num_classes=2) missing, unexpected = model.load_state_dict(state_dict, strict=False) # (optionnel) debug : # print("Missing keys:", missing) # print("Unexpected keys:", unexpected) model = model.to(device) model.eval() return model, device # ---------------------------- # 8) Transforms pour l'inférence # ---------------------------- def get_val_transform(image_size: int | None = None): """ Renvoie les transformations à appliquer aux images pour l'inférence. Si image_size est None → DEFAULT_IMAGE_SIZE (224). """ if image_size is None: image_size = DEFAULT_IMAGE_SIZE transform = transforms.Compose([ transforms.Resize((image_size, image_size)), transforms.ToTensor(), transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), ]) return transform # ---------------------------- # 9) Prétraitement d'une image # ---------------------------- def preprocess_image(image: Image.Image, transform=None, image_size: int | None = None): """ Applique les transforms à une image PIL et ajoute une dimension batch. """ if transform is None: transform = get_val_transform(image_size=image_size) img_tensor = transform(image) # [3, H, W] img_tensor = img_tensor.unsqueeze(0) # [1, 3, H, W] return img_tensor # ---------------------------- # 10) Prédiction depuis un tenseur (classification) # ---------------------------- def predict_from_tensor( image_tensor: torch.Tensor, model: torch.nn.Module, device: torch.device, threshold: float = 0.5, ): """ Prédit la classe (fire/no_fire) à partir d'un tenseur déjà prétraité pour les modèles de classification (EfficientNet, Inception...). """ image_tensor = image_tensor.to(device) with torch.no_grad(): outputs = model(image_tensor) # logits [1, 2] probs = torch.softmax(outputs, dim=1) # probas fire_prob = probs[0, 1].item() predicted_idx = 1 if fire_prob >= threshold else 0 predicted_label = IDX_TO_LABEL[predicted_idx] return predicted_label, fire_prob # ---------------------------- # 11) Prédiction depuis une image PIL # ---------------------------- def predict_from_pil( image: Image.Image, model, device: torch.device, transform=None, threshold: float = 0.5, image_size: int | None = None, ): """ Prédit la classe à partir d'une image PIL. Retour ------ predicted_label : str "fire" ou "no_fire". fire_prob : float Probabilité de "fire". annotated_image : PIL.Image.Image ou None Image annotée (bounding boxes) pour les modèles de détection (YOLO), None pour les modèles de classification. """ if image.mode != "RGB": image = image.convert("RGB") # 🔹 Cas YOLO : modèle de détection (Ultralytics) if hasattr(model, "task") and getattr(model, "task", None) == "detect": results = model(image) result = results[0] boxes = getattr(result, "boxes", None) fire_prob = 0.0 if boxes is not None and len(boxes) > 0: classes = boxes.cls # ids des classes (tensor) confs = boxes.conf # scores de confiance (tensor) # masque des boxes "feu" mask_fire = torch.zeros_like(classes, dtype=torch.bool) for cid in FIRE_CLASS_IDS: mask_fire |= (classes == cid) if mask_fire.any(): fire_prob = float(confs[mask_fire].max().item()) predicted_label = "fire" if fire_prob >= threshold else "no_fire" # Image annotée avec les bounding boxes annotated_image = None try: annotated_array = result.plot() # numpy array BGR annotated_image = Image.fromarray(annotated_array[..., ::-1]) # BGR -> RGB except Exception: annotated_image = None return predicted_label, fire_prob, annotated_image # 🔹 Cas classification classique image_tensor = preprocess_image(image, transform=transform, image_size=image_size) predicted_label, fire_prob = predict_from_tensor(image_tensor, model, device, threshold=threshold) annotated_image = None return predicted_label, fire_prob, annotated_image # ---------------------------- # 12) Prédiction depuis un chemin de fichier # ---------------------------- def predict_from_path( image_path: str, model, device: torch.device, transform=None, threshold: float = 0.5, image_size: int | None = None, ): """ Prédit la classe à partir d'un chemin vers une image. """ image = Image.open(image_path) return predict_from_pil( image=image, model=model, device=device, transform=transform, threshold=threshold, image_size=image_size, ) # ---------------------------- # 13) Exemple d'utilisation en script direct # ---------------------------- if __name__ == "__main__": # Petit test local (à adapter) weights_path = "efficientnet_fire.pt" if not os.path.exists(weights_path): print(f"[ERREUR] Fichier de poids introuvable : {weights_path}") else: model, device = load_model(weights_path) print(f"Modèle chargé sur le device : {device}") transform = get_val_transform() test_image_path = "example.jpg" if not os.path.exists(test_image_path): print(f"[INFO] Aucune image test trouvée à : {test_image_path}") else: label, prob, _ = predict_from_path( test_image_path, model=model, device=device, transform=transform, threshold=0.5, ) print(f"Résultat pour {test_image_path} : label={label}, prob_fire={prob:.4f}")