# Efficient_NetV2 / app.py
# kikogazda's picture
# Update app.py
# 0814784 verified
import torch
import numpy as np
import gradio as gr
import timm
from PIL import Image
from torchvision import transforms
import scipy.io
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
# --- Load model and metadata ---
# Locations of the saved weights and the class-name metadata, plus the device
# every tensor and the model are placed on.
MODEL_PATH = "efficientnetv2_best_model.pth"
META_PATH = "cars_meta.mat"
DEVICE = torch.device("cpu")

# Per-channel statistics fed to transforms.Normalize below.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Class labels are read from the first row of the 'class_names' entry of the
# .mat file; each cell is unwrapped by taking its first element.
meta = scipy.io.loadmat(META_PATH)
class_names = [entry[0] for entry in meta["class_names"][0]]

# Preprocessing pipeline: fixed 224x224 resize, tensor conversion, then
# channel-wise normalisation.
val_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
)
def get_model(num_classes):
    """Create an ``efficientnetv2_rw_s`` network through timm.

    Args:
        num_classes: Forwarded to ``timm.create_model`` as ``num_classes``.

    Returns:
        The model object returned by ``timm.create_model``, built with
        ``pretrained=False``.
    """
    return timm.create_model(
        "efficientnetv2_rw_s",
        pretrained=False,
        num_classes=num_classes,
    )
# Build the network with one output per class name, then restore the saved
# weights from MODEL_PATH.
model = get_model(num_classes=len(class_names))
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(state_dict)
# Place the model on DEVICE and switch it to evaluation mode.
model.to(DEVICE)
model.eval()
# --- Grad-CAM++ prediction function ---
def predict_and_explain(img):
    """Classify an image and overlay a Grad-CAM++ heatmap for the prediction.

    Args:
        img: PIL image supplied by the Gradio input component, or ``None``
            when the form is submitted without an image.

    Returns:
        A ``(heatmap, message)`` tuple. ``heatmap`` is the Grad-CAM++ overlay
        as a PIL image resized to the size of ``img``; ``message`` names the
        predicted class and its index. When ``img`` is ``None`` the heatmap
        is ``None`` and the message asks the user to upload an image.
    """
    if img is None:
        # Nothing was uploaded; report it instead of failing on img.convert.
        return None, "Please upload an image first."

    # The 224x224 RGB copy is used both as model input and as the background
    # the heatmap is blended onto.
    image_pil = img.convert("RGB").resize((224, 224))
    input_tensor = val_transform(image_pil).unsqueeze(0).to(DEVICE)

    # Plain forward pass (no gradients needed) to pick the predicted class.
    with torch.no_grad():
        output = model(input_tensor)
    pred_idx = output.argmax(dim=1).item()
    pred_name = class_names[pred_idx]

    # Grad-CAM++
    targets = [ClassifierOutputTarget(pred_idx)]
    # The CAM object attaches hooks to the target layer. Using it as a
    # context manager releases them on exit, so repeated requests do not
    # pile up hooks on the shared model.
    with GradCAMPlusPlus(model=model, target_layers=[model.blocks[-1][-1]]) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]

    # show_cam_on_image is given the image as float32 scaled to [0, 1].
    image_np = np.array(image_pil).astype(np.float32) / 255.0
    cam_image = show_cam_on_image(image_np, grayscale_cam, use_rgb=True)
    cam_pil = Image.fromarray(cam_image).resize(img.size)  # Resize to match original
    return cam_pil, f"Prediction: {pred_name} (class index {pred_idx})"
# --- Gradio UI ---
# Single-function interface: one PIL image in, the Grad-CAM++ overlay and the
# prediction text out (matching the tuple returned by predict_and_explain).
demo = gr.Interface(
    fn=predict_and_explain,
    inputs=gr.Image(type="pil", label="Upload Car Image"),
    outputs=[
        gr.Image(type="pil", label="Grad-CAM++ Heatmap"),
        gr.Text(label="Prediction"),
    ],
    # The title previously showed mis-decoded bytes in place of the car emoji.
    title="🚗 EfficientNetV2 Car Classifier + Grad-CAM++",
    description="Upload a car image to see its predicted make/model/year and what influenced the prediction.",
)
demo.launch()