File size: 3,333 Bytes
6085c77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
import numpy as np
import cv2
import os
import torch.nn.functional as F
import torchvision.transforms as transforms
import glob
import argparse
import pathlib
from model import build_model
# Command-line interface: the only option is the checkpoint location.
parser = argparse.ArgumentParser()
parser.add_argument(
    '-w',
    '--weights',
    default='/outputs/best_model.pth',
    help='path to the model weights',
)
# Parsed options as a plain dict, e.g. args['weights'].
args = vars(parser.parse_args())

# Computation device: first CUDA device when available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')
# Side length (in pixels) images are resized to before the forward pass.
IMAGE_RESIZE = 224
# Class labels; a predicted class number is used as an index into this list.
CLASS_NAMES = ['Bacterial', 'Fungal', 'Healthy', 'Pests']

# Validation transforms
def get_test_transform(image_size):
    """
    Build the preprocessing pipeline applied to an image before inference.

    :param image_size: Side length used for both dimensions of the resize.

    :return: A ``transforms.Compose`` that converts the input to a PIL
        image, resizes it, converts it to a tensor and normalizes it.
    """
    target_shape = (image_size, image_size)
    # NOTE(review): these look like the commonly used ImageNet channel
    # statistics -- confirm they match what the model was trained with.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.ToPILImage(),
        transforms.Resize(target_shape),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)
def annotate_image(output_class, orig_image):
    """
    Write the predicted class name onto an image.

    :param output_class: Class number; converted with ``int`` and used as an
        index into ``CLASS_NAMES``.

    :param orig_image: Image handed to ``cv2.putText``. The text is drawn
        directly on this object (no copy is made here).

    :return: The same ``orig_image`` object that was passed in.
    """
    label = CLASS_NAMES[int(output_class)]
    text_origin = (5, 35)
    # Red if the image is in BGR channel order (as produced by cv2.imread).
    text_color = (0, 0, 255)
    font_scale = 1.5
    thickness = 2
    cv2.putText(
        orig_image,
        f"{label}",
        text_origin,
        cv2.FONT_HERSHEY_SIMPLEX,
        font_scale,
        text_color,
        thickness,
        lineType=cv2.LINE_AA,
    )
    return orig_image
def inference(model, testloader, device, orig_image):
    """
    Function to run inference on one image and annotate it with the result.

    :param model: The trained model.

    :param testloader: The input batch. Despite the name, the caller in this
        file passes a single preprocessed image tensor with a leading batch
        dimension, not a data loader.

    :param device: The computation device the input is moved to.

    :param orig_image: Image on which the predicted class name is drawn.

    :return: The annotated image, as returned by ``annotate_image``.
    """
    model.eval()
    with torch.no_grad():
        # Forward pass.
        outputs = model(testloader.to(device))
    # Softmax probabilities over dim 1 (the class dimension).
    predictions = F.softmax(outputs, dim=1).cpu().numpy()
    # Predicted class number of the first sample. Taking the argmax along the
    # class axis (rather than over the flattened array) keeps the index within
    # the class range even if more than one sample is passed in.
    output_class = np.argmax(predictions, axis=1)[0]
    # Draw the class name on the original image.
    return annotate_image(output_class, orig_image)

if __name__ == '__main__':
    # Script entry point: load the checkpoint, classify every file found in
    # ../input/inference_data and save the annotated copies to
    # ../outputs/inference_results.
    weights_path = pathlib.Path(args['weights'])
    infer_result_path = os.path.join(
        '..', 'outputs', 'inference_results'
    )
    os.makedirs(infer_result_path, exist_ok=True)
    # map_location remaps the stored tensors onto DEVICE, so a checkpoint
    # saved on a GPU can still be loaded on a CPU-only machine.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(weights_path, map_location=DEVICE)
    # Load the model.
    model = build_model(
        fine_tune=False,
        num_classes=len(CLASS_NAMES)
    ).to(DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    all_image_paths = glob.glob(os.path.join('..', 'input', 'inference_data', '*'))
    transform = get_test_transform(IMAGE_RESIZE)
    for i, image_path in enumerate(all_image_paths):
        print(f"Inference on image: {i+1}")
        image = cv2.imread(image_path)
        # cv2.imread returns None for files it cannot decode; the '*' glob
        # can match non-image files, so skip those instead of crashing.
        if image is None:
            print(f"Skipping unreadable image: {image_path}")
            continue
        # Keep an untouched BGR copy for annotation and saving.
        orig_image = image.copy()
        # The transform pipeline is fed RGB; OpenCV loads images as BGR.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = transform(image)
        # Add the batch dimension expected by the model.
        image = torch.unsqueeze(image, 0)
        result = inference(
            model,
            image,
            DEVICE,
            orig_image
        )
        # Save the image to disk under its original file name.
        image_name = os.path.basename(image_path)
        cv2.imshow('Image', result)
        cv2.waitKey(1)
        cv2.imwrite(
            os.path.join(infer_result_path, image_name), result
        )