kikogazda commited on
Commit
1aa2b41
·
verified ·
1 Parent(s): ca8310f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -6,10 +6,10 @@ from PIL import Image
6
  from torchvision import transforms
7
  import scipy.io
8
  from pytorch_grad_cam import GradCAMPlusPlus
9
- from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
10
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
11
 
12
- # Load model and class names
13
  MODEL_PATH = "efficientnetv2_best_model.pth"
14
  META_PATH = "cars_meta.mat"
15
  DEVICE = torch.device("cpu")
@@ -36,7 +36,7 @@ model.load_state_dict(state_dict)
36
  model.eval()
37
  model.to(DEVICE)
38
 
39
- # Grad-CAM++
40
  def predict_and_explain(img):
41
  image_pil = img.convert("RGB").resize((224, 224))
42
  input_tensor = val_transform(image_pil).unsqueeze(0).to(DEVICE)
@@ -46,6 +46,7 @@ def predict_and_explain(img):
46
  pred_idx = output.argmax(dim=1).item()
47
  pred_name = class_names[pred_idx]
48
 
 
49
  targets = [ClassifierOutputTarget(pred_idx)]
50
  cam = GradCAMPlusPlus(model=model, target_layers=[model.blocks[-1][-1]])
51
  grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]
@@ -56,11 +57,14 @@ def predict_and_explain(img):
56
  cam_pil = Image.fromarray(cam_image)
57
  return cam_pil, f"Prediction: {pred_name} (class index {pred_idx})"
58
 
59
- # Gradio Interface
60
  demo = gr.Interface(
61
  fn=predict_and_explain,
62
- inputs=gr.Image(type="pil", label="Upload Car Image"),
63
- outputs=[gr.Image(label="Grad-CAM++ Heatmap"), gr.Text(label="Prediction")],
 
 
 
64
  title="🚗 EfficientNetV2 Car Classifier + Grad-CAM++",
65
  description="Upload a car image to see its predicted make/model/year and what influenced the prediction.",
66
  )
 
6
  from torchvision import transforms
7
  import scipy.io
8
  from pytorch_grad_cam import GradCAMPlusPlus
9
+ from pytorch_grad_cam.utils.image import show_cam_on_image
10
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
11
 
12
+ # --- Load model and metadata ---
13
  MODEL_PATH = "efficientnetv2_best_model.pth"
14
  META_PATH = "cars_meta.mat"
15
  DEVICE = torch.device("cpu")
 
36
  model.eval()
37
  model.to(DEVICE)
38
 
39
+ # --- Grad-CAM++ prediction function ---
40
  def predict_and_explain(img):
41
  image_pil = img.convert("RGB").resize((224, 224))
42
  input_tensor = val_transform(image_pil).unsqueeze(0).to(DEVICE)
 
46
  pred_idx = output.argmax(dim=1).item()
47
  pred_name = class_names[pred_idx]
48
 
49
+ # Grad-CAM++
50
  targets = [ClassifierOutputTarget(pred_idx)]
51
  cam = GradCAMPlusPlus(model=model, target_layers=[model.blocks[-1][-1]])
52
  grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]
 
57
  cam_pil = Image.fromarray(cam_image)
58
  return cam_pil, f"Prediction: {pred_name} (class index {pred_idx})"
59
 
60
+ # --- Gradio UI ---
61
  demo = gr.Interface(
62
  fn=predict_and_explain,
63
+ inputs=gr.Image(type="pil", label="Upload Car Image", height=350),
64
+ outputs=[
65
+ gr.Image(type="pil", label="Grad-CAM++ Heatmap", height=350),
66
+ gr.Text(label="Prediction")
67
+ ],
68
  title="🚗 EfficientNetV2 Car Classifier + Grad-CAM++",
69
  description="Upload a car image to see its predicted make/model/year and what influenced the prediction.",
70
  )