|
|
import torch |
|
|
import torch.nn as nn |
|
|
from PIL import Image |
|
|
from torch import Tensor |
|
|
from torchvision import transforms |
|
|
|
|
|
from model import VGG16WithCNN |
|
|
|
|
|
|
|
|
def getModel(
    device: torch.device, model_path: str, num_classes: int = 5
) -> nn.Module:
    """
    Build a VGG16WithCNN classifier and load trained weights into it.

    Args:
        device: Device the model is moved to. Checkpoint tensors are mapped
            onto this same device while loading, so a checkpoint saved from a
            GPU can still be loaded on a CPU-only machine.
        model_path: Path of the saved ``state_dict`` file.
        num_classes: Argument passed to ``VGG16WithCNN`` (presumably the number
            of output classes). Defaults to 5, the value previously hard-coded.
            It must match the checkpoint or ``load_state_dict`` will fail.

    Returns:
        The model with the loaded weights, placed on ``device``. Its
        train/eval mode is not changed here; call ``model.eval()`` before
        inference.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        RuntimeError: If the checkpoint keys/shapes do not match the model
            (``load_state_dict`` runs in strict mode by default).
    """
    model = VGG16WithCNN(num_classes)

    # Without map_location, torch.load restores tensors onto the device they
    # were saved from. That fails when a CUDA-saved checkpoint is loaded on a
    # machine where only the CPU is available.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)

    model.to(device)
    return model
|
|
|
|
|
|
|
|
def preprocess_image(image_path: str, image_size: tuple[int, int] = (224, 224)) -> Tensor:
    """
    Preprocess an image file so it matches the model's input requirements.

    The image is converted to RGB and resized to ``image_size``. It is then
    turned into a float tensor and normalized with the commonly used ImageNet
    per-channel mean/std values.

    Args:
        image_path: Path of the image file to load.
        image_size: Target size passed to ``transforms.Resize``; a
            ``(height, width)`` pair resizes to exactly that size.

    Returns:
        A tensor of shape ``(1, 3, height, width)``: a batch that holds the
        single preprocessed image.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist (from ``Image.open``).
    """
    transform = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    # Open the file as a context manager so its handle is closed promptly
    # instead of leaking until garbage collection. convert() returns a new,
    # fully loaded image, so it stays valid after the file is closed.
    with Image.open(image_path) as raw_image:
        image: Image.Image = raw_image.convert("RGB")

    image_tensor: Tensor = transform(image)

    # Add a leading batch dimension: (3, H, W) -> (1, 3, H, W).
    image_tensor = image_tensor.unsqueeze(0)

    return image_tensor
|
|
|
|
|
|
|
|
def predict_single_image(
    image_path: str, model: nn.Module, device: torch.device, class_names: list[str]
) -> str:
    """
    Classify a single image file and return the predicted label name.

    Args:
        image_path: Path of the image file to classify.
        model: Trained classifier. This call switches it to eval mode.
        device: Device the input batch is moved to before the forward pass
            (should be the device the model lives on).
        class_names: Label names, where ``class_names[i]`` is the name of the
            model's output index ``i``.

    Returns:
        The label name of the highest-scoring output index.
    """
    batch = preprocess_image(image_path).to(device)

    # Switch train-specific layers (e.g. dropout, if present) to inference behaviour.
    model.eval()
    with torch.no_grad():
        logits = model(batch)
        # Highest-scoring index along the class dimension (dim 1).
        best_index = logits.max(dim=1).indices

    return class_names[int(best_index.item())]
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Label names indexed by the model's output class index.
    class_names = [
        "Bacterialblight",
        "Blast",
        "Brownspot",
        "Healthy",
        "Tungro",
    ]

    checkpoint_path = "./checkpoints/vgg_net_model_50.pth"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = getModel(device=device, model_path=checkpoint_path)

    test_image_path = "./images/BLAST1_011.jpg"

    try:
        predicted_label = predict_single_image(
            test_image_path, model, device, class_names=class_names
        )
    except FileNotFoundError:
        print("Please provide a valid image path to test single image prediction")
    else:
        # Only the prediction call can raise FileNotFoundError, so reporting
        # lives in the success branch.
        print(
            "\nSingle image prediction result:",
            f"Image: {test_image_path}",
            f"Predicted label: {predicted_label}",
            sep="\n",
        )
|
|
|