| from pathlib import Path |
| from typing import List |
|
|
| import cv2 |
| import gradio as gr |
| import numpy as np |
| import torch |
| from PIL import Image |
|
|
| from models import phc_models |
| from utils import utils, page_utils |
|
|
# Inference device: the first CUDA GPU when one is visible, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
|
# Checkpoint for the two-view classifier. The constant keeps its existing
# spelling ("BILATERIAL") so any external references keep working.
BILATERIAL_WEIGHT = 'weights/phresnet18_cbis2views.pt'

# Two input channels (the CC and MLO views are stacked as channels in
# predict_bilateral) and a single output logit. Presumably ``visualize=True``
# is what makes the forward pass return the extra activation outputs used
# for the heat maps -- confirm in phc_models.
BILATERAL_MODEL = phc_models.PHCResNet18(
    channels=2,
    n=2,
    num_classes=1,
    visualize=True,
)
BILATERAL_MODEL.add_top_blocks(num_classes=1)

# The checkpoint is mapped onto the CPU so it can be opened on machines
# without a GPU; the whole model is then moved to ``device`` and switched
# to evaluation mode (``Module.eval`` returns the module itself).
BILATERAL_MODEL.load_state_dict(
    torch.load(BILATERIAL_WEIGHT, map_location='cpu')
)
BILATERAL_MODEL = BILATERAL_MODEL.to(device).eval()

# Spatial size (rows, columns) every view is resized to before inference.
INPUT_HEIGHT = 600
INPUT_WIDTH = 500
|
|
# File extensions accepted by the upload widgets and by filter_files.
SUPPORTED_IMG_EXT = ['.png', '.jpg', '.jpeg']

# Example [CC, MLO] pairs shown in the UI. Each file name carries the side
# (left/right) and view (cc/mlo) tokens that filter_files looks for.
EXAMPLE_IMAGES = [
    [
        'examples/f4b2d377f43ba0bd_left_cc.png',
        'examples/f4b2d377f43ba0bd_left_mlo.jpg',
    ],
    [
        'examples/f4b2d377f43ba0bd_right_cc.png',
        'examples/f4b2d377f43ba0bd_right_mlo.jpeg',
    ],
    [
        'examples/P_00001_LEFT_cc.jpg',
        'examples/P_00001_LEFT_mlo.jpeg',
    ],
]
|
|
| |
# Warm-up: run a few forward passes on random input at import time,
# presumably so one-off start-up costs are paid before the first real
# request. The input mirrors what predict_bilateral feeds the model:
# float32 values in 0..255 (randint's upper bound is exclusive), shaped
# (1, 2, INPUT_HEIGHT, INPUT_WIDTH) and placed on ``device``.
test_images = np.random.randint(
    0, 256, (2, INPUT_HEIGHT, INPUT_WIDTH)).astype(np.float32)
test_images = torch.from_numpy(test_images).to(device)
test_images = test_images.unsqueeze(0)
for _ in range(10):
    _, _, _ = BILATERAL_MODEL(test_images)
# Drop the reference so the warm-up tensor can be freed.
test_images = None
|
|
|
|
def filter_files(files: List) -> List:
    """Validate an uploaded CC/MLO pair and return it with the CC view first.

    The model requires a CC-MLO pair of views of the same breast. This
    function checks the uploads and rejects:

    - anything other than exactly two files (empty upload slots, i.e.
      ``None`` entries, are not counted);
    - files whose extension is not in ``SUPPORTED_IMG_EXT`` (compared
      case-insensitively);
    - pairs missing a CC or an MLO view, or whose two files are not both
      marked as the same side.

    View and side are detected by a case-insensitive *substring* match
    ("cc", "mlo", "left", "right") against the ``_``-separated parts of
    each file name. Because it is a substring match, unrelated characters
    in a name part can also match.

    Parameters
    ----------
    files : List
        Uploaded files. Each entry may be an object exposing the file path
        as ``.name`` (e.g. ``tempfile._TemporaryFileWrapper``), a path
        string, a ``pathlib.Path``, or ``None`` for an empty upload slot.

    Returns
    -------
    List[pathlib.Path]
        The two file paths. When the first file is recognised as CC the
        input order is kept; otherwise it is reversed.

    Raises
    ------
    gr.Error
        If the number of uploaded files is not 2.
    gr.Error
        If an extension is unsupported.
    gr.Error
        If a specific view or side of the mammography is missing.
    """
    # An empty gr.File slot arrives as None; count only real uploads so the
    # user gets the intended message instead of an AttributeError.
    uploaded = [file for file in files if file is not None]
    if len(uploaded) != 2:
        raise gr.Error(
            f'Need exactly 2 images. Currently have {len(uploaded)} images!')

    # Older Gradio hands over temp-file wrappers (path in ``.name``);
    # newer versions and callers may pass the path itself.
    file_paths = [
        Path(file) if isinstance(file, (str, Path)) else Path(file.name)
        for file in uploaded
    ]

    supported = {ext.lower() for ext in SUPPORTED_IMG_EXT}
    if not all(path.suffix.lower() in supported for path in file_paths):
        raise gr.Error(
            'There is a file with unsupported type. '
            f'Make sure all files are in {SUPPORTED_IMG_EXT}!')

    has_cc = has_mlo = cc_first = False
    n_left = n_right = 0
    for idx, path in enumerate(file_paths):
        parts = [part.lower() for part in path.name.split('_')]

        # View detection: a single name may match both views.
        if any('cc' in part for part in parts):
            has_cc = True
            if idx == 0:
                cc_first = True
        if any('mlo' in part for part in parts):
            has_mlo = True

        # Side detection: "left" wins if a name mentions both sides.
        if any('left' in part for part in parts):
            n_left += 1
        elif any('right' in part for part in parts):
            n_right += 1

    # Put the CC view first, as predict_bilateral expects.
    if not cc_first:
        file_paths.reverse()

    # Both views must be present and both files must be from one side.
    same_side = n_left >= 2 or n_right >= 2
    if not (has_cc and has_mlo and same_side):
        raise gr.Error('Missing bilateral-view pair for Left or Right side.')

    return file_paths
|
|
|
|
def _load_model_input(path):
    """Load one view as a min-max scaled uint8 grayscale image of model size."""
    with Image.open(str(path)) as pil_image:
        # Multi-band (RGB, RGBA, LA, ...), palette and bilevel uploads are
        # collapsed to 8-bit grayscale so each view is a single 2-D channel.
        # Single-band modes (e.g. 16-bit PNG) are kept as-is so the min-max
        # scaling below can use their full dynamic range.
        if pil_image.mode in ('1', 'P') or len(pil_image.getbands()) > 1:
            image = np.array(pil_image.convert('L'))
        else:
            image = np.array(pil_image)
    image = cv2.normalize(
        image, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
    return cv2.resize(
        image, (INPUT_WIDTH, INPUT_HEIGHT), interpolation=cv2.INTER_LINEAR)


def _to_display_rgb(image):
    """Min-max scale a single-channel view to uint8 and replicate it to RGB."""
    scaled = cv2.normalize(
        image, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
    return cv2.cvtColor(scaled, cv2.COLOR_GRAY2RGB)


def predict_bilateral(cc_file, mlo_file):
    """Predict malignancy from a CC/MLO mammography pair.

    Parameters
    ----------
    cc_file, mlo_file
        Uploaded files for the two views (anything ``filter_files``
        accepts). The pair is validated and re-ordered by ``filter_files``,
        so the 'CC'/'MLO' captions follow its ordering rather than the
        upload slot a file came from.

    Returns
    -------
    Tuple[List[Tuple[np.ndarray, str]], Dict[str, float]]
        Gallery items -- the two preprocessed views followed by their
        heat-map overlays, all RGB uint8 -- and a single-entry dict mapping
        the predicted class name to the model's confidence in that class.

    Raises
    ------
    gr.Error
        Propagated from ``filter_files`` when the uploads are invalid.
    """
    filtered_files = filter_files([cc_file, mlo_file])

    # (2, H, W) float32 stack: one channel per view.
    images = np.asarray(
        [_load_model_input(path) for path in filtered_files],
        dtype=np.float32)
    im_h, im_w = images[0].shape[:2]

    images_t = torch.from_numpy(images).unsqueeze(0).to(device)
    out, _, out_refiner = BILATERAL_MODEL(images_t)

    # mean_activations presumably returns a tensor on the model's device --
    # TODO confirm. ``.numpy()`` rejects CUDA tensors and tensors that
    # require grad; ``.detach().cpu()`` is a no-op otherwise.
    activations = utils.mean_activations(out_refiner).detach().cpu().numpy()

    # Report the confidence of the class actually shown, not always the
    # malignancy probability.
    probability = torch.sigmoid(out).detach().cpu().item()
    if probability > 0.5:
        label_name, confidence = 'Malignant', probability
    else:
        label_name, confidence = 'Normal/Benign', 1.0 - probability
    labels_dict = {label_name: confidence}

    # Heat map: assumes ``activations`` is a 2-D map -- TODO confirm.
    heatmap = cv2.normalize(
        activations, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # applyColorMap produces BGR, while the views and Gradio use RGB.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = cv2.resize(
        heatmap, (im_w, im_h), interpolation=cv2.INTER_LINEAR)

    views = [_to_display_rgb(image) for image in images]
    overlays = [
        cv2.addWeighted(view, 1.0, heatmap, 0.5, 0) for view in views
    ]

    displays_imgs = [
        (views[0], 'CC'),
        (views[1], 'MLO'),
        (overlays[0], 'CC Interest Area'),
        (overlays[1], 'MLO Interest Area'),
    ]
    return displays_imgs, labels_dict
|
|
|
|
def run():
    """Build the Gradio UI and launch the server.

    The page header is read from ``index.html``. The CC/MLO upload widgets
    feed ``predict_bilateral``, whose results fill the gallery and the
    label. The Clear button resets all four components. The server listens
    on ``0.0.0.0:7860``.
    """
    with open('index.html', encoding='utf-8') as f:
        html_content = f.read()

    theme = gr.themes.Default(
        primary_hue=page_utils.KALBE_THEME_COLOR,
        secondary_hue=page_utils.KALBE_THEME_COLOR,
    ).set(
        button_primary_background_fill='*primary_600',
        button_primary_background_fill_hover='*primary_500',
        button_primary_text_color='white',
    )

    # NOTE(review): the original indentation of this layout was lost; the
    # nesting below is a reconstruction -- confirm against the intended UI.
    with gr.Blocks(theme=theme) as demo:
        with gr.Column():
            gr.HTML(html_content)
            with gr.Row():
                with gr.Column():
                    cc_file = gr.File(file_count='single',
                                      file_types=SUPPORTED_IMG_EXT,
                                      label='CC View')
                    mlo_file = gr.File(file_count='single',
                                       file_types=SUPPORTED_IMG_EXT,
                                       label='MLO View')
                    with gr.Row():
                        clear_btn = gr.Button('Clear')
                        process_btn = gr.Button('Process', variant="primary")
                with gr.Column():
                    output_gallery = gr.Gallery(
                        label='Highlighted Area').style(grid=[2],
                                                        height='auto')
                    cancer_type = gr.Label(label='Cancer Type')
            gr.Examples(
                examples=EXAMPLE_IMAGES,
                inputs=[cc_file, mlo_file],
            )
            gr.Markdown(
                'Note that this method is sensitive to input image types. '
                'Current pipeline expects the values between 0.0-255.0')

        process_btn.click(
            fn=predict_bilateral,
            inputs=[cc_file, mlo_file],
            outputs=[output_gallery, cancer_type],
        )

        # No inputs are wired, so Gradio calls the handler with no
        # arguments; it must therefore declare no parameters.
        clear_btn.click(
            lambda: (
                gr.update(value=None),
                gr.update(value=None),
                gr.update(value=None),
                gr.update(value=None),
            ),
            inputs=None,
            outputs=[
                cc_file,
                mlo_file,
                output_gallery,
                cancer_type,
            ],
        )

    demo.launch(server_name='0.0.0.0', server_port=7860)


# Start the app only when executed as a script, not when imported.
if __name__ == '__main__':
    run()
|
|