File size: 12,112 Bytes
433b881
 
 
6c28497
433b881
2e0d90d
 
 
 
6c28497
2e0d90d
 
433b881
 
 
 
 
2e0d90d
 
 
 
 
 
 
 
 
 
 
 
 
433b881
 
 
 
 
 
2e0d90d
 
433b881
2e0d90d
 
 
433b881
2e0d90d
433b881
2e0d90d
 
433b881
 
2e0d90d
6c28497
 
 
 
 
 
 
 
 
 
 
 
 
2e0d90d
6c28497
 
2e0d90d
 
 
 
 
433b881
 
2e0d90d
433b881
 
 
 
 
 
 
 
 
6c28497
2e0d90d
433b881
 
 
6c28497
433b881
 
6c28497
433b881
6c28497
2e0d90d
6c28497
2e0d90d
 
 
433b881
6c28497
 
2e0d90d
 
 
6c28497
 
433b881
6c28497
433b881
6c28497
 
 
 
2e0d90d
6c28497
 
 
 
 
 
 
2e0d90d
6c28497
 
 
 
2e0d90d
 
6c28497
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433b881
 
 
 
6c28497
2e0d90d
6c28497
 
 
 
2e0d90d
433b881
2e0d90d
6c28497
 
 
 
 
 
2e0d90d
 
6c28497
 
2e0d90d
 
 
 
6c28497
 
 
433b881
2e0d90d
 
433b881
2e0d90d
 
 
433b881
 
 
 
2e0d90d
6c28497
 
 
2e0d90d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c28497
 
 
 
 
 
 
 
433b881
6c28497
 
433b881
2e0d90d
6c28497
 
433b881
6c28497
433b881
 
 
 
 
 
2e0d90d
433b881
 
6c28497
433b881
 
6c28497
2e0d90d
433b881
6c28497
 
 
433b881
6c28497
 
 
433b881
 
 
 
 
2e0d90d
433b881
 
6c28497
433b881
 
 
 
6c28497
433b881
6c28497
 
433b881
 
 
 
 
2e0d90d
433b881
 
6c28497
 
 
 
2e0d90d
6c28497
433b881
2e0d90d
 
433b881
 
 
 
6c28497
2e0d90d
433b881
 
2e0d90d
433b881
 
 
 
 
2e0d90d
 
 
 
6c28497
 
2e0d90d
6c28497
 
 
 
 
433b881
 
 
 
 
 
 
 
 
2e0d90d
 
 
433b881
 
 
 
2e0d90d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c28497
2e0d90d
 
 
 
433b881
 
2e0d90d
 
 
433b881
6c28497
 
2e0d90d
6c28497
 
 
 
 
433b881
 
 
 
6c28497
2e0d90d
6c28497
 
 
 
 
 
433b881
 
 
2e0d90d
433b881
 
 
2e0d90d
 
433b881
 
 
 
 
 
 
2e0d90d
433b881
 
 
 
2e0d90d
433b881
 
 
 
6c28497
433b881
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
"""
inference.py
------------
Module d'inférence pour des modèles de classification binaire : FIRE (1) / NO_FIRE (0).

Supporte :
- EfficientNet-B0 (classification)
- Inception v3 (classification)
- YOLO (Ultralytics, détection) pour localiser le feu

Retour principal des fonctions de prédiction :
    predicted_label (str), fire_prob (float), annotated_image (PIL.Image ou None)
"""

# ----------------------------
# 1) Imports
# ----------------------------
import os

import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import timm

# YOLO (Ultralytics) – optionnel
try:
    from ultralytics import YOLO
except ImportError:
    YOLO = None  # si la lib n'est pas installée, on gère ça proprement


# ----------------------------
# 2) Constantes globales
# ----------------------------

# Taille d'entrée par défaut si rien n'est précisé
DEFAULT_IMAGE_SIZE = 224

# Stats ImageNet
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Mapping idx -> label lisible
IDX_TO_LABEL = {
    0: "no_fire",
    1: "fire",
}

# Registre des modèles de classification (timm)
MODEL_REGISTRY = {
    "efficientnet_b0": {
        "timm_name": "efficientnet_b0",
        "image_size": 224,
        "classifier_attr": "classifier",
    },
    "inception_v3": {
        "timm_name": "inception_v3",
        "image_size": 299,
        "classifier_attr": "fc",
    },
}

# Modèle par défaut si on ne sait pas quoi choisir
DEFAULT_MODEL_KEY = "efficientnet_b0"

# IDs des classes "feu" pour YOLO (à adapter si besoin après entraînement)
# Exemple : si model.names == {0: 'fire'} → [0]
#          si model.names == {0: 'no_fire', 1: 'fire'} → [1]
FIRE_CLASS_IDS = [0]


# ----------------------------
# 3) Device
# ----------------------------

def get_device():
    """
    Retourne le device à utiliser pour l'inférence :
    - 'cuda' si un GPU est disponible
    - sinon 'cpu'
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


# ----------------------------
# 4) Détection du type de modèle
# ----------------------------

def infer_model_key_from_path(weights_path: str) -> str:
    """
    Devine une clé de modèle (model_key) à partir du nom de fichier.

    Exemple :
    - "yolov8_fire.pt"        → "yolo"
    - "inception3_fire.pt"    → "inception_v3"
    - "efficientnet_fire.pt"  → "efficientnet_b0" (par défaut)
    """
    filename = os.path.basename(weights_path).lower()

    if "yolo" in filename:
        return "yolo"

    if "inception" in filename:
        return "inception_v3"

    return DEFAULT_MODEL_KEY


def get_model_config(model_key: str) -> dict:
    """
    Renvoie la config du modèle à partir d'une model_key.
    Fallback : EfficientNet-B0 si model_key inconnue.
    """
    if model_key in MODEL_REGISTRY:
        return MODEL_REGISTRY[model_key]
    return MODEL_REGISTRY[DEFAULT_MODEL_KEY]


# ----------------------------
# 5) Construction modèle (classification)
# ----------------------------

def build_model(model_key: str, num_classes: int = 2) -> torch.nn.Module:
    """
    Construit un modèle de classification (EfficientNet, Inception...)
    et adapte la dernière couche à num_classes sorties.
    """
    config = get_model_config(model_key)
    timm_name = config["timm_name"]
    classifier_attr = config["classifier_attr"]

    model = timm.create_model(timm_name, pretrained=False)

    classifier = getattr(model, classifier_attr)

    if isinstance(classifier, nn.Linear):
        in_features = classifier.in_features
    else:
        raise ValueError(
            f"Impossible de déterminer in_features pour la tête du modèle '{timm_name}'. "
            f"Attribut '{classifier_attr}' de type {type(classifier)} non supporté."
        )

    new_classifier = nn.Linear(in_features, num_classes)
    setattr(model, classifier_attr, new_classifier)

    return model


# ----------------------------
# 6) Nettoyage de state_dict
# ----------------------------

def _clean_state_dict_keys(state_dict: dict) -> dict:
    """
    Nettoie les clés pour gérer les prefixes 'model.' (Lightning) et 'module.' (DataParallel).
    """
    new_state = {}
    for k, v in state_dict.items():
        new_key = k
        if new_key.startswith("model."):
            new_key = new_key[len("model."):]
        if new_key.startswith("module."):
            new_key = new_key[len("module."):]
        new_state[new_key] = v
    return new_state


# ----------------------------
# 7) Chargement du modèle
# ----------------------------

def load_model(weights_path: str, map_location=None, model_key: str | None = None):
    """
    Charge un modèle avec les poids entraînés.

    - Pour YOLO (Ultralytics) : charge un modèle de détection.
    - Pour EfficientNet / Inception : charge un modèle de classification binaire.

    Retour :
    --------
    model : torch.nn.Module ou YOLO
    device : torch.device
    """
    device = map_location if map_location is not None else get_device()

    # Détecter le type de modèle si non fourni
    if model_key is None:
        model_key = infer_model_key_from_path(weights_path)

    # 🔹 Cas YOLO : modèle de détection (Ultralytics)
    if model_key == "yolo":
        if YOLO is None:
            raise ImportError(
                "Le modèle YOLO est demandé mais la librairie 'ultralytics' "
                "n'est pas installée. Ajoutez 'ultralytics' dans requirements.txt."
            )
        yolo_model = YOLO(weights_path)
        # YOLO gère déjà souvent le device en interne, on essaye juste par sécurité
        try:
            yolo_model.to(device)
        except Exception:
            pass
        return yolo_model, device

    # 🔹 Cas classification (EfficientNet, Inception...)
    checkpoint = torch.load(weights_path, map_location=device)

    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    state_dict = _clean_state_dict_keys(state_dict)

    model = build_model(model_key=model_key, num_classes=2)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    # (optionnel) debug :
    # print("Missing keys:", missing)
    # print("Unexpected keys:", unexpected)

    model = model.to(device)
    model.eval()

    return model, device


# ----------------------------
# 8) Transforms pour l'inférence
# ----------------------------

def get_val_transform(image_size: int | None = None):
    """
    Renvoie les transformations à appliquer aux images pour l'inférence.

    Si image_size est None → DEFAULT_IMAGE_SIZE (224).
    """
    if image_size is None:
        image_size = DEFAULT_IMAGE_SIZE

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
    return transform


# ----------------------------
# 9) Prétraitement d'une image
# ----------------------------

def preprocess_image(image: Image.Image, transform=None, image_size: int | None = None):
    """
    Applique les transforms à une image PIL et ajoute une dimension batch.
    """
    if transform is None:
        transform = get_val_transform(image_size=image_size)

    img_tensor = transform(image)        # [3, H, W]
    img_tensor = img_tensor.unsqueeze(0) # [1, 3, H, W]

    return img_tensor


# ----------------------------
# 10) Prédiction depuis un tenseur (classification)
# ----------------------------

def predict_from_tensor(
    image_tensor: torch.Tensor,
    model: torch.nn.Module,
    device: torch.device,
    threshold: float = 0.5,
):
    """
    Prédit la classe (fire/no_fire) à partir d'un tenseur déjà prétraité
    pour les modèles de classification (EfficientNet, Inception...).
    """
    image_tensor = image_tensor.to(device)

    with torch.no_grad():
        outputs = model(image_tensor)           # logits [1, 2]
        probs = torch.softmax(outputs, dim=1)   # probas

        fire_prob = probs[0, 1].item()
        predicted_idx = 1 if fire_prob >= threshold else 0

    predicted_label = IDX_TO_LABEL[predicted_idx]
    return predicted_label, fire_prob


# ----------------------------
# 11) Prédiction depuis une image PIL
# ----------------------------

def predict_from_pil(
    image: Image.Image,
    model,
    device: torch.device,
    transform=None,
    threshold: float = 0.5,
    image_size: int | None = None,
):
    """
    Prédit la classe à partir d'une image PIL.

    Retour
    ------
    predicted_label : str
        "fire" ou "no_fire".
    fire_prob : float
        Probabilité de "fire".
    annotated_image : PIL.Image.Image ou None
        Image annotée (bounding boxes) pour les modèles de détection (YOLO),
        None pour les modèles de classification.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    # 🔹 Cas YOLO : modèle de détection (Ultralytics)
    if hasattr(model, "task") and getattr(model, "task", None) == "detect":
        results = model(image)
        result = results[0]

        boxes = getattr(result, "boxes", None)
        fire_prob = 0.0

        if boxes is not None and len(boxes) > 0:
            classes = boxes.cls  # ids des classes (tensor)
            confs = boxes.conf   # scores de confiance (tensor)

            # masque des boxes "feu"
            mask_fire = torch.zeros_like(classes, dtype=torch.bool)
            for cid in FIRE_CLASS_IDS:
                mask_fire |= (classes == cid)

            if mask_fire.any():
                fire_prob = float(confs[mask_fire].max().item())

        predicted_label = "fire" if fire_prob >= threshold else "no_fire"

        # Image annotée avec les bounding boxes
        annotated_image = None
        try:
            annotated_array = result.plot()          # numpy array BGR
            annotated_image = Image.fromarray(annotated_array[..., ::-1])  # BGR -> RGB
        except Exception:
            annotated_image = None

        return predicted_label, fire_prob, annotated_image

    # 🔹 Cas classification classique
    image_tensor = preprocess_image(image, transform=transform, image_size=image_size)
    predicted_label, fire_prob = predict_from_tensor(image_tensor, model, device, threshold=threshold)
    annotated_image = None

    return predicted_label, fire_prob, annotated_image


# ----------------------------
# 12) Prédiction depuis un chemin de fichier
# ----------------------------

def predict_from_path(
    image_path: str,
    model,
    device: torch.device,
    transform=None,
    threshold: float = 0.5,
    image_size: int | None = None,
):
    """
    Prédit la classe à partir d'un chemin vers une image.
    """
    image = Image.open(image_path)
    return predict_from_pil(
        image=image,
        model=model,
        device=device,
        transform=transform,
        threshold=threshold,
        image_size=image_size,
    )


# ----------------------------
# 13) Exemple d'utilisation en script direct
# ----------------------------

if __name__ == "__main__":
    # Petit test local (à adapter)
    weights_path = "efficientnet_fire.pt"
    if not os.path.exists(weights_path):
        print(f"[ERREUR] Fichier de poids introuvable : {weights_path}")
    else:
        model, device = load_model(weights_path)
        print(f"Modèle chargé sur le device : {device}")

        transform = get_val_transform()
        test_image_path = "example.jpg"

        if not os.path.exists(test_image_path):
            print(f"[INFO] Aucune image test trouvée à : {test_image_path}")
        else:
            label, prob, _ = predict_from_path(
                test_image_path,
                model=model,
                device=device,
                transform=transform,
                threshold=0.5,
            )
            print(f"Résultat pour {test_image_path} : label={label}, prob_fire={prob:.4f}")