jacksonwambali commited on
Commit
c00ea70
·
verified ·
1 Parent(s): 404fe6a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -41
app.py CHANGED
@@ -3,82 +3,82 @@ import torch
3
  import matplotlib.pyplot as plt
4
  import numpy as np
5
  from fastai.vision.all import load_learner, PILImage
6
- from fastai.vision.utils import show_image
7
  import io
8
- from torchvision.transforms.functional import to_pil_image
9
 
10
- # Hook classes from your notebook
11
  class Hook:
12
- def __init__(self, m, f):
13
- self.hook = m.register_forward_hook(lambda m, i, o: f(o))
14
  def remove(self): self.hook.remove()
15
 
16
  class HookBwd:
17
- def __init__(self, m, f):
18
- self.hook = m.register_backward_hook(lambda m, gi, go: f(go[0]))
19
  def remove(self): self.hook.remove()
20
 
21
  # Load the learner
22
- learn = load_learner('export.pkl')
 
 
 
 
23
 
24
- # Function to predict + generate CAM
25
  def predict_with_cam(img):
26
  img = PILImage.create(img)
27
 
28
- # Get the model and target layer (adjust depending on your model)
29
  model = learn.model
30
- target_layer = model[0][-1] # Might need to adjust based on your architecture
31
-
32
- # Placeholders for activations and gradients
33
- activations = []
34
- gradients = []
35
-
36
- # Hook functions
37
  def hook_activations(out): activations.append(out)
38
  def hook_gradients(grad): gradients.append(grad)
39
-
40
- # Register hooks
41
  h1 = Hook(target_layer, hook_activations)
42
  h2 = HookBwd(target_layer, hook_gradients)
43
-
44
- # Prediction
45
  pred_class, pred_idx, probs = learn.predict(img)
46
-
47
- # Backward pass to get gradients
48
- output = learn.model(img.unsqueeze(0)) if not isinstance(img, torch.Tensor) else learn.model(img)
 
 
49
  output[0, pred_idx].backward()
50
-
51
  # Remove hooks
52
  h1.remove()
53
  h2.remove()
54
-
55
- # Generate CAM
56
- act = activations[0].detach().cpu()[0]
57
- grad = gradients[0].detach().cpu()[0]
58
-
59
  weights = grad.mean(dim=(1, 2), keepdim=True)
60
  cam = (weights * act).sum(0)
61
  cam = cam.clamp(min=0).numpy()
62
-
63
  # Normalize CAM
64
- cam -= cam.min()
65
- cam /= cam.max()
66
-
67
- # Convert to image
68
  fig, ax = plt.subplots()
69
  ax.imshow(img)
70
  ax.imshow(cam, alpha=0.5, cmap='jet')
71
  ax.axis('off')
72
-
73
- # Save CAM to an in-memory buffer
74
  buf = io.BytesIO()
75
- plt.savefig(buf, format='png')
76
  buf.seek(0)
77
-
78
- # Return predictions + CAM image
79
  return {learn.dls.vocab[i]: float(probs[i]) for i in range(len(probs))}, buf
80
 
81
- # Gradio interface
82
  interface = gr.Interface(
83
  fn=predict_with_cam,
84
  inputs=gr.Image(type='pil'),
 
3
  import matplotlib.pyplot as plt
4
  import numpy as np
5
  from fastai.vision.all import load_learner, PILImage
 
6
  import io
 
7
 
8
+ # Ensure custom classes exist before loading the model
9
  class Hook:
10
+ def __init__(self, module, func):
11
+ self.hook = module.register_forward_hook(lambda mod, inp, out: func(out))
12
  def remove(self): self.hook.remove()
13
 
14
  class HookBwd:
15
+ def __init__(self, module, func):
16
+ self.hook = module.register_full_backward_hook(lambda mod, grad_input, grad_output: func(grad_output[0]))
17
  def remove(self): self.hook.remove()
18
 
19
  # Load the learner
20
+ try:
21
+ learn = load_learner('export.pkl')
22
+ print("Model loaded successfully!")
23
+ except Exception as e:
24
+ print(f"Error loading model: {e}")
25
 
26
+ # Function to predict + generate Class Activation Map (CAM)
27
  def predict_with_cam(img):
28
  img = PILImage.create(img)
29
 
30
+ # Get model and target layer (modify as needed for your architecture)
31
  model = learn.model
32
+ target_layer = model[0][-1] # Adjust based on model architecture
33
+
34
+ activations, gradients = [], []
35
+
36
+ # Define hook functions
 
 
37
  def hook_activations(out): activations.append(out)
38
  def hook_gradients(grad): gradients.append(grad)
39
+
40
+ # Attach hooks
41
  h1 = Hook(target_layer, hook_activations)
42
  h2 = HookBwd(target_layer, hook_gradients)
43
+
44
+ # Run prediction
45
  pred_class, pred_idx, probs = learn.predict(img)
46
+
47
+ # Perform backward pass for gradients
48
+ img_tensor = learn.dls.test_dl([img]).one_batch()[0]
49
+ img_tensor.requires_grad_()
50
+ output = model(img_tensor)
51
  output[0, pred_idx].backward()
52
+
53
  # Remove hooks
54
  h1.remove()
55
  h2.remove()
56
+
57
+ # Generate Class Activation Map (CAM)
58
+ act = activations[0].detach().cpu().squeeze(0)
59
+ grad = gradients[0].detach().cpu().squeeze(0)
60
+
61
  weights = grad.mean(dim=(1, 2), keepdim=True)
62
  cam = (weights * act).sum(0)
63
  cam = cam.clamp(min=0).numpy()
64
+
65
  # Normalize CAM
66
+ cam = (cam - cam.min()) / (cam.max() - cam.min())
67
+
68
+ # Plot CAM
 
69
  fig, ax = plt.subplots()
70
  ax.imshow(img)
71
  ax.imshow(cam, alpha=0.5, cmap='jet')
72
  ax.axis('off')
73
+
74
+ # Save CAM image
75
  buf = io.BytesIO()
76
+ plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
77
  buf.seek(0)
78
+
 
79
  return {learn.dls.vocab[i]: float(probs[i]) for i in range(len(probs))}, buf
80
 
81
+ # Create Gradio interface
82
  interface = gr.Interface(
83
  fn=predict_with_cam,
84
  inputs=gr.Image(type='pil'),