Spaces:
Sleeping
Sleeping
| import random | |
| import torch | |
| from PIL import Image | |
| from data_utils import get_eval_transform, prepare_splits, get_class_names | |
| from train_utils import load_model, get_runtime_device | |
def predict_uploaded_image(model_name: str, image: Image.Image):
    """Classify a user-uploaded image with the selected trained model.

    Returns a ``(result_text, prob_dict)`` pair. ``result_text`` is a
    French summary string; ``prob_dict`` maps class name -> probability,
    or is ``None`` when the model or the image is missing.
    """
    # Guard clauses — the French messages are part of the UI contract.
    if not model_name:
        return "Veuillez sélectionner un modèle.", None
    if image is None:
        return "Veuillez importer une image.", None

    device = get_runtime_device()
    model, meta = load_model(model_name, device)
    class_names = meta["config"]["class_names"]

    # Preprocess: force 3-channel RGB, apply the eval transform,
    # then add a leading batch dimension and move to the device.
    batch = get_eval_transform()(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)

    probabilities = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
    predicted = int(torch.argmax(logits, dim=1).item())

    summary = (
        f"Prédiction : {class_names[predicted]}\n"
        f"Confiance : {max(probabilities):.4f}\n\n"
        f"Modèle : {model_name}\n"
        f"Jeu de données : {meta['config']['dataset_name']}\n"
        f"Appareil utilisé : {device}"
    )
    confidences = {
        name: float(probabilities[i]) for i, name in enumerate(class_names)
    }
    return summary, confidences
def test_random_sample(model_name: str):
    """Run the selected model on one randomly chosen test-set sample.

    Returns ``(image, result_text, prob_dict)``. ``image`` and
    ``prob_dict`` are ``None`` when the model is missing or the test
    split is empty; ``result_text`` then carries a French user message.
    """
    if not model_name:
        return None, "Veuillez sélectionner un modèle.", None

    device = get_runtime_device()
    # meta is unused here; label names come from get_class_names() below.
    model, _meta = load_model(model_name, device)
    splits = prepare_splits()
    class_names = get_class_names()
    test_dataset = splits["test"]

    # Bug fix: random.randint(0, -1) raised ValueError on an empty split.
    # Fail gracefully with a message tuple, matching the guard style above.
    if len(test_dataset) == 0:
        return None, "Le jeu de test est vide.", None

    item = test_dataset[random.randrange(len(test_dataset))]
    image = item["image"]
    if not isinstance(image, Image.Image):
        # The dataset may store a path/file object instead of a PIL image.
        image = Image.open(image)
    image = image.convert("RGB")

    label = int(item["label"])
    label_name = class_names[label]

    # Preprocess the sample exactly like an uploaded image: eval
    # transform, batch dimension, target device.
    tensor = get_eval_transform()(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(tensor)
    probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
    pred_idx = int(torch.argmax(logits, dim=1).item())

    result_text = (
        f"Échantillon test aléatoire\n"
        f"Vérité terrain : {label_name}\n"
        f"Prédiction : {class_names[pred_idx]}\n"
        f"Confiance : {max(probs):.4f}\n"
        f"Appareil utilisé : {device}"
    )
    prob_dict = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
    return image, result_text, prob_dict