# Hugging Face page artifacts (not code) — kept as comments so the file parses:
# ekting's picture
# Update app.py
# d32ca46 verified
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
import numpy as np
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
# --- 1. CONFIGURATION ---
# The 14 class labels, in the exact index order used during training in Colab.
# Index i here corresponds to output logit i of the model — never re-sort.
CLASS_NAMES = [
    "APPLE_Apple_Scab",
    "APPLE_Healthy",
    "CORN_Cercospora_Gray_Leaf_Spot",
    "CORN_Common_Rust",
    "CORN_Healthy",
    "CORN_Northern_Leaf_Blight",
    "GRAPE_Black_Rot",
    "GRAPE_Healthy",
    "TOMATO_Early_Blight",
    "TOMATO_Healthy",
    "TOMATO_Leaf_Mold",
    "TOMATO_Mosaic_Virus",
    "TOMATO_Septoria_Leaf_Spot",
    "TOMATO_Yellow_Leaf_Virus",
]
def load_model():
    """Build the MobileNetV2 classifier head and load the fine-tuned weights.

    Returns:
        torch.nn.Module: model in eval mode on CPU, with a custom head
        (last_channel -> 256 -> 14) matching the training checkpoint.
    """
    # Base MobileNetV2 with no pretrained weights — everything comes from the checkpoint.
    model = models.mobilenet_v2(weights=None)
    num_ftrs = model.last_channel

    # EXACT ARCHITECTURE MATCH: the nested Sequential reproduces the checkpoint's
    # parameter names (classifier.1.0, classifier.1.3, ...). Do NOT flatten it,
    # or load_state_dict() will fail on key mismatches.
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.2),              # classifier.0
        nn.Sequential(                  # classifier.1
            nn.Linear(num_ftrs, 256),   # classifier.1.0 (matches the checkpoint)
            nn.ReLU(),                  # classifier.1.1
            nn.Dropout(p=0.5),          # classifier.1.2
            nn.Linear(256, 14),         # classifier.1.3 (14 classes)
        ),
    )

    # Load weights onto CPU (Hugging Face Spaces CPU runtime).
    # weights_only=True restricts unpickling to tensors/containers, preventing
    # arbitrary code execution from a tampered .pth file.
    state_dict = torch.load(
        "final_tuned_plant_model.pth",
        map_location=torch.device('cpu'),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.eval()  # inference mode: disables dropout
    return model
# Initialize model
# Loaded once at import time so every Gradio request reuses the same instance.
model = load_model()
# --- 2. PREPROCESSING ---
# Resize to the 224x224 input the network was trained on, then normalize with
# the ImageNet mean/std — these must match the training pipeline exactly.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
def analyze_plant(img):
    """Classify a leaf image and produce a Grad-CAM heatmap for the top class.

    Args:
        img: PIL image from the Gradio input, or None when no image is set.

    Returns:
        tuple: (confidences dict mapping class name -> probability,
        RGB heatmap overlay ndarray), or (None, None) when img is None.
    """
    if img is None:
        return None, None

    # Force 3-channel RGB: Gradio can deliver RGBA PNGs or grayscale images,
    # which would break the 3-channel Normalize and the heatmap overlay below.
    img = img.convert("RGB")

    # Prepare a batch-of-one input tensor.
    input_tensor = transform(img).unsqueeze(0)

    # 1. Prediction Phase (forward pass only — no gradients needed here).
    with torch.no_grad():
        outputs = model(input_tensor)
        probabilities = torch.nn.functional.softmax(outputs[0], dim=0)

    # Map each class name to its probability for the gr.Label output.
    confidences = {CLASS_NAMES[i]: float(probabilities[i]) for i in range(len(CLASS_NAMES))}

    # 2. Grad-CAM Explainability Phase
    # Target the final expansion layer of the MobileNetV2 feature extractor.
    target_layers = [model.features[-1]]
    cam = GradCAM(model=model, target_layers=target_layers)

    # Generate the heatmap for the highest-probability class.
    top_class = int(probabilities.argmax())
    targets = [ClassifierOutputTarget(top_class)]
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]

    # Overlay the heatmap on the resized original; show_cam_on_image expects a
    # float image scaled to [0, 1], so convert explicitly to float32 first.
    rgb_img = np.array(img.resize((224, 224)), dtype=np.float32) / 255.0
    cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

    return confidences, cam_image
# --- 3. GRADIO INTERFACE ---
# Single-function app: one uploaded image in, top-3 predictions plus the
# Grad-CAM heatmap out.
demo = gr.Interface(
    fn=analyze_plant,
    inputs=gr.Image(type="pil"),  # deliver a PIL image so analyze_plant can transform it directly
    outputs=[
        gr.Label(num_top_classes=3, label="Top Predictions"),
        gr.Image(label="Feature Focus (Grad-CAM)")
    ],
    title="TEK_1371068G: 14-Class Plant Health Diagnostic",
    description="Project X: Upload a leaf image to see the diagnostic result and the visual evidence (heatmap) used by the AI."
)

# Launch only when run as a script (Hugging Face Spaces also executes this path).
if __name__ == "__main__":
    demo.launch()