kikogazda committed on
Commit
ea22817
·
verified ·
1 Parent(s): af4874c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import timm
3
+ import numpy as np
4
+ from PIL import Image
5
+ import gradio as gr
6
+ import json
7
+ from torchvision import transforms
8
+ from pytorch_grad_cam import GradCAMPlusPlus
9
+ from pytorch_grad_cam.utils.image import show_cam_on_image
10
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
11
+
12
# --- Config ---
# Both paths are resolved relative to the process working directory.
MODEL_WEIGHTS = "efficientnetv2_best_model.pth"
CLASS_MAPPING = "class_mapping.json"

# --- Device ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Load class names ---
# Explicit UTF-8 so class names decode the same way regardless of the
# platform's default locale encoding.
with open(CLASS_MAPPING, "r", encoding="utf-8") as f:
    class_names = json.load(f)
# handle both list and dict style
if isinstance(class_names, dict):
    # The dict form must be keyed by the stringified indices "0", "1", ...;
    # a missing index raises KeyError here rather than at prediction time.
    class_names = [class_names[str(i)] for i in range(len(class_names))]

NUM_CLASSES = len(class_names)

# --- Model ---
model = timm.create_model('efficientnetv2_rw_s', pretrained=False, num_classes=NUM_CLASSES, drop_rate=0.3)
# NOTE(review): torch.load unpickles the checkpoint, so only load files you
# trust; consider weights_only=True if the minimum PyTorch version allows it.
model.load_state_dict(torch.load(MODEL_WEIGHTS, map_location=device))
model.to(device)
model.eval()  # inference mode: disables dropout, freezes batch-norm statistics

# --- Preprocessing ---
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# --- Grad-CAM setup (for EfficientNetV2, last block) ---
# Fall back to `layer4[-1]` for backbones that expose no `blocks` attribute.
target_layer = model.blocks[-1] if hasattr(model, "blocks") else model.layer4[-1]
# `use_cuda` is intentionally not passed: newer pytorch-grad-cam releases no
# longer accept that keyword (it raises TypeError) and instead use the device
# the model already lives on. The model was moved to `device` above and inputs
# are moved there by the caller, so the argument is not needed.
cam = GradCAMPlusPlus(model=model, target_layers=[target_layer])
47
+
48
+ # --- Gradio Inference + Explainability ---
49
def predict_and_explain(img: Image.Image):
    """Classify an uploaded image and render a Grad-CAM++ overlay for it.

    Args:
        img: PIL image supplied by the Gradio input component. Gradio passes
            ``None`` when the user submits without uploading anything.

    Returns:
        A ``(overlay, message)`` tuple: ``overlay`` is a PIL image with the
        class-activation heatmap blended over the pixels the model actually
        received, and ``message`` is the text naming the predicted class.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image before submitting.")

    # Preprocess
    # NOTE(review): this pre-resize squashes the image to 224x224 before
    # val_transform resizes/crops it again; kept so predictions are unchanged.
    image_pil = img.convert("RGB").resize((224, 224))
    input_tensor = val_transform(image_pil).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_tensor)
    pred_idx = output.argmax().item()
    pred_name = class_names[pred_idx]

    # Grad-CAM
    targets = [ClassifierOutputTarget(pred_idx)]
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]

    # Build the background from the model input itself (undoing Normalize).
    # val_transform center-crops, so overlaying the heatmap on the uncropped
    # `image_pil` would shift and scale it relative to the picture.
    mean = torch.tensor(imagenet_mean, device=input_tensor.device).view(3, 1, 1)
    std = torch.tensor(imagenet_std, device=input_tensor.device).view(3, 1, 1)
    model_view = (input_tensor[0].detach() * std + mean).clamp(0.0, 1.0)
    # show_cam_on_image expects an HxWx3 float image in [0, 1].
    image_np = model_view.permute(1, 2, 0).cpu().numpy().astype(np.float32)

    cam_image = show_cam_on_image(image_np, grayscale_cam, use_rgb=True)
    return Image.fromarray(cam_image), f"Prediction: {pred_name} (class index {pred_idx})"
63
+
64
# UI components are created up front so the Interface call reads as wiring.
_image_input = gr.Image(type="pil", label="Upload Car Image")
_result_outputs = [
    gr.Image(label="Grad-CAM++ Output"),
    gr.Text(label="Prediction"),
]

# Single-function app: one image in, heatmap overlay plus prediction text out.
demo = gr.Interface(
    fn=predict_and_explain,
    inputs=_image_input,
    outputs=_result_outputs,
    title="🚗 EfficientNetV2 Car Classifier + Grad-CAM Demo",
    description="Upload a car photo to classify its make/model/year and visualize the model's attention with Grad-CAM.",
    allow_flagging='never',
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()