import os import torch import torch.nn as nn from torchvision import models, transforms from PIL import Image import argparse CLASS_NAMES = ['Gray Leaf Spot', 'Healthy'] def get_model(model_name, num_classes, device): if "ResNet50" in model_name: model = models.resnet50(weights=None) model.fc = nn.Linear(model.fc.in_features, num_classes) elif "EfficientNet" in model_name: model = models.efficientnet_b0(weights=None) model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes) else: raise ValueError("Model name must contain 'ResNet50' or 'EfficientNet'") return model.to(device) def predict(image_path, model_path, model_type): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") transform = transforms.Compose([ transforms.Resize((256, 256)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) model = get_model(model_type, len(CLASS_NAMES), device) model.load_state_dict(torch.load(model_path, map_location=device)) model.eval() image = Image.open(image_path).convert('RGB') input_tensor = transform(image).unsqueeze(0).to(device) with torch.no_grad(): output = model(input_tensor) probabilities = torch.nn.functional.softmax(output, dim=1) conf, pred_idx = torch.max(probabilities, 1) result = CLASS_NAMES[pred_idx.item()] confidence = conf.item() * 100 print(f"\nTarget Disease: Gray Leaf Spot (Cercospora zeae-maydis) ") print(f"Prediction: {result}") print(f"Confidence: {confidence:.2f}%") if result == 'Gray Leaf Spot' and confidence > 80: print("Note: Detected characteristic rectangular lesions parallel to leaf veins.") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Maize Leaf Disease Inference") parser.add_argument("--image", type=str, required=True, help="Path to input leaf image") parser.add_argument("--model_path", type=str, required=True, help="Path to .pth file") parser.add_argument("--model_type", type=str, required=True, choices=["ResNet50", "EfficientNet"], help="Architecture type") args = parser.parse_args() predict(args.image, args.model_path, args.model_type)