Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models, transforms | |
| import numpy as np | |
| from PIL import Image | |
| from pytorch_grad_cam import GradCAM | |
| from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
# --- 1. CONFIGURATION ---
# Output labels of the 14-class model trained in Colab. Entry i names the
# classifier's output i (analyze_plant pairs them positionally), so the order
# here is load-bearing. The list is in sorted order, which is consistent with
# how torchvision's ImageFolder assigns class indices.
# NOTE(review): confirm this matches the label order used during training;
# reordering entries would silently mislabel every prediction.
CLASS_NAMES = [
    # Apple
    "APPLE_Apple_Scab",
    "APPLE_Healthy",
    # Corn
    "CORN_Cercospora_Gray_Leaf_Spot",
    "CORN_Common_Rust",
    "CORN_Healthy",
    "CORN_Northern_Leaf_Blight",
    # Grape
    "GRAPE_Black_Rot",
    "GRAPE_Healthy",
    # Tomato
    "TOMATO_Early_Blight",
    "TOMATO_Healthy",
    "TOMATO_Leaf_Mold",
    "TOMATO_Mosaic_Virus",
    "TOMATO_Septoria_Leaf_Spot",
    "TOMATO_Yellow_Leaf_Virus",
]
def load_model(checkpoint_path="final_tuned_plant_model.pth", num_classes=None):
    """Build the fine-tuned MobileNetV2 classifier and load its weights on CPU.

    The classifier head is rebuilt to mirror the checkpoint's module layout
    exactly (``classifier.0`` dropout followed by a nested ``classifier.1``
    MLP), because state-dict keys are derived from module positions.

    Args:
        checkpoint_path: Path to a saved ``state_dict`` for this architecture.
            Defaults to the checkpoint file shipped with the app.
        num_classes: Width of the final linear layer. ``None`` (the default)
            uses ``len(CLASS_NAMES)``, so the head always agrees with the label
            list used to decode predictions.

    Returns:
        The model with weights loaded, switched to evaluation mode.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
        RuntimeError: If the checkpoint's keys or tensor shapes do not match
            the architecture built here (raised by ``load_state_dict``).
    """
    if num_classes is None:
        num_classes = len(CLASS_NAMES)

    # Plain MobileNetV2 backbone; all weights come from the checkpoint below,
    # so no torchvision pretrained weights are downloaded.
    net = models.mobilenet_v2(weights=None)
    num_ftrs = net.last_channel

    # EXACT ARCHITECTURE MATCH with the checkpoint: 256 hidden units, then one
    # logit per class.
    net.classifier = nn.Sequential(
        nn.Dropout(p=0.2),  # classifier.0
        nn.Sequential(  # classifier.1
            nn.Linear(num_ftrs, 256),  # classifier.1.0
            nn.ReLU(),  # classifier.1.1
            nn.Dropout(p=0.5),  # classifier.1.2
            nn.Linear(256, num_classes),  # classifier.1.3
        ),
    )

    # map_location forces CPU tensors, so a checkpoint saved on a GPU still
    # loads on a CPU-only host such as a free Hugging Face Space.
    # weights_only=True limits unpickling to tensors and plain containers,
    # which is all a state_dict needs, so a tampered checkpoint cannot run
    # arbitrary code on load.
    state_dict = torch.load(
        checkpoint_path,
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    net.load_state_dict(state_dict)
    net.eval()
    return net
# Initialize model once at import time; every analyze_plant call reuses this
# module-level instance instead of reloading the checkpoint per request.
model = load_model()

# --- 2. PREPROCESSING ---
# Fixed 224x224 resize -> float tensor in [0, 1] -> per-channel normalisation
# with the standard ImageNet mean/std values.
# NOTE(review): these steps must mirror the preprocessing used in training.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def analyze_plant(img):
    """Classify a leaf image and explain the top prediction with Grad-CAM.

    Args:
        img: A PIL image, as supplied by the ``gr.Image(type="pil")`` input,
            or ``None`` when nothing has been uploaded.

    Returns:
        A ``(confidences, cam_image)`` pair. ``confidences`` maps each name in
        ``CLASS_NAMES`` to its softmax probability. ``cam_image`` is the array
        returned by ``show_cam_on_image``: the Grad-CAM heatmap for the
        highest-probability class blended over the 224x224 resized input.
        Returns ``(None, None)`` when ``img`` is ``None``.
    """
    if img is None:
        return None, None

    # Uploads may be RGBA (PNG with alpha), grayscale or palette images. The
    # network, the 3-channel Normalize and the overlay all expect RGB.
    if img.mode != "RGB":
        img = img.convert("RGB")
    input_tensor = transform(img).unsqueeze(0)

    # 1. Prediction phase: a plain forward pass needs no autograd graph.
    with torch.no_grad():
        logits = model(input_tensor)
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)
    confidences = {
        name: float(prob) for name, prob in zip(CLASS_NAMES, probabilities)
    }
    top_class = int(probabilities.argmax())

    # 2. Grad-CAM explainability phase, targeting the last block of the
    # MobileNetV2 feature extractor. GradCAM registers forward/backward hooks
    # on that layer. The context manager removes them as soon as we finish,
    # so hooks never pile up on the shared module-level model across requests.
    target_layers = [model.features[-1]]
    with GradCAM(model=model, target_layers=target_layers) as cam:
        grayscale_cam = cam(
            input_tensor=input_tensor,
            targets=[ClassifierOutputTarget(top_class)],
        )[0, :]

    # show_cam_on_image wants a float RGB image scaled to [0, 1] with the same
    # spatial size as the CAM (224x224, matching the Resize in `transform`).
    rgb_img = np.asarray(img.resize((224, 224)), dtype=np.float32) / 255.0
    cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
    return confidences, cam_image
# --- 3. GRADIO INTERFACE ---
# One PIL image in; two panels out, matching analyze_plant's return pair:
# the label dictionary (top 3 shown) and the Grad-CAM overlay image.
_leaf_upload = gr.Image(type="pil")
_result_panels = [
    gr.Label(num_top_classes=3, label="Top Predictions"),
    gr.Image(label="Feature Focus (Grad-CAM)"),
]

demo = gr.Interface(
    fn=analyze_plant,
    inputs=_leaf_upload,
    outputs=_result_panels,
    title="TEK_1371068G: 14-Class Plant Health Diagnostic",
    description="Project X: Upload a leaf image to see the diagnostic result and the visual evidence (heatmap) used by the AI.",
)

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()