Spaces:
Sleeping
Sleeping
| """ | |
| Simple Demo for Pest and Disease Classification | |
| Upload an image and get prediction | |
| """ | |
| import torch | |
| from PIL import Image | |
| import json | |
| import argparse | |
| import gradio as gr | |
| from torchvision import transforms | |
| from model import create_model | |
class PestDiseasePredictor:
    """Wrap a trained classification model for single-image inference.

    The label mapping and checkpoint are loaded once in ``__init__``;
    :meth:`predict` then returns a ``{class_name: probability}`` dict sorted
    from most to least likely.
    """

    def __init__(self, checkpoint_path, label_mapping_path, backbone='resnet50', device='cuda'):
        """Load the label mapping, build the model and restore its weights.

        Args:
            checkpoint_path: Path to a file readable by ``torch.load``. Either a
                dict holding the weights under ``'model_state_dict'`` or a bare
                state dict.
            label_mapping_path: Path to a UTF-8 JSON file containing
                ``'id_to_label'`` (class id as string -> label) and
                ``'num_classes'``.
            backbone: Backbone name forwarded to ``create_model``.
            device: Requested device; it is used only when CUDA is available,
                otherwise the CPU is used.

        Raises:
            KeyError: if the label mapping lacks ``'id_to_label'`` or
                ``'num_classes'``.
        """
        # Any requested device is replaced by CPU when CUDA is not available.
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        with open(label_mapping_path, 'r', encoding='utf-8') as f:
            mapping = json.load(f)
        # JSON object keys are always strings; convert them back to int class ids.
        self.id_to_label = {int(k): v for k, v in mapping['id_to_label'].items()}
        self.num_classes = mapping['num_classes']

        # pretrained=False presumably skips fetching backbone weights; the
        # checkpoint below supplies the weights anyway.
        self.model = create_model(
            num_classes=self.num_classes,
            backbone=backbone,
            pretrained=False
        )

        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        # Accept both training checkpoints ({'model_state_dict': ...}) and bare
        # state dicts saved with torch.save(model.state_dict(), path).
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        self.model.load_state_dict(state_dict)
        self.model = self.model.to(self.device)
        self.model.eval()

        # Resize to a fixed 224x224, then normalize with the standard ImageNet
        # channel statistics (presumably what the model was trained with).
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        print(f"Model loaded from {checkpoint_path}")
        print(f"Device: {self.device}")
        print(f"Classes: {self.num_classes}")

    def predict(self, image):
        """Classify a single image.

        Args:
            image: PIL Image; non-RGB modes are converted to RGB first.

        Returns:
            dict: ``{class_name: probability}`` for every model output,
            ordered by descending probability (plain Python floats).

        Raises:
            KeyError: if the model emits a class index missing from the
                label mapping.
        """
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # Add a batch dimension of 1 before moving to the model's device.
        img_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(img_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)

        # .tolist() yields plain Python floats without requiring NumPy.
        probs = probabilities[0].cpu().tolist()
        results = {self.id_to_label[idx]: prob for idx, prob in enumerate(probs)}
        return dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
def create_demo(predictor):
    """Build (but do not launch) the Gradio interface around *predictor*.

    Args:
        predictor: Object exposing ``predict(image) -> {label: probability}``,
            e.g. a ``PestDiseasePredictor``.

    Returns:
        gr.Interface: the configured demo.
    """
    import inspect

    def predict_image(image):
        """Gradio callback: label probabilities, or None when no image is given."""
        if image is None:
            return None
        return predictor.predict(image)

    # Gradio 5 renamed Interface's ``allow_flagging`` to ``flagging_mode``.
    # Pick whichever keyword the installed version accepts so flagging stays
    # disabled on both API generations.
    interface_params = inspect.signature(gr.Interface.__init__).parameters
    flag_kwarg = "flagging_mode" if "flagging_mode" in interface_params else "allow_flagging"

    demo = gr.Interface(
        fn=predict_image,
        inputs=gr.Image(type="pil", label="Upload Image"),
        outputs=gr.Label(num_top_classes=10, label="Predictions"),
        title="🌿 Pest and Disease Classification",
        description="Upload an image of a citrus plant leaf to classify if it's healthy or has pests/diseases.",
        examples=None,
        theme=gr.themes.Soft(),
        **{flag_kwarg: "never"},
    )
    return demo
def main(args):
    """Build the predictor from parsed CLI options and serve the Gradio demo.

    Args:
        args: Parsed namespace providing ``checkpoint``, ``label_mapping``,
            ``backbone``, ``device``, ``host``, ``port`` and ``share``.
    """
    rule = "=" * 60

    print("Starting Pest and Disease Classification Demo...")
    print(rule)

    model_wrapper = PestDiseasePredictor(
        checkpoint_path=args.checkpoint,
        label_mapping_path=args.label_mapping,
        backbone=args.backbone,
        device=args.device,
    )
    interface = create_demo(model_wrapper)

    for line in ("\n" + rule, "Launching demo...", rule):
        print(line)

    launch_options = dict(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
    )
    interface.launch(**launch_options)
if __name__ == "__main__":
    import os

    parser = argparse.ArgumentParser(description='Demo for Pest and Disease Classification')
    parser.add_argument('--checkpoint', type=str, default='checkpoints/best_model.pth',
                        help='Path to model checkpoint')
    parser.add_argument('--label_mapping', type=str, default='label_mapping.json',
                        help='Path to label mapping JSON')
    parser.add_argument('--backbone', type=str, default='resnet50',
                        choices=['resnet50', 'resnet101', 'efficientnet_b0',
                                 'efficientnet_b3', 'mobilenet_v2'],
                        help='Model backbone')
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cuda', 'cpu'],
                        help='Device to use')
    # main() always passes host/port explicitly to launch(). Explicit values
    # override Gradio's GRADIO_SERVER_NAME / GRADIO_SERVER_PORT env vars, so a
    # hard-coded 127.0.0.1 default could leave the app bound to loopback on
    # hosts that configure it through those vars. Read them as defaults and
    # keep the previous values as the fallback.
    parser.add_argument('--host', type=str,
                        default=os.environ.get('GRADIO_SERVER_NAME', '127.0.0.1'),
                        help='Server host (default: $GRADIO_SERVER_NAME or 127.0.0.1)')
    # argparse applies type=int to string defaults, so an env-provided port is
    # converted (and validated) as if it came from the command line.
    parser.add_argument('--port', type=int,
                        default=os.environ.get('GRADIO_SERVER_PORT', 7860),
                        help='Server port (default: $GRADIO_SERVER_PORT or 7860)')
    parser.add_argument('--share', action='store_true',
                        help='Create public link')
    args = parser.parse_args()
    main(args)