Image_Classification / predict_utils.py
CircleStar's picture
Create predict_utils.py
ca9c54c verified
raw
history blame
2.37 kB
import random
import torch
from PIL import Image
from config import IMAGE_SIZE
from data_utils import get_transform, load_charcoal_dataset
from train_utils import load_model, get_runtime_device
def predict_uploaded_image(model_name: str, image: Image.Image):
    """Classify a user-uploaded image with the selected model.

    Args:
        model_name: Identifier passed to ``load_model``. An empty value
            short-circuits with a user-facing message.
        image: PIL image to classify. ``None`` short-circuits with a
            user-facing message.

    Returns:
        A ``(result_text, prob_dict)`` tuple. ``result_text`` is a French
        summary (predicted class, confidence, model, dataset, device).
        ``prob_dict`` maps each class name to its softmax probability.
        When an input is missing, ``prob_dict`` is ``None`` and
        ``result_text`` says what is missing.
    """
    if not model_name:
        return "Veuillez sélectionner un modèle.", None
    if image is None:
        return "Veuillez importer une image.", None

    device = get_runtime_device()
    model, meta = load_model(model_name, device)
    # Label order the model's output indices refer to, read from its metadata.
    class_names = meta["config"]["class_names"]

    # Switch layers such as dropout / batch-norm (if any) to inference
    # behaviour; this is a no-op if load_model already called eval().
    model.eval()

    # Force 3-channel RGB so grayscale / RGBA / palette uploads all go
    # through the same transform; unsqueeze(0) adds a batch dimension of 1.
    tensor = get_transform()(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(tensor)
        # Drop the batch dimension of the single-image batch.
        probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
        pred_idx = int(torch.argmax(logits, dim=1).item())

    # Softmax is monotonic, so the argmax class also has the highest probability.
    confidence = probs[pred_idx]

    result_text = (
        f"Prédiction : {class_names[pred_idx]}\n"
        f"Confiance : {confidence:.4f}\n\n"
        f"Modèle : {model_name}\n"
        f"Jeu de données : {meta['config']['dataset_name']}\n"
        f"Appareil utilisé : {device}"
    )
    prob_dict = {name: float(p) for name, p in zip(class_names, probs)}
    return result_text, prob_dict
def test_random_sample(model_name: str):
    """Run the selected model on one image drawn at random from the test split.

    Args:
        model_name: Identifier passed to ``load_model``. An empty value
            short-circuits with a user-facing message.

    Returns:
        An ``(image, result_text, prob_dict)`` tuple. ``image`` is the sampled
        RGB PIL image. ``result_text`` is a French summary (ground truth,
        prediction, confidence, device). ``prob_dict`` maps each class name to
        its softmax probability. If no model is selected or the test split is
        empty, ``(None, message, None)`` is returned instead.
    """
    if not model_name:
        return None, "Veuillez sélectionner un modèle.", None

    raw, dataset_class_names = load_charcoal_dataset()
    test_dataset = raw["test"]
    # Drawing an index from an empty split would raise ValueError; check
    # before paying for a model load.
    if len(test_dataset) == 0:
        return None, "Le jeu de test est vide.", None

    device = get_runtime_device()
    model, meta = load_model(model_name, device)
    # Decode predictions with the model's own label order (the same source
    # predict_uploaded_image uses). Fall back to the dataset's labels if the
    # metadata does not record them.
    model_class_names = (
        meta.get("config", {}).get("class_names") or dataset_class_names
    )
    # Switch layers such as dropout / batch-norm (if any) to inference
    # behaviour; this is a no-op if load_model already called eval().
    model.eval()

    item = test_dataset[random.randrange(len(test_dataset))]
    image = item["image"].convert("RGB")
    # The ground-truth label is an index into the dataset's own class list.
    label_name = dataset_class_names[int(item["label"])]

    tensor = get_transform()(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(tensor)
        # Drop the batch dimension of the single-image batch.
        probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
        pred_idx = int(torch.argmax(logits, dim=1).item())

    # Softmax is monotonic, so the argmax class also has the highest probability.
    confidence = probs[pred_idx]

    result_text = (
        "Échantillon test aléatoire\n"
        f"Vérité terrain : {label_name}\n"
        f"Prédiction : {model_class_names[pred_idx]}\n"
        f"Confiance : {confidence:.4f}\n"
        f"Appareil utilisé : {device}"
    )
    prob_dict = {name: float(p) for name, p in zip(model_class_names, probs)}
    return image, result_text, prob_dict


# The name matches pytest's default ``test_*`` pattern. Without this flag, any
# test module importing it would collect it and fail on the missing
# ``model_name`` fixture.
test_random_sample.__test__ = False