File size: 3,745 Bytes
a616e56
 
cdc317a
a616e56
 
 
cdc317a
c722f9c
cdc317a
 
 
 
 
 
 
 
 
 
 
a616e56
 
 
 
 
 
 
 
cdc317a
 
a616e56
cdc317a
a616e56
cdc317a
 
 
 
 
 
 
 
 
 
 
 
 
a616e56
 
 
 
 
cdc317a
 
a616e56
 
 
 
 
 
 
 
 
cdc317a
 
 
a616e56
 
c722f9c
 
a616e56
 
 
c722f9c
 
 
 
cdc317a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a616e56
 
 
cdc317a
 
 
 
a616e56
cdc317a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import random

import numpy as np
import torch
from PIL import Image

from config import CLASSICAL_MODEL_TYPES
from data_utils import get_eval_transform, prepare_splits, get_class_names
from train_utils import load_model, get_runtime_device, _load_meta


def _extract_feature(image: Image.Image, device: torch.device) -> np.ndarray:
    """Embed one image with the shared feature backbone.

    The backbone is fetched through ``backbone_utils.load_backbone`` on every
    call and switched to eval mode before a gradient-free forward pass.

    Args:
        image: Input PIL image; it is converted to RGB before preprocessing.
        device: Device on which both the backbone and the input tensor live.

    Returns:
        The backbone output for a batch of one image, moved to CPU and
        converted to a NumPy array (presumably shaped ``(1, feature_dim)``
        -- depends on the backbone, confirm in ``backbone_utils``).
    """
    from backbone_utils import load_backbone

    net = load_backbone(device)
    net.eval()
    preprocess = get_eval_transform()
    # Add a leading batch axis of size 1 before moving to the target device.
    batch = preprocess(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        embedding = net(batch)
    return embedding.cpu().numpy()


def predict_uploaded_image(model_name: str, image: Image.Image):
    """Classify a user-uploaded image with a saved model.

    The branch is chosen from ``meta["config"]["model_type"]`` (default
    ``"cnn"``). Types listed in ``CLASSICAL_MODEL_TYPES`` are scored by a
    classical pipeline on backbone features from ``_extract_feature``. Any
    other type is loaded with ``load_model`` and run end-to-end on the
    eval-transformed image.

    Args:
        model_name: Name of the saved model, as understood by ``_load_meta``.
        image: PIL image supplied by the user.

    Returns:
        ``(result_text, prob_dict)``. ``result_text`` is a human-readable
        (French) summary. ``prob_dict`` maps every class name from the
        model's saved config to its probability. If ``model_name`` or
        ``image`` is missing, returns ``(message, None)`` instead.
    """
    if not model_name:
        return "Veuillez sélectionner un modèle.", None
    if image is None:
        return "Veuillez importer une image.", None

    meta = _load_meta(model_name)
    model_type = meta["config"].get("model_type", "cnn")
    class_names = meta["config"]["class_names"]
    device = get_runtime_device()

    if model_type in CLASSICAL_MODEL_TYPES:
        from classical_ml_utils import load_classical_pipeline
        pipeline = load_classical_pipeline(model_name)
        feat = _extract_feature(image, device)
        # NOTE(review): predict_proba columns follow the pipeline's own class
        # ordering; this assumes it matches the config's class_names order.
        probs = pipeline.predict_proba(feat)[0].tolist()
        pred_idx = int(np.argmax(probs))
    else:
        model, _ = load_model(model_name, device)
        # Force inference mode (dropout off, BatchNorm on running stats);
        # eval() is idempotent, so this is harmless if load_model already did it.
        model.eval()
        tensor = get_eval_transform()(image.convert("RGB")).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(tensor)
            probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
            pred_idx = int(torch.argmax(logits, dim=1).item())

    # Report the probability of the class actually predicted.
    confidence = probs[pred_idx]
    result_text = (
        f"Prédiction : {class_names[pred_idx]}\n"
        f"Confiance : {confidence:.4f}\n\n"
        f"Modèle : {model_name}\n"
        f"Type : {model_type}\n"
        f"Appareil : {device}"
    )
    prob_dict = {name: float(probs[i]) for i, name in enumerate(class_names)}
    return result_text, prob_dict


def test_random_sample(model_name: str):
    """Run a saved model on one randomly drawn image from the test split.

    Args:
        model_name: Name of the saved model, as understood by ``_load_meta``.

    Returns:
        ``(image, result_text, prob_dict)``. ``image`` is the sampled RGB PIL
        image. ``result_text`` is a human-readable (French) summary with the
        ground-truth label and the prediction. ``prob_dict`` maps each class
        name from the model's saved config to its probability. If no model
        is selected, or the test split is empty, returns
        ``(None, message, None)``.
    """
    if not model_name:
        return None, "Veuillez sélectionner un modèle.", None

    meta = _load_meta(model_name)
    model_type = meta["config"].get("model_type", "cnn")
    # Ground truth is named with the dataset's class list. Predictions below
    # use the model's saved class list.
    # NOTE(review): this assumes both lists share the same index order.
    class_names = get_class_names()
    device = get_runtime_device()

    test_dataset = prepare_splits()["test"]
    if len(test_dataset) == 0:
        # randint(0, -1) would raise ValueError on an empty split.
        return None, "Le jeu de test est vide.", None

    item = test_dataset[random.randrange(len(test_dataset))]
    raw_image = item["image"]
    if isinstance(raw_image, Image.Image):
        image = raw_image.convert("RGB")
    else:
        # Image.open keeps the underlying file open. convert() returns a fully
        # loaded copy, so the handle can be closed right away.
        with Image.open(raw_image) as opened:
            image = opened.convert("RGB")
    label_name = class_names[int(item["label"])]

    if model_type in CLASSICAL_MODEL_TYPES:
        from classical_ml_utils import load_classical_pipeline
        pipeline = load_classical_pipeline(model_name)
        feat = _extract_feature(image, device)
        probs = pipeline.predict_proba(feat)[0].tolist()
        pred_idx = int(np.argmax(probs))
    else:
        model, _ = load_model(model_name, device)
        # Force inference mode; eval() is idempotent.
        model.eval()
        tensor = get_eval_transform()(image).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(tensor)
            probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
            pred_idx = int(torch.argmax(logits, dim=1).item())

    model_class_names = meta["config"]["class_names"]
    # Report the probability of the class actually predicted.
    confidence = probs[pred_idx]
    result_text = (
        f"Échantillon test aléatoire\n"
        f"Vérité terrain : {label_name}\n"
        f"Prédiction    : {model_class_names[pred_idx]}\n"
        f"Confiance     : {confidence:.4f}\n"
        f"Type modèle   : {model_type}\n"
        f"Appareil      : {device}"
    )
    prob_dict = {name: float(probs[i]) for i, name in enumerate(model_class_names)}
    return image, result_text, prob_dict