File size: 2,383 Bytes
064be73
 
d394e62
 
064be73
d394e62
064be73
d394e62
1aa2b41
d394e62
064be73
1aa2b41
064be73
 
 
 
 
 
 
 
 
d394e62
064be73
d394e62
064be73
 
 
 
 
 
 
 
 
2a5267a
8892bda
064be73
 
 
1aa2b41
d394e62
 
 
8892bda
064be73
 
d394e62
 
 
1aa2b41
d394e62
ca8310f
d394e62
 
 
 
 
0814784
d394e62
064be73
1aa2b41
d394e62
 
0814784
1aa2b41
0814784
1aa2b41
 
d394e62
 
064be73
 
d394e62
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import numpy as np
import gradio as gr
import timm
from PIL import Image
from torchvision import transforms
import scipy.io
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# --- Load model and metadata ---
# Artifact locations (plain relative paths, so they resolve against the
# process's current working directory) and the inference device.
MODEL_PATH = "efficientnetv2_best_model.pth"
META_PATH = "cars_meta.mat"
DEVICE = torch.device("cpu")  # inference is pinned to CPU

# Per-channel normalization statistics (ImageNet values) consumed by the
# Normalize step of the evaluation transform below.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Evaluation preprocessing: fixed 224x224 resize, PIL -> tensor, normalize.
val_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
)

# Label lookup: the first row of the .mat 'class_names' variable, with each
# entry unwrapped via [0] (presumably loadmat's one-element array around the
# label string — confirm against the metadata file).
meta = scipy.io.loadmat(META_PATH)
class_names = [entry[0] for entry in meta['class_names'][0]]

def get_model(num_classes):
    """Create an un-initialized ``efficientnetv2_rw_s`` network from timm.

    No pretrained weights are downloaded (``pretrained=False``); the caller is
    expected to load its own checkpoint afterwards.

    Args:
        num_classes: Number of outputs for the classification head.

    Returns:
        The freshly constructed timm model.
    """
    return timm.create_model(
        'efficientnetv2_rw_s',
        num_classes=num_classes,
        pretrained=False,
    )

# Build the network with one output per known class, restore the trained
# weights (mapped onto DEVICE while loading), then move it to DEVICE and
# switch to inference mode. nn.Module.to()/.eval() return the module itself.
model = get_model(num_classes=len(class_names))
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(state_dict)
model = model.eval().to(DEVICE)

# --- Grad-CAM++ prediction function ---
def predict_and_explain(img):
    """Classify a car image and render a Grad-CAM++ heatmap for the prediction.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the form is submitted without an upload.

    Returns:
        A ``(heatmap, message)`` tuple. ``heatmap`` is a PIL image resized to
        ``img.size`` with the CAM overlaid on the model input, or ``None`` if
        no image was given. ``message`` is the prediction text shown in the UI.
    """
    if img is None:
        # Gradio hands us None for an empty input; without this guard
        # img.convert() below would raise AttributeError.
        return None, "Please upload an image."

    # Force 3-channel RGB (covers RGBA / grayscale uploads) at model size.
    image_pil = img.convert("RGB").resize((224, 224))
    input_tensor = val_transform(image_pil).unsqueeze(0).to(DEVICE)

    # Plain forward pass for the label; no gradients needed here.
    with torch.no_grad():
        output = model(input_tensor)
    pred_idx = output.argmax(dim=1).item()
    pred_name = class_names[pred_idx]

    # Grad-CAM++ w.r.t. the predicted class, using the last layer of the
    # final block. Using the CAM object as a context manager releases the
    # hooks it registers on the model as soon as we are done, instead of
    # relying on garbage-collection timing across repeated requests.
    targets = [ClassifierOutputTarget(pred_idx)]
    with GradCAMPlusPlus(model=model, target_layers=[model.blocks[-1][-1]]) as cam:
        # One map per batch item; the batch here has exactly one image.
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]

    # show_cam_on_image wants a float32 RGB image scaled to [0, 1].
    image_np = np.array(image_pil).astype(np.float32) / 255.0
    cam_image = show_cam_on_image(image_np, grayscale_cam, use_rgb=True)

    # Scale the overlay back to the uploaded image's original dimensions.
    cam_pil = Image.fromarray(cam_image).resize(img.size)
    return cam_pil, f"Prediction: {pred_name} (class index {pred_idx})"

# --- Gradio UI ---
# Single image input; outputs are the Grad-CAM++ overlay and the prediction
# text, in the same order predict_and_explain returns them.
demo = gr.Interface(
    fn=predict_and_explain,
    inputs=gr.Image(type="pil", label="Upload Car Image"),
    outputs=[
        gr.Image(type="pil", label="Grad-CAM++ Heatmap"),
        gr.Text(label="Prediction"),
    ],
    title="🚗 EfficientNetV2 Car Classifier + Grad-CAM++",
    description="Upload a car image to see its predicted make/model/year and what influenced the prediction.",
)

if __name__ == "__main__":
    # Start the (blocking) server only when executed as a script, so that
    # importing this module elsewhere does not launch it as a side effect.
    demo.launch()