File size: 2,479 Bytes
d10b7cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
import torch.nn as nn
from PIL import Image
from torch import Tensor
from torchvision import transforms

from model import VGG16WithCNN


def getModel(device: torch.device, model_path: str, num_classes: int = 5):
    """
    Build a ``VGG16WithCNN`` model and load trained weights into it.

    Args:
        device: Device the model is moved to. It is also used as
            ``map_location`` so that a checkpoint saved on a GPU can be
            loaded on a CPU-only machine (and vice versa).
        model_path: Path to a saved ``state_dict`` checkpoint.
        num_classes: Sole constructor argument for ``VGG16WithCNN``
            (presumably the number of output classes — must match the
            checkpoint). Defaults to 5, the previously hard-coded value.

    Returns:
        The model with the checkpoint's weights loaded, placed on ``device``.
    """
    model = VGG16WithCNN(num_classes)
    # Load the trained weights. ``map_location`` remaps stored tensors onto
    # ``device`` during deserialization; without it, a checkpoint written from
    # CUDA raises when CUDA is unavailable. ``weights_only=True`` restricts
    # unpickling to tensors/primitive containers.
    state_dict = torch.load(
        model_path,
        map_location=device,
        weights_only=True,
    )
    model.load_state_dict(state_dict)

    model.to(device)
    return model


def preprocess_image(image_path: str, image_size=(224, 224)):
    """
    Load an image file and turn it into a model-ready input batch.

    The image is converted to RGB, resized to ``image_size``, converted to a
    tensor, and normalized per channel with the fixed mean/std below (the
    commonly used ImageNet statistics — presumably matching the training
    preprocessing; confirm against the training script).

    Args:
        image_path: Path of the image file to load.
        image_size: Target size handed to ``transforms.Resize``; a
            ``(height, width)`` pair yields exactly that output size.

    Returns:
        A float tensor with a leading batch dimension of 1, i.e. shape
        ``(1, 3, H, W)``.
    """
    transform = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    # Open inside a context manager so the underlying file handle is released.
    # ``convert`` loads the pixel data into a new image, so ``image`` remains
    # usable after the source file is closed.
    with Image.open(image_path) as raw_image:
        image: Image.Image = raw_image.convert("RGB")
    image_tensor: Tensor = transform(image)
    # Add the batch dimension: (C, H, W) -> (1, C, H, W).
    image_tensor = image_tensor.unsqueeze(0)

    return image_tensor


def predict_single_image(
    image_path: str, model: nn.Module, device: torch.device, class_names: list[str]
) -> str:
    """
    Classify a single image file and return the predicted class name.

    Args:
        image_path: Path of the image to classify.
        model: Trained classifier. This call switches it to eval mode, and it
            stays in eval mode afterwards.
        device: Device the input tensor is moved to (should be the device the
            model lives on).
        class_names: Label names, indexed by the model's output class index.

    Returns:
        ``class_names[i]``, where ``i`` is the index of the largest model
        output along dimension 1.
    """
    batch = preprocess_image(image_path).to(device)

    # Put the model in inference mode so train-only behaviour (e.g. dropout,
    # if the model has any) is disabled, and skip autograd bookkeeping.
    model.eval()
    with torch.no_grad():
        logits = model(batch)
        best_index = torch.max(logits, dim=1).indices

    return class_names[int(best_index.item())]


if __name__ == "__main__":
    # Smoke test: predict the label of one image.
    # NOTE: replace the checkpoint and image paths below with real files.
    checkpoint_path = "./checkpoints/vgg_net_model_50.pth"

    # Label names indexed by model output; the list length must equal the
    # model's number of output classes and the order must match training.
    class_names = [
        "Bacterialblight",
        "Blast",
        "Brownspot",
        "Healthy",
        "Tungro",
    ]

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = getModel(device=device, model_path=checkpoint_path)

    test_image_path = "./images/BLAST1_011.jpg"
    try:
        predicted_label = predict_single_image(
            test_image_path, model, device, class_names=class_names
        )
    except FileNotFoundError:
        # Raised when the test image does not exist.
        print("Please provide a valid image path to test single image prediction")
    else:
        print("\nSingle image prediction result:")
        print(f"Image: {test_image_path}")
        print(f"Predicted label: {predicted_label}")