import torch
import torch.nn as nn
from PIL import Image
from torch import Tensor
from torchvision import transforms

from model import VGG16WithCNN


def getModel(device: torch.device, model_path: str) -> nn.Module:
    """Build a VGG16WithCNN, load trained weights from disk and move it to `device`.

    Args:
        device: Device the model (and its loaded weights) is placed on.
        model_path: Path to a saved state_dict checkpoint.

    Returns:
        The model with weights loaded, already on `device`.
    """
    # NOTE(review): the argument 5 presumably is the number of output classes
    # (it matches the 5 class names used in __main__) — confirm against model.py.
    model = VGG16WithCNN(5)
    # Load the trained weights. map_location remaps tensors that were saved on
    # another device (e.g. CUDA) onto `device`, so a GPU-trained checkpoint can
    # still be loaded on a CPU-only machine.
    model.load_state_dict(
        torch.load(
            model_path,
            map_location=device,
            weights_only=True,
        )
    )
    model.to(device)
    return model


def preprocess_image(image_path: str, image_size=(224, 224)) -> Tensor:
    """Load an image and convert it into a model-ready input tensor.

    The image is converted to RGB, resized to `image_size`, turned into a
    tensor and normalized with the ImageNet channel mean/std values
    (presumably the same preprocessing used in training — confirm).

    Args:
        image_path: Path to the image file.
        image_size: Target (height, width) passed to `transforms.Resize`.

    Returns:
        A tensor with a leading batch dimension of size 1.
    """
    transform = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    # Open the image and convert it. The context manager closes the file handle;
    # convert() returns a new, fully loaded image that stays valid after close.
    with Image.open(image_path) as raw_image:
        image: Image.Image = raw_image.convert("RGB")
    image_tensor: Tensor = transform(image)
    # Add the batch dimension expected by the model.
    image_tensor = image_tensor.unsqueeze(0)
    return image_tensor


def predict_single_image(
    image_path: str, model: nn.Module, device: torch.device, class_names: list[str]
) -> str:
    """Predict the label of a single image.

    Args:
        image_path: Path to the image file.
        model: The classification model.
        device: Device on which inference runs.
        class_names: Label names, indexed by the model's output class index.

    Returns:
        The predicted label name.
    """
    image_tensor = preprocess_image(image_path)
    image_tensor = image_tensor.to(device)

    # Run inference in eval mode without tracking gradients.
    model.eval()
    with torch.no_grad():
        output = model(image_tensor)
        # Index of the highest score along the class dimension.
        _, pred = torch.max(output, 1)
        predicted_label = class_names[int(pred.item())]

    return predicted_label


if __name__ == "__main__":
    # Single-image prediction smoke test.
    # Note: replace the paths below with a real checkpoint / test image.
    p = "./checkpoints/vgg_net_model_50.pth"
    class_names = [
        "Bacterialblight",
        "Blast",
        "Brownspot",
        "Healthy",
        "Tungro",
    ]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = getModel(device=device, model_path=p)
    test_image_path = "./images/BLAST1_011.jpg"
    try:
        predicted_label = predict_single_image(
            test_image_path, model, device, class_names=class_names
        )
        print("\nSingle image prediction result:")
        print(f"Image: {test_image_path}")
        print(f"Predicted label: {predicted_label}")
    except FileNotFoundError:
        print("Please provide a valid image path to test single image prediction")