Spaces:
Sleeping
Sleeping
| from typing import List | |
| from torch import topk | |
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| from utils import MODELS, LABELS | |
| from preprocess_data import build_transform_classification | |
| from load_model import load_model | |
| from run_inference import run_inference | |
| from xai import grad_cam | |
| class WebUI: | |
| def __init__(self): | |
| super().__init__() | |
| self.labels = LABELS | |
| self.num_labels = len(LABELS) | |
| self.image_path = "example_images/" | |
| img_filenames = ['00010575_002', '00010828_039', '00011925_072', | |
| '00018253_059', '00020482_032', '00026221_001'] | |
| self.examples = [ | |
| f"{self.image_path}/{filename}.png" for filename in img_filenames] | |
| self.bbox_examples = [ | |
| f"{self.image_path}/{filename}_bbox.png" for filename in img_filenames] | |
| self.examples_dict = dict(zip(self.examples, self.bbox_examples)) | |
| self.model_str = "densenet121" | |
| self.normalization = "imagenet" | |
| self.img_size = 224 | |
| self.set_model(self.model_str) | |
| self.possible_models = MODELS.keys() | |
| self.device = self.set_device() | |
| def set_model(self, model_str) -> None: | |
| if model_str not in MODELS.keys(): | |
| print("Model not found") | |
| if model_str is None: | |
| print("No model found") | |
| ckpt_file = MODELS[model_str]["ckpt_path"] | |
| self.model, self.normalization, self.img_size = load_model( | |
| ckpt_file, num_labels=self.num_labels, model_str=model_str) | |
| self.model.eval() | |
| def preprocess_image(self, image: Image) -> torch.Tensor: | |
| if image is None: | |
| print("No image found") | |
| transform_pipeline = build_transform_classification( | |
| normalize=self.normalization, crop_size=224, resize=256, tta=True) | |
| transformed_image = transform_pipeline(image) | |
| return transformed_image | |
| def set_device(self): | |
| if torch.cuda.is_available(): | |
| device = torch.device("cuda") # Use GPU | |
| else: | |
| device = torch.device("cpu") # Use CPU | |
| return device | |
| def classify_image(self, image): | |
| if image is None: | |
| raise ValueError("No image found") | |
| top10 = self.run_inference(image) | |
| return {self.labels[top10[1][0][i]]: float(top10[0][0][i]) for i in range(self.num_labels)} | |
| def run_inference(self, image: Image) -> List[float]: | |
| if image is None: | |
| raise ValueError("No image found") | |
| input_tensor = self.preprocess_image(image) | |
| outputs = run_inference(self.model, self.model_str, | |
| input_tensor, self.device, tta=True) | |
| outputs = topk(outputs, k=self.num_labels) | |
| return outputs | |
| def explain_pred(self, image): | |
| if image is None: | |
| raise ValueError("No image found") | |
| input_tensor = self.preprocess_image(image) | |
| cam_image = grad_cam(self.model, self.model_str, input_tensor, image) | |
| return cam_image | |
| def run(self): | |
| with gr.Blocks() as demo: | |
| with gr.Row(): | |
| gr.Markdown(""" | |
| <h1 align="center">Image classification and explainability</h3> | |
| <p align="center"> | |
| This demo shows classification probabilities given four models fine-tuned for the chest X-ray14-dataset. | |
| To create the saliency maps the Grad-CAM method is used. | |
| </p> | |
| <p align="center">HOW TO: (1) Choose a model, (2) choose an image from the examples and (3) click on "Run analysis".</p> | |
| """) | |
| with gr.Row(): | |
| with gr.Group(): | |
| model_select = gr.Dropdown( | |
| choices=self.possible_models, | |
| label="Model selection", | |
| value=self.model_str | |
| ) | |
| with gr.Group(): | |
| image = gr.Image(type="pil", height=512, | |
| label="Input image", show_label=True) | |
| gr.Examples( | |
| label="Examples", | |
| examples=self.examples, | |
| inputs=image, | |
| ) | |
| with gr.Column(scale=1, variant="panel"): | |
| run_btn = gr.Button( | |
| "Run analysis", variant="primary", elem_id="run-button") | |
| with gr.Row(): | |
| with gr.Column(scale=1, variant="panel"): | |
| labels = gr.Label(num_top_classes=self.num_labels, | |
| label="Classification probabilities", show_label=True) | |
| with gr.Column(scale=1, variant="panel"): | |
| saliency = gr.Image( | |
| height=512, label="Saliency map given Grad-CAM", show_label=True) | |
| run_btn.click( | |
| fn=self.run_inference, | |
| inputs=image, | |
| outputs=labels, | |
| ) | |
| run_btn.click( | |
| fn=lambda x: self.classify_image(x), | |
| inputs=image, | |
| outputs=labels, | |
| ) | |
| run_btn.click( | |
| fn=lambda x: self.explain_pred(x), | |
| inputs=image, | |
| outputs=saliency, | |
| ) | |
| run_btn.click( | |
| fn=lambda x: self.set_model(x), | |
| inputs=model_select, | |
| ) | |
| demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=False) | |
| def main(): | |
| ui = WebUI() | |
| ui.run() | |
| if __name__ == "__main__": | |
| main() | |