Spaces:
Sleeping
Sleeping
| import gradio | |
| from fastai.vision.all import * | |
# Directories holding the exported learner and the demo's example images.
MODELS_PATH, EXAMPLES_PATH = Path('./models'), Path('./examples')

# Exported fastai Learner. load_learner unpickles the file, so only ship a
# model.pkl you trust.
learn = load_learner(MODELS_PATH.joinpath('model.pkl'))
# Class names; predict() indexes this with model output positions.
labels = learn.dls.vocab
class Hook:
    """Context manager that records a module's forward output.

    Creating the object registers a forward hook on ``m``. Every forward pass
    through ``m`` then replaces ``self.stored`` with a detached copy of the
    module's output. Leaving the ``with`` block unregisters the hook.
    """

    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_func)

    def hook_func(self, m, i, o):
        """Forward-hook callback: keep a detached clone of the output ``o``."""
        snapshot = o.detach()
        self.stored = snapshot.clone()

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        handle = self.hook
        handle.remove()
class HookBwd:
    """Context manager that records the gradient flowing out of a module.

    Creating the object registers a backward hook on ``m``. Every backward pass
    through ``m`` then replaces ``self.stored`` with a detached copy of the
    gradient w.r.t. the module's first output. Leaving the ``with`` block
    unregisters the hook.
    """

    def __init__(self, m):
        # ``register_backward_hook`` is deprecated. For modules whose forward
        # runs several autograd ops (e.g. a Sequential), it warns on every
        # call and may report gradients for only part of the graph. Prefer the
        # "full" hook (PyTorch >= 1.8) and fall back on older builds.
        # NOTE: full backward hooks forbid in-place modification of the
        # module's inputs/outputs.
        register = getattr(m, "register_full_backward_hook", None)
        if register is None:
            register = m.register_backward_hook
        self.hook = register(self.hook_func)

    def hook_func(self, m, gi, go):
        """Backward-hook callback: ``go`` holds the gradients w.r.t. the outputs."""
        self.stored = go[0].detach().clone()

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        self.hook.remove()
def _save_gradcam(x, act, grad, out_path):
    """Render the class-activation heat map for the one-image batch ``x`` to ``out_path``.

    ``act`` and ``grad`` are the activations and output-gradients captured at
    the target layer. The map weights each activation channel by its
    spatially averaged gradient, as in Grad-CAM. Unlike the original Grad-CAM
    paper, no ReLU is applied to the result.
    """
    weights = grad[0].mean(dim=[1, 2], keepdim=True)
    cam_map = (weights * act[0]).sum(0)
    x_dec = TensorImage(learn.dls.train.decode((x,))[0][0])
    # Stretch the coarse map over the whole image. Using the real input size
    # (rather than a hard-coded 128) keeps the overlay aligned.
    height, width = x.shape[-2:]
    fig, ax = plt.subplots()
    try:
        x_dec.show(ctx=ax)
        ax.imshow(cam_map.detach().cpu(), alpha=0.7, extent=(0, width, height, 0),
                  interpolation='bilinear', cmap='magma')
        fig.savefig(out_path, format="jpg", bbox_inches='tight')
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)


def predict(img):
    """Classify a chest X-ray and produce a class-activation visualisation.

    Args:
        img: image from the Gradio input component; anything accepted by
            ``PILImage.create``.

    Returns:
        ``(labels_probs, path)``, where:

        - ``labels_probs`` maps each class name in ``labels`` to its predicted
          probability.
        - ``path`` is ``Path("gradcam.jpg")``. The file holds the heat map
          for the top-scoring class overlaid on the decoded input, or the
          input image itself when the top class is ``"Negative"``.
    """
    img = PILImage.create(img)
    _pred, _pred_w_idx, probs = learn.predict(img)
    labels_probs = {label: float(prob) for label, prob in zip(labels, probs)}

    # NOTE(review): a fixed output path is only safe while requests are
    # handled one at a time; revisit before raising queue concurrency.
    out_path = Path("gradcam.jpg")
    x, = first(learn.dls.test_dl([img]))
    target_layer = learn.model[0].model.layer4
    with HookBwd(target_layer) as hookg, Hook(target_layer) as hook:
        # A single gradient-enabled forward pass gives both the top class
        # and the hooked activations; no separate no_grad pass is needed.
        output = learn.model.eval()(x)
        cls = int(output.argmax())
        is_negative = labels[cls] == "Negative"
        if not is_negative:
            output[0, cls].backward()
    # backward() accumulates into every parameter's .grad; drop it so
    # nothing is carried over between requests.
    learn.model.zero_grad()

    if is_negative:
        # No heat map for the negative class: return the input image itself.
        img.save(out_path, format="JPEG")
    else:
        _save_gradcam(x, hook.stored, hookg.stored, out_path)
    return labels_probs, out_path
# Markdown text passed to the Gradio interface as its "article".
# Read explicitly as UTF-8 so non-ASCII content does not depend on the
# host's locale encoding.
with open('gradio_article.md', encoding='utf-8') as f:
    article = f.read()
# Static configuration for the Gradio interface.
# Examples: every regular, non-hidden file in EXAMPLES_PATH. The list is
# sorted because Path.iterdir() yields entries in arbitrary order, and stray
# dotfiles or subdirectories would not be loadable as example images.
interface_options = {
    "title": "RSNA Pneumonia Detection",
    "description": "An algorithm that automatically detects potential pneumonia cases. Upload an image or select from the examples below.",
    "examples": sorted(
        p.as_posix()
        for p in EXAMPLES_PATH.iterdir()
        if p.is_file() and not p.name.startswith('.')
    ),
    "article": article,
    "layout": "horizontal",
    "theme": "default",
}
# One image in; predict() returns (class -> probability dict, heat-map file path).
demo = gradio.Interface(fn=predict,
                        inputs=gradio.inputs.Image(shape=(512, 512), label="Chest X-ray"),
                        outputs=[gradio.outputs.Label(num_top_classes=5, label="Detected Class"),
                                 gradio.outputs.Image(type="filepath", label="GradCAM")],
                        **interface_options)
# Start the web app with Gradio's request queue enabled and no public share link.
launch_options = dict(enable_queue=True, share=False)
demo.launch(**launch_options)