# corn-gray-leaf-spot-xai / inference.py
# Author: PulinduVR
# Last commit: 5548a07 ("Added readme")
import os
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import argparse
# Human-readable labels, indexed by the model's predicted class index.
# NOTE(review): the order presumably matches the class indices used at
# training time (e.g. torchvision ImageFolder's sorted folder names) -- confirm.
CLASS_NAMES = [
    "Gray Leaf Spot",
    "Healthy",
]
def get_model(model_name, num_classes, device):
    """Build an untrained classifier with a freshly sized output layer.

    Args:
        model_name: Architecture selector, matched by substring. "ResNet50"
            is checked first, then "EfficientNet" (the B0 variant is built).
        num_classes: Number of output units for the replacement final layer.
        device: Device the model is moved to before being returned.

    Returns:
        The constructed ``torch.nn.Module``, already on ``device``.

    Raises:
        ValueError: If ``model_name`` contains neither supported substring.
    """
    if "ResNet50" in model_name:
        # weights=None -> randomly initialised; real weights are loaded later.
        net = models.resnet50(weights=None)
        net.fc = nn.Linear(net.fc.in_features, num_classes)
        return net.to(device)

    if "EfficientNet" in model_name:
        net = models.efficientnet_b0(weights=None)
        head = net.classifier
        # Replace only the Linear layer at index 1 of the classifier head.
        head[1] = nn.Linear(head[1].in_features, num_classes)
        return net.to(device)

    raise ValueError("Model name must contain 'ResNet50' or 'EfficientNet'")
def predict(image_path, model_path, model_type):
    """Classify a single leaf image, print a short report, and return it.

    Args:
        image_path: Path to the input image. It is converted to RGB before
            preprocessing.
        model_path: Path to a saved ``state_dict`` for the chosen architecture.
        model_type: Architecture selector forwarded to ``get_model`` (must
            contain "ResNet50" or "EfficientNet").

    Returns:
        A ``(label, confidence)`` tuple. ``label`` is the predicted entry of
        ``CLASS_NAMES``. ``confidence`` is its softmax probability as a
        percentage (0-100). The same values are also printed to stdout.

    Raises:
        ValueError: If ``model_type`` names an unsupported architecture.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ImageNet mean/std normalisation.
    # NOTE(review): presumably this mirrors the training-time preprocessing
    # (e.g. whether a 224 centre crop was used) -- confirm.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    model = get_model(model_type, len(CLASS_NAMES), device)
    # SECURITY NOTE(review): torch.load unpickles the file, and unpickling can
    # execute arbitrary code. Only load checkpoints from trusted sources, or
    # pass weights_only=True on torch versions that support it.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()  # use inference behaviour for dropout / batch-norm layers

    # The context manager closes the underlying file deterministically.
    # convert() returns a new image, so it remains valid after the close.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    input_tensor = transform(image).unsqueeze(0).to(device)  # add batch dim

    with torch.no_grad():
        logits = model(input_tensor)
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        conf, pred_idx = torch.max(probabilities, 1)

    result = CLASS_NAMES[pred_idx.item()]
    confidence = conf.item() * 100

    print("\nTarget Disease: Gray Leaf Spot (Cercospora zeae-maydis) ")
    print(f"Prediction: {result}")
    print(f"Confidence: {confidence:.2f}%")
    if result == 'Gray Leaf Spot' and confidence > 80:
        # NOTE(review): fixed explanatory text, not derived from the model.
        # The classifier only outputs class scores and does not localise lesions.
        print("Note: Detected characteristic rectangular lesions parallel to leaf veins.")

    # Returned so programmatic callers need not parse stdout. Callers that
    # ignored the previous None return value are unaffected.
    return result, confidence
def main(argv=None):
    """Parse command-line arguments and run ``predict`` on one image.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching plain script invocation.
    """
    parser = argparse.ArgumentParser(description="Maize Leaf Disease Inference")
    parser.add_argument("--image", type=str, required=True, help="Path to input leaf image")
    parser.add_argument("--model_path", type=str, required=True, help="Path to .pth file")
    parser.add_argument("--model_type", type=str, required=True, choices=["ResNet50", "EfficientNet"], help="Architecture type")
    args = parser.parse_args(argv)
    predict(args.image, args.model_path, args.model_type)


if __name__ == "__main__":
    main()