| import torch |
| import os |
| import gradio as gr |
| import numpy as np |
| from torchvision.transforms import transforms |
| from typing import Optional |
| import torch.nn as nn |
| from utils import page_utils |
|
|
class BasicBlock(nn.Module):
    """
    Two-convolution residual unit.

    The main path is conv3x3 -> batch norm -> ReLU -> conv3x3 -> batch norm.
    The block input (optionally passed through ``identity_downsample``) is
    then added to that result and a final ReLU is applied.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    stride : int, optional
        Stride of the first convolution, by default 1.
    identity_downsample : Optional[torch.nn.Module], optional
        Module applied to the skip connection before the addition,
        by default None (the input is added unchanged).

    Methods
    -------
    forward(x: torch.Tensor) -> torch.Tensor:
        Apply forward computation.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        identity_downsample: Optional[torch.nn.Module] = None,
    ):
        super().__init__()
        # Attribute names become state_dict keys, so they must stay as they
        # are for previously saved weights to load.
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        # Only the first convolution may be strided; the second keeps the
        # spatial size.
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.identity_downsample = identity_downsample

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the residual unit on ``x`` and return the activated sum."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Skip connection: project the input only when a projection module
        # was supplied.
        shortcut = x
        if self.identity_downsample is not None:
            shortcut = self.identity_downsample(shortcut)

        out += shortcut
        return self.relu(out)
|
|
class ResNet18(nn.Module):
    """
    ResNet-18 style classifier.

    The network is a 7x7 strided convolution stem with max pooling, four
    stages of two ``BasicBlock`` units each (64, 128, 256 and 512 channels),
    global average pooling, and a linear classification layer.

    Parameters
    ----------
    input_channels : int
        Number of input channels.
    num_classes : int
        Number of class outputs.

    Methods
    -------
    forward(x: torch.Tensor) -> torch.Tensor:
        Apply forward computation.
    """

    def __init__(self, input_channels: int, num_classes: int):
        super().__init__()

        # Stem.
        self.conv1 = nn.Conv2d(input_channels, 64,
                               kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages; every stage after the first halves the spatial
        # size (stride 2) while changing the channel count.
        self.layer1 = self._make_layer(64, 64, stride=1)
        self.layer2 = self._make_layer(64, 128, stride=2)
        self.layer3 = self._make_layer(128, 256, stride=2)
        self.layer4 = self._make_layer(256, 512, stride=2)

        # Classification head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def identity_downsample(self, in_channels: int, out_channels: int) -> nn.Module:
        """Build the skip-connection projection: a stride-2 3x3 conv plus batch norm."""
        projection_conv = nn.Conv2d(in_channels, out_channels,
                                    kernel_size=3, stride=2, padding=1)
        return nn.Sequential(projection_conv, nn.BatchNorm2d(out_channels))

    def _make_layer(self, in_channels: int, out_channels: int, stride: int) -> nn.Module:
        """Build one stage: a (possibly strided) block followed by a plain block."""
        # A projection is only created for strided stages, and it is always
        # stride 2; a channel change with stride 1 gets no projection.
        projection = (
            self.identity_downsample(in_channels, out_channels)
            if stride != 1
            else None
        )
        entry_block = BasicBlock(in_channels, out_channels,
                                 stride=stride, identity_downsample=projection)
        follow_block = BasicBlock(out_channels, out_channels)
        return nn.Sequential(entry_block, follow_block)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward computation and return one score per class for each input."""
        features = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)

        pooled = self.avgpool(features)
        # Collapse everything after the batch axis before the linear layer.
        flat = pooled.view(pooled.shape[0], -1)
        return self.fc(flat)
|
|
# Build the network (3 input channels, 3 output classes) and restore the
# trained weights on CPU.
model = ResNet18(3, 3)
# NOTE(review): torch.load unpickles the file; only load checkpoints that come
# from a trusted source.
checkpoint = torch.load('epoch=49-step=1750.ckpt', map_location=torch.device('cpu'))

# Checkpoint keys carry a leading "net." (presumably the attribute name the
# training wrapper used for the model -- confirm) that ResNet18 itself does not
# have. Strip it only when it is a true prefix, so a key that merely contains
# the substring "net." elsewhere is left untouched.
_CHECKPOINT_PREFIX = 'net.'
state_dict = {
    (key[len(_CHECKPOINT_PREFIX):] if key.startswith(_CHECKPOINT_PREFIX) else key): value
    for key, value in checkpoint['state_dict'].items()
}

model.load_state_dict(state_dict)
# Inference mode: batch norm layers use their stored running statistics.
model.eval()
|
|
|
|
# Labels in alphabetical order; image_classifier pairs them positionally with
# the model's output scores.
class_names = sorted(['benign', 'malignant', 'normal'])

# Location of the example images offered in the UI.
example_dir = "SAMPLES"
|
|
# Deterministic inference-time preprocessing.
#
# The random rotation that used to sit between Resize and ToTensor is a
# training-time augmentation; applying it here made the prediction for one and
# the same image change from one submission to the next, so it is left out.
transformation_pipeline = transforms.Compose([
    # Incoming images are numpy arrays; the following transforms work on PIL.
    transforms.ToPILImage(),
    # Collapse to grayscale but keep three channels for the 3-channel model.
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    # Per-channel statistics (presumably computed on the training set -- confirm).
    transforms.Normalize(
        mean=[0.233827, 0.2338219, 0.23378967],
        std=[0.2016421162328173, 0.20164345656093885, 0.20160390432148026],
    ),
])
|
|
def preprocess_image(image: np.ndarray) -> torch.Tensor:
    """Convert a raw image array into a one-item batch for the model.

    The array is run through ``transformation_pipeline`` and a leading
    dimension of size 1 is inserted in front of the result.

    Note that the input image is in RGB mode.

    Parameters
    ----------
    image: np.ndarray
        Input image from callback.

    Returns
    -------
    torch.Tensor
        The transformed image with an extra leading dimension of size 1.
    """
    transformed = transformation_pipeline(image)
    return transformed.unsqueeze(0)
|
|
def image_classifier(inp: Optional[np.ndarray]) -> dict:
    """Classify an image and return a probability for every class.

    Parameters
    ----------
    inp: Optional[np.ndarray]
        Input image from callback; ``None`` when no image was provided.

    Returns
    -------
    dict
        Maps each entry of ``class_names`` to its softmax probability.

    Raises
    ------
    gr.Error
        If ``inp`` is ``None``.
    """
    if inp is None:
        # gr.Error is an exception: it only reaches the user when raised.
        # Returning it wrapped in a set gave the output component a value it
        # cannot display and never surfaced any error.
        raise gr.Error("Please provide an image before submitting.")

    image = preprocess_image(inp).to(dtype=torch.float32)

    # Pure inference: no gradients are needed, so skip autograd tracking.
    with torch.no_grad():
        logits = model(image)
        probabilities = torch.nn.functional.softmax(logits, dim=1)

    # The batch holds a single image; take its row of class probabilities.
    scores = probabilities[0].numpy().tolist()
    return dict(zip(class_names, scores))
|
|
def _read_text(path):
    """Return the entire content of the file at ``path``, decoded as UTF-8."""
    with open(path, encoding="utf-8") as handle:
        return handle.read()


# HTML fragments rendered by the page: the header and the author section.
description = _read_text('index.html')
author_info = _read_text('author.html')
|
|
# NOTE(review): the nesting of the layout below was reconstructed from the
# order of the statements -- confirm it matches the intended page layout.

# Default theme recoloured with the project colour and custom primary buttons.
_page_theme = gr.themes.Default(
    primary_hue=page_utils.KALBE_THEME_COLOR,
    secondary_hue=page_utils.KALBE_THEME_COLOR,
).set(
    button_primary_background_fill="*primary_600",
    button_primary_background_fill_hover="*primary_500",
    button_primary_text_color="white",
)


def _reset_fields():
    """Return updates that empty the image input and the prediction label."""
    return gr.update(value=None), gr.update(value=None)


with gr.Blocks(theme=_page_theme) as demo:

    # Page header.
    with gr.Column():
        gr.HTML(description)

    with gr.Row():
        # Input side: the image with its two buttons underneath.
        with gr.Column():
            inp = gr.Image(label="image", image_mode="RGB")
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit")

        # Output side: class probabilities.
        out = gr.Label(label="prediction", num_top_classes=3)

    # Button wiring.
    submit_btn.click(fn=image_classifier, inputs=inp, outputs=out)
    clear_btn.click(_reset_fields, inputs=None, outputs=[inp, out])

    # Example images that can be loaded into the input.
    gr.Markdown("## Image Examples")
    gr.Examples(
        example_dir,
        inputs=[inp],
        label="Image Examples",
        cache_examples=False,
    )

    # Page footer.
    with gr.Column():
        gr.HTML(author_info)

demo.launch(share=True)