jacksonwambali commited on
Commit
a76f433
·
verified ·
1 Parent(s): a8776a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -9
app.py CHANGED
@@ -1,17 +1,89 @@
1
  import gradio as gr
 
 
 
2
  from fastai.vision.all import load_learner, PILImage
 
 
 
3
 
4
- # Import any custom code here (if needed)
5
- # Example:
6
- # from your_module import CustomTransform, CustomModel
 
 
7
 
 
 
 
 
 
 
8
  learn = load_learner('export.pkl')
9
 
10
- def predict(img):
 
11
  img = PILImage.create(img)
12
- pred, pred_idx, probs = learn.predict(img)
13
- labels = learn.dls.vocab
14
- return {labels[i]: float(probs[i]) for i in range(len(labels))}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- interface = gr.Interface(fn=predict, inputs=gr.Image(), outputs=gr.Label(num_top_classes=3))
17
- interface.launch(share=True)
 
1
  import gradio as gr
2
+ import torch
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
  from fastai.vision.all import load_learner, PILImage
6
+ from fastai.vision.utils import show_image
7
+ import io
8
+ from torchvision.transforms.functional import to_pil_image
9
 
10
+ # Hook classes from your notebook
11
+ class Hook:
12
+ def __init__(self, m, f):
13
+ self.hook = m.register_forward_hook(lambda m, i, o: f(o))
14
+ def remove(self): self.hook.remove()
15
 
16
+ class HookBwd:
17
+ def __init__(self, m, f):
18
+ self.hook = m.register_backward_hook(lambda m, gi, go: f(go[0]))
19
+ def remove(self): self.hook.remove()
20
+
21
+ # Load the learner
22
  learn = load_learner('export.pkl')
23
 
24
+ # Function to predict + generate CAM
25
+ def predict_with_cam(img):
26
  img = PILImage.create(img)
27
+
28
+ # Get the model and target layer (adjust depending on your model)
29
+ model = learn.model
30
+ target_layer = model[0][-1] # Might need to adjust based on your architecture
31
+
32
+ # Placeholders for activations and gradients
33
+ activations = []
34
+ gradients = []
35
+
36
+ # Hook functions
37
+ def hook_activations(out): activations.append(out)
38
+ def hook_gradients(grad): gradients.append(grad)
39
+
40
+ # Register hooks
41
+ h1 = Hook(target_layer, hook_activations)
42
+ h2 = HookBwd(target_layer, hook_gradients)
43
+
44
+ # Prediction
45
+ pred_class, pred_idx, probs = learn.predict(img)
46
+
47
+ # Backward pass to get gradients
48
+ output = learn.model(img.unsqueeze(0)) if not isinstance(img, torch.Tensor) else learn.model(img)
49
+ output[0, pred_idx].backward()
50
+
51
+ # Remove hooks
52
+ h1.remove()
53
+ h2.remove()
54
+
55
+ # Generate CAM
56
+ act = activations[0].detach().cpu()[0]
57
+ grad = gradients[0].detach().cpu()[0]
58
+
59
+ weights = grad.mean(dim=(1, 2), keepdim=True)
60
+ cam = (weights * act).sum(0)
61
+ cam = cam.clamp(min=0).numpy()
62
+
63
+ # Normalize CAM
64
+ cam -= cam.min()
65
+ cam /= cam.max()
66
+
67
+ # Convert to image
68
+ fig, ax = plt.subplots()
69
+ ax.imshow(img)
70
+ ax.imshow(cam, alpha=0.5, cmap='jet')
71
+ ax.axis('off')
72
+
73
+ # Save CAM to an in-memory buffer
74
+ buf = io.BytesIO()
75
+ plt.savefig(buf, format='png')
76
+ buf.seek(0)
77
+
78
+ # Return predictions + CAM image
79
+ return {learn.dls.vocab[i]: float(probs[i]) for i in range(len(probs))}, buf
80
+
81
+ # Gradio interface
82
+ interface = gr.Interface(
83
+ fn=predict_with_cam,
84
+ inputs=gr.Image(type='pil'),
85
+ outputs=[gr.Label(num_top_classes=3), gr.Image(type='pil')],
86
+ title="Image Classifier with CAM"
87
+ )
88
 
89
+ interface.launch()