Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import os | |
| import numpy as np | |
| import cv2 | |
| import matplotlib.pyplot as plt | |
| import torch | |
| import albumentations as albu | |
| from torch.utils.data import DataLoader | |
| from torch.utils.data import Dataset as BaseDataset | |
| from catalyst.dl import SupervisedRunner | |
| import segmentation_models_pytorch as smp | |
| from io import StringIO | |
# Launch locally with:
#   streamlit run c:/Users/ronni/Downloads/polyp_seg_web_app/app.py

# Folders holding the bundled test images and their segmentation masks.
x_test_dir = 'test/test/images'
y_test_dir = 'test/test/masks'

# Segmentation model configuration (consumed by model_infer).
ENCODER = 'mobilenet_v2'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['polyp', 'background']
ACTIVATION = 'sigmoid'

# Input normalisation function matching the encoder / pretrained weights above.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
def visualize(**images):
    """Plot images in one row and display the rendered figure in the Streamlit page.

    Args:
        **images: mapping of ``name -> image``. Each keyword becomes one
            subplot; its title is the keyword with underscores replaced by
            spaces and title-cased (``predicted_mask`` -> ``Predicted Mask``).
    """
    # Local import: the file only imports StringIO at module level.
    from io import BytesIO

    n = len(images)
    fig = plt.figure(figsize=(16, 5))
    try:
        for i, (name, image) in enumerate(images.items()):
            ax = fig.add_subplot(1, n, i + 1)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_title(' '.join(name.split('_')).title())
            ax.imshow(image)

        # Render to memory rather than to a fixed 'x.png' on disk, so that
        # concurrent sessions cannot overwrite each other's output.
        buffer = BytesIO()
        fig.savefig(buffer, format='png', dpi=400)
        buffer.seek(0)
        st.image(buffer)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this each rerun of the script leaks one figure.
        plt.close(fig)
def get_training_augmentation():
    """Build the albumentations pipeline applied to training samples.

    The pipeline flips / shifts / scales the sample, resizes it to 576x736,
    optionally adds Gaussian noise and a perspective warp, and then applies
    (each with probability 0.9) one transform from each of three groups:
    brightness/contrast-like, sharpen/blur, and contrast/hue.

    Return:
        transform: albumentations.Compose
    """
    geometry = [
        albu.HorizontalFlip(p=0.5),
        albu.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0),
        albu.Resize(576, 736, always_apply=True, p=1),
    ]
    noise_and_warp = [
        albu.IAAAdditiveGaussianNoise(p=0.2),
        albu.IAAPerspective(p=0.5),
    ]

    # Exactly one member of each group is applied when the group fires.
    exposure_group = albu.OneOf(
        [albu.CLAHE(p=1), albu.RandomBrightness(p=1), albu.RandomGamma(p=1)],
        p=0.9,
    )
    sharpness_group = albu.OneOf(
        [albu.IAASharpen(p=1), albu.Blur(blur_limit=3, p=1), albu.MotionBlur(blur_limit=3, p=1)],
        p=0.9,
    )
    colour_group = albu.OneOf(
        [albu.RandomContrast(p=1), albu.HueSaturationValue(p=1)],
        p=0.9,
    )

    pipeline = geometry + noise_and_warp + [exposure_group, sharpness_group, colour_group]
    return albu.Compose(pipeline)
def get_validation_augmentation():
    """Build the evaluation-time transform: a plain resize to 576x736.

    Both target dimensions are multiples of 32.

    Return:
        transform: albumentations.Compose
    """
    return albu.Compose([albu.Resize(576, 736)])
def to_tensor(x, **kwargs):
    """Move the last axis of a 3-D array to the front and cast to float32.

    Extra keyword arguments are accepted and ignored.
    """
    channels_first = x.transpose(2, 0, 1)
    return channels_first.astype('float32')
def get_preprocessing(preprocessing_fn):
    """Construct the preprocessing transform.

    The image is first passed through ``preprocessing_fn``; afterwards both
    image and mask are converted with ``to_tensor``.

    Args:
        preprocessing_fn (callable): data normalization function
            (can be specific for each pretrained neural network)
    Return:
        transform: albumentations.Compose
    """
    normalise = albu.Lambda(image=preprocessing_fn)
    channels_first = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([normalise, channels_first])
class Dataset(BaseDataset):
    """Dataset yielding ``(image, mask)`` pairs read from disk with OpenCV.

    Args:
        images_dir (str): path to images folder. When ``single_file`` is true
            this is instead the name of one file inside ``test/test/images``.
        masks_dir (str): path to segmentation masks folder (unused when
            ``single_file`` is true; the mask is then looked up under
            ``test/test/masks`` using the same file name).
        classes (list): names of classes to extract from the segmentation
            mask. Defaults to every entry of ``CLASSES``.
        augmentation (albumentations.Compose): data transformation pipeline
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing
            (e.g. normalization, shape manipulation, etc.)
        single_file (bool): treat ``images_dir`` as a single file name
            instead of a directory to list.
    """

    CLASSES = ['polyp', 'background']

    def __init__(
            self,
            images_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
            single_file=False
    ):
        if single_file:
            # Keep the same list layout as directory mode so that indexing
            # and len() behave identically in both modes.
            self.ids = [images_dir]
            self.images_fps = [os.path.join('test/test/images', images_dir)]
            self.masks_fps = [os.path.join('test/test/masks', images_dir)]
        else:
            self.ids = os.listdir(images_dir)
            self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
            self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

        if classes is None:
            classes = self.CLASSES
        # convert str names to class values on masks
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        """Return the ``(image, mask)`` pair at index ``i``.

        Raises:
            IndexError: if ``i`` is out of range.
            FileNotFoundError: if the image or mask cannot be read.
        """
        # Index first so an out-of-range ``i`` ends iteration cleanly.
        image_path = self.images_fps[i]
        mask_path = self.masks_fps[i]

        # read data (cv2.imread returns None instead of raising on failure)
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError('Could not read image: {}'.format(image_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, 0)
        if mask is None:
            raise FileNotFoundError('Could not read mask: {}'.format(mask_path))
        # Binarise: values below 8 become 0, values above 8 become 255
        # (a value of exactly 8 is left untouched).
        mask[mask < 8] = 0
        mask[mask > 8] = 255

        # extract certain classes from mask (e.g. polyp)
        # NOTE(review): class values are indices into CLASSES (0, 1) while the
        # binarised mask holds 0 / 8 / 255, so only index 0 can ever match.
        # Confirm this is the encoding the model was trained with.
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        """Return the number of samples."""
        return len(self.ids)
def model_infer(img_name):
    """Run the trained U-Net++ on one test image and return arrays for display.

    Args:
        img_name (str): file name of the image; it is handed to ``Dataset``
            in single-file mode, which resolves it under the test folders.

    Returns:
        tuple: ``(image_vis, gt_mask, pr_mask)`` where
            * ``image_vis`` is the preprocessed image transposed back from
              channels-first to channels-last,
            * ``gt_mask`` is the first channel of the dataset mask as uint8,
            * ``pr_mask`` is the first output channel thresholded at 0.5,
              as uint8.

    Raises:
        RuntimeError: if the prediction loader yields no batch.
    """
    # NOTE(review): the encoder is initialised from ENCODER_WEIGHTS and then
    # overwritten by the checkpoint below; confirm whether the pretrained
    # download is actually needed here.
    model = smp.UnetPlusPlus(
        encoder_name=ENCODER,
        encoder_weights=ENCODER_WEIGHTS,
        encoder_depth=5,
        decoder_channels=(256, 128, 64, 32, 16),
        classes=len(CLASSES),
        activation=ACTIVATION,
        decoder_attention_type=None,
    )
    checkpoint = torch.load('best.pth', map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    test_dataset = Dataset(
        img_name,
        img_name,
        augmentation=get_validation_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
        classes=CLASSES,
        single_file=True
    )
    test_dataloader = DataLoader(test_dataset)

    # Only the first prediction is ever displayed, so take exactly one batch
    # instead of collecting several and discarding the rest.
    runner = SupervisedRunner()
    predictions = runner.predict_loader(model=model, loader=test_dataloader, cpu=True)
    prediction = next(iter(predictions), None)
    if prediction is None:
        raise RuntimeError('No prediction was produced for {!r}'.format(img_name))

    threshold = 0.5
    image, mask = test_dataset[0]
    image_vis = image.transpose(1, 2, 0)
    gt_mask = mask[0].astype('uint8')
    # First item of the batch, then first output channel.
    probabilities = prediction['logits'][0].detach().cpu().numpy()
    pr_mask = (probabilities > threshold).astype('uint8')[0]
    return image_vis, gt_mask, pr_mask
# Title shown in the browser tab and at the top of the page.
PAGE_TITLE = "Polyp Segmentation"

# Introductory paragraph rendered under the title.
SUBHEADER = (
    "Polyps are growths in the colon which can be precursors to colon cancer and are of particular interest "
    "when performing colonoscopies. Improving automatic detection of polyps helps doctors analyze thousands of frames from colonoscopy videos "
    "and leads to more reliable and efficient prevention of colon cancer. This web app uses a CNN trained on colonoscopy images from the Kvasir dataset."
)
def file_selector(folder_path='.'):
    """Show a select box listing ``folder_path`` and return the chosen path.

    Args:
        folder_path (str): directory whose entries are offered for selection.

    Returns:
        str: ``folder_path`` joined with the selected entry name.
    """
    choice = st.selectbox('Select a file', os.listdir(folder_path))
    return os.path.join(folder_path, choice)
def file_selector_ui():
    """Let the user pick a test image and echo the selection on the page.

    Returns:
        str: path of the selected file inside ``./test/test/images``.
    """
    folder_path = './test/test/images'
    filename = file_selector(folder_path=folder_path)
    # Show forward slashes only. The earlier rfind-based character swap
    # overwrote the last character of the name whenever the path contained
    # no backslash (rfind returns -1), i.e. on every non-Windows host.
    st.write('You selected`%s`' % filename.replace('\\', '/'))
    return filename
def file_upload(folder_path='.'):
    """Show a file uploader and return a path built from the uploaded file's name.

    The upload itself is not written to disk here; the returned path only
    combines ``./test/test/images`` with the uploaded file's name.

    Args:
        folder_path (str): accepted for backward compatibility; the path is
            always built against ``./test/test/images``.

    Returns:
        str or None: the resulting path, or ``None`` while no file has been
        uploaded yet.
    """
    folder_path = './test/test/images'
    uploaded_file = st.file_uploader("Choose a file")
    if uploaded_file is None:
        # The uploader returns None until the user provides a file.
        return None

    filename = os.path.join(folder_path, uploaded_file.name)
    # Display with forward slashes; replace() is a no-op when there is no
    # backslash, unlike indexing with the -1 that rfind returns.
    st.write('You selected`%s`' % filename.replace('\\', '/'))
    return filename
def main():
    """Render the page: pick a test image, run inference and show the result.

    Nothing is displayed unless the selected path is an existing ``.jpg`` file.
    """
    st.set_page_config(page_title=PAGE_TITLE, layout="wide")
    st.title(PAGE_TITLE)
    st.markdown(SUBHEADER)

    image_path = os.path.abspath(file_selector_ui())
    if not os.path.isfile(image_path):
        return

    _, file_extension = os.path.splitext(image_path)
    if file_extension.lower() != ".jpg":
        return

    # basename() works with either path separator. Slicing after the last
    # backslash kept the whole absolute path on POSIX hosts, which made the
    # mask lookup resolve to the image file itself.
    to_infer = os.path.basename(image_path)
    image_vis, _gt_mask, pr_mask = model_infer(to_infer)
    # The ground-truth mask is returned by model_infer but not displayed.
    visualize(
        image=image_vis,
        predicted_mask=pr_mask
    )


if __name__ == "__main__":
    main()