Spaces:
Running
Running
| #!/usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| import torch | |
| import gradio as gr | |
| from lib import create_model | |
| from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group | |
| from lib.dataloader import ImageMixin | |
# Input files, given as paths relative to the current working directory:
# the checkpoint passed to `model.load_weight` and the saved training
# parameters consumed by `load_parameter`.
test_weight: str = './weight_epoch-200_best.pt'
parameter: str = './parameters.json'
class ImageHandler(ImageMixin):
    """Turn one image from the Gradio widget into a model input dict.

    The transform pipeline comes from ``ImageMixin._make_transforms``, which
    is defined in ``lib.dataloader``. Presumably it reads ``self.params``;
    that is why ``params`` is stored before the pipeline is built (verify
    against the mixin).
    """

    def __init__(self, params):
        # Keep this order: `params` must be on the instance before the
        # mixin builds the transforms.
        self.params = params
        self.transform = self._make_transforms()

    def set_image(self, image):
        """Return ``{'image': batch}``, where ``batch`` is the transformed image.

        A leading axis of size 1 is prepended to ``batch`` via
        ``unsqueeze(0)``. This assumes the transform yields a torch tensor —
        TODO confirm in ``lib.dataloader``.
        """
        return {'image': self.transform(image).unsqueeze(0)}
def load_parameter(parameter):
    """Load saved training options and reshape them for CPU inference.

    Parameters
    ----------
    parameter : str
        Path handed to ``lib.options._retrieve_parameter``; here it is
        ``./parameters.json``.

    Returns
    -------
    tuple
        ``(args_model, args_dataloader)``, i.e. the ``'model'`` and
        ``'dataloader'`` groups produced by ``_dispatch_by_group``.
    """
    opts = ParamSet()
    for key, value in _retrieve_parameter(parameter).items():
        setattr(opts, key, value)

    # Values forced for the demo regardless of what training used. Judging by
    # the option names, these switch off augmentation, the sampler, pretrained
    # weights and the MLP setting — NOTE(review): confirm in lib.options.
    inference_overrides = (
        ('augmentation', 'no'),
        ('sampler', 'no'),
        ('pretrained', False),
        ('mlp', None),
    )
    for key, value in inference_overrides:
        setattr(opts, key, value)

    # `net` mirrors the saved `model` option; everything runs on the CPU.
    opts.net = opts.model
    opts.device = torch.device('cpu')

    # Tuple elements evaluate left to right, so 'model' is dispatched first.
    return (
        _dispatch_by_group(opts, 'model'),
        _dispatch_by_group(opts, 'dataloader'),
    )
args_model, args_dataloader = load_parameter(parameter)
model = create_model(args_model)
model.load_weight(test_weight)
# This app only runs inference, so switch the model to eval mode once at
# start-up rather than on every request.
model.eval()

# ImageHandler builds its transform pipeline in __init__ from the fixed
# dataloader options. Build it once and reuse it for every request instead of
# rebuilding it per click.
image_handler = ImageHandler(args_dataloader)


def main(image):
    """Estimate age from one chest radiograph submitted through the UI.

    Parameters
    ----------
    image : PIL.Image.Image or None
        The grayscale image from the ``gr.Image(type="pil", image_mode="L")``
        widget. It is ``None`` when *Inference* is pressed with no upload.

    Returns
    -------
    str
        The value of the model's first output, formatted to two decimals.

    Raises
    ------
    gr.Error
        If no image was supplied. Gradio shows the message to the user
        instead of a stack trace from the transform pipeline.
    """
    if image is None:
        raise gr.Error("Please upload a chest radiograph before running inference.")

    batch = image_handler.set_image(image)
    with torch.no_grad():
        outputs = model(batch)

    # `outputs` is a mapping keyed by label name; only its first entry is shown.
    first_label = next(iter(outputs))
    # Gradients are already off under no_grad, so no detach/numpy round-trip
    # is needed. .item() requires a single-element tensor.
    estimate = outputs[first_label].item()
    return f"{estimate:.2f}"
# Static HTML panel that `gr.HTML(html_content)` renders below the page title.
# It states the expected input preprocessing (grayscale, 320x320, longest side
# scaled to 320 and the rest padded with black) and cites the publication
# behind the model. The text is shown verbatim to users.
html_content = """
<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px;">
<h3>Image preprocess</h3>
<p>Only grayscale 320×320 resolution works appropriately.</p>
<p>The longest side of the Xp should be downscaled to 320 pixels while maintaining the aspect ratio,
and the width along the shorter side should be padded black to 320 pixels.
</p>
<h3>Publication Details</h3>
<p>See details in our publication, titled
"Chest radiography as a biomarker of ageing: artificial intelligence-based,
multi-institutional model development and validation in Japan"
</p>
<p><strong>Link:</strong> <a href="https://www.thelancet.com/journals/lanhl/article/PIIS2666-7568(23)00133-2/fulltext" target="_blank">
https://www.thelancet.com/journals/lanhl/article/PIIS2666-7568(23)00133-2/fulltext
</a></p>
</div>
"""
# Gradio
# Page layout, top to bottom:
#   1. centred title
#   2. description panel
#   3. a row with the image input and the estimated-age label
#   4. the inference button
#   5. a row of bundled sample radiographs
# Each sample is a (file path, caption) pair rendered with its own gr.Examples.
_sample_cxrs = (
    ('./samples/66_female_xp.png', 'Sample CXR 1: 66 years old female'),
    ('./samples/28_male_xp.png', 'Sample CXR 2: 28 years old male'),
)

with gr.Blocks(
    title="Aging Biomarker from CXR",
    css=".gradio-container {background:mintcream;}",
) as demo:
    gr.HTML("""<div style="text-align:center"><h2>Aging Biomarker from CXR</h2></div>""")
    gr.HTML(html_content)

    with gr.Row():
        # type="pil" + image_mode="L": `main` receives a grayscale PIL image.
        input_image = gr.Image(type="pil", image_mode="L")
        output_label = gr.Label(label="Estimated age")

    send_btn = gr.Button("Inference")
    send_btn.click(fn=main, inputs=input_image, outputs=output_label)

    with gr.Row():
        for _sample_path, _sample_caption in _sample_cxrs:
            gr.Examples([_sample_path], label=_sample_caption, inputs=input_image)

demo.launch(debug=True)