ninafr8175 commited on
Commit
6c28497
·
1 Parent(s): d318834

app update

Browse files
Files changed (3) hide show
  1. .gitignore +5 -0
  2. app.py +58 -28
  3. inference.py +240 -119
.gitignore CHANGED
@@ -7,5 +7,10 @@
7
  # Fichiers générés automatiquement
8
  __pycache__/
9
  *.pyc
 
10
  *.pptx
11
  .dockerignore
 
 
 
 
 
7
  # Fichiers générés automatiquement
8
  __pycache__/
9
  *.pyc
10
+
11
  *.pptx
12
  .dockerignore
13
+ app_v0.py
14
+ inference_v0.py
15
+ Dockerfile
16
+ Notes.txt
app.py CHANGED
@@ -1,6 +1,7 @@
1
- import streamlit as st # librairie pour le dashboard
2
- from PIL import Image # pour ouvrir les images
3
- from inference import ( # fonctions importées du fichier inference.py
 
4
  load_model,
5
  get_val_transform,
6
  predict_from_pil
@@ -10,36 +11,48 @@ from inference import ( # fonctions importées du fichier
10
  # Configuration de la page
11
  # -----------------------------------------------------------
12
  st.set_page_config(
13
- page_title="Fire Detection Dashboard", # titre de l’onglet du navigateur
14
- page_icon="🔥", # icône (emoji)
15
- layout="centered" # mise en page centrée
16
  )
17
 
18
  # -----------------------------------------------------------
19
- # Chargement du modèle (une seule fois)
20
  # -----------------------------------------------------------
21
- @st.cache_resource
22
- def load_app_model():
23
- """
24
- Charge le modèle, le device et la transform une seule fois,
25
- puis les réutilise pour toutes les prédictions.
26
- """
27
- model, device = load_model("efficientnet_fire.pt") # charge les poids
28
- transform = get_val_transform() # transform validation/inférence
29
- return model, device, transform
 
30
 
31
- model, device, transform = load_app_model()
 
 
32
 
33
  # -----------------------------------------------------------
34
- # Sidebar : infos et paramètres
35
  # -----------------------------------------------------------
36
  st.sidebar.title("⚙️ Paramètres")
 
 
 
 
 
 
 
 
 
 
37
  st.sidebar.markdown(
38
  """
39
- Ce dashboard utilise un modèle EfficientNet-B0,
40
- entraîné à prédire **FIRE / NO FIRE** sur des images.
41
 
42
- - Classe 0 : **no_fire**
43
  - Classe 1 : **fire**
44
  """
45
  )
@@ -54,26 +67,41 @@ threshold = st.sidebar.slider(
54
 
55
  st.sidebar.markdown(f"Seuil actuel : **{threshold:.2f}**")
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  # -----------------------------------------------------------
58
  # Titre principal
59
  # -----------------------------------------------------------
60
  st.title("🔥 Fire Detection Dashboard")
61
  st.markdown(
62
  """
63
- Ce prototype permet de tester un modèle de détection de feu,
64
- sur des images individuelles.
65
 
66
  _Charger une image pour obtenir une prédiction._
67
  """
68
  )
69
 
70
  # -----------------------------------------------------------
71
- # Zone d'upload d'image (texte personnalisé ajouté)
72
  # -----------------------------------------------------------
73
  uploaded_file = st.file_uploader(
74
- "📂 Déposez une image ici (ou cliquez sur Browse Files pour choisir une image)",
75
  type=["jpg", "jpeg", "png"],
76
- help="Formats supportés : JPG, JPEG, PNG\nMaximum 200MB par image",
77
  accept_multiple_files=False
78
  )
79
 
@@ -121,8 +149,10 @@ else:
121
  with st.expander("🔍 Détails techniques (optionnel)"):
122
  st.markdown(
123
  f"""
124
- - Label retourné : **{label}**
125
- - Probabilité brute de la classe *fire* : **{prob:.4f}**
 
 
126
  - Seuil de décision : **{threshold:.2f}**
127
 
128
  Si `prob_fire >= seuil` → prédiction = *fire*,
 
1
+ import os # pour vérifier la présence des fichiers de modèles
2
+ import streamlit as st # librairie pour le dashboard
3
+ from PIL import Image # pour ouvrir les images
4
+ from inference import ( # fonctions importées du fichier inference.py
5
  load_model,
6
  get_val_transform,
7
  predict_from_pil
 
11
  # Configuration de la page
12
  # -----------------------------------------------------------
13
  st.set_page_config(
14
+ page_title="Fire Detection Dashboard", # titre de l’onglet du navigateur
15
+ page_icon="🔥", # icône (emoji)
16
+ layout="centered" # mise en page centrée
17
  )
18
 
19
  # -----------------------------------------------------------
20
+ # Déclaration des modèles potentiels
21
  # -----------------------------------------------------------
22
+ ALL_MODEL_FILES = {
23
+ "Modèle Efficientnet Baseline": "efficientnet_fire.pt",
24
+ "Modèle Efficientnet Improved": "efficientnet_fire_2.pt",
25
+ "Modèle Inception3": "inception3_fire.pt",
26
+ }
27
+
28
+ # Ne garder que les modèles réellement présents dans le repo
29
+ MODEL_FILES = {
30
+ name: path for name, path in ALL_MODEL_FILES.items() if os.path.exists(path)
31
+ }
32
 
33
+ if len(MODEL_FILES) == 0:
34
+ st.error("❌ Aucun modèle trouvé dans le repository. Ajoutez au moins un fichier .pt.")
35
+ st.stop()
36
 
37
  # -----------------------------------------------------------
38
+ # Sidebar : choix du modèle + infos et paramètres
39
  # -----------------------------------------------------------
40
  st.sidebar.title("⚙️ Paramètres")
41
+
42
+ selected_model_name = st.sidebar.selectbox(
43
+ "Choisir le modèle à utiliser",
44
+ options=list(MODEL_FILES.keys()),
45
+ index=0
46
+ )
47
+ selected_model_path = MODEL_FILES[selected_model_name]
48
+
49
+ st.sidebar.markdown(f"🧠 Modèle sélectionné : **{selected_model_name}**")
50
+
51
  st.sidebar.markdown(
52
  """
53
+ Ce dashboard prédit **FIRE / NO FIRE** sur des images.
 
54
 
55
+ - Classe 0 : **no_fire**
56
  - Classe 1 : **fire**
57
  """
58
  )
 
67
 
68
  st.sidebar.markdown(f"Seuil actuel : **{threshold:.2f}**")
69
 
70
+ # -----------------------------------------------------------
71
+ # Chargement du modèle (en fonction du choix)
72
+ # -----------------------------------------------------------
73
+ @st.cache_resource
74
+ def load_app_model(model_path: str):
75
+ """
76
+ Charge le modèle, le device et la transform une seule fois
77
+ pour un chemin donné, puis les réutilise pour toutes les prédictions.
78
+ """
79
+ model, device = load_model(model_path) # charge les poids du modèle choisi
80
+ transform = get_val_transform() # transform validation/inférence
81
+ return model, device, transform
82
+
83
+ model, device, transform = load_app_model(selected_model_path)
84
+
85
  # -----------------------------------------------------------
86
  # Titre principal
87
  # -----------------------------------------------------------
88
  st.title("🔥 Fire Detection Dashboard")
89
  st.markdown(
90
  """
91
+ Ce prototype permet de tester un modèle de détection de feu
92
+ sur des images individuelles.
93
 
94
  _Charger une image pour obtenir une prédiction._
95
  """
96
  )
97
 
98
  # -----------------------------------------------------------
99
+ # Zone d'upload d'image
100
  # -----------------------------------------------------------
101
  uploaded_file = st.file_uploader(
102
+ "📂 Déposez une image ici (ou cliquez sur Browse Files pour choisir une image)",
103
  type=["jpg", "jpeg", "png"],
104
+ help="Formats supportés : JPG, JPEG, PNG\nMaximum 200MB par image",
105
  accept_multiple_files=False
106
  )
107
 
 
149
  with st.expander("🔍 Détails techniques (optionnel)"):
150
  st.markdown(
151
  f"""
152
+ - Modèle utilisé : **{selected_model_name}**
153
+ - Fichier de poids : `{selected_model_path}`
154
+ - Label retourné : **{label}**
155
+ - Probabilité brute de la classe *fire* : **{prob:.4f}**
156
  - Seuil de décision : **{threshold:.2f}**
157
 
158
  Si `prob_fire >= seuil` → prédiction = *fire*,
inference.py CHANGED
@@ -1,21 +1,23 @@
1
  """
2
  inference.py
3
  ------------
4
- Module d'inférence pour le modèle EfficientNet-B0 entraîné
5
- sur la classification binaire : FIRE (1) / NO_FIRE (0).
6
 
7
- Compatible :
8
- - Google Colab
9
- - Exécution locale (Python)
10
- - Lightning AI
11
- - HuggingFace Spaces / Streamlit
 
 
 
12
 
13
  Usage typique :
14
  ---------------
15
  from inference import load_model, get_val_transform, predict_from_path
16
 
17
  model, device = load_model("efficientnet_fire.pt")
18
- transform = get_val_transform()
19
 
20
  label, prob = predict_from_path("mon_image.jpg", model, device, transform)
21
  print(label, prob)
@@ -24,19 +26,20 @@ print(label, prob)
24
  # ----------------------------
25
  # 1) Imports
26
  # ----------------------------
27
- import torch # bibliothèque principale pour le deep learning
28
- import torch.nn as nn # pour définir la tête de classification
29
- from torchvision import transforms # pour les pré-traitements d'images
30
- from PIL import Image # pour charger les images depuis un fichier
31
- import timm # pour charger EfficientNet-B0
 
32
 
33
 
34
  # ----------------------------
35
  # 2) Constantes globales
36
  # ----------------------------
37
 
38
- # Taille d'entrée du modèle EfficientNet-B0
39
- IMAGE_SIZE = 224 # (224 x 224 pixels)
40
 
41
  # Moyennes et écarts-types d'ImageNet (pour normaliser les images)
42
  IMAGENET_MEAN = [0.485, 0.456, 0.406] # moyenne des canaux R, G, B
@@ -48,6 +51,27 @@ IDX_TO_LABEL = {
48
  1: "fire" # classe 1 → feu
49
  }
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  # ----------------------------
53
  # 3) Utilitaires device
@@ -59,47 +83,115 @@ def get_device():
59
  - 'cuda' si un GPU est disponible
60
  - sinon 'cpu'
61
  """
62
- # torch.cuda.is_available() renvoie True si un GPU CUDA est accessible
63
  if torch.cuda.is_available():
64
- return torch.device("cuda") # on utilisera le GPU
65
  else:
66
- return torch.device("cpu") # sinon le CPU
67
 
68
 
69
  # ----------------------------
70
- # 4) Chargement du modèle
71
  # ----------------------------
72
 
73
- def build_model(num_classes=2):
74
  """
75
- Construit l'architecture EfficientNet-B0 avec une tête
76
- adaptée à la classification binaire (2 classes).
77
- Les poids seront chargés ensuite via load_state_dict.
 
78
  """
79
- # On crée le modèle EfficientNet-B0 sans poids pré-entraînés ici
80
- # (les poids spécifiques à ton projet seront chargés après)
81
- model = timm.create_model("efficientnet_b0", pretrained=False)
 
 
82
 
83
- # On récupère le nombre de features en entrée de la dernière couche
84
- in_features = model.classifier.in_features
85
 
86
- # On remplace la dernière couche par une couche linéaire avec num_classes sorties
87
- model.classifier = nn.Linear(in_features, num_classes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  return model
90
 
91
 
92
- def load_model(weights_path: str, map_location=None):
 
 
 
 
 
 
 
 
93
  """
94
- Charge le modèle EfficientNet-B0 avec les poids entraînés.
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  Paramètres
97
  ----------
98
  weights_path : str
99
- Chemin vers le fichier .pt contenant les poids (state_dict).
100
  map_location : torch.device ou None
101
- Device sur lequel charger les poids.
102
- Si None, on détecte automatiquement (GPU si dispo, sinon CPU).
 
 
103
 
104
  Retour
105
  ------
@@ -108,52 +200,87 @@ def load_model(weights_path: str, map_location=None):
108
  device : torch.device
109
  Le device utilisé (cuda ou cpu).
110
  """
111
- # On détecte le device si non fourni
112
  device = map_location if map_location is not None else get_device()
113
 
114
- # On construit l'architecture du modèle
115
- model = build_model(num_classes=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
- # On charge le dictionnaire de poids sauvegardés (state_dict)
118
- state_dict = torch.load(weights_path, map_location=device)
119
 
120
- # On applique les poids au modèle
121
- model.load_state_dict(state_dict)
122
 
123
- # On envoie le modèle sur le bon device (GPU ou CPU)
124
- model = model.to(device)
 
125
 
126
- # On passe le modèle en mode évaluation (important pour dropout, batchnorm, etc.)
 
127
  model.eval()
128
 
129
  return model, device
130
 
131
 
132
  # ----------------------------
133
- # 5) Transforms pour l'inférence
134
  # ----------------------------
135
 
136
- def get_val_transform():
137
  """
138
  Renvoie les transformations à appliquer aux images pour l'inférence.
139
- Ce sont les mêmes que pour la validation :
140
- - Resize 224x224
141
- - ToTensor
142
- - Normalize (ImageNet)
 
 
 
 
 
143
  """
 
 
 
144
  transform = transforms.Compose([
145
- transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), # redimensionne en 224x224
146
- transforms.ToTensor(), # convertit PIL → Tensor [0,1]
147
- transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) # normalise selon ImageNet
148
  ])
149
  return transform
150
 
151
 
152
  # ----------------------------
153
- # 6) Prétraitement d'une image
154
  # ----------------------------
155
 
156
- def preprocess_image(image: Image.Image, transform=None):
157
  """
158
  Applique les transforms à une image PIL et ajoute une dimension batch.
159
 
@@ -162,43 +289,43 @@ def preprocess_image(image: Image.Image, transform=None):
162
  image : PIL.Image.Image
163
  Image brute chargée (par exemple via Image.open(...)).
164
  transform : callable ou None
165
- Transformations à appliquer (si None, on utilise get_val_transform()).
 
 
166
 
167
  Retour
168
  ------
169
  image_tensor : torch.Tensor
170
- Tenseur prêt pour l'inférence, de taille [1, 3, 224, 224].
171
  """
172
- # Si aucune transform n'est fournie, on utilise la transform par défaut
173
  if transform is None:
174
- transform = get_val_transform()
175
 
176
- # On applique la transform à l'image PIL → tensor [3, 224, 224]
177
- img_tensor = transform(image)
178
-
179
- # On ajoute une dimension batch devant : [1, 3, 224, 224]
180
- img_tensor = img_tensor.unsqueeze(0)
181
 
182
  return img_tensor
183
 
184
 
185
  # ----------------------------
186
- # 7) Fonction de prédiction principale
187
  # ----------------------------
188
 
189
- def predict_from_tensor(image_tensor: torch.Tensor,
190
- model: torch.nn.Module,
191
- device: torch.device,
192
- threshold: float = 0.5):
 
 
193
  """
194
  Prédit la classe (fire/no_fire) à partir d'un tenseur déjà prétraité.
195
 
196
  Paramètres
197
  ----------
198
  image_tensor : torch.Tensor
199
- Tenseur d'images de taille [1, 3, 224, 224] (batch de 1 image).
200
  model : torch.nn.Module
201
- Modèle EfficientNet-B0 chargé.
202
  device : torch.device
203
  Device sur lequel le modèle est (cuda ou cpu).
204
  threshold : float
@@ -211,37 +338,32 @@ def predict_from_tensor(image_tensor: torch.Tensor,
211
  fire_prob : float
212
  Probabilité prédite pour la classe "fire" (entre 0 et 1).
213
  """
214
- # On envoie l'image sur le même device que le modèle
215
  image_tensor = image_tensor.to(device)
216
 
217
- # On désactive le calcul des gradients pour l'inférence
218
  with torch.no_grad():
219
- # Le modèle renvoie des logits de taille [1, 2]
220
- outputs = model(image_tensor)
221
-
222
- # On convertit en probabilités via softmax
223
- probs = torch.softmax(outputs, dim=1)
224
 
225
- # Probabilité de la classe fire (indice 1)
226
  fire_prob = probs[0, 1].item()
227
 
228
- # On décide du label en comparant à un seuil
229
  if fire_prob >= threshold:
230
- predicted_idx = 1 # feu
231
  else:
232
- predicted_idx = 0 # pas de feu
233
 
234
- # Conversion en label lisible
235
  predicted_label = IDX_TO_LABEL[predicted_idx]
236
 
237
  return predicted_label, fire_prob
238
 
239
 
240
- def predict_from_pil(image: Image.Image,
241
- model: torch.nn.Module,
242
- device: torch.device,
243
- transform=None,
244
- threshold: float = 0.5):
 
 
 
245
  """
246
  Prédit la classe à partir d'une image PIL.
247
 
@@ -250,13 +372,15 @@ def predict_from_pil(image: Image.Image,
250
  image : PIL.Image.Image
251
  Image chargée (par exemple via Image.open).
252
  model : torch.nn.Module
253
- Modèle EfficientNet-B0 chargé.
254
  device : torch.device
255
  Device (cuda ou cpu).
256
  transform : callable ou None
257
  Transformations à appliquer à l'image.
258
  threshold : float
259
  Seuil sur la probabilité de FEU.
 
 
260
 
261
  Retour
262
  ------
@@ -265,22 +389,22 @@ def predict_from_pil(image: Image.Image,
265
  fire_prob : float
266
  Probabilité de "fire".
267
  """
268
- # On s'assure que l'image est en mode RGB
269
  if image.mode != "RGB":
270
  image = image.convert("RGB")
271
 
272
- # On prétraite l'image (resize, tensor, normalize, batch)
273
- image_tensor = preprocess_image(image, transform=transform)
274
 
275
- # On délègue la prédiction à predict_from_tensor
276
  return predict_from_tensor(image_tensor, model, device, threshold=threshold)
277
 
278
 
279
- def predict_from_path(image_path: str,
280
- model: torch.nn.Module,
281
- device: torch.device,
282
- transform=None,
283
- threshold: float = 0.5):
 
 
 
284
  """
285
  Prédit la classe à partir d'un chemin vers une image.
286
 
@@ -289,13 +413,15 @@ def predict_from_path(image_path: str,
289
  image_path : str
290
  Chemin vers le fichier image (jpg, png, etc.).
291
  model : torch.nn.Module
292
- Modèle EfficientNet-B0 chargé.
293
  device : torch.device
294
  Device (cuda ou cpu).
295
  transform : callable ou None
296
  Transformations à appliquer.
297
  threshold : float
298
  Seuil sur la probabilité de FEU.
 
 
299
 
300
  Retour
301
  ------
@@ -304,52 +430,47 @@ def predict_from_path(image_path: str,
304
  fire_prob : float
305
  Probabilité de "fire".
306
  """
307
- # On charge l'image depuis le disque via PIL
308
  image = Image.open(image_path)
309
-
310
- # On délègue la prédiction à la fonction base sur PIL
311
- return predict_from_pil(image, model, device, transform=transform, threshold=threshold)
 
 
 
 
 
312
 
313
 
314
  # ----------------------------
315
- # 8) Exemple d'utilisation en script direct
316
  # ----------------------------
317
 
318
  if __name__ == "__main__":
319
  """
320
- Ce bloc s'exécute uniquement si on lance le fichier directement :
321
  python inference.py
322
-
323
- Tu peux le modifier pour faire un petit test rapide en local
324
- ou dans un notebook via !python inference.py.
325
  """
326
- import os
327
-
328
- # Chemin vers le fichier de poids (à adapter si besoin)
329
- weights_path = "efficientnet_fire.pt"
330
 
331
  if not os.path.exists(weights_path):
332
  print(f"[ERREUR] Fichier de poids introuvable : {weights_path}")
333
  else:
334
- # 1) On charge le modèle et on détecte le device
335
  model, device = load_model(weights_path)
336
  print(f"Modèle chargé sur le device : {device}")
337
 
338
- # 2) On récupère la transform de validation/inférence
339
  transform = get_val_transform()
340
 
341
- # 3) Exemple : prédire sur une image de test (chemin à adapter)
342
- test_image_path = "example.jpg" # ← remplace par une vraie image
343
 
344
  if not os.path.exists(test_image_path):
345
  print(f"[INFO] Aucune image test trouvée à : {test_image_path}")
346
- print(" Modifie le chemin dans __main__ pour tester une image.")
347
  else:
348
  label, prob = predict_from_path(
349
  test_image_path,
350
  model=model,
351
  device=device,
352
  transform=transform,
353
- threshold=0.5
354
  )
355
  print(f"Résultat pour {test_image_path} : label={label}, prob_fire={prob:.4f}")
 
1
  """
2
  inference.py
3
  ------------
4
+ Module d'inférence pour des modèles de classification binaire : FIRE (1) / NO_FIRE (0).
 
5
 
6
+ Compatible avec plusieurs architectures (ex. EfficientNet-B0, Inception v3),
7
+ sans changer le code dès qu'on ajoute un nouveau fichier .pt, à condition :
8
+
9
+ - que le nom du fichier permette de deviner le type de modèle
10
+ (ex : "inception3_fire.pt" Inception v3,
11
+ sinon → EfficientNet-B0 par défaut),
12
+ - ou que le fichier .pt contienne un dictionnaire avec "state_dict" (et
13
+ éventuellement "model_name" / "model_key").
14
 
15
  Usage typique :
16
  ---------------
17
  from inference import load_model, get_val_transform, predict_from_path
18
 
19
  model, device = load_model("efficientnet_fire.pt")
20
+ transform = get_val_transform() # ou get_val_transform(image_size_personnalisée)
21
 
22
  label, prob = predict_from_path("mon_image.jpg", model, device, transform)
23
  print(label, prob)
 
26
  # ----------------------------
27
  # 1) Imports
28
  # ----------------------------
29
+ import os # pour gérer les chemins de fichiers
30
+ import torch # bibliothèque principale pour le deep learning
31
+ import torch.nn as nn # pour définir les têtes de classification
32
+ from torchvision import transforms # pour les pré-traitements d'images
33
+ from PIL import Image # pour charger les images depuis un fichier
34
+ import timm # pour charger des architectures (EfficientNet, Inception, etc.)
35
 
36
 
37
  # ----------------------------
38
  # 2) Constantes globales
39
  # ----------------------------
40
 
41
+ # Taille d'entrée par défaut (utilisée si on ne précise rien d'autre).
42
+ DEFAULT_IMAGE_SIZE = 224 # (224 x 224 pixels)
43
 
44
  # Moyennes et écarts-types d'ImageNet (pour normaliser les images)
45
  IMAGENET_MEAN = [0.485, 0.456, 0.406] # moyenne des canaux R, G, B
 
51
  1: "fire" # classe 1 → feu
52
  }
53
 
54
+ # Registre des modèles supportés.
55
+ # - model_key : identifiant interne (nos clés)
56
+ # - timm_name : nom utilisé dans timm.create_model(...)
57
+ # - image_size : taille d'entrée "recommandée" pour ce modèle (optionnelle)
58
+ # - classifier_attr : nom de l'attribut contenant la dernière couche de classification
59
+ MODEL_REGISTRY = {
60
+ "efficientnet_b0": {
61
+ "timm_name": "efficientnet_b0",
62
+ "image_size": 224,
63
+ "classifier_attr": "classifier",
64
+ },
65
+ "inception_v3": {
66
+ "timm_name": "inception_v3",
67
+ "image_size": 299,
68
+ "classifier_attr": "fc",
69
+ },
70
+ }
71
+
72
+ # Si on ne reconnaît pas de modèle particulier, on utilisera EfficientNet-B0 par défaut.
73
+ DEFAULT_MODEL_KEY = "efficientnet_b0"
74
+
75
 
76
  # ----------------------------
77
  # 3) Utilitaires device
 
83
  - 'cuda' si un GPU est disponible
84
  - sinon 'cpu'
85
  """
 
86
  if torch.cuda.is_available():
87
+ return torch.device("cuda")
88
  else:
89
+ return torch.device("cpu")
90
 
91
 
92
  # ----------------------------
93
+ # 4) Détection du type de modèle
94
  # ----------------------------
95
 
96
+ def infer_model_key_from_path(weights_path: str) -> str:
97
  """
98
+ Devine une clé de modèle (model_key) à partir du nom de fichier.
99
+ Exemple :
100
+ - "inception3_fire.pt" → "inception_v3"
101
+ - "efficientnet_fire.pt" → "efficientnet_b0" (par défaut)
102
  """
103
+ filename = os.path.basename(weights_path).lower()
104
+
105
+ # Règle simple : si "inception" est dans le nom → Inception v3
106
+ if "inception" in filename:
107
+ return "inception_v3"
108
 
109
+ # Sinon, par défaut EfficientNet-B0
110
+ return DEFAULT_MODEL_KEY
111
 
112
+
113
+ def get_model_config(model_key: str) -> dict:
114
+ """
115
+ Renvoie la config du modèle à partir d'une model_key.
116
+ Si model_key n'est pas connue, on retourne la config du modèle par défaut.
117
+ """
118
+ if model_key in MODEL_REGISTRY:
119
+ return MODEL_REGISTRY[model_key]
120
+ # Fallback de sécurité : ne jamais casser si clé inconnue
121
+ return MODEL_REGISTRY[DEFAULT_MODEL_KEY]
122
+
123
+
124
+ # ----------------------------
125
+ # 5) Construction de l'architecture
126
+ # ----------------------------
127
+
128
+ def build_model(model_key: str, num_classes: int = 2) -> torch.nn.Module:
129
+ """
130
+ Construit l'architecture correspondant à model_key
131
+ et adapte la tête de classification à num_classes sorties.
132
+ """
133
+ config = get_model_config(model_key)
134
+ timm_name = config["timm_name"]
135
+ classifier_attr = config["classifier_attr"]
136
+
137
+ # On crée le backbone via timm (sans pré-entraînement, les poids viendront du .pt)
138
+ model = timm.create_model(timm_name, pretrained=False)
139
+
140
+ # On adapte la tête de classification à notre problème binaire.
141
+ # On récupère l'ancienne couche de classification (fc, classifier, etc.)
142
+ classifier = getattr(model, classifier_attr)
143
+
144
+ # Certains modèles ont déjà un nn.Linear, on récupère in_features
145
+ if isinstance(classifier, nn.Linear):
146
+ in_features = classifier.in_features
147
+ else:
148
+ # Cas plus exotique : on essaie de deviner proprement, sinon on lève une erreur claire.
149
+ raise ValueError(
150
+ f"Impossible de déterminer in_features pour la tête du modèle '{timm_name}'. "
151
+ f"Attribut '{classifier_attr}' de type {type(classifier)} non supporté."
152
+ )
153
+
154
+ # On remplace par une couche linéaire adaptée à notre nombre de classes
155
+ new_classifier = nn.Linear(in_features, num_classes)
156
+ setattr(model, classifier_attr, new_classifier)
157
 
158
  return model
159
 
160
 
161
+ # ----------------------------
162
+ # 6) Chargement du modèle
163
+ # ----------------------------
164
+
165
+ def _clean_state_dict_keys(state_dict: dict) -> dict:
166
+ """
167
+ Nettoie les clés d'un state_dict pour gérer plusieurs cas courants :
168
+ - clés préfixées par 'model.' (Lightning)
169
+ - clés préfixées par 'module.' (DataParallel)
170
  """
171
+ new_state_dict = {}
172
+ for k, v in state_dict.items():
173
+ new_key = k
174
+ if new_key.startswith("model."):
175
+ new_key = new_key[len("model."):]
176
+ if new_key.startswith("module."):
177
+ new_key = new_key[len("module."):]
178
+ new_state_dict[new_key] = v
179
+ return new_state_dict
180
+
181
+
182
+ def load_model(weights_path: str, map_location=None, model_key: str | None = None):
183
+ """
184
+ Charge un modèle avec les poids entraînés.
185
 
186
  Paramètres
187
  ----------
188
  weights_path : str
189
+ Chemin vers le fichier .pt (state_dict ou dict avec 'state_dict').
190
  map_location : torch.device ou None
191
+ Device sur lequel charger les poids. Si None, on détecte automatiquement.
192
+ model_key : str ou None
193
+ Clé de modèle à utiliser (ex: 'efficientnet_b0', 'inception_v3').
194
+ Si None, on essaie de la déduire du nom de fichier.
195
 
196
  Retour
197
  ------
 
200
  device : torch.device
201
  Le device utilisé (cuda ou cpu).
202
  """
203
+ # 1) Sélection du device
204
  device = map_location if map_location is not None else get_device()
205
 
206
+ # 2) Si la model_key n'est pas donnée, on essaye de l'inférer depuis le nom de fichier
207
+ if model_key is None:
208
+ model_key = infer_model_key_from_path(weights_path)
209
+
210
+ # 3) Chargement brut du .pt
211
+ checkpoint = torch.load(weights_path, map_location=device)
212
+
213
+ # 4) Si le checkpoint est un dict complet (ex: {'state_dict': ..., 'model_name': ...})
214
+ if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
215
+ state_dict = checkpoint["state_dict"]
216
+
217
+ # Optionnel : si un model_name ou model_key est stocké dedans, on peut le préférer.
218
+ if "model_key" in checkpoint and checkpoint["model_key"] in MODEL_REGISTRY:
219
+ model_key = checkpoint["model_key"]
220
+ elif "model_name" in checkpoint:
221
+ # Tentative : si model_name correspond au timm_name d'une entrée du registre
222
+ model_name_lower = str(checkpoint["model_name"]).lower()
223
+ for k, cfg in MODEL_REGISTRY.items():
224
+ if cfg["timm_name"].lower() == model_name_lower:
225
+ model_key = k
226
+ break
227
+ else:
228
+ # Cas simple : le fichier .pt est directement un state_dict
229
+ state_dict = checkpoint
230
+
231
+ # 5) Nettoyage des clés du state_dict (Lightning, DataParallel...)
232
+ state_dict = _clean_state_dict_keys(state_dict)
233
 
234
+ # 6) Construction de l'architecture adaptée
235
+ model = build_model(model_key=model_key, num_classes=2)
236
 
237
+ # 7) Application des poids (strict=False pour tolérer quelques petites diff de clés)
238
+ missing, unexpected = model.load_state_dict(state_dict, strict=False)
239
 
240
+ # (Optionnel) DEBUG : on pourrait afficher missing / unexpected si besoin
241
+ # print("Missing keys:", missing)
242
+ # print("Unexpected keys:", unexpected)
243
 
244
+ # 8) Envoi sur le bon device + mode eval
245
+ model = model.to(device)
246
  model.eval()
247
 
248
  return model, device
249
 
250
 
251
  # ----------------------------
252
+ # 7) Transforms pour l'inférence
253
  # ----------------------------
254
 
255
+ def get_val_transform(image_size: int | None = None):
256
  """
257
  Renvoie les transformations à appliquer aux images pour l'inférence.
258
+ Paramètres
259
+ ----------
260
+ image_size : int ou None
261
+ Si None, on utilise DEFAULT_IMAGE_SIZE (224).
262
+ Sinon, on redimensionne en (image_size, image_size).
263
+
264
+ Retour
265
+ ------
266
+ transform : torchvision.transforms.Compose
267
  """
268
+ if image_size is None:
269
+ image_size = DEFAULT_IMAGE_SIZE
270
+
271
  transform = transforms.Compose([
272
+ transforms.Resize((image_size, image_size)),
273
+ transforms.ToTensor(),
274
+ transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
275
  ])
276
  return transform
277
 
278
 
279
  # ----------------------------
280
+ # 8) Prétraitement d'une image
281
  # ----------------------------
282
 
283
+ def preprocess_image(image: Image.Image, transform=None, image_size: int | None = None):
284
  """
285
  Applique les transforms à une image PIL et ajoute une dimension batch.
286
 
 
289
  image : PIL.Image.Image
290
  Image brute chargée (par exemple via Image.open(...)).
291
  transform : callable ou None
292
+ Transformations à appliquer (si None, on utilise get_val_transform(image_size)).
293
+ image_size : int ou None
294
+ Taille de redimensionnement si transform est None.
295
 
296
  Retour
297
  ------
298
  image_tensor : torch.Tensor
299
+ Tenseur prêt pour l'inférence, de taille [1, 3, H, W].
300
  """
 
301
  if transform is None:
302
+ transform = get_val_transform(image_size=image_size)
303
 
304
+ img_tensor = transform(image) # [3, H, W]
305
+ img_tensor = img_tensor.unsqueeze(0) # [1, 3, H, W]
 
 
 
306
 
307
  return img_tensor
308
 
309
 
310
  # ----------------------------
311
+ # 9) Fonction de prédiction principale
312
  # ----------------------------
313
 
314
+ def predict_from_tensor(
315
+ image_tensor: torch.Tensor,
316
+ model: torch.nn.Module,
317
+ device: torch.device,
318
+ threshold: float = 0.5
319
+ ):
320
  """
321
  Prédit la classe (fire/no_fire) à partir d'un tenseur déjà prétraité.
322
 
323
  Paramètres
324
  ----------
325
  image_tensor : torch.Tensor
326
+ Tenseur d'images de taille [1, 3, H, W] (batch de 1 image).
327
  model : torch.nn.Module
328
+ Modèle chargé (EfficientNet, Inception, etc.).
329
  device : torch.device
330
  Device sur lequel le modèle est (cuda ou cpu).
331
  threshold : float
 
338
  fire_prob : float
339
  Probabilité prédite pour la classe "fire" (entre 0 et 1).
340
  """
 
341
  image_tensor = image_tensor.to(device)
342
 
 
343
  with torch.no_grad():
344
+ outputs = model(image_tensor) # logits [1, 2]
345
+ probs = torch.softmax(outputs, dim=1) # probabilités
 
 
 
346
 
 
347
  fire_prob = probs[0, 1].item()
348
 
 
349
  if fire_prob >= threshold:
350
+ predicted_idx = 1 # feu
351
  else:
352
+ predicted_idx = 0 # pas de feu
353
 
 
354
  predicted_label = IDX_TO_LABEL[predicted_idx]
355
 
356
  return predicted_label, fire_prob
357
 
358
 
359
+ def predict_from_pil(
360
+ image: Image.Image,
361
+ model: torch.nn.Module,
362
+ device: torch.device,
363
+ transform=None,
364
+ threshold: float = 0.5,
365
+ image_size: int | None = None,
366
+ ):
367
  """
368
  Prédit la classe à partir d'une image PIL.
369
 
 
372
  image : PIL.Image.Image
373
  Image chargée (par exemple via Image.open).
374
  model : torch.nn.Module
375
+ Modèle chargé.
376
  device : torch.device
377
  Device (cuda ou cpu).
378
  transform : callable ou None
379
  Transformations à appliquer à l'image.
380
  threshold : float
381
  Seuil sur la probabilité de FEU.
382
+ image_size : int ou None
383
+ Taille utilisée si transform est None.
384
 
385
  Retour
386
  ------
 
389
  fire_prob : float
390
  Probabilité de "fire".
391
  """
 
392
  if image.mode != "RGB":
393
  image = image.convert("RGB")
394
 
395
+ image_tensor = preprocess_image(image, transform=transform, image_size=image_size)
 
396
 
 
397
  return predict_from_tensor(image_tensor, model, device, threshold=threshold)
398
 
399
 
400
+ def predict_from_path(
401
+ image_path: str,
402
+ model: torch.nn.Module,
403
+ device: torch.device,
404
+ transform=None,
405
+ threshold: float = 0.5,
406
+ image_size: int | None = None,
407
+ ):
408
  """
409
  Prédit la classe à partir d'un chemin vers une image.
410
 
 
413
  image_path : str
414
  Chemin vers le fichier image (jpg, png, etc.).
415
  model : torch.nn.Module
416
+ Modèle chargé.
417
  device : torch.device
418
  Device (cuda ou cpu).
419
  transform : callable ou None
420
  Transformations à appliquer.
421
  threshold : float
422
  Seuil sur la probabilité de FEU.
423
+ image_size : int ou None
424
+ Taille utilisée si transform est None.
425
 
426
  Retour
427
  ------
 
430
  fire_prob : float
431
  Probabilité de "fire".
432
  """
 
433
  image = Image.open(image_path)
434
+ return predict_from_pil(
435
+ image,
436
+ model=model,
437
+ device=device,
438
+ transform=transform,
439
+ threshold=threshold,
440
+ image_size=image_size,
441
+ )
442
 
443
 
444
  # ----------------------------
445
+ # 10) Exemple d'utilisation en script direct
446
  # ----------------------------
447
 
448
  if __name__ == "__main__":
449
  """
450
+ Exemple simple pour tester le module en local :
451
  python inference.py
 
 
 
452
  """
453
+ weights_path = "efficientnet_fire.pt" # à adapter si besoin
 
 
 
454
 
455
  if not os.path.exists(weights_path):
456
  print(f"[ERREUR] Fichier de poids introuvable : {weights_path}")
457
  else:
 
458
  model, device = load_model(weights_path)
459
  print(f"Modèle chargé sur le device : {device}")
460
 
461
+ # On utilise la taille par défaut (224) pour ce test
462
  transform = get_val_transform()
463
 
464
+ test_image_path = "example.jpg" # à adapter si besoin
 
465
 
466
  if not os.path.exists(test_image_path):
467
  print(f"[INFO] Aucune image test trouvée à : {test_image_path}")
 
468
  else:
469
  label, prob = predict_from_path(
470
  test_image_path,
471
  model=model,
472
  device=device,
473
  transform=transform,
474
+ threshold=0.5,
475
  )
476
  print(f"Résultat pour {test_image_path} : label={label}, prob_fire={prob:.4f}")