| import os | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| import argparse | |
# Human-readable labels indexed by classifier output position: predict()
# sizes the model head with len(CLASS_NAMES) and maps the argmax index back
# through this list. NOTE(review): the order must match the label order used
# when the checkpoint was trained — confirm against the training code.
CLASS_NAMES = [
    'Gray Leaf Spot',
    'Healthy',
]
def get_model(model_name, num_classes, device):
    """Build an untrained classifier with a ``num_classes``-way output head.

    The architecture is chosen by substring match on ``model_name``. A name
    containing "ResNet50" (checked first) yields a torchvision ResNet-50
    whose ``fc`` layer is replaced. Otherwise, a name containing
    "EfficientNet" yields an EfficientNet-B0 whose ``classifier[1]`` layer is
    replaced. No pretrained weights are requested (``weights=None``).

    Args:
        model_name: Architecture selector; must contain "ResNet50" or
            "EfficientNet".
        num_classes: Number of output logits for the replacement head.
        device: Target passed to ``Module.to``.

    Returns:
        The constructed model, moved to ``device``.

    Raises:
        ValueError: If ``model_name`` matches neither architecture.
    """
    # Reject unknown selectors up front so the branch below only has to
    # distinguish between the two supported backbones.
    if "ResNet50" not in model_name and "EfficientNet" not in model_name:
        raise ValueError("Model name must contain 'ResNet50' or 'EfficientNet'")

    if "ResNet50" in model_name:
        net = models.resnet50(weights=None)
        net.fc = nn.Linear(net.fc.in_features, num_classes)
    else:
        net = models.efficientnet_b0(weights=None)
        head = net.classifier
        # Index 1 is the final Linear layer of this classifier Sequential.
        head[1] = nn.Linear(head[1].in_features, num_classes)
    return net.to(device)
def predict(image_path, model_path, model_type):
    """Classify a single leaf image, print a short report, and return it.

    Args:
        image_path: Path to the input image. It is converted to RGB before
            preprocessing.
        model_path: Path to a saved ``state_dict`` compatible with the
            architecture selected by ``model_type``.
        model_type: Architecture selector forwarded to ``get_model``
            (e.g. "ResNet50" or "EfficientNet").

    Returns:
        ``(label, confidence)``: ``label`` is an entry of ``CLASS_NAMES``, and
        ``confidence`` is that label's softmax probability in percent
        (0-100).

    Raises:
        ValueError: From ``get_model`` when ``model_type`` is unsupported.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Fixed-size resize followed by normalisation with the standard ImageNet
    # channel statistics. NOTE(review): presumably mirrors the training-time
    # preprocessing — confirm against the training script.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    model = get_model(model_type, len(CLASS_NAMES), device)
    # The checkpoint path comes from the caller/CLI. weights_only=True limits
    # unpickling to tensors and plain containers, which is all a state_dict
    # needs, so a crafted file cannot execute arbitrary code on load.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    # Open the image as a context manager so its file handle is released as
    # soon as the RGB copy has been made.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    input_tensor = transform(image).unsqueeze(0).to(device)  # add batch dim

    with torch.no_grad():
        output = model(input_tensor)
        probabilities = torch.nn.functional.softmax(output, dim=1)
        conf, pred_idx = torch.max(probabilities, 1)

    result = CLASS_NAMES[pred_idx.item()]
    confidence = conf.item() * 100

    print("\nTarget Disease: Gray Leaf Spot (Cercospora zeae-maydis) ")
    print(f"Prediction: {result}")
    print(f"Confidence: {confidence:.2f}%")
    # NOTE(review): this message is triggered only by the class label and a
    # confidence threshold; the classifier does not localise lesions, so
    # "Detected" overstates what was actually computed.
    if result == 'Gray Leaf Spot' and confidence > 80:
        print("Note: Detected characteristic rectangular lesions parallel to leaf veins.")
    return result, confidence
if __name__ == "__main__":
    # Command-line entry point: parse arguments, check that the input files
    # exist, then run a single prediction.
    parser = argparse.ArgumentParser(description="Maize Leaf Disease Inference")
    parser.add_argument("--image", type=str, required=True, help="Path to input leaf image")
    parser.add_argument("--model_path", type=str, required=True, help="Path to .pth file")
    parser.add_argument("--model_type", type=str, required=True, choices=["ResNet50", "EfficientNet"], help="Architecture type")
    args = parser.parse_args()
    # Fail fast with a usage message (argparse exits with status 2) instead of
    # a traceback from deep inside PIL or torch.load when a path is wrong.
    for flag, path in (("--image", args.image), ("--model_path", args.model_path)):
        if not os.path.isfile(path):
            parser.error(f"{flag}: file not found: {path}")
    predict(args.image, args.model_path, args.model_type)