Fire_Detection_Streamlit / inference.py
ninafr8175
update dashboard 6
2e0d90d
"""
inference.py
------------
Module d'inférence pour des modèles de classification binaire : FIRE (1) / NO_FIRE (0).
Supporte :
- EfficientNet-B0 (classification)
- Inception v3 (classification)
- YOLO (Ultralytics, détection) pour localiser le feu
Retour principal des fonctions de prédiction :
predicted_label (str), fire_prob (float), annotated_image (PIL.Image ou None)
"""
# ----------------------------
# 1) Imports
# ----------------------------
import os
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import timm
# YOLO (Ultralytics) – optionnel
try:
from ultralytics import YOLO
except ImportError:
YOLO = None # si la lib n'est pas installée, on gère ça proprement
# ----------------------------
# 2) Constantes globales
# ----------------------------
# Taille d'entrée par défaut si rien n'est précisé
DEFAULT_IMAGE_SIZE = 224
# Stats ImageNet
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Mapping idx -> label lisible
IDX_TO_LABEL = {
0: "no_fire",
1: "fire",
}
# Registre des modèles de classification (timm)
MODEL_REGISTRY = {
"efficientnet_b0": {
"timm_name": "efficientnet_b0",
"image_size": 224,
"classifier_attr": "classifier",
},
"inception_v3": {
"timm_name": "inception_v3",
"image_size": 299,
"classifier_attr": "fc",
},
}
# Modèle par défaut si on ne sait pas quoi choisir
DEFAULT_MODEL_KEY = "efficientnet_b0"
# IDs des classes "feu" pour YOLO (à adapter si besoin après entraînement)
# Exemple : si model.names == {0: 'fire'} → [0]
# si model.names == {0: 'no_fire', 1: 'fire'} → [1]
FIRE_CLASS_IDS = [0]
# ----------------------------
# 3) Device
# ----------------------------
def get_device():
"""
Retourne le device à utiliser pour l'inférence :
- 'cuda' si un GPU est disponible
- sinon 'cpu'
"""
if torch.cuda.is_available():
return torch.device("cuda")
return torch.device("cpu")
# ----------------------------
# 4) Détection du type de modèle
# ----------------------------
def infer_model_key_from_path(weights_path: str) -> str:
"""
Devine une clé de modèle (model_key) à partir du nom de fichier.
Exemple :
- "yolov8_fire.pt" → "yolo"
- "inception3_fire.pt" → "inception_v3"
- "efficientnet_fire.pt" → "efficientnet_b0" (par défaut)
"""
filename = os.path.basename(weights_path).lower()
if "yolo" in filename:
return "yolo"
if "inception" in filename:
return "inception_v3"
return DEFAULT_MODEL_KEY
def get_model_config(model_key: str) -> dict:
"""
Renvoie la config du modèle à partir d'une model_key.
Fallback : EfficientNet-B0 si model_key inconnue.
"""
if model_key in MODEL_REGISTRY:
return MODEL_REGISTRY[model_key]
return MODEL_REGISTRY[DEFAULT_MODEL_KEY]
# ----------------------------
# 5) Construction modèle (classification)
# ----------------------------
def build_model(model_key: str, num_classes: int = 2) -> torch.nn.Module:
"""
Construit un modèle de classification (EfficientNet, Inception...)
et adapte la dernière couche à num_classes sorties.
"""
config = get_model_config(model_key)
timm_name = config["timm_name"]
classifier_attr = config["classifier_attr"]
model = timm.create_model(timm_name, pretrained=False)
classifier = getattr(model, classifier_attr)
if isinstance(classifier, nn.Linear):
in_features = classifier.in_features
else:
raise ValueError(
f"Impossible de déterminer in_features pour la tête du modèle '{timm_name}'. "
f"Attribut '{classifier_attr}' de type {type(classifier)} non supporté."
)
new_classifier = nn.Linear(in_features, num_classes)
setattr(model, classifier_attr, new_classifier)
return model
# ----------------------------
# 6) Nettoyage de state_dict
# ----------------------------
def _clean_state_dict_keys(state_dict: dict) -> dict:
"""
Nettoie les clés pour gérer les prefixes 'model.' (Lightning) et 'module.' (DataParallel).
"""
new_state = {}
for k, v in state_dict.items():
new_key = k
if new_key.startswith("model."):
new_key = new_key[len("model."):]
if new_key.startswith("module."):
new_key = new_key[len("module."):]
new_state[new_key] = v
return new_state
# ----------------------------
# 7) Chargement du modèle
# ----------------------------
def load_model(weights_path: str, map_location=None, model_key: str | None = None):
"""
Charge un modèle avec les poids entraînés.
- Pour YOLO (Ultralytics) : charge un modèle de détection.
- Pour EfficientNet / Inception : charge un modèle de classification binaire.
Retour :
--------
model : torch.nn.Module ou YOLO
device : torch.device
"""
device = map_location if map_location is not None else get_device()
# Détecter le type de modèle si non fourni
if model_key is None:
model_key = infer_model_key_from_path(weights_path)
# 🔹 Cas YOLO : modèle de détection (Ultralytics)
if model_key == "yolo":
if YOLO is None:
raise ImportError(
"Le modèle YOLO est demandé mais la librairie 'ultralytics' "
"n'est pas installée. Ajoutez 'ultralytics' dans requirements.txt."
)
yolo_model = YOLO(weights_path)
# YOLO gère déjà souvent le device en interne, on essaye juste par sécurité
try:
yolo_model.to(device)
except Exception:
pass
return yolo_model, device
# 🔹 Cas classification (EfficientNet, Inception...)
checkpoint = torch.load(weights_path, map_location=device)
if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
else:
state_dict = checkpoint
state_dict = _clean_state_dict_keys(state_dict)
model = build_model(model_key=model_key, num_classes=2)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
# (optionnel) debug :
# print("Missing keys:", missing)
# print("Unexpected keys:", unexpected)
model = model.to(device)
model.eval()
return model, device
# ----------------------------
# 8) Transforms pour l'inférence
# ----------------------------
def get_val_transform(image_size: int | None = None):
"""
Renvoie les transformations à appliquer aux images pour l'inférence.
Si image_size est None → DEFAULT_IMAGE_SIZE (224).
"""
if image_size is None:
image_size = DEFAULT_IMAGE_SIZE
transform = transforms.Compose([
transforms.Resize((image_size, image_size)),
transforms.ToTensor(),
transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
return transform
# ----------------------------
# 9) Prétraitement d'une image
# ----------------------------
def preprocess_image(image: Image.Image, transform=None, image_size: int | None = None):
"""
Applique les transforms à une image PIL et ajoute une dimension batch.
"""
if transform is None:
transform = get_val_transform(image_size=image_size)
img_tensor = transform(image) # [3, H, W]
img_tensor = img_tensor.unsqueeze(0) # [1, 3, H, W]
return img_tensor
# ----------------------------
# 10) Prédiction depuis un tenseur (classification)
# ----------------------------
def predict_from_tensor(
image_tensor: torch.Tensor,
model: torch.nn.Module,
device: torch.device,
threshold: float = 0.5,
):
"""
Prédit la classe (fire/no_fire) à partir d'un tenseur déjà prétraité
pour les modèles de classification (EfficientNet, Inception...).
"""
image_tensor = image_tensor.to(device)
with torch.no_grad():
outputs = model(image_tensor) # logits [1, 2]
probs = torch.softmax(outputs, dim=1) # probas
fire_prob = probs[0, 1].item()
predicted_idx = 1 if fire_prob >= threshold else 0
predicted_label = IDX_TO_LABEL[predicted_idx]
return predicted_label, fire_prob
# ----------------------------
# 11) Prédiction depuis une image PIL
# ----------------------------
def predict_from_pil(
image: Image.Image,
model,
device: torch.device,
transform=None,
threshold: float = 0.5,
image_size: int | None = None,
):
"""
Prédit la classe à partir d'une image PIL.
Retour
------
predicted_label : str
"fire" ou "no_fire".
fire_prob : float
Probabilité de "fire".
annotated_image : PIL.Image.Image ou None
Image annotée (bounding boxes) pour les modèles de détection (YOLO),
None pour les modèles de classification.
"""
if image.mode != "RGB":
image = image.convert("RGB")
# 🔹 Cas YOLO : modèle de détection (Ultralytics)
if hasattr(model, "task") and getattr(model, "task", None) == "detect":
results = model(image)
result = results[0]
boxes = getattr(result, "boxes", None)
fire_prob = 0.0
if boxes is not None and len(boxes) > 0:
classes = boxes.cls # ids des classes (tensor)
confs = boxes.conf # scores de confiance (tensor)
# masque des boxes "feu"
mask_fire = torch.zeros_like(classes, dtype=torch.bool)
for cid in FIRE_CLASS_IDS:
mask_fire |= (classes == cid)
if mask_fire.any():
fire_prob = float(confs[mask_fire].max().item())
predicted_label = "fire" if fire_prob >= threshold else "no_fire"
# Image annotée avec les bounding boxes
annotated_image = None
try:
annotated_array = result.plot() # numpy array BGR
annotated_image = Image.fromarray(annotated_array[..., ::-1]) # BGR -> RGB
except Exception:
annotated_image = None
return predicted_label, fire_prob, annotated_image
# 🔹 Cas classification classique
image_tensor = preprocess_image(image, transform=transform, image_size=image_size)
predicted_label, fire_prob = predict_from_tensor(image_tensor, model, device, threshold=threshold)
annotated_image = None
return predicted_label, fire_prob, annotated_image
# ----------------------------
# 12) Prédiction depuis un chemin de fichier
# ----------------------------
def predict_from_path(
image_path: str,
model,
device: torch.device,
transform=None,
threshold: float = 0.5,
image_size: int | None = None,
):
"""
Prédit la classe à partir d'un chemin vers une image.
"""
image = Image.open(image_path)
return predict_from_pil(
image=image,
model=model,
device=device,
transform=transform,
threshold=threshold,
image_size=image_size,
)
# ----------------------------
# 13) Exemple d'utilisation en script direct
# ----------------------------
if __name__ == "__main__":
# Petit test local (à adapter)
weights_path = "efficientnet_fire.pt"
if not os.path.exists(weights_path):
print(f"[ERREUR] Fichier de poids introuvable : {weights_path}")
else:
model, device = load_model(weights_path)
print(f"Modèle chargé sur le device : {device}")
transform = get_val_transform()
test_image_path = "example.jpg"
if not os.path.exists(test_image_path):
print(f"[INFO] Aucune image test trouvée à : {test_image_path}")
else:
label, prob, _ = predict_from_path(
test_image_path,
model=model,
device=device,
transform=transform,
threshold=0.5,
)
print(f"Résultat pour {test_image_path} : label={label}, prob_fire={prob:.4f}")