"""Gradio demo: EfficientNetV2 car classifier with Grad-CAM++ explanations.

Loads a timm ``efficientnetv2_rw_s`` checkpoint and class names from a MATLAB
``.mat`` file, then serves a web UI that returns the predicted class together
with a Grad-CAM++ heatmap overlaid on the uploaded image.
"""

import torch
import numpy as np
import gradio as gr
import timm
from PIL import Image
from torchvision import transforms
import scipy.io
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# --- Load model and metadata ---
MODEL_PATH = "efficientnetv2_best_model.pth"
META_PATH = "cars_meta.mat"
DEVICE = torch.device("cpu")

# Spatial size fed to the model; the Grad-CAM++ overlay is drawn at this size too.
INPUT_SIZE = (224, 224)

# Each entry of meta['class_names'][0] is unwrapped with [0] to get the name
# string (assumes the usual scipy.io.loadmat cell-array layout — TODO confirm).
meta = scipy.io.loadmat(META_PATH)
class_names = [x[0] for x in meta['class_names'][0]]

# ImageNet normalization statistics (presumably what the model was trained with).
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

val_transform = transforms.Compose([
    transforms.Resize(INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])


def get_model(num_classes):
    """Build an EfficientNetV2-RW-S with a ``num_classes``-way classifier head.

    ``pretrained=False`` because the weights are loaded from MODEL_PATH below.
    """
    return timm.create_model('efficientnetv2_rw_s', pretrained=False, num_classes=num_classes)


model = get_model(num_classes=len(class_names))
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# Since this is a plain state_dict, weights_only=True (torch >= 1.13) would be safer.
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(state_dict)
model.eval()
model.to(DEVICE)

# Grad-CAM++ target: the last block of the final stage in ``model.blocks``.
# Built once here instead of on every request.
TARGET_LAYERS = [model.blocks[-1][-1]]


# --- Grad-CAM++ prediction function ---
def predict_and_explain(img):
    """Classify ``img`` and return a Grad-CAM++ overlay for the predicted class.

    Args:
        img: PIL image from the Gradio ``Image(type="pil")`` input, or None
            when the user submits without uploading an image.

    Returns:
        Tuple ``(heatmap, text)``. ``heatmap`` is the Grad-CAM++ overlay
        resized back to ``img.size``. ``text`` is the prediction string.
        If ``img`` is None, returns ``(None, <prompt message>)``.
    """
    if img is None:
        # Gradio passes None when Submit is pressed with no image; without
        # this guard ``img.convert`` below raises AttributeError.
        return None, "Please upload a car image."

    # Resize explicitly so the overlay background (image_np) has the same
    # size as the CAM computed from the model input.
    image_pil = img.convert("RGB").resize(INPUT_SIZE)
    input_tensor = val_transform(image_pil).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        output = model(input_tensor)
    pred_idx = output.argmax(dim=1).item()
    pred_name = class_names[pred_idx]

    # Grad-CAM++ registers hooks on the shared global model. The context
    # manager releases them deterministically after each request, including
    # when an exception occurs, rather than relying on garbage collection.
    targets = [ClassifierOutputTarget(pred_idx)]
    with GradCAMPlusPlus(model=model, target_layers=TARGET_LAYERS) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]

    image_np = np.array(image_pil).astype(np.float32) / 255.0
    cam_image = show_cam_on_image(image_np, grayscale_cam, use_rgb=True)
    cam_pil = Image.fromarray(cam_image).resize(img.size)  # Resize to match original
    return cam_pil, f"Prediction: {pred_name} (class index {pred_idx})"


# --- Gradio UI ---
demo = gr.Interface(
    fn=predict_and_explain,
    inputs=gr.Image(type="pil", label="Upload Car Image"),
    outputs=[
        gr.Image(type="pil", label="Grad-CAM++ Heatmap"),
        gr.Text(label="Prediction"),
    ],
    title="🚗 EfficientNetV2 Car Classifier + Grad-CAM++",
    description="Upload a car image to see its predicted make/model/year and what influenced the prediction.",
)

demo.launch()