# Uploaded by Subh775: "Upload folder using huggingface_hub"
# (commit 6085c77, verified).
import torch
import numpy as np
import cv2
import os
import torch.nn.functional as F
import torchvision.transforms as transforms
import glob
import argparse
import pathlib
from model import build_model
# Construct the argument parser.
parser = argparse.ArgumentParser()
parser.add_argument(
    '-w', '--weights',
    default='/outputs/best_model.pth',
    help='path to the model weights',
)
# Only consume the real command line when this file is run as a script.
# Parsing sys.argv at import time makes a plain `import` of this module exit
# (or mis-parse) whenever the importing process has its own argv (pytest,
# notebooks, other CLIs). Importers still get an `args` dict, populated with
# the parser defaults.
if __name__ == '__main__':
    args = vars(parser.parse_args())
else:
    args = vars(parser.parse_args([]))
# Constants and other configurations.
# Use the first CUDA GPU when one is available, otherwise the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Height and width (pixels) images are resized to before the forward pass.
IMAGE_RESIZE = 224
# Predicted class index -> label. Presumably this order must match the class
# indices used at training time — confirm against the training script.
CLASS_NAMES = ['Bacterial', 'Fungal', 'Healthy', 'Pests']
# Validation transforms
def get_test_transform(image_size):
    """Build the deterministic preprocessing pipeline used for inference.

    The returned callable converts the input image to a PIL image, resizes
    it to a square ``image_size`` x ``image_size``, turns it into a tensor,
    and normalizes each channel with the standard ImageNet mean/std values.

    :param image_size: Target height and width in pixels.
    :return: A ``transforms.Compose`` applying the steps above in order.
    """
    # Standard ImageNet per-channel statistics.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline_steps = [
        transforms.ToPILImage(),
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(pipeline_steps)
def annotate_image(output_class, orig_image):
    """Write the predicted class name onto ``orig_image``.

    :param output_class: Predicted class index; converted with ``int`` and
        used to index ``CLASS_NAMES``.
    :param orig_image: Image array to draw on. ``cv2.putText`` draws in
        place, so this array is modified as well as returned.
    :return: The same ``orig_image`` object, now carrying the label.
    """
    label = CLASS_NAMES[int(output_class)]
    text_origin = (5, 35)   # bottom-left corner of the text, in pixels
    text_colour = (0, 0, 255)  # red in OpenCV's BGR channel order
    font_scale, thickness = 1.5, 2
    cv2.putText(
        orig_image,
        f"{label}",
        text_origin,
        cv2.FONT_HERSHEY_SIMPLEX,
        font_scale,
        text_colour,
        thickness,
        lineType=cv2.LINE_AA,
    )
    return orig_image
def inference(model, testloader, device, orig_image):
    """Classify one preprocessed image and draw the predicted label on it.

    :param model: The trained model; it is put in eval mode before the pass.
    :param testloader: Despite its name, a preprocessed image tensor with a
        leading batch dimension (the main loop passes a batch of one built
        with ``unsqueeze(0)``; shape assumed (1, C, H, W) — TODO confirm for
        other callers).
    :param device: The computation device the tensor is moved to.
    :param orig_image: The original image to annotate; ``annotate_image``
        draws on it in place.
    :return: ``orig_image`` annotated with the predicted class name.
    """
    model.eval()
    with torch.no_grad():
        image = testloader.to(device)
        # Forward pass.
        outputs = model(image)
        # Softmax over the class dimension: one probability row per image.
        probabilities = F.softmax(outputs, dim=1).cpu().numpy()
    # Use the first (only) row explicitly. Taking argmax over the flattened
    # array would yield a flat index >= num_classes for a batch larger than
    # one, which would then be out of range for the class-name lookup.
    output_class = int(np.argmax(probabilities[0]))
    # Draw the predicted label on the original image.
    return annotate_image(output_class, orig_image)
if __name__ == '__main__':
    # Run the classifier over every file in ../input/inference_data, then
    # show and save an annotated copy of each into ../outputs/inference_results.
    weights_path = pathlib.Path(args['weights'])
    infer_result_path = os.path.join('..', 'outputs', 'inference_results')
    os.makedirs(infer_result_path, exist_ok=True)
    # map_location lets a checkpoint saved from a GPU be loaded on a CPU-only
    # machine; without it torch.load restores tensors onto the device they
    # were saved from, which fails when that device is unavailable.
    checkpoint = torch.load(weights_path, map_location=DEVICE)
    # Load the model.
    model = build_model(
        fine_tune=False,
        num_classes=len(CLASS_NAMES),
    ).to(DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Sorted so files are processed (and numbered) in a reproducible order.
    all_image_paths = sorted(
        glob.glob(os.path.join('..', 'input', 'inference_data', '*'))
    )
    transform = get_test_transform(IMAGE_RESIZE)
    for i, image_path in enumerate(all_image_paths):
        print(f"Inference on image: {i+1}")
        image = cv2.imread(image_path)
        # cv2.imread returns None instead of raising for files it cannot
        # decode (the '*' glob may match non-images); skip those.
        if image is None:
            print(f"Skipping unreadable file: {image_path}")
            continue
        # Keep an unmodified BGR copy to draw on and save.
        orig_image = image.copy()
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = transform(image)
        # Add a leading batch dimension of size one.
        image = torch.unsqueeze(image, 0)
        result = inference(
            model,
            image,
            DEVICE,
            orig_image,
        )
        # Show the result briefly, then save it under its original file name.
        image_name = os.path.basename(image_path)
        cv2.imshow('Image', result)
        cv2.waitKey(1)
        cv2.imwrite(
            os.path.join(infer_result_path, image_name), result
        )
    # Close the preview window opened by cv2.imshow.
    cv2.destroyAllWindows()