# Hugging Face Spaces status: Build error
import torch
from monai.bundle import ConfigParser
import gradio as gr

from utils import page_utils

# --- MONAI bundle -----------------------------------------------------------
# The bundle's JSON files describe the inference workflow; ConfigParser reads
# them and builds the components referenced by id below.
parser = ConfigParser()
parser.read_config(f="configs/inference.json")  # workflow definition
parser.read_meta(f="configs/metadata.json")     # bundle metadata

# Workflow components, looked up by their ids in the parsed config.
inference = parser.get_parsed_content("inferer")
network = parser.get_parsed_content("network_def")
preprocess = parser.get_parsed_content("preprocessing")

# Restore the trained weights, mapping every tensor onto the CPU.
# strict=True makes load_state_dict fail if the checkpoint keys do not match
# the network definition exactly.
# NOTE(review): torch.load unpickles the checkpoint, so only point it at a
# trusted file; consider weights_only=True if the pinned torch supports it.
state_dict = torch.load("models/model.pt", map_location="cpu")
network.load_state_dict(state_dict, strict=True)
# Maps each position of the classifier's probability vector to a readable
# class name (classify_image pairs prob[i] with class_names[i]).
class_names = dict(
    enumerate(("Other", "Inflammatory", "Epithelial", "Spindle-Shaped"))
)
def classify_image(image_file, label_file):
    """Classify a sample from its RGB image and matching label mask.

    Args:
        image_file: Path to the RGB image, as delivered by the ``gr.Image``
            component configured with ``type="filepath"``.
        label_file: Path to the matching label-mask image.

    Returns:
        dict mapping every name in ``class_names`` to its softmax probability,
        suitable as the value of a ``gr.Label`` output.

    Raises:
        gr.Error: if either input is missing (e.g. "Process" was clicked
            before both images were provided).
    """
    # Without this guard an empty input fails deep inside the MONAI
    # transforms with an error message that means nothing to the user.
    if image_file is None or label_file is None:
        raise gr.Error("Please provide both an image and its label mask.")

    batch = preprocess({"image": image_file, "label": label_file})

    # Switch to evaluation mode before inference.
    network.eval()
    with torch.no_grad():
        # Add a batch dimension; expect 4 channels input (3 RGB, 1 Label mask).
        logits = inference(batch["image"].unsqueeze(dim=0), network)
        # Outputs created under no_grad carry no autograd graph, so no
        # .detach() is needed before converting to NumPy.
        prob = logits.softmax(-1).cpu().numpy()[0]

    return {name: float(prob[idx]) for idx, name in class_names.items()}
# Example [image, label-mask] path pairs shown under the demo, one list per
# row. Each image and its mask share the same file stem.
example_files1 = [
    [f"sample_data/Images/{stem}.png", f"sample_data/Labels/{stem}.png"]
    for stem in (
        "test_11_2_0628",
        "test_9_4_0149",
        "test_12_3_0292",
        "test_9_4_0019",
    )
]
example_files2 = [
    [f"sample_data/Images/{stem}.png", f"sample_data/Labels/{stem}.png"]
    for stem in (
        "test_14_3_0433",
        "test_14_4_0544",
        "train_1_1_0095",
        "train_1_3_0020",
    )
]
# Static HTML header rendered at the top of the page.
with open("index.html", encoding="utf-8") as html_file:
    html_content = html_file.read()

with gr.Blocks(
    theme=gr.themes.Default(
        primary_hue=page_utils.KALBE_THEME_COLOR,
        secondary_hue=page_utils.KALBE_THEME_COLOR,
    ).set(
        button_primary_background_fill="*primary_600",
        button_primary_background_fill_hover="*primary_500",
        button_primary_text_color="white",
    )
) as app:
    gr.HTML(html_content)

    # NOTE(review): the original indentation was lost; this component nesting
    # is a reconstruction -- confirm it matches the intended layout.
    with gr.Row():
        with gr.Column():
            with gr.Row():
                inp_img = gr.Image(type="filepath", image_mode="RGB")
                label_img = gr.Image(type="filepath", image_mode="L")
            with gr.Row():
                clear_btn = gr.Button(value="Clear")
                process_btn = gr.Button(value="Process", variant="primary")
        # Show every class instead of a hard-coded count so the widget stays
        # in sync with class_names.
        out_txt = gr.Label(label="Probabilities", num_top_classes=len(class_names))

    process_btn.click(fn=classify_image, inputs=[inp_img, label_img], outputs=out_txt)
    # Reset both inputs and the result panel.
    clear_btn.click(
        lambda: (gr.update(value=None), gr.update(value=None), gr.update(value=None)),
        inputs=None,
        outputs=[inp_img, label_img, out_txt],
    )

    gr.Markdown("## Image Examples")
    # One row per example list; each [image, label] pair gets its own
    # gr.Examples widget inside that row.
    for example_row in (example_files1, example_files2):
        with gr.Row():
            for pair in example_row:
                gr.Examples([pair], inputs=[inp_img, label_img])

# Start the server only when executed as a script, not on import.
if __name__ == "__main__":
    app.launch()