File size: 2,483 Bytes
ca9c54c
 
 
 
 
cfe30ee
ca9c54c
 
 
 
 
 
 
 
 
 
 
 
 
 
cfe30ee
ca9c54c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cfe30ee
 
 
ca9c54c
 
 
 
cfe30ee
 
 
 
 
 
ca9c54c
 
 
cfe30ee
ca9c54c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import random

import torch
from PIL import Image

from data_utils import get_eval_transform, prepare_splits, get_class_names
from train_utils import load_model, get_runtime_device


def predict_uploaded_image(model_name: str, image: Image.Image):
    """Classify a user-supplied image with the selected model.

    Args:
        model_name: Identifier handed to ``load_model``. A falsy value
            short-circuits with a prompt asking the user to choose a model.
        image: The uploaded PIL image, or ``None`` when nothing was uploaded.

    Returns:
        A ``(text, probabilities)`` pair. ``text`` is a French summary of the
        predicted class, its softmax confidence, the model name, the dataset
        name recorded in the model metadata, and the device used.
        ``probabilities`` maps every class name listed in the model metadata
        to its softmax score. When an input is missing, ``text`` is a prompt
        and ``probabilities`` is ``None``.
    """
    # Guard clauses: report missing inputs instead of failing later on.
    if not model_name:
        return "Veuillez sélectionner un modèle.", None
    if image is None:
        return "Veuillez importer une image.", None

    device = get_runtime_device()
    model, meta = load_model(model_name, device)
    config = meta["config"]
    labels = config["class_names"]
    preprocess = get_eval_transform()

    # Build a batch of one image on the same device as the model.
    rgb_image = image.convert("RGB")
    batch = preprocess(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        scores = model(batch)
        distribution = torch.softmax(scores, dim=1).squeeze(0).cpu().tolist()
        best = int(torch.argmax(scores, dim=1).item())

    # The empty string produces the blank line between the prediction part
    # and the model/context part of the summary.
    summary = "\n".join(
        [
            f"Prédiction : {labels[best]}",
            f"Confiance : {max(distribution):.4f}",
            "",
            f"Modèle : {model_name}",
            f"Jeu de données : {config['dataset_name']}",
            f"Appareil utilisé : {device}",
        ]
    )

    scores_by_label = {
        label: float(distribution[position])
        for position, label in enumerate(labels)
    }
    return summary, scores_by_label


def test_random_sample(model_name: str):
    """Run the selected model on one randomly chosen test-split example.

    Args:
        model_name: Identifier handed to ``load_model``. A falsy value
            short-circuits with a prompt asking the user to choose a model.

    Returns:
        A ``(image, text, probabilities)`` triple:
        - ``image`` is the sampled example converted to RGB.
        - ``text`` is a French summary with the ground-truth label, the
          prediction, its softmax confidence and the device.
        - ``probabilities`` maps each model class name to its softmax score.
        When the model name is missing or the test split is empty, ``image``
        and ``probabilities`` are ``None`` and ``text`` explains why.
    """
    if not model_name:
        return None, "Veuillez sélectionner un modèle.", None

    device = get_runtime_device()
    model, meta = load_model(model_name, device)

    splits = prepare_splits()
    dataset_class_names = get_class_names()
    # Predicted indices refer to the model's own output layer, so they are
    # named with the class list stored in the model metadata. Fall back to
    # the dataset's list for models whose metadata does not record one.
    model_class_names = (
        meta.get("config", {}).get("class_names") or dataset_class_names
    )
    test_dataset = splits["test"]

    # random.randint(0, -1) would raise an opaque ValueError on an empty split.
    if len(test_dataset) == 0:
        return None, "Le jeu de test est vide.", None

    item = test_dataset[random.randrange(len(test_dataset))]

    raw_image = item["image"]
    if isinstance(raw_image, Image.Image):
        image = raw_image.convert("RGB")
    else:
        # Image.open keeps the source file open until the image is closed.
        # convert() loads the pixels into a new image, so the opened one can
        # be closed immediately and the file handle is not leaked.
        with Image.open(raw_image) as opened:
            image = opened.convert("RGB")

    # The ground-truth label is a dataset index, so use the dataset's names.
    label = int(item["label"])
    label_name = dataset_class_names[label]

    transform = get_eval_transform()
    tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(tensor)
        probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
        pred_idx = int(torch.argmax(logits, dim=1).item())

    result_text = (
        f"Échantillon test aléatoire\n"
        f"Vérité terrain : {label_name}\n"
        f"Prédiction : {model_class_names[pred_idx]}\n"
        f"Confiance : {max(probs):.4f}\n"
        f"Appareil utilisé : {device}"
    )

    # enumerate (rather than zip) keeps an IndexError if the class list is
    # longer than the model output, instead of silently truncating.
    prob_dict = {
        name: float(probs[i]) for i, name in enumerate(model_class_names)
    }
    return image, result_text, prob_dict