File size: 2,784 Bytes
7537fbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import gradio
from fastai.vision.all import *

# Locations of the exported model and the example images, relative to the
# working directory the app is started from.
MODELS_PATH = Path('./models')
EXAMPLES_PATH = Path('./examples')

# Exported fastai Learner. NOTE: load_learner unpickles the file, so only ever
# point it at a trusted model export.
learn = load_learner(MODELS_PATH / 'model.pkl')
# Class names in the order of the model's output units.
labels = learn.dls.vocab

class Hook:
    """Context manager that captures a module's forward output.

    Every time the wrapped module runs forward, a detached copy of its output is
    kept in ``stored``. Leaving the ``with`` block unregisters the hook.
    """

    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_func)

    def hook_func(self, module, inputs, output):
        # Detach + clone: the saved activation carries no autograd history and
        # does not alias a tensor that later code might modify.
        self.stored = output.detach().clone()

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        self.hook.remove()

class HookBwd():
    """Context manager that captures the gradient w.r.t. a module's output.

    After a backward pass through the wrapped module, ``stored`` holds a detached
    copy of the first entry of ``grad_output``. The hook must be registered
    before the forward pass it should observe. Leaving the ``with`` block
    unregisters it.
    """

    def __init__(self, m):
        # ``register_backward_hook`` is deprecated and documented as unreliable for
        # modules made of several autograd ops (such as a ResNet stage). Prefer the
        # "full" variant and fall back only on PyTorch builds that predate it.
        register = getattr(m, "register_full_backward_hook", None) or m.register_backward_hook
        self.hook = register(self.hook_func)

    def hook_func(self, m, gi, go):
        # go[0] is the gradient flowing into the module's (first) output.
        self.stored = go[0].detach().clone()

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        self.hook.remove()

def _gradcam(x):
    """Run one forward/backward pass on ``x`` and build a Grad-CAM map.

    Args:
        x: single-item input batch as produced by ``learn.dls.test_dl``.

    Returns:
        ``(cls, cam_map)``. ``cls`` is the index of the highest-scoring output.
        ``cam_map`` is the sum over channels of ``layer4``'s activations, each
        weighted by the spatial mean of its gradient for ``cls``.
    """
    model = learn.model.eval()
    target = learn.model[0].model.layer4
    # Both hooks must be in place before the forward pass they observe.
    with HookBwd(target) as hookg, Hook(target) as hook:
        output = model(x)
        cls = int(output.argmax())
        output[0, cls].backward()
        act, grad = hook.stored, hookg.stored
    # backward() accumulates into every parameter's .grad; clear it so repeated
    # requests do not keep summing gradients into the served model.
    model.zero_grad()
    weights = grad[0].mean(dim=[1, 2], keepdim=True)
    return cls, (weights * act[0]).sum(0)


def _save_overlay(x_dec, cam_map, path):
    """Save ``x_dec`` with ``cam_map`` blended on top as a JPEG at ``path``."""
    fig, ax = plt.subplots()
    try:
        x_dec.show(ctx=ax)
        # Stretch the coarse CAM grid over the displayed image. The extent comes
        # from the decoded image itself instead of a hard-coded 128x128.
        height, width = x_dec.shape[-2:]
        ax.imshow(cam_map.detach().cpu(), alpha=0.7, extent=(0, width, height, 0),
                  interpolation='bilinear', cmap='magma')
        fig.savefig(path, format="jpg", bbox_inches='tight')
    finally:
        # Close this specific figure, even on failure, so figures don't accumulate.
        plt.close(fig)


def predict(img):
    """Classify an image and render a Grad-CAM view for the top-scoring class.

    Args:
        img: the image delivered by the Gradio input (anything
            ``PILImage.create`` accepts).

    Returns:
        ``(labels_probs, path)``. ``labels_probs`` maps each class label to its
        predicted probability. ``path`` is ``gradcam.jpg``: the Grad-CAM overlay,
        or the plain input image when the top class is ``"Negative"``.
    """
    img = PILImage.create(img)
    _pred, _pred_idx, probs = learn.predict(img)
    labels_probs = {label: float(p) for label, p in zip(labels, probs)}

    x, = first(learn.dls.test_dl([img]))
    cls, cam_map = _gradcam(x)

    out_path = Path("gradcam.jpg")
    if labels[cls] == "Negative":
        # For the "Negative" class the unmodified input is returned, so the
        # heatmap rendering is skipped entirely.
        img.save(out_path, format="JPEG")
    else:
        x_dec = TensorImage(learn.dls.train.decode((x,))[0][0])
        _save_overlay(x_dec, cam_map, out_path)
    return labels_probs, out_path

# Long-form Markdown shown under the demo. Decode explicitly as UTF-8 so any
# non-ASCII text does not depend on the platform's default locale encoding.
with open('gradio_article.md', encoding='utf-8') as f:
    article = f.read()

# Keyword options for gradio.Interface (title, description, examples, etc.).
interface_options = {
    "title": "RSNA Pneumonia Detection",
    "description": "An algorithm that automatically detects potential pneumonia cases. Upload an image or select from the examples below.",
    # Sorted so the example gallery order is stable across filesystems and
    # restarts. Directories and hidden files (e.g. .DS_Store) are skipped because
    # they cannot be loaded as example images.
    "examples": [
        f'{EXAMPLES_PATH}/{f.name}'
        for f in sorted(EXAMPLES_PATH.iterdir())
        if f.is_file() and not f.name.startswith('.')
    ],
    "article": article,
    "layout": "horizontal",
    "theme": "default",
}

# Image in -> (class probabilities, Grad-CAM image path) out, matching what
# predict() returns.
demo = gradio.Interface(
    fn=predict,
    inputs=gradio.inputs.Image(shape=(512, 512), label="Chest X-ray"),
    outputs=[
        gradio.outputs.Label(num_top_classes=5, label="Detected Class"),
        gradio.outputs.Image(type="filepath", label="GradCAM"),
    ],
    **interface_options,
)

# Passed straight through to gradio's Interface.launch().
launch_options = dict(enable_queue=True, share=False)

demo.launch(**launch_options)