Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| import timm | |
| from PIL import Image | |
| from torchvision import transforms | |
| import scipy.io | |
| from pytorch_grad_cam import GradCAMPlusPlus | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
| from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget | |
# --- Load model and metadata ---
MODEL_PATH = "efficientnetv2_best_model.pth"  # weights file, loaded further below
META_PATH = "cars_meta.mat"  # MATLAB .mat file holding the class-name table
DEVICE = torch.device("cpu")  # inference is pinned to CPU

# Per-channel RGB normalization statistics (the standard ImageNet values).
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

meta = scipy.io.loadmat(META_PATH)
# Take row 0 of `class_names`, then element 0 of each entry -- presumably a
# 1xN MATLAB cell array whose cells each wrap one name string (TODO confirm).
class_names = list(map(lambda cell: cell[0], meta["class_names"][0]))

# Evaluation preprocessing: fixed 224x224 resize, [0, 1] tensor, normalization.
val_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
)
def get_model(num_classes):
    """Create timm's ``efficientnetv2_rw_s`` network with a fresh classifier head.

    No pretrained weights are downloaded (``pretrained=False``); the caller is
    expected to load trained weights afterwards.

    Args:
        num_classes: Number of output logits for the classification head.

    Returns:
        The uninitialized-for-inference ``torch.nn.Module`` built by timm.
    """
    architecture = "efficientnetv2_rw_s"
    return timm.create_model(
        architecture,
        pretrained=False,
        num_classes=num_classes,
    )
# Build the architecture with one output per known class name, then load the
# trained weights onto DEVICE.
model = get_model(num_classes=len(class_names))
# The checkpoint is passed straight to load_state_dict, i.e. it is a plain
# state_dict. weights_only=True restricts unpickling to tensors and primitive
# containers: enough for a state_dict, and it never executes pickled code.
# Setting it explicitly also keeps behavior identical across PyTorch versions,
# whose default for this flag has changed. (Requires torch >= 1.13.)
state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # inference mode (e.g. dropout off, norm layers use stored stats)
model.to(DEVICE)
# --- Grad-CAM++ prediction function ---
def predict_and_explain(img):
    """Classify an image and return a Grad-CAM++ overlay for the predicted class.

    Args:
        img: PIL image from the Gradio upload widget, or ``None`` when the user
            submits without uploading anything.

    Returns:
        A ``(heatmap, message)`` tuple. ``heatmap`` is a PIL image resized to
        ``img.size`` showing the Grad-CAM++ map blended over the input, and
        ``message`` names the predicted class and its index. If ``img`` is
        ``None``, returns ``(None, <upload prompt>)`` instead of raising.
    """
    if img is None:
        # Gradio calls the function with None for an empty image input;
        # without this guard img.convert would raise AttributeError.
        return None, "Please upload a car image."

    rgb = img.convert("RGB")
    # Give val_transform the un-resized image so its Resize is the single,
    # authoritative resize applied to the model input.
    input_tensor = val_transform(rgb).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = model(input_tensor)
    pred_idx = logits.argmax(dim=1).item()
    pred_name = class_names[pred_idx]

    # Grad-CAM++ needs gradients, so it runs its own forward/backward pass
    # outside no_grad. Using it as a context manager releases the hooks it
    # registers on `model` as soon as the map is computed.
    targets = [ClassifierOutputTarget(pred_idx)]
    with GradCAMPlusPlus(model=model, target_layers=[model.blocks[-1][-1]]) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]

    # The CAM has the input tensor's spatial size (224x224 via val_transform),
    # so overlay it on a 224x224 copy scaled to [0, 1] as show_cam_on_image expects.
    image_np = np.asarray(rgb.resize((224, 224)), dtype=np.float32) / 255.0
    cam_image = show_cam_on_image(image_np, grayscale_cam, use_rgb=True)
    cam_pil = Image.fromarray(cam_image).resize(img.size)  # back to original size
    return cam_pil, f"Prediction: {pred_name} (class index {pred_idx})"
# --- Gradio UI ---
# One image in; heatmap image plus prediction text out (the order matches the
# tuple returned by predict_and_explain).
demo = gr.Interface(
    fn=predict_and_explain,
    inputs=gr.Image(type="pil", label="Upload Car Image"),
    outputs=[
        gr.Image(type="pil", label="Grad-CAM++ Heatmap"),
        gr.Text(label="Prediction"),
    ],
    title="EfficientNetV2 Car Classifier + Grad-CAM++",
    description="Upload a car image to see its predicted make/model/year and what influenced the prediction.",
)

# Start the server only when executed as a script (e.g. `python app.py`), so
# importing this module builds `demo` without blocking on launch().
if __name__ == "__main__":
    demo.launch()