Spaces:
Running
Running
| """ | |
| Simple Demo for Pest and Disease Classification | |
| For Hugging Face Space Deployment | |
| Supports both single model and ensemble prediction | |
| """ | |
| import torch | |
| from PIL import Image | |
| import json | |
| import gradio as gr | |
| from torchvision import transforms | |
| import numpy as np | |
| from pathlib import Path | |
| from model import create_model | |
class PestDiseasePredictor:
    """Classifier backed by a single trained checkpoint.

    ``predict`` takes a PIL image and returns a dict of label -> softmax
    probability, sorted from the most to the least likely class.
    """

    def __init__(self, checkpoint_path, label_mapping_path, backbone='resnet50', device='cuda'):
        """Restore a trained model and prepare the preprocessing pipeline.

        Args:
            checkpoint_path: Checkpoint file holding a ``'model_state_dict'`` entry.
            label_mapping_path: UTF-8 JSON file providing ``'id_to_label'`` and
                ``'num_classes'``.
            backbone: Architecture name forwarded to ``create_model``.
            device: Preferred torch device; CPU is used instead when CUDA is
                unavailable.
        """
        cuda_ok = torch.cuda.is_available()
        self.device = torch.device(device if cuda_ok else 'cpu')

        with open(label_mapping_path, 'r', encoding='utf-8') as mapping_file:
            label_info = json.load(mapping_file)
        # JSON object keys are strings; convert them back to integer class ids
        # so they line up with the output indices of the model.
        raw_labels = label_info['id_to_label']
        self.id_to_label = dict((int(key), name) for key, name in raw_labels.items())
        self.num_classes = label_info['num_classes']

        network = create_model(
            num_classes=self.num_classes,
            backbone=backbone,
            pretrained=False,
        )
        # NOTE: torch.load unpickles the file -- only point this at trusted checkpoints.
        saved = torch.load(checkpoint_path, map_location=self.device)
        network.load_state_dict(saved['model_state_dict'])
        self.model = network.to(self.device)
        self.model.eval()

        # Resize to 224x224, convert to a tensor, then normalize per channel
        # (the commonly used ImageNet mean/std values).
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std),
        ])

        for line in (
            f"β Model loaded from {checkpoint_path}",
            f"π» Device: {self.device}",
            f"π Classes: {self.num_classes}",
        ):
            print(line)

    def predict(self, image):
        """Classify one PIL image.

        Args:
            image: PIL image; converted to RGB first if it uses another mode.

        Returns:
            dict mapping each class label to its softmax probability (a Python
            float), ordered from highest to lowest probability.
        """
        rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
        batch = self.transform(rgb_image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.model(batch)
            distribution = torch.nn.functional.softmax(logits, dim=1)[0].cpu().numpy()
        ranking = {self.id_to_label[index]: float(score) for index, score in enumerate(distribution)}
        ordered = sorted(ranking.items(), key=lambda pair: pair[1], reverse=True)
        return dict(ordered)
class EnsemblePredictor:
    """Ensemble of fold checkpoints combined by weighted soft voting.

    All members share one ``backbone`` and one label mapping. At prediction
    time each model's softmax distribution is scaled by its normalized weight,
    and the scaled distributions are summed.
    """

    def __init__(self, checkpoint_paths, weights, label_mapping_path, backbone='efficientnet_b3', device='cuda'):
        """Validate the ensemble configuration, then load every member model.

        Args:
            checkpoint_paths: Sequence of checkpoint files, one per ensemble
                member; each must contain a ``'model_state_dict'`` entry.
            weights: One voting weight per checkpoint, in the same order.
                Weights must be finite and non-negative with a positive sum;
                they are normalized to sum to 1.
            label_mapping_path: UTF-8 JSON file with ``'id_to_label'`` and
                ``'num_classes'`` keys.
            backbone: Backbone name forwarded to ``create_model``.
            device: Requested torch device; CPU is used whenever CUDA is not
                available.

        Raises:
            ValueError: If ``checkpoint_paths`` is empty, the number of weights
                differs from the number of checkpoints, or the weights are
                negative, non-finite, or all zero.
            FileNotFoundError: If a checkpoint file does not exist.
        """
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        # Validate before any file I/O or model construction so a bad
        # configuration fails fast. Without this check, a length mismatch
        # either raised IndexError midway through loading, or (with surplus
        # weights) silently produced scores not summing to 1, because zip()
        # dropped weights that had already been counted in the normalization.
        checkpoint_paths = list(checkpoint_paths)
        self.weights = self._normalize_weights(weights, len(checkpoint_paths))

        with open(label_mapping_path, 'r', encoding='utf-8') as f:
            mapping = json.load(f)
        # JSON object keys are strings; restore the integer class ids.
        self.id_to_label = {int(k): v for k, v in mapping['id_to_label'].items()}
        self.num_classes = mapping['num_classes']

        self.models = []
        print(f"\n{'='*80}")
        print("Loading Ensemble Models")
        print(f"{'='*80}")
        for i, checkpoint_path in enumerate(checkpoint_paths):
            print(f"\nModel {i+1}/{len(checkpoint_paths)}")
            print(f" Checkpoint: {checkpoint_path}")
            print(f" Weight: {self.weights[i]:.4f}")
            self.models.append(self._load_model(checkpoint_path, backbone))
            print(f" β Loaded successfully")

        print(f"\n{'='*80}")
        print(f"β Ensemble loaded: {len(self.models)} models")
        print(f"π» Device: {self.device}")
        print(f"π Classes: {self.num_classes}")
        print(f"{'='*80}\n")

        # Image transforms (resize to 224x224, tensor conversion, per-channel
        # normalization with the commonly used ImageNet mean/std values).
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    @staticmethod
    def _normalize_weights(weights, expected_count):
        """Validate ``weights`` and return them as a float array summing to 1.

        Raises:
            ValueError: If ``expected_count`` is 0, the number of weights is
                not ``expected_count``, or the weights are negative,
                non-finite, or sum to zero.
        """
        if expected_count == 0:
            raise ValueError("checkpoint_paths must contain at least one checkpoint")
        weights = np.asarray(weights, dtype=float)
        if weights.ndim != 1 or weights.shape[0] != expected_count:
            raise ValueError(
                f"Expected {expected_count} weights (one per checkpoint), "
                f"got shape {weights.shape}"
            )
        if not np.all(np.isfinite(weights)) or np.any(weights < 0):
            raise ValueError("weights must be finite and non-negative")
        total = weights.sum()
        if total <= 0:
            raise ValueError("weights must have a positive sum")
        return weights / total

    def _load_model(self, checkpoint_path, backbone):
        """Build one ``backbone`` model, restore it from ``checkpoint_path``, and put it in eval mode.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` does not exist.
        """
        # Check existence before building the network, so no model is
        # constructed only to be thrown away.
        if not Path(checkpoint_path).exists():
            print(f" β Checkpoint not found: {checkpoint_path}")
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
        model = create_model(
            num_classes=self.num_classes,
            backbone=backbone,
            pretrained=False
        )
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model = model.to(self.device)
        model.eval()
        return model

    def predict(self, image):
        """Classify one PIL image by weighted soft voting across the ensemble.

        Args:
            image: PIL image; converted to RGB if it uses another mode.

        Returns:
            dict mapping every class label to its ensemble probability (a
            Python float), ordered from most to least probable.
        """
        if image.mode != 'RGB':
            image = image.convert('RGB')
        img_tensor = self.transform(image).unsqueeze(0).to(self.device)

        # __init__ builds exactly one model per validated weight, so this
        # zip pairs every model with its own weight.
        ensemble_probs = np.zeros(self.num_classes)
        with torch.no_grad():
            for model, weight in zip(self.models, self.weights):
                probs = torch.nn.functional.softmax(model(img_tensor), dim=1)[0].cpu().numpy()
                ensemble_probs += weight * probs

        results = {self.id_to_label[idx]: float(prob) for idx, prob in enumerate(ensemble_probs)}
        return dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
# ========== For Hugging Face Space ==========
# Configuration shared by both predictors.
label_mapping_path = "label_mapping.json"
backbone = 'efficientnet_b3'
device = "cuda"

# Single-fold predictor. The app currently serves predictions from the
# ensemble below; this one is loaded as a drop-in alternative.
single_checkpoint = "checkpoints/best_model_fold1.pth"
single_predictor = PestDiseasePredictor(
    checkpoint_path=single_checkpoint,
    label_mapping_path=label_mapping_path,
    backbone=backbone,
    device=device,
)

# Five-fold ensemble (folds 1..5) with equal voting weights.
ensemble_checkpoints = [f"checkpoints/best_model_fold{fold}.pth" for fold in range(1, 6)]
ensemble_weights = [1.0] * len(ensemble_checkpoints)
ensemble_predictor = EnsemblePredictor(
    checkpoint_paths=ensemble_checkpoints,
    weights=ensemble_weights,
    label_mapping_path=label_mapping_path,
    backbone=backbone,
    device=device,
)
def predict_image(image):
    """Gradio callback: classify ``image`` with the fold ensemble.

    Returns None when no image is supplied; otherwise returns the
    label -> probability dict produced by ``ensemble_predictor.predict``.
    To serve the single-fold model instead, call ``single_predictor.predict``.
    """
    if image is not None:
        return ensemble_predictor.predict(image)
    return None
# Gradio UI: one PIL image in, the ten highest-scoring labels out.
# Components are created in the same order the original inline arguments were evaluated.
_upload_input = gr.Image(type="pil", label="Upload Image")
_prediction_output = gr.Label(num_top_classes=10, label="Predictions")
demo = gr.Interface(
    fn=predict_image,
    inputs=_upload_input,
    outputs=_prediction_output,
    title="<center>πΏ Pest and Disease Classification</center>",
    description="Upload an image of a citrus leaf to classify its pest or disease type.",
    theme=gr.themes.Soft(),
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
    # confirm against the Gradio version this Space pins.
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()