"""Inference helpers: classify an uploaded image, or a random sample from the test split."""

import random
from typing import Dict, List, Optional, Sequence, Tuple

import torch
from PIL import Image

from data_utils import get_eval_transform, prepare_splits, get_class_names
from train_utils import load_model, get_runtime_device


def _classify(model, image: Image.Image, device) -> Tuple[int, List[float]]:
    """Run ``model`` on a single RGB image and return ``(predicted_index, probabilities)``.

    ``probabilities`` is the softmax over dimension 1 of the model output for the
    single batch element, as a plain Python list.
    """
    # NOTE(review): assumes the eval transform yields a single-image tensor
    # (presumably C x H x W), so unsqueeze(0) adds the batch dimension — confirm.
    tensor = get_eval_transform()(image).unsqueeze(0).to(device)
    # Idempotent; guards against load_model returning a module still in training
    # mode (dropout / batch-norm updates). TODO(review): drop if load_model already does this.
    model.eval()
    with torch.no_grad():
        logits = model(tensor)
    probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
    pred_idx = int(torch.argmax(logits, dim=1).item())
    return pred_idx, probs


def _build_prob_dict(class_names: Sequence[str], probs: Sequence[float]) -> Dict[str, float]:
    """Map each class name to its probability (paired by position)."""
    return {name: float(p) for name, p in zip(class_names, probs)}


def predict_uploaded_image(
    model_name: str, image: Image.Image
) -> Tuple[str, Optional[Dict[str, float]]]:
    """Classify a user-supplied image with the named model.

    Args:
        model_name: Name passed to ``load_model``; empty/falsy means "not selected".
        image: The image to classify, or ``None`` if nothing was uploaded.

    Returns:
        ``(result_text, prob_dict)``. ``result_text`` is a human-readable (French)
        summary. ``prob_dict`` maps each class name to its probability. On missing
        input, ``result_text`` is a prompt message and ``prob_dict`` is ``None``.
    """
    if not model_name:
        return "Veuillez sélectionner un modèle.", None
    if image is None:
        return "Veuillez importer une image.", None

    device = get_runtime_device()
    model, meta = load_model(model_name, device)
    class_names = meta["config"]["class_names"]

    pred_idx, probs = _classify(model, image.convert("RGB"), device)

    result_text = (
        f"Prédiction : {class_names[pred_idx]}\n"
        # Probability of the reported class (equal to max(probs) since softmax is monotonic).
        f"Confiance : {probs[pred_idx]:.4f}\n\n"
        f"Modèle : {model_name}\n"
        f"Jeu de données : {meta['config']['dataset_name']}\n"
        f"Appareil utilisé : {device}"
    )
    return result_text, _build_prob_dict(class_names, probs)


def test_random_sample(
    model_name: str,
) -> Tuple[Optional[Image.Image], str, Optional[Dict[str, float]]]:
    """Classify one randomly chosen item from the test split with the named model.

    Args:
        model_name: Name passed to ``load_model``; empty/falsy means "not selected".

    Returns:
        ``(image, result_text, prob_dict)``. ``image`` is the RGB sample. ``result_text``
        reports ground truth, prediction and confidence. ``prob_dict`` maps class name
        to probability. On missing model or an empty test split, ``image`` and
        ``prob_dict`` are ``None`` and ``result_text`` explains why.
    """
    if not model_name:
        return None, "Veuillez sélectionner un modèle.", None

    test_dataset = prepare_splits()["test"]
    if len(test_dataset) == 0:
        # random.randrange(0) would raise ValueError; report it to the user instead.
        return None, "Le jeu de test est vide.", None

    device = get_runtime_device()
    model, meta = load_model(model_name, device)
    # Model outputs are indexed by the model's own class list (as in
    # predict_uploaded_image); dataset labels are indexed by the dataset's list.
    model_class_names = meta["config"]["class_names"]
    dataset_class_names = get_class_names()

    item = test_dataset[random.randrange(len(test_dataset))]
    raw_image = item["image"]
    if isinstance(raw_image, Image.Image):
        image = raw_image.convert("RGB")
    else:
        # Presumably a path or file-like object. Decode it and close the handle;
        # convert() returns a fully loaded copy that stays valid after close.
        with Image.open(raw_image) as opened:
            image = opened.convert("RGB")

    label_name = dataset_class_names[int(item["label"])]
    pred_idx, probs = _classify(model, image, device)

    result_text = (
        f"Échantillon test aléatoire\n"
        f"Vérité terrain : {label_name}\n"
        f"Prédiction : {model_class_names[pred_idx]}\n"
        f"Confiance : {probs[pred_idx]:.4f}\n"
        f"Appareil utilisé : {device}"
    )
    return image, result_text, _build_prob_dict(model_class_names, probs)