import random

import torch
from PIL import Image

from config import IMAGE_SIZE
from data_utils import get_transform, load_charcoal_dataset
from train_utils import load_model, get_runtime_device


def _to_input_tensor(image: Image.Image, device) -> torch.Tensor:
    """Apply the project transform to an RGB image and return a 1-image batch on *device*."""
    transform = get_transform()
    return transform(image).unsqueeze(0).to(device)


def _predict_probs(model, tensor: torch.Tensor):
    """Run one forward pass and return ``(probabilities, predicted_index)``.

    ``probabilities`` is a plain Python list of softmax scores over dim 1 for
    the single image in the batch.
    """
    # eval() is idempotent; it guarantees dropout / batch-norm run in
    # inference mode even if load_model did not already set it (not visible
    # from this module).
    model.eval()
    with torch.no_grad():
        logits = model(tensor)
    probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
    pred_idx = int(torch.argmax(logits, dim=1).item())
    return probs, pred_idx


def _probs_by_class(class_names, probs):
    """Map each class name to its probability as a float."""
    return {name: float(p) for name, p in zip(class_names, probs)}


def predict_uploaded_image(model_name: str, image: Image.Image):
    """Classify a user-supplied image with the named model.

    Args:
        model_name: Identifier passed to ``load_model``; empty means "none selected".
        image: PIL image to classify, or ``None`` if nothing was uploaded.

    Returns:
        ``(result_text, prob_dict)`` on success, where ``prob_dict`` maps class
        name -> probability; ``(message, None)`` when an input is missing.
    """
    if not model_name:
        return "Veuillez sélectionner un modèle.", None
    if image is None:
        return "Veuillez importer une image.", None

    device = get_runtime_device()
    model, meta = load_model(model_name, device)
    class_names = meta["config"]["class_names"]

    tensor = _to_input_tensor(image.convert("RGB"), device)
    probs, pred_idx = _predict_probs(model, tensor)

    result_text = (
        f"Prédiction : {class_names[pred_idx]}\n"
        f"Confiance : {probs[pred_idx]:.4f}\n\n"
        f"Modèle : {model_name}\n"
        f"Jeu de données : {meta['config']['dataset_name']}\n"
        f"Appareil utilisé : {device}"
    )
    return result_text, _probs_by_class(class_names, probs)


def test_random_sample(model_name: str):
    """Classify one randomly chosen image from the dataset's test split.

    Args:
        model_name: Identifier passed to ``load_model``; empty means "none selected".

    Returns:
        ``(image, result_text, prob_dict)`` on success, where ``image`` is the
        RGB-converted sample; ``(None, message, None)`` when no model is
        selected or the test split is empty.
    """
    if not model_name:
        return None, "Veuillez sélectionner un modèle.", None

    device = get_runtime_device()
    model, meta = load_model(model_name, device)

    raw, class_names = load_charcoal_dataset()
    test_dataset = raw["test"]
    # Guard: random index selection on an empty split would raise ValueError.
    if len(test_dataset) == 0:
        return None, "Le jeu de test est vide.", None

    item = test_dataset[random.randrange(len(test_dataset))]
    image = item["image"].convert("RGB")
    label_name = class_names[int(item["label"])]

    tensor = _to_input_tensor(image, device)
    probs, pred_idx = _predict_probs(model, tensor)

    result_text = (
        f"Échantillon test aléatoire\n"
        f"Vérité terrain : {label_name}\n"
        f"Prédiction : {class_names[pred_idx]}\n"
        f"Confiance : {probs[pred_idx]:.4f}\n"
        f"Appareil utilisé : {device}"
    )
    return image, result_text, _probs_by_class(class_names, probs)