Image_Classification / predict_utils.py
CircleStar's picture
Update predict_utils.py
cfe30ee verified
import random
import torch
from PIL import Image
from data_utils import get_eval_transform, prepare_splits, get_class_names
from train_utils import load_model, get_runtime_device
def predict_uploaded_image(model_name: str, image: Image.Image):
    """Classify an uploaded image with the selected saved model.

    Args:
        model_name: Name of the model to load via ``load_model``. A falsy
            value short-circuits with a prompt asking the user to pick one.
        image: The uploaded PIL image, or ``None`` when nothing was uploaded.

    Returns:
        A ``(result_text, prob_dict)`` tuple. ``result_text`` is a French
        summary (predicted label, confidence, model name, dataset name,
        device); ``prob_dict`` maps each class name to its softmax
        probability. When an input is missing, or when the model's output
        size disagrees with its stored class list, ``result_text`` is an
        explanatory message and ``prob_dict`` is ``None``.
    """
    if not model_name:
        return "Veuillez sélectionner un modèle.", None
    if image is None:
        return "Veuillez importer une image.", None

    device = get_runtime_device()
    model, meta = load_model(model_name, device)
    config = meta["config"]
    class_names = config["class_names"]

    # Normalise to 3-channel RGB whatever the upload's mode (RGBA, L, P, ...),
    # then add a batch axis of size 1 for the model.
    transform = get_eval_transform()
    tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(tensor)
    # Softmax over dim 1 (assumed to be the class axis of a (1, C) output),
    # then drop the batch axis and convert to plain Python floats.
    probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()

    # Output index i is labelled with class_names[i]; if the counts differ the
    # predicted label and the probability table would be wrong or truncated.
    if len(probs) != len(class_names):
        return (
            f"Incohérence : le modèle produit {len(probs)} sorties "
            f"mais {len(class_names)} noms de classes sont définis.",
            None,
        )

    pred_idx = int(torch.argmax(logits, dim=1).item())
    result_text = (
        f"Prédiction : {class_names[pred_idx]}\n"
        f"Confiance : {probs[pred_idx]:.4f}\n\n"
        f"Modèle : {model_name}\n"
        f"Jeu de données : {config['dataset_name']}\n"
        f"Appareil utilisé : {device}"
    )
    prob_dict = {name: float(p) for name, p in zip(class_names, probs)}
    return result_text, prob_dict
def test_random_sample(model_name: str):
    """Run the selected model on one randomly drawn test-split image.

    Args:
        model_name: Name of the model to load via ``load_model``. A falsy
            value short-circuits with a prompt asking the user to pick one.

    Returns:
        An ``(image, result_text, prob_dict)`` tuple: the sample as an RGB PIL
        image, a French summary comparing ground truth with the prediction,
        and a mapping from class name to softmax probability. When no model
        is selected or the test split is empty, returns
        ``(None, message, None)``; when the model's output size disagrees with
        its class list, returns ``(image, message, None)``.
    """
    if not model_name:
        return None, "Veuillez sélectionner un modèle.", None

    device = get_runtime_device()
    model, meta = load_model(model_name, device)
    splits = prepare_splits()
    # Ground-truth labels index the dataset's class list...
    dataset_class_names = get_class_names()
    # ...while model outputs index the list saved with the model (the same
    # list predict_uploaded_image uses). Fall back to the dataset list when
    # the model metadata does not carry one.
    model_class_names = (
        meta.get("config", {}).get("class_names") or dataset_class_names
    )

    test_dataset = splits["test"]
    if len(test_dataset) == 0:
        # Drawing an index from an empty split would raise ValueError.
        return None, "Le jeu de test est vide.", None
    item = test_dataset[random.randrange(len(test_dataset))]

    raw_image = item["image"]
    if isinstance(raw_image, Image.Image):
        image = raw_image.convert("RGB")
    else:
        # Not yet decoded (presumably a path or file object): open it in a
        # context manager so the file handle is released once convert() has
        # loaded the pixels into a new, independent image.
        with Image.open(raw_image) as opened:
            image = opened.convert("RGB")

    label_name = dataset_class_names[int(item["label"])]

    # Add a batch axis of size 1 for the model.
    transform = get_eval_transform()
    tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(tensor)
    # Softmax over dim 1 (assumed class axis of a (1, C) output), then drop
    # the batch axis and convert to plain Python floats.
    probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()

    # Output index i is labelled with model_class_names[i]; a count mismatch
    # would make the prediction label and probability table wrong.
    if len(probs) != len(model_class_names):
        return (
            image,
            f"Incohérence : le modèle produit {len(probs)} sorties "
            f"mais {len(model_class_names)} noms de classes sont définis.",
            None,
        )

    pred_idx = int(torch.argmax(logits, dim=1).item())
    result_text = (
        "Échantillon test aléatoire\n"
        f"Vérité terrain : {label_name}\n"
        f"Prédiction : {model_class_names[pred_idx]}\n"
        f"Confiance : {probs[pred_idx]:.4f}\n"
        f"Appareil utilisé : {device}"
    )
    prob_dict = {name: float(p) for name, p in zip(model_class_names, probs)}
    return image, result_text, prob_dict