"""
inference.py
------------
Inference module for binary classification models: FIRE (1) / NO_FIRE (0).

Supports:
- EfficientNet-B0 (classification)
- Inception v3 (classification)
- YOLO (Ultralytics, detection) to localize the fire

Main return values of the prediction functions:
    predicted_label (str), fire_prob (float), annotated_image (PIL.Image or None)
"""
|
|
| |
| |
| |
| import os |
|
|
| import torch |
| import torch.nn as nn |
| from torchvision import transforms |
| from PIL import Image |
| import timm |
|
|
| |
| try: |
| from ultralytics import YOLO |
| except ImportError: |
| YOLO = None |
|
|
|
|
| |
| |
| |
|
|
| |
# Default square input resolution for classification models.
DEFAULT_IMAGE_SIZE = 224

# ImageNet normalization statistics, applied at inference time
# (must match the preprocessing used during training).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Mapping from class index to human-readable label.
IDX_TO_LABEL = {
    0: "no_fire",
    1: "fire",
}

# Per-architecture configuration:
# - timm_name:       model identifier passed to timm.create_model
# - image_size:      expected square input resolution
# - classifier_attr: attribute name of the final head to replace (see build_model)
MODEL_REGISTRY = {
    "efficientnet_b0": {
        "timm_name": "efficientnet_b0",
        "image_size": 224,
        "classifier_attr": "classifier",
    },
    "inception_v3": {
        "timm_name": "inception_v3",
        "image_size": 299,
        "classifier_attr": "fc",
    },
}

# Fallback architecture when a model key cannot be resolved.
DEFAULT_MODEL_KEY = "efficientnet_b0"

# YOLO class id(s) counted as "fire" detections.
# NOTE(review): assumes the detection model was trained with fire as class 0 — confirm.
FIRE_CLASS_IDS = [0]
|
|
|
|
| |
| |
| |
|
|
def get_device():
    """Return the inference device: CUDA when a GPU is available, else CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
| |
| |
| |
|
|
def infer_model_key_from_path(weights_path: str) -> str:
    """Guess a model key from the weights file name.

    Examples
    --------
    - "yolov8_fire.pt"     -> "yolo"
    - "inception3_fire.pt" -> "inception_v3"
    - anything else        -> DEFAULT_MODEL_KEY ("efficientnet_b0")
    """
    filename = os.path.basename(weights_path).lower()

    # Order matters: "yolo" wins over "inception" if both markers appear.
    for marker, key in (("yolo", "yolo"), ("inception", "inception_v3")):
        if marker in filename:
            return key

    return DEFAULT_MODEL_KEY
|
|
|
|
def get_model_config(model_key: str) -> dict:
    """Return the registry entry for ``model_key``.

    Falls back to the EfficientNet-B0 configuration for unknown keys.
    """
    return MODEL_REGISTRY.get(model_key, MODEL_REGISTRY[DEFAULT_MODEL_KEY])
|
|
|
|
| |
| |
| |
|
|
def build_model(model_key: str, num_classes: int = 2) -> torch.nn.Module:
    """Instantiate a timm classification backbone and replace its head.

    The final layer (named by ``classifier_attr`` in the registry) is swapped
    for a fresh ``nn.Linear`` producing ``num_classes`` logits.

    Raises
    ------
    ValueError
        When the configured head attribute is not an ``nn.Linear`` (its
        input width cannot be determined).
    """
    config = get_model_config(model_key)
    backbone_name = config["timm_name"]
    head_attr = config["classifier_attr"]

    model = timm.create_model(backbone_name, pretrained=False)
    head = getattr(model, head_attr)

    # Guard clause: only Linear heads expose in_features directly.
    if not isinstance(head, nn.Linear):
        raise ValueError(
            f"Impossible de déterminer in_features pour la tête du modèle '{backbone_name}'. "
            f"Attribut '{head_attr}' de type {type(head)} non supporté."
        )

    setattr(model, head_attr, nn.Linear(head.in_features, num_classes))
    return model
|
|
|
|
| |
| |
| |
|
|
| def _clean_state_dict_keys(state_dict: dict) -> dict: |
| """ |
| Nettoie les clés pour gérer les prefixes 'model.' (Lightning) et 'module.' (DataParallel). |
| """ |
| new_state = {} |
| for k, v in state_dict.items(): |
| new_key = k |
| if new_key.startswith("model."): |
| new_key = new_key[len("model."):] |
| if new_key.startswith("module."): |
| new_key = new_key[len("module."):] |
| new_state[new_key] = v |
| return new_state |
|
|
|
|
| |
| |
| |
|
|
def load_model(weights_path: str, map_location=None, model_key: str | None = None):
    """Load a trained model from a weights file.

    - YOLO (Ultralytics): returns a detection model.
    - EfficientNet / Inception: returns a binary classifier in eval mode.

    Parameters
    ----------
    weights_path : str
        Path to the checkpoint / weights file.
    map_location : optional
        Target device; when None the default device from get_device() is used.
    model_key : str | None
        Architecture key; inferred from the file name when None.

    Returns
    -------
    model : torch.nn.Module or YOLO
    device : torch.device
    """
    device = get_device() if map_location is None else map_location

    if model_key is None:
        model_key = infer_model_key_from_path(weights_path)

    # Detection path (Ultralytics).
    if model_key == "yolo":
        if YOLO is None:
            raise ImportError(
                "Le modèle YOLO est demandé mais la librairie 'ultralytics' "
                "n'est pas installée. Ajoutez 'ultralytics' dans requirements.txt."
            )
        detector = YOLO(weights_path)
        try:
            # Best effort: some ultralytics versions manage device placement themselves.
            detector.to(device)
        except Exception:
            pass
        return detector, device

    # Classification path: read the checkpoint, unwrap and normalize its keys.
    checkpoint = torch.load(weights_path, map_location=device)

    if isinstance(checkpoint, dict):
        state_dict = checkpoint.get("state_dict", checkpoint)
    else:
        state_dict = checkpoint

    state_dict = _clean_state_dict_keys(state_dict)

    model = build_model(model_key=model_key, num_classes=2)
    # strict=False tolerates heads/buffers that differ from the checkpoint.
    model.load_state_dict(state_dict, strict=False)

    model.to(device)
    model.eval()
    return model, device
|
|
|
|
| |
| |
| |
|
|
def get_val_transform(image_size: int | None = None):
    """Build the inference-time preprocessing pipeline.

    Resizes to a ``image_size`` x ``image_size`` square (DEFAULT_IMAGE_SIZE
    when None), converts to a tensor, and applies ImageNet normalization.
    """
    size = DEFAULT_IMAGE_SIZE if image_size is None else image_size
    return transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
|
|
|
|
| |
| |
| |
|
|
| def preprocess_image(image: Image.Image, transform=None, image_size: int | None = None): |
| """ |
| Applique les transforms à une image PIL et ajoute une dimension batch. |
| """ |
| if transform is None: |
| transform = get_val_transform(image_size=image_size) |
|
|
| img_tensor = transform(image) |
| img_tensor = img_tensor.unsqueeze(0) |
|
|
| return img_tensor |
|
|
|
|
| |
| |
| |
|
|
def predict_from_tensor(
    image_tensor: torch.Tensor,
    model: torch.nn.Module,
    device: torch.device,
    threshold: float = 0.5,
):
    """Classify an already-preprocessed tensor as fire / no_fire.

    Intended for classification models (EfficientNet, Inception...).
    Returns (predicted_label, fire_prob) where fire_prob is the softmax
    probability of class 1 ("fire").
    """
    batch = image_tensor.to(device)

    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.softmax(logits, dim=1)

    fire_prob = probabilities[0, 1].item()
    # Thresholding: class 1 when P(fire) reaches the threshold.
    predicted_idx = int(fire_prob >= threshold)
    return IDX_TO_LABEL[predicted_idx], fire_prob
|
|
|
|
| |
| |
| |
|
|
def predict_from_pil(
    image: Image.Image,
    model,
    device: torch.device,
    transform=None,
    threshold: float = 0.5,
    image_size: int | None = None,
):
    """
    Predict the class from a PIL image.

    Handles both detection models (Ultralytics YOLO, identified by their
    ``task == "detect"`` attribute) and classification models.

    Returns
    -------
    predicted_label : str
        "fire" or "no_fire".
    fire_prob : float
        Probability of "fire" (classification) or the highest confidence
        among fire detections (YOLO).
    annotated_image : PIL.Image.Image or None
        Annotated image (bounding boxes) for detection models (YOLO),
        None for classification models.
    """
    # Models expect 3-channel RGB input.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # --- Detection branch (Ultralytics YOLO) ---
    if hasattr(model, "task") and getattr(model, "task", None) == "detect":
        results = model(image)
        result = results[0]  # one input image -> first Results object

        boxes = getattr(result, "boxes", None)
        fire_prob = 0.0

        if boxes is not None and len(boxes) > 0:
            classes = boxes.cls
            confs = boxes.conf

            # Boolean mask selecting detections whose class id is a fire class.
            mask_fire = torch.zeros_like(classes, dtype=torch.bool)
            for cid in FIRE_CLASS_IDS:
                mask_fire |= (classes == cid)

            # Fire probability = max confidence over fire detections.
            if mask_fire.any():
                fire_prob = float(confs[mask_fire].max().item())

        predicted_label = "fire" if fire_prob >= threshold else "no_fire"

        # Best-effort rendering; annotation failure must not fail the prediction.
        annotated_image = None
        try:
            annotated_array = result.plot()
            # result.plot() presumably returns a BGR array (OpenCV convention);
            # reversing the last axis converts to RGB — TODO confirm for this
            # ultralytics version.
            annotated_image = Image.fromarray(annotated_array[..., ::-1])
        except Exception:
            annotated_image = None

        return predicted_label, fire_prob, annotated_image

    # --- Classification branch (EfficientNet / Inception): no annotation ---
    image_tensor = preprocess_image(image, transform=transform, image_size=image_size)
    predicted_label, fire_prob = predict_from_tensor(image_tensor, model, device, threshold=threshold)
    annotated_image = None

    return predicted_label, fire_prob, annotated_image
|
|
|
|
| |
| |
| |
|
|
def predict_from_path(
    image_path: str,
    model,
    device: torch.device,
    transform=None,
    threshold: float = 0.5,
    image_size: int | None = None,
):
    """Predict the class from a path to an image file.

    The image is opened inside a context manager so the underlying file
    handle is released deterministically (PIL otherwise keeps it open due
    to lazy loading). The prediction runs while the file is still open.

    Returns the same triple as predict_from_pil:
        (predicted_label, fire_prob, annotated_image)
    """
    with Image.open(image_path) as image:
        return predict_from_pil(
            image=image,
            model=model,
            device=device,
            transform=transform,
            threshold=threshold,
            image_size=image_size,
        )
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Smoke test: load the default weights and, if a sample image exists,
    # run a single prediction and print the result.
    ckpt_path = "efficientnet_fire.pt"
    if not os.path.exists(ckpt_path):
        print(f"[ERREUR] Fichier de poids introuvable : {ckpt_path}")
    else:
        model, device = load_model(ckpt_path)
        print(f"Modèle chargé sur le device : {device}")

        val_transform = get_val_transform()
        sample_path = "example.jpg"

        if not os.path.exists(sample_path):
            print(f"[INFO] Aucune image test trouvée à : {sample_path}")
        else:
            label, prob, _ = predict_from_path(
                sample_path,
                model=model,
                device=device,
                transform=val_transform,
                threshold=0.5,
            )
            print(f"Résultat pour {sample_path} : label={label}, prob_fire={prob:.4f}")
|
|