ninafr8175 commited on
Commit ·
2e0d90d
1
Parent(s): bf91db4
update dashboard 6
Browse files- .gitignore +2 -0
- app.py +29 -3
- inference.py +143 -202
- requirements.txt +0 -0
.gitignore
CHANGED
|
@@ -11,6 +11,8 @@ __pycache__/
|
|
| 11 |
*.pptx
|
| 12 |
.dockerignore
|
| 13 |
app_v0.py
|
|
|
|
| 14 |
inference_v0.py
|
|
|
|
| 15 |
Dockerfile
|
| 16 |
Notes.txt
|
|
|
|
| 11 |
*.pptx
|
| 12 |
.dockerignore
|
| 13 |
app_v0.py
|
| 14 |
+
app_v1.py
|
| 15 |
inference_v0.py
|
| 16 |
+
inference_v1.py
|
| 17 |
Dockerfile
|
| 18 |
Notes.txt
|
app.py
CHANGED
|
@@ -14,6 +14,7 @@ ALL_MODEL_FILES = {
|
|
| 14 |
"Modèle EfficientnetB0 Baseline": "efficientnet_fire.pt",
|
| 15 |
"Modèle EfficientnetB0 FE": "efficientnet_fire_2.pt",
|
| 16 |
"Modèle EfficientnetB0 FT": "efficientnet_fire_3.pt",
|
|
|
|
| 17 |
"Modèle Inception3": "inception3_fire.pt",
|
| 18 |
}
|
| 19 |
|
|
@@ -55,7 +56,22 @@ MODEL_METRICS = {
|
|
| 55 |
"fp": 440,
|
| 56 |
"fn": 680,
|
| 57 |
},
|
| 58 |
-
# "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
}
|
| 60 |
|
| 61 |
# -----------------------------------------------------------
|
|
@@ -164,13 +180,12 @@ else:
|
|
| 164 |
# Afficher l'image uploadée
|
| 165 |
# -------------------------------------------------------
|
| 166 |
image = Image.open(uploaded_file)
|
| 167 |
-
st.image(image, caption="Image chargée", use_container_width=True)
|
| 168 |
|
| 169 |
# -------------------------------------------------------
|
| 170 |
# Prédiction
|
| 171 |
# -------------------------------------------------------
|
| 172 |
with st.spinner("Analyse de l'image en cours..."):
|
| 173 |
-
label, prob = predict_from_pil(
|
| 174 |
image=image,
|
| 175 |
model=model,
|
| 176 |
device=device,
|
|
@@ -178,6 +193,17 @@ else:
|
|
| 178 |
threshold=threshold
|
| 179 |
)
|
| 180 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
# -------------------------------------------------------
|
| 182 |
# Affichage du résultat avec couleur
|
| 183 |
# -------------------------------------------------------
|
|
|
|
| 14 |
"Modèle EfficientnetB0 Baseline": "efficientnet_fire.pt",
|
| 15 |
"Modèle EfficientnetB0 FE": "efficientnet_fire_2.pt",
|
| 16 |
"Modèle EfficientnetB0 FT": "efficientnet_fire_3.pt",
|
| 17 |
+
"Modèle YOLOv8": "yolov8_fire.pt",
|
| 18 |
"Modèle Inception3": "inception3_fire.pt",
|
| 19 |
}
|
| 20 |
|
|
|
|
| 56 |
"fp": 440,
|
| 57 |
"fn": 680,
|
| 58 |
},
|
| 59 |
+
# "yolov8_fire.pt": {
|
| 60 |
+
# "accuracy": 0.7717,
|
| 61 |
+
# "precision": 0.8713,
|
| 62 |
+
# "recall": 0.8142,
|
| 63 |
+
# "f1": 0.8418,
|
| 64 |
+
# "fp": 440,
|
| 65 |
+
# "fn": 680,
|
| 66 |
+
# },
|
| 67 |
+
# "inception3_fire.pt": {
|
| 68 |
+
# "accuracy": 0.7717,
|
| 69 |
+
# "precision": 0.8713,
|
| 70 |
+
# "recall": 0.8142,
|
| 71 |
+
# "f1": 0.8418,
|
| 72 |
+
# "fp": 440,
|
| 73 |
+
# "fn": 680,
|
| 74 |
+
# },
|
| 75 |
}
|
| 76 |
|
| 77 |
# -----------------------------------------------------------
|
|
|
|
| 180 |
# Afficher l'image uploadée
|
| 181 |
# -------------------------------------------------------
|
| 182 |
image = Image.open(uploaded_file)
|
|
|
|
| 183 |
|
| 184 |
# -------------------------------------------------------
|
| 185 |
# Prédiction
|
| 186 |
# -------------------------------------------------------
|
| 187 |
with st.spinner("Analyse de l'image en cours..."):
|
| 188 |
+
label, prob, annotated_image = predict_from_pil(
|
| 189 |
image=image,
|
| 190 |
model=model,
|
| 191 |
device=device,
|
|
|
|
| 193 |
threshold=threshold
|
| 194 |
)
|
| 195 |
|
| 196 |
+
# Image originale
|
| 197 |
+
st.image(image, caption="Image chargée", use_container_width=True)
|
| 198 |
+
|
| 199 |
+
# Si YOLO a fourni une image annotée, on l'affiche
|
| 200 |
+
if annotated_image is not None:
|
| 201 |
+
st.image(
|
| 202 |
+
annotated_image,
|
| 203 |
+
caption="Zones détectées (modèle de détection)",
|
| 204 |
+
use_container_width=True,
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
# -------------------------------------------------------
|
| 208 |
# Affichage du résultat avec couleur
|
| 209 |
# -------------------------------------------------------
|
inference.py
CHANGED
|
@@ -3,59 +3,51 @@ inference.py
|
|
| 3 |
------------
|
| 4 |
Module d'inférence pour des modèles de classification binaire : FIRE (1) / NO_FIRE (0).
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
sinon → EfficientNet-B0 par défaut),
|
| 12 |
-
- ou que le fichier .pt contienne un dictionnaire avec "state_dict" (et
|
| 13 |
-
éventuellement "model_name" / "model_key").
|
| 14 |
-
|
| 15 |
-
Usage typique :
|
| 16 |
-
---------------
|
| 17 |
-
from inference import load_model, get_val_transform, predict_from_path
|
| 18 |
-
|
| 19 |
-
model, device = load_model("efficientnet_fire.pt")
|
| 20 |
-
transform = get_val_transform() # ou get_val_transform(image_size_personnalisée)
|
| 21 |
-
|
| 22 |
-
label, prob = predict_from_path("mon_image.jpg", model, device, transform)
|
| 23 |
-
print(label, prob)
|
| 24 |
"""
|
| 25 |
|
| 26 |
# ----------------------------
|
| 27 |
# 1) Imports
|
| 28 |
# ----------------------------
|
| 29 |
-
import os
|
| 30 |
-
|
| 31 |
-
import torch
|
| 32 |
-
|
| 33 |
-
from
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
|
| 37 |
# ----------------------------
|
| 38 |
# 2) Constantes globales
|
| 39 |
# ----------------------------
|
| 40 |
|
| 41 |
-
# Taille d'entrée par défaut
|
| 42 |
-
DEFAULT_IMAGE_SIZE = 224
|
| 43 |
|
| 44 |
-
#
|
| 45 |
-
IMAGENET_MEAN = [0.485, 0.456, 0.406]
|
| 46 |
-
IMAGENET_STD = [0.229, 0.224, 0.225]
|
| 47 |
|
| 48 |
-
# Mapping
|
| 49 |
IDX_TO_LABEL = {
|
| 50 |
-
0: "no_fire",
|
| 51 |
-
1: "fire"
|
| 52 |
}
|
| 53 |
|
| 54 |
-
# Registre des modèles
|
| 55 |
-
# - model_key : identifiant interne (nos clés)
|
| 56 |
-
# - timm_name : nom utilisé dans timm.create_model(...)
|
| 57 |
-
# - image_size : taille d'entrée "recommandée" pour ce modèle (optionnelle)
|
| 58 |
-
# - classifier_attr : nom de l'attribut contenant la dernière couche de classification
|
| 59 |
MODEL_REGISTRY = {
|
| 60 |
"efficientnet_b0": {
|
| 61 |
"timm_name": "efficientnet_b0",
|
|
@@ -69,12 +61,17 @@ MODEL_REGISTRY = {
|
|
| 69 |
},
|
| 70 |
}
|
| 71 |
|
| 72 |
-
#
|
| 73 |
DEFAULT_MODEL_KEY = "efficientnet_b0"
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
# ----------------------------
|
| 77 |
-
# 3)
|
| 78 |
# ----------------------------
|
| 79 |
|
| 80 |
def get_device():
|
|
@@ -85,8 +82,7 @@ def get_device():
|
|
| 85 |
"""
|
| 86 |
if torch.cuda.is_available():
|
| 87 |
return torch.device("cuda")
|
| 88 |
-
|
| 89 |
-
return torch.device("cpu")
|
| 90 |
|
| 91 |
|
| 92 |
# ----------------------------
|
|
@@ -96,62 +92,58 @@ def get_device():
|
|
| 96 |
def infer_model_key_from_path(weights_path: str) -> str:
|
| 97 |
"""
|
| 98 |
Devine une clé de modèle (model_key) à partir du nom de fichier.
|
|
|
|
| 99 |
Exemple :
|
| 100 |
-
- "
|
| 101 |
-
- "
|
|
|
|
| 102 |
"""
|
| 103 |
filename = os.path.basename(weights_path).lower()
|
| 104 |
|
| 105 |
-
|
|
|
|
|
|
|
| 106 |
if "inception" in filename:
|
| 107 |
return "inception_v3"
|
| 108 |
|
| 109 |
-
# Sinon, par défaut → EfficientNet-B0
|
| 110 |
return DEFAULT_MODEL_KEY
|
| 111 |
|
| 112 |
|
| 113 |
def get_model_config(model_key: str) -> dict:
|
| 114 |
"""
|
| 115 |
Renvoie la config du modèle à partir d'une model_key.
|
| 116 |
-
|
| 117 |
"""
|
| 118 |
if model_key in MODEL_REGISTRY:
|
| 119 |
return MODEL_REGISTRY[model_key]
|
| 120 |
-
# Fallback de sécurité : ne jamais casser si clé inconnue
|
| 121 |
return MODEL_REGISTRY[DEFAULT_MODEL_KEY]
|
| 122 |
|
| 123 |
|
| 124 |
# ----------------------------
|
| 125 |
-
# 5) Construction
|
| 126 |
# ----------------------------
|
| 127 |
|
| 128 |
def build_model(model_key: str, num_classes: int = 2) -> torch.nn.Module:
|
| 129 |
"""
|
| 130 |
-
Construit
|
| 131 |
-
et adapte la
|
| 132 |
"""
|
| 133 |
config = get_model_config(model_key)
|
| 134 |
timm_name = config["timm_name"]
|
| 135 |
classifier_attr = config["classifier_attr"]
|
| 136 |
|
| 137 |
-
# On crée le backbone via timm (sans pré-entraînement, les poids viendront du .pt)
|
| 138 |
model = timm.create_model(timm_name, pretrained=False)
|
| 139 |
|
| 140 |
-
# On adapte la tête de classification à notre problème binaire.
|
| 141 |
-
# On récupère l'ancienne couche de classification (fc, classifier, etc.)
|
| 142 |
classifier = getattr(model, classifier_attr)
|
| 143 |
|
| 144 |
-
# Certains modèles ont déjà un nn.Linear, on récupère in_features
|
| 145 |
if isinstance(classifier, nn.Linear):
|
| 146 |
in_features = classifier.in_features
|
| 147 |
else:
|
| 148 |
-
# Cas plus exotique : on essaie de deviner proprement, sinon on lève une erreur claire.
|
| 149 |
raise ValueError(
|
| 150 |
f"Impossible de déterminer in_features pour la tête du modèle '{timm_name}'. "
|
| 151 |
f"Attribut '{classifier_attr}' de type {type(classifier)} non supporté."
|
| 152 |
)
|
| 153 |
|
| 154 |
-
# On remplace par une couche linéaire adaptée à notre nombre de classes
|
| 155 |
new_classifier = nn.Linear(in_features, num_classes)
|
| 156 |
setattr(model, classifier_attr, new_classifier)
|
| 157 |
|
|
@@ -159,89 +151,78 @@ def build_model(model_key: str, num_classes: int = 2) -> torch.nn.Module:
|
|
| 159 |
|
| 160 |
|
| 161 |
# ----------------------------
|
| 162 |
-
# 6)
|
| 163 |
# ----------------------------
|
| 164 |
|
| 165 |
def _clean_state_dict_keys(state_dict: dict) -> dict:
|
| 166 |
"""
|
| 167 |
-
Nettoie les clés
|
| 168 |
-
- clés préfixées par 'model.' (Lightning)
|
| 169 |
-
- clés préfixées par 'module.' (DataParallel)
|
| 170 |
"""
|
| 171 |
-
|
| 172 |
for k, v in state_dict.items():
|
| 173 |
new_key = k
|
| 174 |
if new_key.startswith("model."):
|
| 175 |
new_key = new_key[len("model."):]
|
| 176 |
if new_key.startswith("module."):
|
| 177 |
new_key = new_key[len("module."):]
|
| 178 |
-
|
| 179 |
-
return
|
| 180 |
|
| 181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
def load_model(weights_path: str, map_location=None, model_key: str | None = None):
|
| 183 |
"""
|
| 184 |
Charge un modèle avec les poids entraînés.
|
| 185 |
|
| 186 |
-
|
| 187 |
-
-
|
| 188 |
-
weights_path : str
|
| 189 |
-
Chemin vers le fichier .pt (state_dict ou dict avec 'state_dict').
|
| 190 |
-
map_location : torch.device ou None
|
| 191 |
-
Device sur lequel charger les poids. Si None, on détecte automatiquement.
|
| 192 |
-
model_key : str ou None
|
| 193 |
-
Clé de modèle à utiliser (ex: 'efficientnet_b0', 'inception_v3').
|
| 194 |
-
Si None, on essaie de la déduire du nom de fichier.
|
| 195 |
|
| 196 |
-
Retour
|
| 197 |
-
------
|
| 198 |
-
model : torch.nn.Module
|
| 199 |
-
Le modèle prêt pour l'inférence.
|
| 200 |
device : torch.device
|
| 201 |
-
Le device utilisé (cuda ou cpu).
|
| 202 |
"""
|
| 203 |
-
# 1) Sélection du device
|
| 204 |
device = map_location if map_location is not None else get_device()
|
| 205 |
|
| 206 |
-
#
|
| 207 |
if model_key is None:
|
| 208 |
model_key = infer_model_key_from_path(weights_path)
|
| 209 |
|
| 210 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
checkpoint = torch.load(weights_path, map_location=device)
|
| 212 |
|
| 213 |
-
# 4) Si le checkpoint est un dict complet (ex: {'state_dict': ..., 'model_name': ...})
|
| 214 |
if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
|
| 215 |
state_dict = checkpoint["state_dict"]
|
| 216 |
-
|
| 217 |
-
# Optionnel : si un model_name ou model_key est stocké dedans, on peut le préférer.
|
| 218 |
-
if "model_key" in checkpoint and checkpoint["model_key"] in MODEL_REGISTRY:
|
| 219 |
-
model_key = checkpoint["model_key"]
|
| 220 |
-
elif "model_name" in checkpoint:
|
| 221 |
-
# Tentative : si model_name correspond au timm_name d'une entrée du registre
|
| 222 |
-
model_name_lower = str(checkpoint["model_name"]).lower()
|
| 223 |
-
for k, cfg in MODEL_REGISTRY.items():
|
| 224 |
-
if cfg["timm_name"].lower() == model_name_lower:
|
| 225 |
-
model_key = k
|
| 226 |
-
break
|
| 227 |
else:
|
| 228 |
-
# Cas simple : le fichier .pt est directement un state_dict
|
| 229 |
state_dict = checkpoint
|
| 230 |
|
| 231 |
-
# 5) Nettoyage des clés du state_dict (Lightning, DataParallel...)
|
| 232 |
state_dict = _clean_state_dict_keys(state_dict)
|
| 233 |
|
| 234 |
-
# 6) Construction de l'architecture adaptée
|
| 235 |
model = build_model(model_key=model_key, num_classes=2)
|
| 236 |
-
|
| 237 |
-
# 7) Application des poids (strict=False pour tolérer quelques petites diff de clés)
|
| 238 |
missing, unexpected = model.load_state_dict(state_dict, strict=False)
|
| 239 |
|
| 240 |
-
# (
|
| 241 |
# print("Missing keys:", missing)
|
| 242 |
# print("Unexpected keys:", unexpected)
|
| 243 |
|
| 244 |
-
# 8) Envoi sur le bon device + mode eval
|
| 245 |
model = model.to(device)
|
| 246 |
model.eval()
|
| 247 |
|
|
@@ -249,21 +230,14 @@ def load_model(weights_path: str, map_location=None, model_key: str | None = Non
|
|
| 249 |
|
| 250 |
|
| 251 |
# ----------------------------
|
| 252 |
-
#
|
| 253 |
# ----------------------------
|
| 254 |
|
| 255 |
def get_val_transform(image_size: int | None = None):
|
| 256 |
"""
|
| 257 |
Renvoie les transformations à appliquer aux images pour l'inférence.
|
| 258 |
-
Paramètres
|
| 259 |
-
----------
|
| 260 |
-
image_size : int ou None
|
| 261 |
-
Si None, on utilise DEFAULT_IMAGE_SIZE (224).
|
| 262 |
-
Sinon, on redimensionne en (image_size, image_size).
|
| 263 |
|
| 264 |
-
|
| 265 |
-
------
|
| 266 |
-
transform : torchvision.transforms.Compose
|
| 267 |
"""
|
| 268 |
if image_size is None:
|
| 269 |
image_size = DEFAULT_IMAGE_SIZE
|
|
@@ -277,26 +251,12 @@ def get_val_transform(image_size: int | None = None):
|
|
| 277 |
|
| 278 |
|
| 279 |
# ----------------------------
|
| 280 |
-
#
|
| 281 |
# ----------------------------
|
| 282 |
|
| 283 |
def preprocess_image(image: Image.Image, transform=None, image_size: int | None = None):
|
| 284 |
"""
|
| 285 |
Applique les transforms à une image PIL et ajoute une dimension batch.
|
| 286 |
-
|
| 287 |
-
Paramètres
|
| 288 |
-
----------
|
| 289 |
-
image : PIL.Image.Image
|
| 290 |
-
Image brute chargée (par exemple via Image.open(...)).
|
| 291 |
-
transform : callable ou None
|
| 292 |
-
Transformations à appliquer (si None, on utilise get_val_transform(image_size)).
|
| 293 |
-
image_size : int ou None
|
| 294 |
-
Taille de redimensionnement si transform est None.
|
| 295 |
-
|
| 296 |
-
Retour
|
| 297 |
-
------
|
| 298 |
-
image_tensor : torch.Tensor
|
| 299 |
-
Tenseur prêt pour l'inférence, de taille [1, 3, H, W].
|
| 300 |
"""
|
| 301 |
if transform is None:
|
| 302 |
transform = get_val_transform(image_size=image_size)
|
|
@@ -308,57 +268,39 @@ def preprocess_image(image: Image.Image, transform=None, image_size: int | None
|
|
| 308 |
|
| 309 |
|
| 310 |
# ----------------------------
|
| 311 |
-
#
|
| 312 |
# ----------------------------
|
| 313 |
|
| 314 |
def predict_from_tensor(
|
| 315 |
image_tensor: torch.Tensor,
|
| 316 |
model: torch.nn.Module,
|
| 317 |
device: torch.device,
|
| 318 |
-
threshold: float = 0.5
|
| 319 |
):
|
| 320 |
"""
|
| 321 |
-
Prédit la classe (fire/no_fire) à partir d'un tenseur déjà prétraité
|
| 322 |
-
|
| 323 |
-
Paramètres
|
| 324 |
-
----------
|
| 325 |
-
image_tensor : torch.Tensor
|
| 326 |
-
Tenseur d'images de taille [1, 3, H, W] (batch de 1 image).
|
| 327 |
-
model : torch.nn.Module
|
| 328 |
-
Modèle chargé (EfficientNet, Inception, etc.).
|
| 329 |
-
device : torch.device
|
| 330 |
-
Device sur lequel le modèle est (cuda ou cpu).
|
| 331 |
-
threshold : float
|
| 332 |
-
Seuil sur la probabilité de FEU pour décider entre no_fire / fire.
|
| 333 |
-
|
| 334 |
-
Retour
|
| 335 |
-
------
|
| 336 |
-
predicted_label : str
|
| 337 |
-
"fire" ou "no_fire".
|
| 338 |
-
fire_prob : float
|
| 339 |
-
Probabilité prédite pour la classe "fire" (entre 0 et 1).
|
| 340 |
"""
|
| 341 |
image_tensor = image_tensor.to(device)
|
| 342 |
|
| 343 |
with torch.no_grad():
|
| 344 |
outputs = model(image_tensor) # logits [1, 2]
|
| 345 |
-
probs = torch.softmax(outputs, dim=1) #
|
| 346 |
|
| 347 |
fire_prob = probs[0, 1].item()
|
| 348 |
-
|
| 349 |
-
if fire_prob >= threshold:
|
| 350 |
-
predicted_idx = 1 # feu
|
| 351 |
-
else:
|
| 352 |
-
predicted_idx = 0 # pas de feu
|
| 353 |
|
| 354 |
predicted_label = IDX_TO_LABEL[predicted_idx]
|
| 355 |
-
|
| 356 |
return predicted_label, fire_prob
|
| 357 |
|
| 358 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
def predict_from_pil(
|
| 360 |
image: Image.Image,
|
| 361 |
-
model
|
| 362 |
device: torch.device,
|
| 363 |
transform=None,
|
| 364 |
threshold: float = 0.5,
|
|
@@ -367,39 +309,66 @@ def predict_from_pil(
|
|
| 367 |
"""
|
| 368 |
Prédit la classe à partir d'une image PIL.
|
| 369 |
|
| 370 |
-
Paramètres
|
| 371 |
-
----------
|
| 372 |
-
image : PIL.Image.Image
|
| 373 |
-
Image chargée (par exemple via Image.open).
|
| 374 |
-
model : torch.nn.Module
|
| 375 |
-
Modèle chargé.
|
| 376 |
-
device : torch.device
|
| 377 |
-
Device (cuda ou cpu).
|
| 378 |
-
transform : callable ou None
|
| 379 |
-
Transformations à appliquer à l'image.
|
| 380 |
-
threshold : float
|
| 381 |
-
Seuil sur la probabilité de FEU.
|
| 382 |
-
image_size : int ou None
|
| 383 |
-
Taille utilisée si transform est None.
|
| 384 |
-
|
| 385 |
Retour
|
| 386 |
------
|
| 387 |
predicted_label : str
|
| 388 |
"fire" ou "no_fire".
|
| 389 |
fire_prob : float
|
| 390 |
Probabilité de "fire".
|
|
|
|
|
|
|
|
|
|
| 391 |
"""
|
| 392 |
if image.mode != "RGB":
|
| 393 |
image = image.convert("RGB")
|
| 394 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
image_tensor = preprocess_image(image, transform=transform, image_size=image_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 396 |
|
| 397 |
-
return predict_from_tensor(image_tensor, model, device, threshold=threshold)
|
| 398 |
|
|
|
|
|
|
|
|
|
|
| 399 |
|
| 400 |
def predict_from_path(
|
| 401 |
image_path: str,
|
| 402 |
-
model
|
| 403 |
device: torch.device,
|
| 404 |
transform=None,
|
| 405 |
threshold: float = 0.5,
|
|
@@ -407,32 +376,10 @@ def predict_from_path(
|
|
| 407 |
):
|
| 408 |
"""
|
| 409 |
Prédit la classe à partir d'un chemin vers une image.
|
| 410 |
-
|
| 411 |
-
Paramètres
|
| 412 |
-
----------
|
| 413 |
-
image_path : str
|
| 414 |
-
Chemin vers le fichier image (jpg, png, etc.).
|
| 415 |
-
model : torch.nn.Module
|
| 416 |
-
Modèle chargé.
|
| 417 |
-
device : torch.device
|
| 418 |
-
Device (cuda ou cpu).
|
| 419 |
-
transform : callable ou None
|
| 420 |
-
Transformations à appliquer.
|
| 421 |
-
threshold : float
|
| 422 |
-
Seuil sur la probabilité de FEU.
|
| 423 |
-
image_size : int ou None
|
| 424 |
-
Taille utilisée si transform est None.
|
| 425 |
-
|
| 426 |
-
Retour
|
| 427 |
-
------
|
| 428 |
-
predicted_label : str
|
| 429 |
-
"fire" ou "no_fire".
|
| 430 |
-
fire_prob : float
|
| 431 |
-
Probabilité de "fire".
|
| 432 |
"""
|
| 433 |
image = Image.open(image_path)
|
| 434 |
return predict_from_pil(
|
| 435 |
-
image,
|
| 436 |
model=model,
|
| 437 |
device=device,
|
| 438 |
transform=transform,
|
|
@@ -442,31 +389,25 @@ def predict_from_path(
|
|
| 442 |
|
| 443 |
|
| 444 |
# ----------------------------
|
| 445 |
-
#
|
| 446 |
# ----------------------------
|
| 447 |
|
| 448 |
if __name__ == "__main__":
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
python inference.py
|
| 452 |
-
"""
|
| 453 |
-
weights_path = "efficientnet_fire.pt" # à adapter si besoin
|
| 454 |
-
|
| 455 |
if not os.path.exists(weights_path):
|
| 456 |
print(f"[ERREUR] Fichier de poids introuvable : {weights_path}")
|
| 457 |
else:
|
| 458 |
model, device = load_model(weights_path)
|
| 459 |
print(f"Modèle chargé sur le device : {device}")
|
| 460 |
|
| 461 |
-
# On utilise la taille par défaut (224) pour ce test
|
| 462 |
transform = get_val_transform()
|
| 463 |
-
|
| 464 |
-
test_image_path = "example.jpg" # à adapter si besoin
|
| 465 |
|
| 466 |
if not os.path.exists(test_image_path):
|
| 467 |
print(f"[INFO] Aucune image test trouvée à : {test_image_path}")
|
| 468 |
else:
|
| 469 |
-
label, prob = predict_from_path(
|
| 470 |
test_image_path,
|
| 471 |
model=model,
|
| 472 |
device=device,
|
|
|
|
| 3 |
------------
|
| 4 |
Module d'inférence pour des modèles de classification binaire : FIRE (1) / NO_FIRE (0).
|
| 5 |
|
| 6 |
+
Supporte :
|
| 7 |
+
- EfficientNet-B0 (classification)
|
| 8 |
+
- Inception v3 (classification)
|
| 9 |
+
- YOLO (Ultralytics, détection) pour localiser le feu
|
| 10 |
|
| 11 |
+
Retour principal des fonctions de prédiction :
|
| 12 |
+
predicted_label (str), fire_prob (float), annotated_image (PIL.Image ou None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
"""
|
| 14 |
|
| 15 |
# ----------------------------
|
| 16 |
# 1) Imports
|
| 17 |
# ----------------------------
|
| 18 |
+
import os
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
from torchvision import transforms
|
| 23 |
+
from PIL import Image
|
| 24 |
+
import timm
|
| 25 |
+
|
| 26 |
+
# YOLO (Ultralytics) – optionnel
|
| 27 |
+
try:
|
| 28 |
+
from ultralytics import YOLO
|
| 29 |
+
except ImportError:
|
| 30 |
+
YOLO = None # si la lib n'est pas installée, on gère ça proprement
|
| 31 |
|
| 32 |
|
| 33 |
# ----------------------------
|
| 34 |
# 2) Constantes globales
|
| 35 |
# ----------------------------
|
| 36 |
|
| 37 |
+
# Taille d'entrée par défaut si rien n'est précisé
|
| 38 |
+
DEFAULT_IMAGE_SIZE = 224
|
| 39 |
|
| 40 |
+
# Stats ImageNet
|
| 41 |
+
IMAGENET_MEAN = [0.485, 0.456, 0.406]
|
| 42 |
+
IMAGENET_STD = [0.229, 0.224, 0.225]
|
| 43 |
|
| 44 |
+
# Mapping idx -> label lisible
|
| 45 |
IDX_TO_LABEL = {
|
| 46 |
+
0: "no_fire",
|
| 47 |
+
1: "fire",
|
| 48 |
}
|
| 49 |
|
| 50 |
+
# Registre des modèles de classification (timm)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
MODEL_REGISTRY = {
|
| 52 |
"efficientnet_b0": {
|
| 53 |
"timm_name": "efficientnet_b0",
|
|
|
|
| 61 |
},
|
| 62 |
}
|
| 63 |
|
| 64 |
+
# Modèle par défaut si on ne sait pas quoi choisir
|
| 65 |
DEFAULT_MODEL_KEY = "efficientnet_b0"
|
| 66 |
|
| 67 |
+
# IDs des classes "feu" pour YOLO (à adapter si besoin après entraînement)
|
| 68 |
+
# Exemple : si model.names == {0: 'fire'} → [0]
|
| 69 |
+
# si model.names == {0: 'no_fire', 1: 'fire'} → [1]
|
| 70 |
+
FIRE_CLASS_IDS = [0]
|
| 71 |
+
|
| 72 |
|
| 73 |
# ----------------------------
|
| 74 |
+
# 3) Device
|
| 75 |
# ----------------------------
|
| 76 |
|
| 77 |
def get_device():
|
|
|
|
| 82 |
"""
|
| 83 |
if torch.cuda.is_available():
|
| 84 |
return torch.device("cuda")
|
| 85 |
+
return torch.device("cpu")
|
|
|
|
| 86 |
|
| 87 |
|
| 88 |
# ----------------------------
|
|
|
|
| 92 |
def infer_model_key_from_path(weights_path: str) -> str:
|
| 93 |
"""
|
| 94 |
Devine une clé de modèle (model_key) à partir du nom de fichier.
|
| 95 |
+
|
| 96 |
Exemple :
|
| 97 |
+
- "yolov8_fire.pt" → "yolo"
|
| 98 |
+
- "inception3_fire.pt" → "inception_v3"
|
| 99 |
+
- "efficientnet_fire.pt" → "efficientnet_b0" (par défaut)
|
| 100 |
"""
|
| 101 |
filename = os.path.basename(weights_path).lower()
|
| 102 |
|
| 103 |
+
if "yolo" in filename:
|
| 104 |
+
return "yolo"
|
| 105 |
+
|
| 106 |
if "inception" in filename:
|
| 107 |
return "inception_v3"
|
| 108 |
|
|
|
|
| 109 |
return DEFAULT_MODEL_KEY
|
| 110 |
|
| 111 |
|
| 112 |
def get_model_config(model_key: str) -> dict:
|
| 113 |
"""
|
| 114 |
Renvoie la config du modèle à partir d'une model_key.
|
| 115 |
+
Fallback : EfficientNet-B0 si model_key inconnue.
|
| 116 |
"""
|
| 117 |
if model_key in MODEL_REGISTRY:
|
| 118 |
return MODEL_REGISTRY[model_key]
|
|
|
|
| 119 |
return MODEL_REGISTRY[DEFAULT_MODEL_KEY]
|
| 120 |
|
| 121 |
|
| 122 |
# ----------------------------
|
| 123 |
+
# 5) Construction modèle (classification)
|
| 124 |
# ----------------------------
|
| 125 |
|
| 126 |
def build_model(model_key: str, num_classes: int = 2) -> torch.nn.Module:
|
| 127 |
"""
|
| 128 |
+
Construit un modèle de classification (EfficientNet, Inception...)
|
| 129 |
+
et adapte la dernière couche à num_classes sorties.
|
| 130 |
"""
|
| 131 |
config = get_model_config(model_key)
|
| 132 |
timm_name = config["timm_name"]
|
| 133 |
classifier_attr = config["classifier_attr"]
|
| 134 |
|
|
|
|
| 135 |
model = timm.create_model(timm_name, pretrained=False)
|
| 136 |
|
|
|
|
|
|
|
| 137 |
classifier = getattr(model, classifier_attr)
|
| 138 |
|
|
|
|
| 139 |
if isinstance(classifier, nn.Linear):
|
| 140 |
in_features = classifier.in_features
|
| 141 |
else:
|
|
|
|
| 142 |
raise ValueError(
|
| 143 |
f"Impossible de déterminer in_features pour la tête du modèle '{timm_name}'. "
|
| 144 |
f"Attribut '{classifier_attr}' de type {type(classifier)} non supporté."
|
| 145 |
)
|
| 146 |
|
|
|
|
| 147 |
new_classifier = nn.Linear(in_features, num_classes)
|
| 148 |
setattr(model, classifier_attr, new_classifier)
|
| 149 |
|
|
|
|
| 151 |
|
| 152 |
|
| 153 |
# ----------------------------
|
| 154 |
+
# 6) Nettoyage de state_dict
|
| 155 |
# ----------------------------
|
| 156 |
|
| 157 |
def _clean_state_dict_keys(state_dict: dict) -> dict:
|
| 158 |
"""
|
| 159 |
+
Nettoie les clés pour gérer les prefixes 'model.' (Lightning) et 'module.' (DataParallel).
|
|
|
|
|
|
|
| 160 |
"""
|
| 161 |
+
new_state = {}
|
| 162 |
for k, v in state_dict.items():
|
| 163 |
new_key = k
|
| 164 |
if new_key.startswith("model."):
|
| 165 |
new_key = new_key[len("model."):]
|
| 166 |
if new_key.startswith("module."):
|
| 167 |
new_key = new_key[len("module."):]
|
| 168 |
+
new_state[new_key] = v
|
| 169 |
+
return new_state
|
| 170 |
|
| 171 |
|
| 172 |
+
# ----------------------------
|
| 173 |
+
# 7) Chargement du modèle
|
| 174 |
+
# ----------------------------
|
| 175 |
+
|
| 176 |
def load_model(weights_path: str, map_location=None, model_key: str | None = None):
|
| 177 |
"""
|
| 178 |
Charge un modèle avec les poids entraînés.
|
| 179 |
|
| 180 |
+
- Pour YOLO (Ultralytics) : charge un modèle de détection.
|
| 181 |
+
- Pour EfficientNet / Inception : charge un modèle de classification binaire.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
+
Retour :
|
| 184 |
+
--------
|
| 185 |
+
model : torch.nn.Module ou YOLO
|
|
|
|
| 186 |
device : torch.device
|
|
|
|
| 187 |
"""
|
|
|
|
| 188 |
device = map_location if map_location is not None else get_device()
|
| 189 |
|
| 190 |
+
# Détecter le type de modèle si non fourni
|
| 191 |
if model_key is None:
|
| 192 |
model_key = infer_model_key_from_path(weights_path)
|
| 193 |
|
| 194 |
+
# 🔹 Cas YOLO : modèle de détection (Ultralytics)
|
| 195 |
+
if model_key == "yolo":
|
| 196 |
+
if YOLO is None:
|
| 197 |
+
raise ImportError(
|
| 198 |
+
"Le modèle YOLO est demandé mais la librairie 'ultralytics' "
|
| 199 |
+
"n'est pas installée. Ajoutez 'ultralytics' dans requirements.txt."
|
| 200 |
+
)
|
| 201 |
+
yolo_model = YOLO(weights_path)
|
| 202 |
+
# YOLO gère déjà souvent le device en interne, on essaye juste par sécurité
|
| 203 |
+
try:
|
| 204 |
+
yolo_model.to(device)
|
| 205 |
+
except Exception:
|
| 206 |
+
pass
|
| 207 |
+
return yolo_model, device
|
| 208 |
+
|
| 209 |
+
# 🔹 Cas classification (EfficientNet, Inception...)
|
| 210 |
checkpoint = torch.load(weights_path, map_location=device)
|
| 211 |
|
|
|
|
| 212 |
if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
|
| 213 |
state_dict = checkpoint["state_dict"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
else:
|
|
|
|
| 215 |
state_dict = checkpoint
|
| 216 |
|
|
|
|
| 217 |
state_dict = _clean_state_dict_keys(state_dict)
|
| 218 |
|
|
|
|
| 219 |
model = build_model(model_key=model_key, num_classes=2)
|
|
|
|
|
|
|
| 220 |
missing, unexpected = model.load_state_dict(state_dict, strict=False)
|
| 221 |
|
| 222 |
+
# (optionnel) debug :
|
| 223 |
# print("Missing keys:", missing)
|
| 224 |
# print("Unexpected keys:", unexpected)
|
| 225 |
|
|
|
|
| 226 |
model = model.to(device)
|
| 227 |
model.eval()
|
| 228 |
|
|
|
|
| 230 |
|
| 231 |
|
| 232 |
# ----------------------------
|
| 233 |
+
# 8) Transforms pour l'inférence
|
| 234 |
# ----------------------------
|
| 235 |
|
| 236 |
def get_val_transform(image_size: int | None = None):
|
| 237 |
"""
|
| 238 |
Renvoie les transformations à appliquer aux images pour l'inférence.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
|
| 240 |
+
Si image_size est None → DEFAULT_IMAGE_SIZE (224).
|
|
|
|
|
|
|
| 241 |
"""
|
| 242 |
if image_size is None:
|
| 243 |
image_size = DEFAULT_IMAGE_SIZE
|
|
|
|
| 251 |
|
| 252 |
|
| 253 |
# ----------------------------
|
| 254 |
+
# 9) Prétraitement d'une image
|
| 255 |
# ----------------------------
|
| 256 |
|
| 257 |
def preprocess_image(image: Image.Image, transform=None, image_size: int | None = None):
|
| 258 |
"""
|
| 259 |
Applique les transforms à une image PIL et ajoute une dimension batch.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
"""
|
| 261 |
if transform is None:
|
| 262 |
transform = get_val_transform(image_size=image_size)
|
|
|
|
| 268 |
|
| 269 |
|
| 270 |
# ----------------------------
|
| 271 |
+
# 10) Prédiction depuis un tenseur (classification)
|
| 272 |
# ----------------------------
|
| 273 |
|
| 274 |
def predict_from_tensor(
|
| 275 |
image_tensor: torch.Tensor,
|
| 276 |
model: torch.nn.Module,
|
| 277 |
device: torch.device,
|
| 278 |
+
threshold: float = 0.5,
|
| 279 |
):
|
| 280 |
"""
|
| 281 |
+
Prédit la classe (fire/no_fire) à partir d'un tenseur déjà prétraité
|
| 282 |
+
pour les modèles de classification (EfficientNet, Inception...).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
"""
|
| 284 |
image_tensor = image_tensor.to(device)
|
| 285 |
|
| 286 |
with torch.no_grad():
|
| 287 |
outputs = model(image_tensor) # logits [1, 2]
|
| 288 |
+
probs = torch.softmax(outputs, dim=1) # probas
|
| 289 |
|
| 290 |
fire_prob = probs[0, 1].item()
|
| 291 |
+
predicted_idx = 1 if fire_prob >= threshold else 0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
|
| 293 |
predicted_label = IDX_TO_LABEL[predicted_idx]
|
|
|
|
| 294 |
return predicted_label, fire_prob
|
| 295 |
|
| 296 |
|
| 297 |
+
# ----------------------------
|
| 298 |
+
# 11) Prédiction depuis une image PIL
|
| 299 |
+
# ----------------------------
|
| 300 |
+
|
| 301 |
def predict_from_pil(
|
| 302 |
image: Image.Image,
|
| 303 |
+
model,
|
| 304 |
device: torch.device,
|
| 305 |
transform=None,
|
| 306 |
threshold: float = 0.5,
|
|
|
|
| 309 |
"""
|
| 310 |
Prédit la classe à partir d'une image PIL.
|
| 311 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
Retour
|
| 313 |
------
|
| 314 |
predicted_label : str
|
| 315 |
"fire" ou "no_fire".
|
| 316 |
fire_prob : float
|
| 317 |
Probabilité de "fire".
|
| 318 |
+
annotated_image : PIL.Image.Image ou None
|
| 319 |
+
Image annotée (bounding boxes) pour les modèles de détection (YOLO),
|
| 320 |
+
None pour les modèles de classification.
|
| 321 |
"""
|
| 322 |
if image.mode != "RGB":
|
| 323 |
image = image.convert("RGB")
|
| 324 |
|
| 325 |
+
# 🔹 Cas YOLO : modèle de détection (Ultralytics)
|
| 326 |
+
if hasattr(model, "task") and getattr(model, "task", None) == "detect":
|
| 327 |
+
results = model(image)
|
| 328 |
+
result = results[0]
|
| 329 |
+
|
| 330 |
+
boxes = getattr(result, "boxes", None)
|
| 331 |
+
fire_prob = 0.0
|
| 332 |
+
|
| 333 |
+
if boxes is not None and len(boxes) > 0:
|
| 334 |
+
classes = boxes.cls # ids des classes (tensor)
|
| 335 |
+
confs = boxes.conf # scores de confiance (tensor)
|
| 336 |
+
|
| 337 |
+
# masque des boxes "feu"
|
| 338 |
+
mask_fire = torch.zeros_like(classes, dtype=torch.bool)
|
| 339 |
+
for cid in FIRE_CLASS_IDS:
|
| 340 |
+
mask_fire |= (classes == cid)
|
| 341 |
+
|
| 342 |
+
if mask_fire.any():
|
| 343 |
+
fire_prob = float(confs[mask_fire].max().item())
|
| 344 |
+
|
| 345 |
+
predicted_label = "fire" if fire_prob >= threshold else "no_fire"
|
| 346 |
+
|
| 347 |
+
# Image annotée avec les bounding boxes
|
| 348 |
+
annotated_image = None
|
| 349 |
+
try:
|
| 350 |
+
annotated_array = result.plot() # numpy array BGR
|
| 351 |
+
annotated_image = Image.fromarray(annotated_array[..., ::-1]) # BGR -> RGB
|
| 352 |
+
except Exception:
|
| 353 |
+
annotated_image = None
|
| 354 |
+
|
| 355 |
+
return predicted_label, fire_prob, annotated_image
|
| 356 |
+
|
| 357 |
+
# 🔹 Cas classification classique
|
| 358 |
image_tensor = preprocess_image(image, transform=transform, image_size=image_size)
|
| 359 |
+
predicted_label, fire_prob = predict_from_tensor(image_tensor, model, device, threshold=threshold)
|
| 360 |
+
annotated_image = None
|
| 361 |
+
|
| 362 |
+
return predicted_label, fire_prob, annotated_image
|
| 363 |
|
|
|
|
| 364 |
|
| 365 |
+
# ----------------------------
|
| 366 |
+
# 12) Prédiction depuis un chemin de fichier
|
| 367 |
+
# ----------------------------
|
| 368 |
|
| 369 |
def predict_from_path(
|
| 370 |
image_path: str,
|
| 371 |
+
model,
|
| 372 |
device: torch.device,
|
| 373 |
transform=None,
|
| 374 |
threshold: float = 0.5,
|
|
|
|
| 376 |
):
|
| 377 |
"""
|
| 378 |
Prédit la classe à partir d'un chemin vers une image.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
"""
|
| 380 |
image = Image.open(image_path)
|
| 381 |
return predict_from_pil(
|
| 382 |
+
image=image,
|
| 383 |
model=model,
|
| 384 |
device=device,
|
| 385 |
transform=transform,
|
|
|
|
| 389 |
|
| 390 |
|
| 391 |
# ----------------------------
|
| 392 |
+
# 13) Exemple d'utilisation en script direct
|
| 393 |
# ----------------------------
|
| 394 |
|
| 395 |
if __name__ == "__main__":
|
| 396 |
+
# Petit test local (à adapter)
|
| 397 |
+
weights_path = "efficientnet_fire.pt"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 398 |
if not os.path.exists(weights_path):
|
| 399 |
print(f"[ERREUR] Fichier de poids introuvable : {weights_path}")
|
| 400 |
else:
|
| 401 |
model, device = load_model(weights_path)
|
| 402 |
print(f"Modèle chargé sur le device : {device}")
|
| 403 |
|
|
|
|
| 404 |
transform = get_val_transform()
|
| 405 |
+
test_image_path = "example.jpg"
|
|
|
|
| 406 |
|
| 407 |
if not os.path.exists(test_image_path):
|
| 408 |
print(f"[INFO] Aucune image test trouvée à : {test_image_path}")
|
| 409 |
else:
|
| 410 |
+
label, prob, _ = predict_from_path(
|
| 411 |
test_image_path,
|
| 412 |
model=model,
|
| 413 |
device=device,
|
requirements.txt
CHANGED
|
Binary files a/requirements.txt and b/requirements.txt differ
|
|
|