"""Load the CLIP classifier weights and display predictions for test images.

Run as a script: it restores the model from ``./clip_model.pth``, runs each
image listed in ``test_images`` through the model and shows the images side
by side, each titled with its predicted class.
"""

import os
import sys

import matplotlib.pyplot as plt
import torch
from PIL import Image

# Add ./src (resolved against the current working directory) to the import
# path before the project import below.
# NOTE(review): presumably needed so modules under src/ can import each
# other by bare name -- confirm.
sys.path.append(os.path.abspath("./src"))

from src.model import CLIPClassifier, get_processor  # noqa: E402


# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = 2
model = CLIPClassifier(num_classes=num_classes).to(device)

# Fail early with an explicit message when the weights file is missing.
model_load_path = "./clip_model.pth"
if not os.path.exists(model_load_path):
    raise FileNotFoundError(
        f"❌ Le fichier {model_load_path} n'existe pas ! Lance `train.py` d'abord."
    )

# NOTE(review): torch.load unpickles the checkpoint; only load files you
# trust, and consider weights_only=True if the installed torch supports it.
model.load_state_dict(torch.load(model_load_path, map_location=device))
model.eval()
print("✅ Modèle chargé avec succès !")

processor = get_processor()

test_images = ["./001.JPG", "./0004.JPG"]

# squeeze=False makes subplots always return a 2-D array of Axes, so the
# indexing below also works when `test_images` holds a single path (the
# default would return a bare Axes object that cannot be indexed).
fig, axes_grid = plt.subplots(
    1, len(test_images), figsize=(10, 5), squeeze=False
)
axes = axes_grid[0]  # the single row: one Axes per image

for i, img_path in enumerate(test_images):
    if not os.path.exists(img_path):
        raise FileNotFoundError(f"❌ L'image {img_path} est introuvable !")

    # convert() returns an independent in-memory copy, so the underlying
    # file can be closed as soon as the conversion is done.
    with Image.open(img_path) as raw_image:
        image = raw_image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    # NOTE(review): presumably "pixel_values" carries a leading batch
    # dimension of 1, making the squeeze/unsqueeze pair a round trip --
    # kept unchanged until the processor's output shape is confirmed.
    pixel_values = inputs["pixel_values"].squeeze(0).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(pixel_values.unsqueeze(0))
        predicted_label = torch.argmax(logits, dim=1).item()

    # Label 0 is displayed as "Normal", any other label as "Anomaly".
    axes[i].imshow(image)
    axes[i].set_title(
        f"Prédit : {'Normal' if predicted_label == 0 else 'Anomaly'}",
        fontsize=12,
    )
    axes[i].axis("off")

plt.show()
|
|