Boyun7's picture
upload all files
03d5bce
"""
Simple demo for pest and disease classification.
Upload an image and get a prediction.
"""
import torch
from PIL import Image
import json
import argparse
import gradio as gr
from torchvision import transforms
from model import create_model
class PestDiseasePredictor:
    """Load a trained classifier and return class probabilities for images.

    The architecture is built with ``create_model``. Its weights are restored
    from a checkpoint file, and class names come from a label-mapping JSON
    file with ``id_to_label`` and ``num_classes`` entries.
    """

    def __init__(self, checkpoint_path, label_mapping_path, backbone='resnet50', device='cuda'):
        """Build the model, restore its weights and set up preprocessing.

        Args:
            checkpoint_path: Path to a file written by ``torch.save``. It is
                either a dict holding the weights under ``'model_state_dict'``
                or a bare ``state_dict``.
            label_mapping_path: Path to a UTF-8 JSON file. The file holds an
                ``id_to_label`` object (string ids -> class names) and an
                integer ``num_classes``.
            backbone: Backbone name forwarded to ``create_model``.
            device: Requested torch device. The CPU is used instead whenever
                CUDA is not available.

        Raises:
            ValueError: If the label mapping has no name for some class id in
                ``range(num_classes)``.
        """
        # Fall back to CPU so the demo still starts on machines without a GPU.
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        # Load (and validate) the label mapping before building the model so a
        # bad mapping is reported before the expensive checkpoint load.
        self.id_to_label, self.num_classes = self._load_label_mapping(label_mapping_path)

        # Build the architecture; weights come from the checkpoint, so no
        # pretrained download is needed.
        self.model = create_model(
            num_classes=self.num_classes,
            backbone=backbone,
            pretrained=False
        )

        # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(self._extract_state_dict(checkpoint))
        self.model = self.model.to(self.device)
        self.model.eval()

        # Resize + ImageNet channel statistics. Presumably this must match the
        # preprocessing used at training time -- confirm against the trainer.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        print(f"Model loaded from {checkpoint_path}")
        print(f"Device: {self.device}")
        print(f"Classes: {self.num_classes}")

    @staticmethod
    def _load_label_mapping(label_mapping_path):
        """Read the label-mapping JSON and return ``(id_to_label, num_classes)``.

        JSON object keys are always strings, so ids are converted to ``int``
        to match the output indices of the model.

        Raises:
            ValueError: If any id in ``range(num_classes)`` has no label.
        """
        with open(label_mapping_path, 'r', encoding='utf-8') as f:
            mapping = json.load(f)
        id_to_label = {int(k): v for k, v in mapping['id_to_label'].items()}
        num_classes = mapping['num_classes']
        # predict() looks up every output index, so a gap would otherwise
        # surface as a KeyError on the first request instead of at startup.
        missing = [i for i in range(num_classes) if i not in id_to_label]
        if missing:
            raise ValueError(
                f"Label mapping {label_mapping_path} has no name for class ids {missing}"
            )
        return id_to_label, num_classes

    @staticmethod
    def _extract_state_dict(checkpoint):
        """Return model weights from a loaded checkpoint.

        Training checkpoints wrap the weights under ``'model_state_dict'``
        (alongside other entries). A file saved with
        ``torch.save(model.state_dict())`` is the state dict itself.
        """
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            return checkpoint['model_state_dict']
        return checkpoint

    def predict(self, image):
        """Predict class probabilities for one image.

        Args:
            image: PIL Image. Non-RGB images are converted to RGB first.

        Returns:
            dict: ``{class_name: probability}`` as Python floats, ordered from
            most to least likely. Ties keep class-id order.
        """
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # Add a batch dimension of 1 and move the batch next to the model.
        batch = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(batch)
            probs = torch.softmax(logits, dim=1)[0].cpu().tolist()

        results = {self.id_to_label[idx]: prob for idx, prob in enumerate(probs)}
        # sorted() is stable, so equal probabilities keep class-id order.
        return dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
def create_demo(predictor):
    """Wrap *predictor* in a Gradio ``Interface`` for single-image classification.

    Args:
        predictor: Object whose ``predict(image)`` returns a
            ``{label: probability}`` dict.

    Returns:
        The configured ``gr.Interface``. It has not been launched yet.
    """
    def predict_image(image):
        """Gradio callback: classify *image*, or return None when nothing is uploaded."""
        return None if image is None else predictor.predict(image)

    return gr.Interface(
        fn=predict_image,
        inputs=gr.Image(type="pil", label="Upload Image"),
        outputs=gr.Label(num_top_classes=10, label="Predictions"),
        title="🌿 Pest and Disease Classification",
        description="Upload an image of a citrus plant leaf to classify if it's healthy or has pests/diseases.",
        examples=None,
        theme=gr.themes.Soft(),
        allow_flagging="never",
    )
def main(args):
    """Build the predictor from parsed CLI options and serve the Gradio demo.

    Args:
        args: Parsed arguments providing ``checkpoint``, ``label_mapping``,
            ``backbone``, ``device``, ``host``, ``port`` and ``share``.
    """
    rule = "=" * 60

    print("Starting Pest and Disease Classification Demo...")
    print(rule)

    demo = create_demo(
        PestDiseasePredictor(
            checkpoint_path=args.checkpoint,
            label_mapping_path=args.label_mapping,
            backbone=args.backbone,
            device=args.device,
        )
    )

    for line in ("\n" + rule, "Launching demo...", rule):
        print(line)

    demo.launch(server_name=args.host, server_port=args.port, share=args.share)
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    backbone_choices = ['resnet50', 'resnet101', 'efficientnet_b0',
                        'efficientnet_b3', 'mobilenet_v2']
    cli = argparse.ArgumentParser(description='Demo for Pest and Disease Classification')

    # Model and label files
    cli.add_argument('--checkpoint', default='checkpoints/best_model.pth', type=str,
                     help='Path to model checkpoint')
    cli.add_argument('--label_mapping', default='label_mapping.json', type=str,
                     help='Path to label mapping JSON')

    # Model configuration
    cli.add_argument('--backbone', default='resnet50', type=str, choices=backbone_choices,
                     help='Model backbone')
    cli.add_argument('--device', default='cuda', type=str, choices=['cuda', 'cpu'],
                     help='Device to use')

    # Web server options
    cli.add_argument('--host', default='127.0.0.1', type=str, help='Server host')
    cli.add_argument('--port', default=7860, type=int, help='Server port')
    cli.add_argument('--share', action='store_true', help='Create public link')

    main(cli.parse_args())