ninafr8175 commited on
Commit
2e0d90d
·
1 Parent(s): bf91db4

update dashboard 6

Browse files
Files changed (4) hide show
  1. .gitignore +2 -0
  2. app.py +29 -3
  3. inference.py +143 -202
  4. requirements.txt +0 -0
.gitignore CHANGED
@@ -11,6 +11,8 @@ __pycache__/
11
  *.pptx
12
  .dockerignore
13
  app_v0.py
 
14
  inference_v0.py
 
15
  Dockerfile
16
  Notes.txt
 
11
  *.pptx
12
  .dockerignore
13
  app_v0.py
14
+ app_v1.py
15
  inference_v0.py
16
+ inference_v1.py
17
  Dockerfile
18
  Notes.txt
app.py CHANGED
@@ -14,6 +14,7 @@ ALL_MODEL_FILES = {
14
  "Modèle EfficientnetB0 Baseline": "efficientnet_fire.pt",
15
  "Modèle EfficientnetB0 FE": "efficientnet_fire_2.pt",
16
  "Modèle EfficientnetB0 FT": "efficientnet_fire_3.pt",
 
17
  "Modèle Inception3": "inception3_fire.pt",
18
  }
19
 
@@ -55,7 +56,22 @@ MODEL_METRICS = {
55
  "fp": 440,
56
  "fn": 680,
57
  },
58
- # "inception3_fire.pt": {...},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  }
60
 
61
  # -----------------------------------------------------------
@@ -164,13 +180,12 @@ else:
164
  # Afficher l'image uploadée
165
  # -------------------------------------------------------
166
  image = Image.open(uploaded_file)
167
- st.image(image, caption="Image chargée", use_container_width=True)
168
 
169
  # -------------------------------------------------------
170
  # Prédiction
171
  # -------------------------------------------------------
172
  with st.spinner("Analyse de l'image en cours..."):
173
- label, prob = predict_from_pil(
174
  image=image,
175
  model=model,
176
  device=device,
@@ -178,6 +193,17 @@ else:
178
  threshold=threshold
179
  )
180
 
 
 
 
 
 
 
 
 
 
 
 
181
  # -------------------------------------------------------
182
  # Affichage du résultat avec couleur
183
  # -------------------------------------------------------
 
14
  "Modèle EfficientnetB0 Baseline": "efficientnet_fire.pt",
15
  "Modèle EfficientnetB0 FE": "efficientnet_fire_2.pt",
16
  "Modèle EfficientnetB0 FT": "efficientnet_fire_3.pt",
17
+ "Modèle YOLOv8": "yolov8_fire.pt",
18
  "Modèle Inception3": "inception3_fire.pt",
19
  }
20
 
 
56
  "fp": 440,
57
  "fn": 680,
58
  },
59
+ # "yolov8_fire.pt": {
60
+ # "accuracy": 0.7717,
61
+ # "precision": 0.8713,
62
+ # "recall": 0.8142,
63
+ # "f1": 0.8418,
64
+ # "fp": 440,
65
+ # "fn": 680,
66
+ # },
67
+ # "inception3_fire.pt": {
68
+ # "accuracy": 0.7717,
69
+ # "precision": 0.8713,
70
+ # "recall": 0.8142,
71
+ # "f1": 0.8418,
72
+ # "fp": 440,
73
+ # "fn": 680,
74
+ # },
75
  }
76
 
77
  # -----------------------------------------------------------
 
180
  # Afficher l'image uploadée
181
  # -------------------------------------------------------
182
  image = Image.open(uploaded_file)
 
183
 
184
  # -------------------------------------------------------
185
  # Prédiction
186
  # -------------------------------------------------------
187
  with st.spinner("Analyse de l'image en cours..."):
188
+ label, prob, annotated_image = predict_from_pil(
189
  image=image,
190
  model=model,
191
  device=device,
 
193
  threshold=threshold
194
  )
195
 
196
+ # Image originale
197
+ st.image(image, caption="Image chargée", use_container_width=True)
198
+
199
+ # Si YOLO a fourni une image annotée, on l'affiche
200
+ if annotated_image is not None:
201
+ st.image(
202
+ annotated_image,
203
+ caption="Zones détectées (modèle de détection)",
204
+ use_container_width=True,
205
+ )
206
+
207
  # -------------------------------------------------------
208
  # Affichage du résultat avec couleur
209
  # -------------------------------------------------------
inference.py CHANGED
@@ -3,59 +3,51 @@ inference.py
3
  ------------
4
  Module d'inférence pour des modèles de classification binaire : FIRE (1) / NO_FIRE (0).
5
 
6
- Compatible avec plusieurs architectures (ex. EfficientNet-B0, Inception v3),
7
- sans changer le code dès qu'on ajoute un nouveau fichier .pt, à condition :
 
 
8
 
9
- - que le nom du fichier permette de deviner le type de modèle
10
- (ex : "inception3_fire.pt" Inception v3,
11
- sinon → EfficientNet-B0 par défaut),
12
- - ou que le fichier .pt contienne un dictionnaire avec "state_dict" (et
13
- éventuellement "model_name" / "model_key").
14
-
15
- Usage typique :
16
- ---------------
17
- from inference import load_model, get_val_transform, predict_from_path
18
-
19
- model, device = load_model("efficientnet_fire.pt")
20
- transform = get_val_transform() # ou get_val_transform(image_size_personnalisée)
21
-
22
- label, prob = predict_from_path("mon_image.jpg", model, device, transform)
23
- print(label, prob)
24
  """
25
 
26
  # ----------------------------
27
  # 1) Imports
28
  # ----------------------------
29
- import os # pour gérer les chemins de fichiers
30
- import torch # bibliothèque principale pour le deep learning
31
- import torch.nn as nn # pour définir les têtes de classification
32
- from torchvision import transforms # pour les pré-traitements d'images
33
- from PIL import Image # pour charger les images depuis un fichier
34
- import timm # pour charger des architectures (EfficientNet, Inception, etc.)
 
 
 
 
 
 
 
35
 
36
 
37
  # ----------------------------
38
  # 2) Constantes globales
39
  # ----------------------------
40
 
41
- # Taille d'entrée par défaut (utilisée si on ne précise rien d'autre).
42
- DEFAULT_IMAGE_SIZE = 224 # (224 x 224 pixels)
43
 
44
- # Moyennes et écarts-types d'ImageNet (pour normaliser les images)
45
- IMAGENET_MEAN = [0.485, 0.456, 0.406] # moyenne des canaux R, G, B
46
- IMAGENET_STD = [0.229, 0.224, 0.225] # écart-type des canaux R, G, B
47
 
48
- # Mapping des classes numériques vers des labels lisibles
49
  IDX_TO_LABEL = {
50
- 0: "no_fire", # classe 0 → pas de feu
51
- 1: "fire" # classe 1 → feu
52
  }
53
 
54
- # Registre des modèles supportés.
55
- # - model_key : identifiant interne (nos clés)
56
- # - timm_name : nom utilisé dans timm.create_model(...)
57
- # - image_size : taille d'entrée "recommandée" pour ce modèle (optionnelle)
58
- # - classifier_attr : nom de l'attribut contenant la dernière couche de classification
59
  MODEL_REGISTRY = {
60
  "efficientnet_b0": {
61
  "timm_name": "efficientnet_b0",
@@ -69,12 +61,17 @@ MODEL_REGISTRY = {
69
  },
70
  }
71
 
72
- # Si on ne reconnaît pas de modèle particulier, on utilisera EfficientNet-B0 par défaut.
73
  DEFAULT_MODEL_KEY = "efficientnet_b0"
74
 
 
 
 
 
 
75
 
76
  # ----------------------------
77
- # 3) Utilitaires device
78
  # ----------------------------
79
 
80
  def get_device():
@@ -85,8 +82,7 @@ def get_device():
85
  """
86
  if torch.cuda.is_available():
87
  return torch.device("cuda")
88
- else:
89
- return torch.device("cpu")
90
 
91
 
92
  # ----------------------------
@@ -96,62 +92,58 @@ def get_device():
96
  def infer_model_key_from_path(weights_path: str) -> str:
97
  """
98
  Devine une clé de modèle (model_key) à partir du nom de fichier.
 
99
  Exemple :
100
- - "inception3_fire.pt" → "inception_v3"
101
- - "efficientnet_fire.pt" → "efficientnet_b0" (par défaut)
 
102
  """
103
  filename = os.path.basename(weights_path).lower()
104
 
105
- # Règle simple : si "inception" est dans le nom → Inception v3
 
 
106
  if "inception" in filename:
107
  return "inception_v3"
108
 
109
- # Sinon, par défaut → EfficientNet-B0
110
  return DEFAULT_MODEL_KEY
111
 
112
 
113
  def get_model_config(model_key: str) -> dict:
114
  """
115
  Renvoie la config du modèle à partir d'une model_key.
116
- Si model_key n'est pas connue, on retourne la config du modèle par défaut.
117
  """
118
  if model_key in MODEL_REGISTRY:
119
  return MODEL_REGISTRY[model_key]
120
- # Fallback de sécurité : ne jamais casser si clé inconnue
121
  return MODEL_REGISTRY[DEFAULT_MODEL_KEY]
122
 
123
 
124
  # ----------------------------
125
- # 5) Construction de l'architecture
126
  # ----------------------------
127
 
128
  def build_model(model_key: str, num_classes: int = 2) -> torch.nn.Module:
129
  """
130
- Construit l'architecture correspondant à model_key
131
- et adapte la tête de classification à num_classes sorties.
132
  """
133
  config = get_model_config(model_key)
134
  timm_name = config["timm_name"]
135
  classifier_attr = config["classifier_attr"]
136
 
137
- # On crée le backbone via timm (sans pré-entraînement, les poids viendront du .pt)
138
  model = timm.create_model(timm_name, pretrained=False)
139
 
140
- # On adapte la tête de classification à notre problème binaire.
141
- # On récupère l'ancienne couche de classification (fc, classifier, etc.)
142
  classifier = getattr(model, classifier_attr)
143
 
144
- # Certains modèles ont déjà un nn.Linear, on récupère in_features
145
  if isinstance(classifier, nn.Linear):
146
  in_features = classifier.in_features
147
  else:
148
- # Cas plus exotique : on essaie de deviner proprement, sinon on lève une erreur claire.
149
  raise ValueError(
150
  f"Impossible de déterminer in_features pour la tête du modèle '{timm_name}'. "
151
  f"Attribut '{classifier_attr}' de type {type(classifier)} non supporté."
152
  )
153
 
154
- # On remplace par une couche linéaire adaptée à notre nombre de classes
155
  new_classifier = nn.Linear(in_features, num_classes)
156
  setattr(model, classifier_attr, new_classifier)
157
 
@@ -159,89 +151,78 @@ def build_model(model_key: str, num_classes: int = 2) -> torch.nn.Module:
159
 
160
 
161
  # ----------------------------
162
- # 6) Chargement du modèle
163
  # ----------------------------
164
 
165
  def _clean_state_dict_keys(state_dict: dict) -> dict:
166
  """
167
- Nettoie les clés d'un state_dict pour gérer plusieurs cas courants :
168
- - clés préfixées par 'model.' (Lightning)
169
- - clés préfixées par 'module.' (DataParallel)
170
  """
171
- new_state_dict = {}
172
  for k, v in state_dict.items():
173
  new_key = k
174
  if new_key.startswith("model."):
175
  new_key = new_key[len("model."):]
176
  if new_key.startswith("module."):
177
  new_key = new_key[len("module."):]
178
- new_state_dict[new_key] = v
179
- return new_state_dict
180
 
181
 
 
 
 
 
182
  def load_model(weights_path: str, map_location=None, model_key: str | None = None):
183
  """
184
  Charge un modèle avec les poids entraînés.
185
 
186
- Paramètres
187
- ----------
188
- weights_path : str
189
- Chemin vers le fichier .pt (state_dict ou dict avec 'state_dict').
190
- map_location : torch.device ou None
191
- Device sur lequel charger les poids. Si None, on détecte automatiquement.
192
- model_key : str ou None
193
- Clé de modèle à utiliser (ex: 'efficientnet_b0', 'inception_v3').
194
- Si None, on essaie de la déduire du nom de fichier.
195
 
196
- Retour
197
- ------
198
- model : torch.nn.Module
199
- Le modèle prêt pour l'inférence.
200
  device : torch.device
201
- Le device utilisé (cuda ou cpu).
202
  """
203
- # 1) Sélection du device
204
  device = map_location if map_location is not None else get_device()
205
 
206
- # 2) Si la model_key n'est pas donnée, on essaye de l'inférer depuis le nom de fichier
207
  if model_key is None:
208
  model_key = infer_model_key_from_path(weights_path)
209
 
210
- # 3) Chargement brut du .pt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  checkpoint = torch.load(weights_path, map_location=device)
212
 
213
- # 4) Si le checkpoint est un dict complet (ex: {'state_dict': ..., 'model_name': ...})
214
  if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
215
  state_dict = checkpoint["state_dict"]
216
-
217
- # Optionnel : si un model_name ou model_key est stocké dedans, on peut le préférer.
218
- if "model_key" in checkpoint and checkpoint["model_key"] in MODEL_REGISTRY:
219
- model_key = checkpoint["model_key"]
220
- elif "model_name" in checkpoint:
221
- # Tentative : si model_name correspond au timm_name d'une entrée du registre
222
- model_name_lower = str(checkpoint["model_name"]).lower()
223
- for k, cfg in MODEL_REGISTRY.items():
224
- if cfg["timm_name"].lower() == model_name_lower:
225
- model_key = k
226
- break
227
  else:
228
- # Cas simple : le fichier .pt est directement un state_dict
229
  state_dict = checkpoint
230
 
231
- # 5) Nettoyage des clés du state_dict (Lightning, DataParallel...)
232
  state_dict = _clean_state_dict_keys(state_dict)
233
 
234
- # 6) Construction de l'architecture adaptée
235
  model = build_model(model_key=model_key, num_classes=2)
236
-
237
- # 7) Application des poids (strict=False pour tolérer quelques petites diff de clés)
238
  missing, unexpected = model.load_state_dict(state_dict, strict=False)
239
 
240
- # (Optionnel) DEBUG : on pourrait afficher missing / unexpected si besoin
241
  # print("Missing keys:", missing)
242
  # print("Unexpected keys:", unexpected)
243
 
244
- # 8) Envoi sur le bon device + mode eval
245
  model = model.to(device)
246
  model.eval()
247
 
@@ -249,21 +230,14 @@ def load_model(weights_path: str, map_location=None, model_key: str | None = Non
249
 
250
 
251
  # ----------------------------
252
- # 7) Transforms pour l'inférence
253
  # ----------------------------
254
 
255
  def get_val_transform(image_size: int | None = None):
256
  """
257
  Renvoie les transformations à appliquer aux images pour l'inférence.
258
- Paramètres
259
- ----------
260
- image_size : int ou None
261
- Si None, on utilise DEFAULT_IMAGE_SIZE (224).
262
- Sinon, on redimensionne en (image_size, image_size).
263
 
264
- Retour
265
- ------
266
- transform : torchvision.transforms.Compose
267
  """
268
  if image_size is None:
269
  image_size = DEFAULT_IMAGE_SIZE
@@ -277,26 +251,12 @@ def get_val_transform(image_size: int | None = None):
277
 
278
 
279
  # ----------------------------
280
- # 8) Prétraitement d'une image
281
  # ----------------------------
282
 
283
  def preprocess_image(image: Image.Image, transform=None, image_size: int | None = None):
284
  """
285
  Applique les transforms à une image PIL et ajoute une dimension batch.
286
-
287
- Paramètres
288
- ----------
289
- image : PIL.Image.Image
290
- Image brute chargée (par exemple via Image.open(...)).
291
- transform : callable ou None
292
- Transformations à appliquer (si None, on utilise get_val_transform(image_size)).
293
- image_size : int ou None
294
- Taille de redimensionnement si transform est None.
295
-
296
- Retour
297
- ------
298
- image_tensor : torch.Tensor
299
- Tenseur prêt pour l'inférence, de taille [1, 3, H, W].
300
  """
301
  if transform is None:
302
  transform = get_val_transform(image_size=image_size)
@@ -308,57 +268,39 @@ def preprocess_image(image: Image.Image, transform=None, image_size: int | None
308
 
309
 
310
  # ----------------------------
311
- # 9) Fonction de prédiction principale
312
  # ----------------------------
313
 
314
  def predict_from_tensor(
315
  image_tensor: torch.Tensor,
316
  model: torch.nn.Module,
317
  device: torch.device,
318
- threshold: float = 0.5
319
  ):
320
  """
321
- Prédit la classe (fire/no_fire) à partir d'un tenseur déjà prétraité.
322
-
323
- Paramètres
324
- ----------
325
- image_tensor : torch.Tensor
326
- Tenseur d'images de taille [1, 3, H, W] (batch de 1 image).
327
- model : torch.nn.Module
328
- Modèle chargé (EfficientNet, Inception, etc.).
329
- device : torch.device
330
- Device sur lequel le modèle est (cuda ou cpu).
331
- threshold : float
332
- Seuil sur la probabilité de FEU pour décider entre no_fire / fire.
333
-
334
- Retour
335
- ------
336
- predicted_label : str
337
- "fire" ou "no_fire".
338
- fire_prob : float
339
- Probabilité prédite pour la classe "fire" (entre 0 et 1).
340
  """
341
  image_tensor = image_tensor.to(device)
342
 
343
  with torch.no_grad():
344
  outputs = model(image_tensor) # logits [1, 2]
345
- probs = torch.softmax(outputs, dim=1) # probabilités
346
 
347
  fire_prob = probs[0, 1].item()
348
-
349
- if fire_prob >= threshold:
350
- predicted_idx = 1 # feu
351
- else:
352
- predicted_idx = 0 # pas de feu
353
 
354
  predicted_label = IDX_TO_LABEL[predicted_idx]
355
-
356
  return predicted_label, fire_prob
357
 
358
 
 
 
 
 
359
  def predict_from_pil(
360
  image: Image.Image,
361
- model: torch.nn.Module,
362
  device: torch.device,
363
  transform=None,
364
  threshold: float = 0.5,
@@ -367,39 +309,66 @@ def predict_from_pil(
367
  """
368
  Prédit la classe à partir d'une image PIL.
369
 
370
- Paramètres
371
- ----------
372
- image : PIL.Image.Image
373
- Image chargée (par exemple via Image.open).
374
- model : torch.nn.Module
375
- Modèle chargé.
376
- device : torch.device
377
- Device (cuda ou cpu).
378
- transform : callable ou None
379
- Transformations à appliquer à l'image.
380
- threshold : float
381
- Seuil sur la probabilité de FEU.
382
- image_size : int ou None
383
- Taille utilisée si transform est None.
384
-
385
  Retour
386
  ------
387
  predicted_label : str
388
  "fire" ou "no_fire".
389
  fire_prob : float
390
  Probabilité de "fire".
 
 
 
391
  """
392
  if image.mode != "RGB":
393
  image = image.convert("RGB")
394
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
  image_tensor = preprocess_image(image, transform=transform, image_size=image_size)
 
 
 
 
396
 
397
- return predict_from_tensor(image_tensor, model, device, threshold=threshold)
398
 
 
 
 
399
 
400
  def predict_from_path(
401
  image_path: str,
402
- model: torch.nn.Module,
403
  device: torch.device,
404
  transform=None,
405
  threshold: float = 0.5,
@@ -407,32 +376,10 @@ def predict_from_path(
407
  ):
408
  """
409
  Prédit la classe à partir d'un chemin vers une image.
410
-
411
- Paramètres
412
- ----------
413
- image_path : str
414
- Chemin vers le fichier image (jpg, png, etc.).
415
- model : torch.nn.Module
416
- Modèle chargé.
417
- device : torch.device
418
- Device (cuda ou cpu).
419
- transform : callable ou None
420
- Transformations à appliquer.
421
- threshold : float
422
- Seuil sur la probabilité de FEU.
423
- image_size : int ou None
424
- Taille utilisée si transform est None.
425
-
426
- Retour
427
- ------
428
- predicted_label : str
429
- "fire" ou "no_fire".
430
- fire_prob : float
431
- Probabilité de "fire".
432
  """
433
  image = Image.open(image_path)
434
  return predict_from_pil(
435
- image,
436
  model=model,
437
  device=device,
438
  transform=transform,
@@ -442,31 +389,25 @@ def predict_from_path(
442
 
443
 
444
  # ----------------------------
445
- # 10) Exemple d'utilisation en script direct
446
  # ----------------------------
447
 
448
  if __name__ == "__main__":
449
- """
450
- Exemple simple pour tester le module en local :
451
- python inference.py
452
- """
453
- weights_path = "efficientnet_fire.pt" # à adapter si besoin
454
-
455
  if not os.path.exists(weights_path):
456
  print(f"[ERREUR] Fichier de poids introuvable : {weights_path}")
457
  else:
458
  model, device = load_model(weights_path)
459
  print(f"Modèle chargé sur le device : {device}")
460
 
461
- # On utilise la taille par défaut (224) pour ce test
462
  transform = get_val_transform()
463
-
464
- test_image_path = "example.jpg" # à adapter si besoin
465
 
466
  if not os.path.exists(test_image_path):
467
  print(f"[INFO] Aucune image test trouvée à : {test_image_path}")
468
  else:
469
- label, prob = predict_from_path(
470
  test_image_path,
471
  model=model,
472
  device=device,
 
3
  ------------
4
  Module d'inférence pour des modèles de classification binaire : FIRE (1) / NO_FIRE (0).
5
 
6
+ Supporte :
7
+ - EfficientNet-B0 (classification)
8
+ - Inception v3 (classification)
9
+ - YOLO (Ultralytics, détection) pour localiser le feu
10
 
11
+ Retour principal des fonctions de prédiction :
12
+ predicted_label (str), fire_prob (float), annotated_image (PIL.Image ou None)
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  """
14
 
15
  # ----------------------------
16
  # 1) Imports
17
  # ----------------------------
18
+ import os
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ from torchvision import transforms
23
+ from PIL import Image
24
+ import timm
25
+
26
+ # YOLO (Ultralytics) – optionnel
27
+ try:
28
+ from ultralytics import YOLO
29
+ except ImportError:
30
+ YOLO = None # si la lib n'est pas installée, on gère ça proprement
31
 
32
 
33
  # ----------------------------
34
  # 2) Constantes globales
35
  # ----------------------------
36
 
37
+ # Taille d'entrée par défaut si rien n'est précisé
38
+ DEFAULT_IMAGE_SIZE = 224
39
 
40
+ # Stats ImageNet
41
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
42
+ IMAGENET_STD = [0.229, 0.224, 0.225]
43
 
44
+ # Mapping idx -> label lisible
45
  IDX_TO_LABEL = {
46
+ 0: "no_fire",
47
+ 1: "fire",
48
  }
49
 
50
+ # Registre des modèles de classification (timm)
 
 
 
 
51
  MODEL_REGISTRY = {
52
  "efficientnet_b0": {
53
  "timm_name": "efficientnet_b0",
 
61
  },
62
  }
63
 
64
+ # Modèle par défaut si on ne sait pas quoi choisir
65
  DEFAULT_MODEL_KEY = "efficientnet_b0"
66
 
67
+ # IDs des classes "feu" pour YOLO (à adapter si besoin après entraînement)
68
+ # Exemple : si model.names == {0: 'fire'} → [0]
69
+ # si model.names == {0: 'no_fire', 1: 'fire'} → [1]
70
+ FIRE_CLASS_IDS = [0]
71
+
72
 
73
  # ----------------------------
74
+ # 3) Device
75
  # ----------------------------
76
 
77
  def get_device():
 
82
  """
83
  if torch.cuda.is_available():
84
  return torch.device("cuda")
85
+ return torch.device("cpu")
 
86
 
87
 
88
  # ----------------------------
 
92
  def infer_model_key_from_path(weights_path: str) -> str:
93
  """
94
  Devine une clé de modèle (model_key) à partir du nom de fichier.
95
+
96
  Exemple :
97
+ - "yolov8_fire.pt" → "yolo"
98
+ - "inception3_fire.pt" → "inception_v3"
99
+ - "efficientnet_fire.pt" → "efficientnet_b0" (par défaut)
100
  """
101
  filename = os.path.basename(weights_path).lower()
102
 
103
+ if "yolo" in filename:
104
+ return "yolo"
105
+
106
  if "inception" in filename:
107
  return "inception_v3"
108
 
 
109
  return DEFAULT_MODEL_KEY
110
 
111
 
112
  def get_model_config(model_key: str) -> dict:
113
  """
114
  Renvoie la config du modèle à partir d'une model_key.
115
+ Fallback : EfficientNet-B0 si model_key inconnue.
116
  """
117
  if model_key in MODEL_REGISTRY:
118
  return MODEL_REGISTRY[model_key]
 
119
  return MODEL_REGISTRY[DEFAULT_MODEL_KEY]
120
 
121
 
122
  # ----------------------------
123
+ # 5) Construction modèle (classification)
124
  # ----------------------------
125
 
126
  def build_model(model_key: str, num_classes: int = 2) -> torch.nn.Module:
127
  """
128
+ Construit un modèle de classification (EfficientNet, Inception...)
129
+ et adapte la dernière couche à num_classes sorties.
130
  """
131
  config = get_model_config(model_key)
132
  timm_name = config["timm_name"]
133
  classifier_attr = config["classifier_attr"]
134
 
 
135
  model = timm.create_model(timm_name, pretrained=False)
136
 
 
 
137
  classifier = getattr(model, classifier_attr)
138
 
 
139
  if isinstance(classifier, nn.Linear):
140
  in_features = classifier.in_features
141
  else:
 
142
  raise ValueError(
143
  f"Impossible de déterminer in_features pour la tête du modèle '{timm_name}'. "
144
  f"Attribut '{classifier_attr}' de type {type(classifier)} non supporté."
145
  )
146
 
 
147
  new_classifier = nn.Linear(in_features, num_classes)
148
  setattr(model, classifier_attr, new_classifier)
149
 
 
151
 
152
 
153
  # ----------------------------
154
+ # 6) Nettoyage de state_dict
155
  # ----------------------------
156
 
157
  def _clean_state_dict_keys(state_dict: dict) -> dict:
158
  """
159
+ Nettoie les clés pour gérer les prefixes 'model.' (Lightning) et 'module.' (DataParallel).
 
 
160
  """
161
+ new_state = {}
162
  for k, v in state_dict.items():
163
  new_key = k
164
  if new_key.startswith("model."):
165
  new_key = new_key[len("model."):]
166
  if new_key.startswith("module."):
167
  new_key = new_key[len("module."):]
168
+ new_state[new_key] = v
169
+ return new_state
170
 
171
 
172
+ # ----------------------------
173
+ # 7) Chargement du modèle
174
+ # ----------------------------
175
+
176
  def load_model(weights_path: str, map_location=None, model_key: str | None = None):
177
  """
178
  Charge un modèle avec les poids entraînés.
179
 
180
+ - Pour YOLO (Ultralytics) : charge un modèle de détection.
181
+ - Pour EfficientNet / Inception : charge un modèle de classification binaire.
 
 
 
 
 
 
 
182
 
183
+ Retour :
184
+ --------
185
+ model : torch.nn.Module ou YOLO
 
186
  device : torch.device
 
187
  """
 
188
  device = map_location if map_location is not None else get_device()
189
 
190
+ # Détecter le type de modèle si non fourni
191
  if model_key is None:
192
  model_key = infer_model_key_from_path(weights_path)
193
 
194
+ # 🔹 Cas YOLO : modèle de détection (Ultralytics)
195
+ if model_key == "yolo":
196
+ if YOLO is None:
197
+ raise ImportError(
198
+ "Le modèle YOLO est demandé mais la librairie 'ultralytics' "
199
+ "n'est pas installée. Ajoutez 'ultralytics' dans requirements.txt."
200
+ )
201
+ yolo_model = YOLO(weights_path)
202
+ # YOLO gère déjà souvent le device en interne, on essaye juste par sécurité
203
+ try:
204
+ yolo_model.to(device)
205
+ except Exception:
206
+ pass
207
+ return yolo_model, device
208
+
209
+ # 🔹 Cas classification (EfficientNet, Inception...)
210
  checkpoint = torch.load(weights_path, map_location=device)
211
 
 
212
  if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
213
  state_dict = checkpoint["state_dict"]
 
 
 
 
 
 
 
 
 
 
 
214
  else:
 
215
  state_dict = checkpoint
216
 
 
217
  state_dict = _clean_state_dict_keys(state_dict)
218
 
 
219
  model = build_model(model_key=model_key, num_classes=2)
 
 
220
  missing, unexpected = model.load_state_dict(state_dict, strict=False)
221
 
222
+ # (optionnel) debug :
223
  # print("Missing keys:", missing)
224
  # print("Unexpected keys:", unexpected)
225
 
 
226
  model = model.to(device)
227
  model.eval()
228
 
 
230
 
231
 
232
  # ----------------------------
233
+ # 8) Transforms pour l'inférence
234
  # ----------------------------
235
 
236
  def get_val_transform(image_size: int | None = None):
237
  """
238
  Renvoie les transformations à appliquer aux images pour l'inférence.
 
 
 
 
 
239
 
240
+ Si image_size est None → DEFAULT_IMAGE_SIZE (224).
 
 
241
  """
242
  if image_size is None:
243
  image_size = DEFAULT_IMAGE_SIZE
 
251
 
252
 
253
  # ----------------------------
254
+ # 9) Prétraitement d'une image
255
  # ----------------------------
256
 
257
  def preprocess_image(image: Image.Image, transform=None, image_size: int | None = None):
258
  """
259
  Applique les transforms à une image PIL et ajoute une dimension batch.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  """
261
  if transform is None:
262
  transform = get_val_transform(image_size=image_size)
 
268
 
269
 
270
  # ----------------------------
271
+ # 10) Prédiction depuis un tenseur (classification)
272
  # ----------------------------
273
 
274
  def predict_from_tensor(
275
  image_tensor: torch.Tensor,
276
  model: torch.nn.Module,
277
  device: torch.device,
278
+ threshold: float = 0.5,
279
  ):
280
  """
281
+ Prédit la classe (fire/no_fire) à partir d'un tenseur déjà prétraité
282
+ pour les modèles de classification (EfficientNet, Inception...).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  """
284
  image_tensor = image_tensor.to(device)
285
 
286
  with torch.no_grad():
287
  outputs = model(image_tensor) # logits [1, 2]
288
+ probs = torch.softmax(outputs, dim=1) # probas
289
 
290
  fire_prob = probs[0, 1].item()
291
+ predicted_idx = 1 if fire_prob >= threshold else 0
 
 
 
 
292
 
293
  predicted_label = IDX_TO_LABEL[predicted_idx]
 
294
  return predicted_label, fire_prob
295
 
296
 
297
+ # ----------------------------
298
+ # 11) Prédiction depuis une image PIL
299
+ # ----------------------------
300
+
301
  def predict_from_pil(
302
  image: Image.Image,
303
+ model,
304
  device: torch.device,
305
  transform=None,
306
  threshold: float = 0.5,
 
309
  """
310
  Prédit la classe à partir d'une image PIL.
311
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  Retour
313
  ------
314
  predicted_label : str
315
  "fire" ou "no_fire".
316
  fire_prob : float
317
  Probabilité de "fire".
318
+ annotated_image : PIL.Image.Image ou None
319
+ Image annotée (bounding boxes) pour les modèles de détection (YOLO),
320
+ None pour les modèles de classification.
321
  """
322
  if image.mode != "RGB":
323
  image = image.convert("RGB")
324
 
325
+ # 🔹 Cas YOLO : modèle de détection (Ultralytics)
326
+ if hasattr(model, "task") and getattr(model, "task", None) == "detect":
327
+ results = model(image)
328
+ result = results[0]
329
+
330
+ boxes = getattr(result, "boxes", None)
331
+ fire_prob = 0.0
332
+
333
+ if boxes is not None and len(boxes) > 0:
334
+ classes = boxes.cls # ids des classes (tensor)
335
+ confs = boxes.conf # scores de confiance (tensor)
336
+
337
+ # masque des boxes "feu"
338
+ mask_fire = torch.zeros_like(classes, dtype=torch.bool)
339
+ for cid in FIRE_CLASS_IDS:
340
+ mask_fire |= (classes == cid)
341
+
342
+ if mask_fire.any():
343
+ fire_prob = float(confs[mask_fire].max().item())
344
+
345
+ predicted_label = "fire" if fire_prob >= threshold else "no_fire"
346
+
347
+ # Image annotée avec les bounding boxes
348
+ annotated_image = None
349
+ try:
350
+ annotated_array = result.plot() # numpy array BGR
351
+ annotated_image = Image.fromarray(annotated_array[..., ::-1]) # BGR -> RGB
352
+ except Exception:
353
+ annotated_image = None
354
+
355
+ return predicted_label, fire_prob, annotated_image
356
+
357
+ # 🔹 Cas classification classique
358
  image_tensor = preprocess_image(image, transform=transform, image_size=image_size)
359
+ predicted_label, fire_prob = predict_from_tensor(image_tensor, model, device, threshold=threshold)
360
+ annotated_image = None
361
+
362
+ return predicted_label, fire_prob, annotated_image
363
 
 
364
 
365
+ # ----------------------------
366
+ # 12) Prédiction depuis un chemin de fichier
367
+ # ----------------------------
368
 
369
  def predict_from_path(
370
  image_path: str,
371
+ model,
372
  device: torch.device,
373
  transform=None,
374
  threshold: float = 0.5,
 
376
  ):
377
  """
378
  Prédit la classe à partir d'un chemin vers une image.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
  """
380
  image = Image.open(image_path)
381
  return predict_from_pil(
382
+ image=image,
383
  model=model,
384
  device=device,
385
  transform=transform,
 
389
 
390
 
391
  # ----------------------------
392
+ # 13) Exemple d'utilisation en script direct
393
  # ----------------------------
394
 
395
  if __name__ == "__main__":
396
+ # Petit test local (à adapter)
397
+ weights_path = "efficientnet_fire.pt"
 
 
 
 
398
  if not os.path.exists(weights_path):
399
  print(f"[ERREUR] Fichier de poids introuvable : {weights_path}")
400
  else:
401
  model, device = load_model(weights_path)
402
  print(f"Modèle chargé sur le device : {device}")
403
 
 
404
  transform = get_val_transform()
405
+ test_image_path = "example.jpg"
 
406
 
407
  if not os.path.exists(test_image_path):
408
  print(f"[INFO] Aucune image test trouvée à : {test_image_path}")
409
  else:
410
+ label, prob, _ = predict_from_path(
411
  test_image_path,
412
  model=model,
413
  device=device,
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ