ninafr8175 commited on
Commit
433b881
·
1 Parent(s): 8d8d12a

to push to hf

Browse files
Files changed (9) hide show
  1. .gitattributes +56 -0
  2. .gitignore +11 -0
  3. Dockerfile +18 -0
  4. Fire_detection_Project.ipynb +0 -0
  5. README.md +12 -119
  6. app.py +131 -0
  7. efficientnet_fire.pt +3 -0
  8. inference.py +355 -0
  9. requirements.txt +0 -0
.gitattributes ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ * text=auto
3
+
4
+ *.py text eol=lf
5
+ *.ipynb -diff
6
+ *.sh text eol=lf
7
+ Dockerfile text eol=lf
8
+ *.yml text eol=lf
9
+ *.yaml text eol=lf
10
+
11
+ *.ps1 text eol=crlf
12
+ *.bat text eol=crlf
13
+
14
+ *.csv text eol=lf working-tree-encoding=UTF-8
15
+ *.tsv text eol=lf working-tree-encoding=UTF-8
16
+
17
+ *.png binary
18
+ *.jpg binary
19
+ *.jpeg binary
20
+ *.pdf binary
21
+
22
+ *.7z filter=lfs diff=lfs merge=lfs -text
23
+ *.arrow filter=lfs diff=lfs merge=lfs -text
24
+ *.bin filter=lfs diff=lfs merge=lfs -text
25
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
26
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
27
+ *.ftz filter=lfs diff=lfs merge=lfs -text
28
+ *.gz filter=lfs diff=lfs merge=lfs -text
29
+ *.h5 filter=lfs diff=lfs merge=lfs -text
30
+ *.joblib filter=lfs diff=lfs merge=lfs -text
31
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
32
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
33
+ *.model filter=lfs diff=lfs merge=lfs -text
34
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
35
+ *.npy filter=lfs diff=lfs merge=lfs -text
36
+ *.npz filter=lfs diff=lfs merge=lfs -text
37
+ *.onnx filter=lfs diff=lfs merge=lfs -text
38
+ *.ot filter=lfs diff=lfs merge=lfs -text
39
+ *.parquet filter=lfs diff=lfs merge=lfs -text
40
+ *.pb filter=lfs diff=lfs merge=lfs -text
41
+ *.pickle filter=lfs diff=lfs merge=lfs -text
42
+ *.pkl filter=lfs diff=lfs merge=lfs -text
43
+ *.pt filter=lfs diff=lfs merge=lfs -text
44
+ *.pth filter=lfs diff=lfs merge=lfs -text
45
+ *.rar filter=lfs diff=lfs merge=lfs -text
46
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
47
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
48
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
49
+ *.tar filter=lfs diff=lfs merge=lfs -text
50
+ *.tflite filter=lfs diff=lfs merge=lfs -text
51
+ *.tgz filter=lfs diff=lfs merge=lfs -text
52
+ *.wasm filter=lfs diff=lfs merge=lfs -text
53
+ *.xz filter=lfs diff=lfs merge=lfs -text
54
+ *.zip filter=lfs diff=lfs merge=lfs -text
55
+ *.zst filter=lfs diff=lfs merge=lfs -text
56
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Environnement virtuel
2
+ .venv/
3
+
4
+ # Variables secrètes
5
+ .env
6
+
7
+ # Fichiers générés automatiquement
8
+ __pycache__/
9
+ *.pyc
10
+ *.pptx
11
+ .dockerignore
Dockerfile ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.13-slim
2
+
3
+ WORKDIR /app
4
+
5
+ ENV PYTHONDONTWRITEBYTECODE=1
6
+ ENV PYTHONUNBUFFERED=1
7
+
8
+ RUN apt-get update && apt-get install -y --no-install-recommends \
9
+ && rm -rf /var/lib/apt/lists/*
10
+
11
+ COPY requirements.txt ./requirements.txt
12
+ RUN pip install --no-cache-dir -r requirements.txt
13
+
14
+ COPY . .
15
+
16
+ EXPOSE 7860
17
+
18
+ CMD ["streamlit", "run", "app.py", "--server.port=7860", "--server.address=0.0.0.0"]
Fire_detection_Project.ipynb ADDED
Binary file (646 kB). View file
 
README.md CHANGED
@@ -1,119 +1,12 @@
1
- # Fire_detection
2
-
3
- ## Datasets et publications
4
- Datasets :
5
- https://docs.google.com/spreadsheets/d/1mSaPf1uFKBiJR4ECJ9Qw8Ne89AfqBFdzOHYIDpTkkfE/edit?gid=0#gid=0
6
- Publications :
7
- https://docs.google.com/document/d/1pBWHai2zxGqSHPpdVhFFJNARVEa5PGTwTP2JDLOGoOk/edit?usp=sharing
8
-
9
- # Idées / Pistes
10
-
11
- un site génial : https://www.alertwildfire.org/?viewMode=Grid
12
-
13
- Un réseau de caméra de sécrité anti incendies qui couvre tout l'Ouest des Etats_Unis
14
- Ils n'ont pas d'ia connectés aux caméras pour détecter automatiquement les incendies
15
- idée de projet : on peut concevoir notre ia pour qu'elle puisse surveiller ces videos en temps réel et lancer des alertes
16
-
17
- # feuille de route pour y voir plus clair
18
-
19
- Problématique : Détection précoce des incendies
20
- - afin d'avoir le plus de chances possible
21
-
22
- **Etapes** :
23
-
24
- I. **Datasets** et **publications** : (Mercredi et jeudi)
25
- - début de feu
26
- - feux déjà bien présent
27
- - avec variable cible
28
- - Choix du dataset : OK : https://www.kaggle.com/datasets/elmadafri/the-wildfire-dataset
29
- - Dataset vidéos surveillance découpées en images. Vue de loin, surtout fumée.
30
- - Publications : voir ce que les autres ont fait
31
- - lien google doc explications publications
32
- https://docs.google.com/document/d/1pBWHai2zxGqSHPpdVhFFJNARVEa5PGTwTP2JDLOGoOk/edit?usp=sharing
33
-
34
- II. **Choix** : (Jeudi)
35
- - finir publications
36
- - de la démarche à suivre / Variable retenus
37
- - cleaner le dataset
38
-
39
- III. **Baseline** : (Vendredi)
40
- - cleanner le dataset
41
- - essayer d'entrainer un modèle sur ces données
42
-
43
- # Structure deep learning - d'après les publications :
44
- https://www.mdpi.com/1424-8220/20/22/6442#Early_Fire_Detection_Systems
45
-
46
- ### Préparation des images
47
- * redimensionner les images
48
- * Augmenter les données (rotation, miroir, luminosité ...)
49
- * split de la publication : Train test Val : 60-20-20 ou 70-15-15
50
-
51
- ### CNN avancé (type VGG, LeNet, ou GoogleNet)
52
-
53
- * On charge un modèle existant (pré-entraîné) puis on ajuste les couches de sortie (fine-tuning) :
54
- * Exemple pour VGG16 :
55
- from tensorflow.keras.applications import VGG16
56
- base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224,224,3))
57
- > On ajoute une nouvelle tête de classification adaptée à la détection feu/non-feu
58
-
59
- ### Combiner images fixes et mouvement
60
- * images statiques
61
- * mouvement par "optical flow" : déplacement de pixel entre deux images fixes
62
- * (différentes méthodes : " Lucas-kanade", "Pyramdal Lucas-Kanade", "Farneback"
63
- * Ca donne une carte qui indique la vitesse et le déplacement des pixels
64
- * sinon j'amagine que c'est quasi impossible de distinguer fumée / nuage
65
- * concaténer les deux
66
- * il faut qu'ils soient normalisés - échelle comparable -
67
- * et redimensionner : trop de dimension > surapprentissage / nécessite plus de données (réduction dimensionalité potentiel)
68
- * ça donne un gros vecteur final ui combine les deux ancien
69
- * si problème de dimension avant : problème de dimension entre ces vecteurs
70
- * la couche de sortie fournie la probabilité feu / non feu
71
-
72
- ### Entraîner 2 CNN, de façon hiérarchique
73
- * 1er : Dit si feu ou non
74
- * 2e : Sert à localiser le feu dans l'image
75
- * superpixels : grand groupe pixel avec ses voisins
76
-
77
- ### Détection temps-réel
78
- * Pour qu'il puisse détecter en temps réel il faut que ce soit léger
79
- * MobileNet, SqueezeNet : architecture compressé
80
- * cadence d'image (fps) limité, peut être directement embarqué dans caméra
81
- * système adaptatif : si suspicion : activer caméra HD sinon flu basse résolution
82
-
83
- ### Segmentation par réseaux spécialisés
84
- * un modèle type "FCN" ou Deeplabv3" : pas tout à fait compris
85
- * comprends rien mais pour ça il va d'aord détecter les régions par Faster R-CNN
86
- * ensuite il passe ces régions dans un LSTM pour "validation temporelle" (il prend tant de frames autour pour voir le mouvement et l'évoluion de la fumée)
87
-
88
- ### Amélioration
89
- * training avec data augmentation
90
- * validation croisée : diviser en k fold pour évaluer robestesse
91
-
92
-
93
-
94
- # Problèmatiques :
95
- influençant le choix du modèle à pré-entrainer
96
-
97
- ## Problème principal : les caméras
98
-
99
- ### Variabilité extrême :
100
- - Caméra de volontaires, publique ou privé
101
- - rafraichissement inégale : parfois 1 à plusieurs minutes voir plusieurs heures
102
- - Variabilité extrême : lumière, météo, angle
103
-
104
- ### Qualité trés basse
105
- - caméra trés distante > fumée petite > l'IA doit être sensible au faibles signaux
106
- - Donc architecture puissante pour capter les petits indices
107
- - DOnc faux positive élevé > Couche de de validation ou score de confiance (ou les deux)
108
-
109
- ### Donc le pipeline :
110
- - Le pipeline devra récupérer automatiquement les images du site : traitement de batch en temps réel
111
- - pour chaque caméra : il faut enregistrer : Quel distance / angle / paysage et créer un base de donnée de référencement
112
- - on peut envisager un modèle spécifique pour chaque type
113
-
114
- Donc, le mieux est de contacter ces charmants personnages à l'origine de ce site, ci-nommés :
115
- - Dr. Bill Savran, NSL Lab Manager, University of Nevada Reno : wsavran@unr.edu
116
- - Dr. Christie Rowe, NSL Director, University of Nevada Reno, : rowec@unr.edu
117
- On découvre que ce site a été mis en place par l'Université du Nevada, on leur fait un coucou
118
-
119
-
 
1
+ ---
2
+ title: Fire Detection
3
+ emoji: 💻
4
+ colorFrom: green
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 5.49.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ short_description: Fire Detection System using Streamlit
12
+ ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st # librairie pour le dashboard
2
+ from PIL import Image # pour ouvrir les images
3
+ from inference import ( # fonctions importées du fichier inference.py
4
+ load_model,
5
+ get_val_transform,
6
+ predict_from_pil
7
+ )
8
+
9
+ # -----------------------------------------------------------
10
+ # Configuration de la page
11
+ # -----------------------------------------------------------
12
+ st.set_page_config(
13
+ page_title="Fire Detection Dashboard", # titre de l’onglet du navigateur
14
+ page_icon="🔥", # icône (emoji)
15
+ layout="centered" # mise en page centrée
16
+ )
17
+
18
+ # -----------------------------------------------------------
19
+ # Chargement du modèle (une seule fois)
20
+ # -----------------------------------------------------------
21
+ @st.cache_resource
22
+ def load_app_model():
23
+ """
24
+ Charge le modèle, le device et la transform une seule fois,
25
+ puis les réutilise pour toutes les prédictions.
26
+ """
27
+ model, device = load_model("efficientnet_fire.pt") # charge les poids
28
+ transform = get_val_transform() # transform validation/inférence
29
+ return model, device, transform
30
+
31
+ model, device, transform = load_app_model()
32
+
33
+ # -----------------------------------------------------------
34
+ # Sidebar : infos et paramètres
35
+ # -----------------------------------------------------------
36
+ st.sidebar.title("⚙️ Paramètres")
37
+ st.sidebar.markdown(
38
+ """
39
+ Ce dashboard utilise un modèle EfficientNet-B0,
40
+ entraîné à prédire **FIRE / NO FIRE** sur des images.
41
+
42
+ - Classe 0 : **no_fire**
43
+ - Classe 1 : **fire**
44
+ """
45
+ )
46
+
47
+ threshold = st.sidebar.slider(
48
+ "Seuil de détection du feu (probabilité minimale pour 'fire')",
49
+ min_value=0.1,
50
+ max_value=0.9,
51
+ value=0.5,
52
+ step=0.05,
53
+ )
54
+
55
+ st.sidebar.markdown(f"Seuil actuel : **{threshold:.2f}**")
56
+
57
+ # -----------------------------------------------------------
58
+ # Titre principal
59
+ # -----------------------------------------------------------
60
+ st.title("🔥 Fire Detection Dashboard")
61
+ st.markdown(
62
+ """
63
+ Ce prototype permet de tester un modèle de détection de feu,
64
+ sur des images individuelles.
65
+
66
+ _Charger une image pour obtenir une prédiction._
67
+ """
68
+ )
69
+
70
+ # -----------------------------------------------------------
71
+ # Zone d'upload d'image (texte personnalisé ajouté)
72
+ # -----------------------------------------------------------
73
+ uploaded_file = st.file_uploader(
74
+ "📂 Déposez une image ici (ou cliquez sur Browse Files pour choisir une image)",
75
+ type=["jpg", "jpeg", "png"],
76
+ help="Formats supportés : JPG, JPEG, PNG\nMaximum 200MB par image",
77
+ accept_multiple_files=False
78
+ )
79
+
80
+ # -----------------------------------------------------------
81
+ # Si aucune image n'est encore uploadée
82
+ # -----------------------------------------------------------
83
+ if uploaded_file is None:
84
+ st.info("👉 En attente d'une image. Charger une photo de forêt, flamme, paysage, etc.")
85
+ else:
86
+ # -------------------------------------------------------
87
+ # Afficher l'image uploadée
88
+ # -------------------------------------------------------
89
+ image = Image.open(uploaded_file)
90
+ st.image(image, caption="Image chargée", use_container_width=True)
91
+
92
+ # -------------------------------------------------------
93
+ # Prédiction
94
+ # -------------------------------------------------------
95
+ with st.spinner("Analyse de l'image en cours..."):
96
+ label, prob = predict_from_pil(
97
+ image=image,
98
+ model=model,
99
+ device=device,
100
+ transform=transform,
101
+ threshold=threshold
102
+ )
103
+
104
+ # -------------------------------------------------------
105
+ # Affichage du résultat avec couleur
106
+ # -------------------------------------------------------
107
+ prob_percent = prob * 100
108
+
109
+ if label == "fire":
110
+ st.error(
111
+ f"🔥 **FEU DÉTECTÉ** \nProbabilité de feu : **{prob_percent:.2f}%** \n(Seuil utilisé : {threshold:.2f})"
112
+ )
113
+ else:
114
+ st.success(
115
+ f"✅ **PAS DE FEU DÉTECTÉ** \nProbabilité de feu : **{prob_percent:.2f}%** \n(Seuil utilisé : {threshold:.2f})"
116
+ )
117
+
118
+ # -------------------------------------------------------
119
+ # Détails supplémentaires (dans un expander)
120
+ # -------------------------------------------------------
121
+ with st.expander("🔍 Détails techniques (optionnel)"):
122
+ st.markdown(
123
+ f"""
124
+ - Label retourné : **{label}**
125
+ - Probabilité brute de la classe *fire* : **{prob:.4f}**
126
+ - Seuil de décision : **{threshold:.2f}**
127
+
128
+ Si `prob_fire >= seuil` → prédiction = *fire*,
129
+ sinon → *no_fire*.
130
+ """
131
+ )
efficientnet_fire.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1618ff899b64649c5752a546f57d2a05ad8d85e0cd385ffd07b7732d49f5fbda
3
+ size 16338683
inference.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ inference.py
3
+ ------------
4
+ Module d'inférence pour le modèle EfficientNet-B0 entraîné
5
+ sur la classification binaire : FIRE (1) / NO_FIRE (0).
6
+
7
+ Compatible :
8
+ - Google Colab
9
+ - Exécution locale (Python)
10
+ - Lightning AI
11
+ - HuggingFace Spaces / Streamlit
12
+
13
+ Usage typique :
14
+ ---------------
15
+ from inference import load_model, get_val_transform, predict_from_path
16
+
17
+ model, device = load_model("efficientnet_fire.pt")
18
+ transform = get_val_transform()
19
+
20
+ label, prob = predict_from_path("mon_image.jpg", model, device, transform)
21
+ print(label, prob)
22
+ """
23
+
24
+ # ----------------------------
25
+ # 1) Imports
26
+ # ----------------------------
27
+ import torch # bibliothèque principale pour le deep learning
28
+ import torch.nn as nn # pour définir la tête de classification
29
+ from torchvision import transforms # pour les pré-traitements d'images
30
+ from PIL import Image # pour charger les images depuis un fichier
31
+ import timm # pour charger EfficientNet-B0
32
+
33
+
34
+ # ----------------------------
35
+ # 2) Constantes globales
36
+ # ----------------------------
37
+
38
+ # Taille d'entrée du modèle EfficientNet-B0
39
+ IMAGE_SIZE = 224 # (224 x 224 pixels)
40
+
41
+ # Moyennes et écarts-types d'ImageNet (pour normaliser les images)
42
+ IMAGENET_MEAN = [0.485, 0.456, 0.406] # moyenne des canaux R, G, B
43
+ IMAGENET_STD = [0.229, 0.224, 0.225] # écart-type des canaux R, G, B
44
+
45
+ # Mapping des classes numériques vers des labels lisibles
46
+ IDX_TO_LABEL = {
47
+ 0: "no_fire", # classe 0 → pas de feu
48
+ 1: "fire" # classe 1 → feu
49
+ }
50
+
51
+
52
+ # ----------------------------
53
+ # 3) Utilitaires device
54
+ # ----------------------------
55
+
56
+ def get_device():
57
+ """
58
+ Retourne le device à utiliser pour l'inférence :
59
+ - 'cuda' si un GPU est disponible
60
+ - sinon 'cpu'
61
+ """
62
+ # torch.cuda.is_available() renvoie True si un GPU CUDA est accessible
63
+ if torch.cuda.is_available():
64
+ return torch.device("cuda") # on utilisera le GPU
65
+ else:
66
+ return torch.device("cpu") # sinon le CPU
67
+
68
+
69
+ # ----------------------------
70
+ # 4) Chargement du modèle
71
+ # ----------------------------
72
+
73
+ def build_model(num_classes=2):
74
+ """
75
+ Construit l'architecture EfficientNet-B0 avec une tête
76
+ adaptée à la classification binaire (2 classes).
77
+ Les poids seront chargés ensuite via load_state_dict.
78
+ """
79
+ # On crée le modèle EfficientNet-B0 sans poids pré-entraînés ici
80
+ # (les poids spécifiques à ton projet seront chargés après)
81
+ model = timm.create_model("efficientnet_b0", pretrained=False)
82
+
83
+ # On récupère le nombre de features en entrée de la dernière couche
84
+ in_features = model.classifier.in_features
85
+
86
+ # On remplace la dernière couche par une couche linéaire avec num_classes sorties
87
+ model.classifier = nn.Linear(in_features, num_classes)
88
+
89
+ return model
90
+
91
+
92
+ def load_model(weights_path: str, map_location=None):
93
+ """
94
+ Charge le modèle EfficientNet-B0 avec les poids entraînés.
95
+
96
+ Paramètres
97
+ ----------
98
+ weights_path : str
99
+ Chemin vers le fichier .pt contenant les poids (state_dict).
100
+ map_location : torch.device ou None
101
+ Device sur lequel charger les poids.
102
+ Si None, on détecte automatiquement (GPU si dispo, sinon CPU).
103
+
104
+ Retour
105
+ ------
106
+ model : torch.nn.Module
107
+ Le modèle prêt pour l'inférence.
108
+ device : torch.device
109
+ Le device utilisé (cuda ou cpu).
110
+ """
111
+ # On détecte le device si non fourni
112
+ device = map_location if map_location is not None else get_device()
113
+
114
+ # On construit l'architecture du modèle
115
+ model = build_model(num_classes=2)
116
+
117
+ # On charge le dictionnaire de poids sauvegardés (state_dict)
118
+ state_dict = torch.load(weights_path, map_location=device)
119
+
120
+ # On applique les poids au modèle
121
+ model.load_state_dict(state_dict)
122
+
123
+ # On envoie le modèle sur le bon device (GPU ou CPU)
124
+ model = model.to(device)
125
+
126
+ # On passe le modèle en mode évaluation (important pour dropout, batchnorm, etc.)
127
+ model.eval()
128
+
129
+ return model, device
130
+
131
+
132
+ # ----------------------------
133
+ # 5) Transforms pour l'inférence
134
+ # ----------------------------
135
+
136
+ def get_val_transform():
137
+ """
138
+ Renvoie les transformations à appliquer aux images pour l'inférence.
139
+ Ce sont les mêmes que pour la validation :
140
+ - Resize 224x224
141
+ - ToTensor
142
+ - Normalize (ImageNet)
143
+ """
144
+ transform = transforms.Compose([
145
+ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), # redimensionne en 224x224
146
+ transforms.ToTensor(), # convertit PIL → Tensor [0,1]
147
+ transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) # normalise selon ImageNet
148
+ ])
149
+ return transform
150
+
151
+
152
+ # ----------------------------
153
+ # 6) Prétraitement d'une image
154
+ # ----------------------------
155
+
156
+ def preprocess_image(image: Image.Image, transform=None):
157
+ """
158
+ Applique les transforms à une image PIL et ajoute une dimension batch.
159
+
160
+ Paramètres
161
+ ----------
162
+ image : PIL.Image.Image
163
+ Image brute chargée (par exemple via Image.open(...)).
164
+ transform : callable ou None
165
+ Transformations à appliquer (si None, on utilise get_val_transform()).
166
+
167
+ Retour
168
+ ------
169
+ image_tensor : torch.Tensor
170
+ Tenseur prêt pour l'inférence, de taille [1, 3, 224, 224].
171
+ """
172
+ # Si aucune transform n'est fournie, on utilise la transform par défaut
173
+ if transform is None:
174
+ transform = get_val_transform()
175
+
176
+ # On applique la transform à l'image PIL → tensor [3, 224, 224]
177
+ img_tensor = transform(image)
178
+
179
+ # On ajoute une dimension batch devant : [1, 3, 224, 224]
180
+ img_tensor = img_tensor.unsqueeze(0)
181
+
182
+ return img_tensor
183
+
184
+
185
+ # ----------------------------
186
+ # 7) Fonction de prédiction principale
187
+ # ----------------------------
188
+
189
+ def predict_from_tensor(image_tensor: torch.Tensor,
190
+ model: torch.nn.Module,
191
+ device: torch.device,
192
+ threshold: float = 0.5):
193
+ """
194
+ Prédit la classe (fire/no_fire) à partir d'un tenseur déjà prétraité.
195
+
196
+ Paramètres
197
+ ----------
198
+ image_tensor : torch.Tensor
199
+ Tenseur d'images de taille [1, 3, 224, 224] (batch de 1 image).
200
+ model : torch.nn.Module
201
+ Modèle EfficientNet-B0 chargé.
202
+ device : torch.device
203
+ Device sur lequel le modèle est (cuda ou cpu).
204
+ threshold : float
205
+ Seuil sur la probabilité de FEU pour décider entre no_fire / fire.
206
+
207
+ Retour
208
+ ------
209
+ predicted_label : str
210
+ "fire" ou "no_fire".
211
+ fire_prob : float
212
+ Probabilité prédite pour la classe "fire" (entre 0 et 1).
213
+ """
214
+ # On envoie l'image sur le même device que le modèle
215
+ image_tensor = image_tensor.to(device)
216
+
217
+ # On désactive le calcul des gradients pour l'inférence
218
+ with torch.no_grad():
219
+ # Le modèle renvoie des logits de taille [1, 2]
220
+ outputs = model(image_tensor)
221
+
222
+ # On convertit en probabilités via softmax
223
+ probs = torch.softmax(outputs, dim=1)
224
+
225
+ # Probabilité de la classe fire (indice 1)
226
+ fire_prob = probs[0, 1].item()
227
+
228
+ # On décide du label en comparant à un seuil
229
+ if fire_prob >= threshold:
230
+ predicted_idx = 1 # feu
231
+ else:
232
+ predicted_idx = 0 # pas de feu
233
+
234
+ # Conversion en label lisible
235
+ predicted_label = IDX_TO_LABEL[predicted_idx]
236
+
237
+ return predicted_label, fire_prob
238
+
239
+
240
+ def predict_from_pil(image: Image.Image,
241
+ model: torch.nn.Module,
242
+ device: torch.device,
243
+ transform=None,
244
+ threshold: float = 0.5):
245
+ """
246
+ Prédit la classe à partir d'une image PIL.
247
+
248
+ Paramètres
249
+ ----------
250
+ image : PIL.Image.Image
251
+ Image chargée (par exemple via Image.open).
252
+ model : torch.nn.Module
253
+ Modèle EfficientNet-B0 chargé.
254
+ device : torch.device
255
+ Device (cuda ou cpu).
256
+ transform : callable ou None
257
+ Transformations à appliquer à l'image.
258
+ threshold : float
259
+ Seuil sur la probabilité de FEU.
260
+
261
+ Retour
262
+ ------
263
+ predicted_label : str
264
+ "fire" ou "no_fire".
265
+ fire_prob : float
266
+ Probabilité de "fire".
267
+ """
268
+ # On s'assure que l'image est en mode RGB
269
+ if image.mode != "RGB":
270
+ image = image.convert("RGB")
271
+
272
+ # On prétraite l'image (resize, tensor, normalize, batch)
273
+ image_tensor = preprocess_image(image, transform=transform)
274
+
275
+ # On délègue la prédiction à predict_from_tensor
276
+ return predict_from_tensor(image_tensor, model, device, threshold=threshold)
277
+
278
+
279
+ def predict_from_path(image_path: str,
280
+ model: torch.nn.Module,
281
+ device: torch.device,
282
+ transform=None,
283
+ threshold: float = 0.5):
284
+ """
285
+ Prédit la classe à partir d'un chemin vers une image.
286
+
287
+ Paramètres
288
+ ----------
289
+ image_path : str
290
+ Chemin vers le fichier image (jpg, png, etc.).
291
+ model : torch.nn.Module
292
+ Modèle EfficientNet-B0 chargé.
293
+ device : torch.device
294
+ Device (cuda ou cpu).
295
+ transform : callable ou None
296
+ Transformations à appliquer.
297
+ threshold : float
298
+ Seuil sur la probabilité de FEU.
299
+
300
+ Retour
301
+ ------
302
+ predicted_label : str
303
+ "fire" ou "no_fire".
304
+ fire_prob : float
305
+ Probabilité de "fire".
306
+ """
307
+ # On charge l'image depuis le disque via PIL
308
+ image = Image.open(image_path)
309
+
310
+ # On délègue la prédiction à la fonction base sur PIL
311
+ return predict_from_pil(image, model, device, transform=transform, threshold=threshold)
312
+
313
+
314
+ # ----------------------------
315
+ # 8) Exemple d'utilisation en script direct
316
+ # ----------------------------
317
+
318
+ if __name__ == "__main__":
319
+ """
320
+ Ce bloc s'exécute uniquement si on lance le fichier directement :
321
+ python inference.py
322
+
323
+ Tu peux le modifier pour faire un petit test rapide en local
324
+ ou dans un notebook via !python inference.py.
325
+ """
326
+ import os
327
+
328
+ # Chemin vers le fichier de poids (à adapter si besoin)
329
+ weights_path = "efficientnet_fire.pt"
330
+
331
+ if not os.path.exists(weights_path):
332
+ print(f"[ERREUR] Fichier de poids introuvable : {weights_path}")
333
+ else:
334
+ # 1) On charge le modèle et on détecte le device
335
+ model, device = load_model(weights_path)
336
+ print(f"Modèle chargé sur le device : {device}")
337
+
338
+ # 2) On récupère la transform de validation/inférence
339
+ transform = get_val_transform()
340
+
341
+ # 3) Exemple : prédire sur une image de test (chemin à adapter)
342
+ test_image_path = "example.jpg" # ← remplace par une vraie image
343
+
344
+ if not os.path.exists(test_image_path):
345
+ print(f"[INFO] Aucune image test trouvée à : {test_image_path}")
346
+ print(" Modifie le chemin dans __main__ pour tester une image.")
347
+ else:
348
+ label, prob = predict_from_path(
349
+ test_image_path,
350
+ model=model,
351
+ device=device,
352
+ transform=transform,
353
+ threshold=0.5
354
+ )
355
+ print(f"Résultat pour {test_image_path} : label={label}, prob_fire={prob:.4f}")
requirements.txt ADDED
Binary file (2.07 kB). View file