# ham1000 / app.py
# Uploaded by ganteng88 ("Upload 7 files", commit 0dbadfe).
import gradio as gr
import torch
from torchvision.transforms import transforms
import numpy as np
from typing import Optional
import torch.nn as nn
import os
from utils import page_utils
class BasicBlock(nn.Module):
    """Two-convolution residual unit used to build each ResNet stage.

    Main path: 3x3 conv -> BatchNorm -> ReLU -> 3x3 conv -> BatchNorm.
    The shortcut (the raw input, or ``identity_downsample(input)`` when a
    projection is supplied) is added to the main path and the sum is passed
    through a final ReLU.

    Parameters
    ----------
    in_channels : int
        Channel count of the incoming feature map.
    out_channels : int
        Channel count produced by both convolutions.
    stride : int, optional
        Stride of the first convolution, by default 1.
    identity_downsample : Optional[torch.nn.Module], optional
        Projection applied to the shortcut so its shape matches the main
        path's output, by default None (shortcut is the input unchanged).
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 stride: int = 1,
                 identity_downsample: Optional[torch.nn.Module] = None):
        super().__init__()
        # Attribute names are the state_dict keys the checkpoint is loaded
        # against, so they must not be renamed.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.identity_downsample = identity_downsample

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the residual unit on ``x`` and return the activated sum."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Project the shortcut only when the main path changed its shape.
        if self.identity_downsample is None:
            shortcut = x
        else:
            shortcut = self.identity_downsample(x)
        out += shortcut
        return self.relu(out)
class ResNet18(nn.Module):
    """Construct a ResNet-18 style classifier.

    Stem (7x7 conv, BatchNorm, ReLU, max-pool) -> four stages of two
    ``BasicBlock`` units each -> global average pool -> linear head.

    Parameters
    ----------
    input_channels : int
        Number of input channels.
    num_classes : int
        Number of class outputs (logits).
    """

    def __init__(self, input_channels: int, num_classes: int):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7,
                               stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Stage 1 keeps the resolution; stages 2-4 halve it and double channels.
        self.layer1 = self._make_layer(64, 64, stride=1)
        self.layer2 = self._make_layer(64, 128, stride=2)
        self.layer3 = self._make_layer(128, 256, stride=2)
        self.layer4 = self._make_layer(256, 512, stride=2)
        # Head: global average pool, then a linear classifier.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def identity_downsample(self, in_channels: int, out_channels: int,
                            stride: int = 2) -> nn.Module:
        """Build the shortcut projection that matches a stage's output shape.

        Parameters
        ----------
        in_channels : int
            Channels of the block input.
        out_channels : int
            Channels the shortcut must produce.
        stride : int, optional
            Spatial stride of the projection, by default 2 (the stride used
            by every downsampling stage in this network).

        Notes
        -----
        This uses a 3x3 convolution rather than a 1x1 projection. The kernel
        size fixes the parameter shapes that ``load_state_dict`` must find in
        the saved checkpoint, so it is intentionally left as-is.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3,
                      stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
        )

    def _make_layer(self, in_channels: int, out_channels: int,
                    stride: int) -> nn.Module:
        """Create one stage: two BasicBlocks, the first possibly downsampling.

        A shortcut projection is required whenever the first block changes
        the feature-map shape -- when it strides *or* when it changes the
        channel count -- otherwise the residual addition fails on mismatched
        shapes. For the stages built in ``__init__`` this yields exactly the
        same modules (and state_dict keys) as before.
        """
        identity_downsample = None
        if stride != 1 or in_channels != out_channels:
            identity_downsample = self.identity_downsample(
                in_channels, out_channels, stride=stride)
        return nn.Sequential(
            BasicBlock(in_channels, out_channels,
                       identity_downsample=identity_downsample, stride=stride),
            BasicBlock(out_channels, out_channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``(batch, num_classes)``."""
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        # (batch, 512, 1, 1) -> (batch, 512)
        x = torch.flatten(x, 1)
        return self.fc(x)
# Build the network (1 grayscale input channel, 7 classes) and load the
# trained weights onto the CPU.
model = ResNet18(1, 7)
# NOTE(review): torch.load unpickles the file, so only ever point it at the
# checkpoint bundled with the app. Recent torch releases default to
# weights_only=True, which may reject extra objects stored in a training
# checkpoint -- confirm against the pinned torch version.
checkpoint = torch.load('ham1.ckpt', map_location=torch.device('cpu'))
# The checkpoint's state dict names parameters with a leading `net.` prefix
# (e.g. `net.conv1.weight`), while ResNet18 expects the bare names, so strip
# that prefix. Only a *leading* `net.` is removed; other keys pass through.
_NET_PREFIX = 'net.'
state_dict = {
    (key[len(_NET_PREFIX):] if key.startswith(_NET_PREFIX) else key): value
    for key, value in checkpoint['state_dict'].items()
}
model.load_state_dict(state_dict)
# Switch BatchNorm layers to use their running statistics for inference.
model.eval()
# Output labels, zipped in order with the model's softmax outputs. Sorting
# alphabetically presumably mirrors how the training data assigned class
# indices (e.g. torchvision ImageFolder) -- TODO confirm against training code.
class_names = sorted(['akk', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc'])
# Folder holding the sample images offered in the Examples gallery.
examples_dir = "sample"
# Preprocessing applied to every uploaded image, in order:
#   ndarray -> PIL image -> single-channel grayscale -> central 224x224 crop
#   -> float tensor in [0, 1] -> normalized with a single mean/std.
# NOTE(review): there is no Resize step, so larger images are cropped to their
# central 224x224 region rather than scaled down -- confirm this matches the
# training preprocessing.
# NOTE(review): mean/std look like the ImageNet red-channel statistics;
# presumably they mirror what the model was trained with -- confirm.
_preprocess_steps = [
    transforms.ToPILImage(),
    transforms.Grayscale(num_output_channels=1),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485], std=[0.229]),
]
transformation_pipeline = transforms.Compose(_preprocess_steps)
def preprocess_image(image: np.ndarray):
    """Turn a raw image from the UI into a batched model input.

    Parameters
    ----------
    image: np.ndarray
        Image array handed over by the Gradio callback (RGB mode).

    Returns
    -------
    torch.Tensor
        The output of ``transformation_pipeline`` with a leading batch
        dimension of size 1 prepended.
    """
    tensor = transformation_pipeline(image)
    # Add the batch axis the model expects.
    return tensor.unsqueeze(0)
def image_classifier(inp):
    """Classify an image and return per-class probabilities.

    Parameters
    ----------
    inp: Optional[np.ndarray] = None
        Image array from the Gradio callback, or None when the user pressed
        "Process" without providing an image.

    Returns
    -------
    Optional[Dict[str, float]]
        Mapping from each name in ``class_names`` to its softmax probability,
        or None when no image was provided (the same value the Clear button
        writes to the Label, so the output is simply left empty).
    """
    # No image: return None rather than placeholder labels that are not
    # among the model's classes.
    if inp is None:
        return None

    image = preprocess_image(inp).to(dtype=torch.float32)

    # Inference only: skip building an autograd graph for every request.
    with torch.no_grad():
        logits = model(image)

    # Softmax over the class axis, then take the single batch item.
    probabilities = torch.softmax(logits, dim=1)[0].tolist()
    return dict(zip(class_names, probabilities))
# Static HTML description rendered at the top of the page.
with open('index.html', encoding="utf-8") as f:
    description = f.read()

theme = gr.themes.Default(
    primary_hue=page_utils.KALBE_THEME_COLOR,
    secondary_hue=page_utils.KALBE_THEME_COLOR,
).set(
    button_primary_background_fill="*primary_600",
    button_primary_background_fill_hover="*primary_500",
    button_primary_text_color="white",
)

# Gradio UI: description on top, image input with Clear/Process buttons on
# the left, top-3 class probabilities on the right, then a sample gallery.
with gr.Blocks(theme=theme) as app:
    with gr.Column():
        gr.HTML(description)
        with gr.Row():
            with gr.Column():
                inp_img = gr.Image()
                with gr.Row():
                    clear_btn = gr.Button(value="Clear")
                    process_btn = gr.Button(value="Process", variant="primary")
            with gr.Column():
                out_txt = gr.Label(label="Probabilities", num_top_classes=3)

    process_btn.click(image_classifier, inputs=inp_img, outputs=out_txt)
    # Reset both the uploaded image and the probability label.
    clear_btn.click(
        lambda: (gr.update(value=None), gr.update(value=None)),
        inputs=None,
        outputs=[inp_img, out_txt],
    )

    gr.Markdown("## Image Examples")
    gr.Examples(
        examples=[
            os.path.join(examples_dir, filename)
            for filename in (
                "ISIC_0000108_downsampled.jpeg",
                "ISIC_0000142_downsampled.jpeg",
                "ISIC_0012792_downsampled.jpeg",
                "ISIC_0024452.jpeg",
                "ISIC_0025957.jpeg",
                "ISIC_0026876.jpeg",
                "ISIC_0027385.jpeg",
                "ISIC_0030956.jpeg",
            )
        ],
        inputs=inp_img,
        outputs=out_txt,
        fn=image_classifier,
        cache_examples=False,
    )
    gr.Markdown(
        line_breaks=True,
        value='Author: Jason Adrian (jasonadriann6@gmail.com) <div class="row"><a href="https://github.com/jasonadriann?tab=repositories"><img alt="GitHub" src="https://img.shields.io/badge/Jason%20Adrian-000000?logo=github"></a> </div>',
    )

# Start the server only when run as a script, not when this module is imported.
if __name__ == "__main__":
    app.launch(share=True)