arpandesign's picture
Update app.py
6a4fc80 verified
import gradio as gr
import random
from PIL import Image, ImageEnhance
import torch
from glob import glob
from torch import nn
from torchvision import transforms
def predict_number(custom_image: Image.Image):
    """Classify a handwritten-digit image and return per-digit probabilities.

    The image is converted to 8-bit grayscale, its contrast is boosted by a
    factor of 5, every pixel is thresholded to pure black/white (values above
    128 become 255, everything else 0) and the result is resized to 28x28
    before being fed to the module-level ``model_0`` on ``device``.

    Args:
        custom_image: PIL image supplied by the Gradio ``Image`` input.
            Gradio passes ``None`` when the form is submitted without an
            image.

    Returns:
        A dict mapping each digit label ``"0"``..``"9"`` to its softmax
        probability as a Python ``float`` (the format ``gr.Label`` expects).

    Raises:
        gr.Error: If no image was supplied, so the UI shows a readable
            message instead of an ``AttributeError`` traceback.
    """
    if custom_image is None:
        raise gr.Error("Please provide an image of a handwritten digit.")

    # Pre-processing, one step per line so each stage can be inspected.
    grayscale = custom_image.convert("L")
    high_contrast = ImageEnhance.Contrast(grayscale).enhance(5.0)
    binarized = high_contrast.point(lambda p: 255 if p > 128 else 0)
    digit_image = binarized.resize((28, 28))

    # ToTensor yields a (1, 28, 28) float tensor in [0, 1]; unsqueeze adds the
    # batch dimension the convolutional model expects.
    batch = transforms.ToTensor()(digit_image).unsqueeze(0).to(device)

    model_0.eval()
    with torch.inference_mode():
        # Softmax over the class dimension; [0] selects the single image.
        probabilities = torch.softmax(model_0(batch), dim=1)[0].tolist()

    class_names = "0123456789"
    return dict(zip(class_names, probabilities))
class NumberClassifier(nn.Module):
    """Small convolutional network that maps an image to class logits.

    Architecture: two convolutional blocks followed by a linear head. Each
    block is conv -> ReLU -> conv -> ReLU -> 2x2 max-pool, where every
    convolution uses a 2x2 kernel, stride 1 and no padding.

    The linear head is sized for 28x28 inputs. Each unpadded 2x2 convolution
    shrinks height/width by 1 and each pool halves it (rounding down):
    28 -> 27 -> 26 -> 13 after block 1, then 13 -> 12 -> 11 -> 5 after
    block 2, hence ``hidden_units * 5 * 5`` flattened features.

    The attribute names (``conv_block_1``, ``conv_block_2``, ``classifier``)
    and the position of each layer inside its ``nn.Sequential`` determine the
    ``state_dict`` keys, so they must stay stable for saved checkpoints to
    load.

    Args:
        input_shape: Number of channels in the input image.
        hidden_units: Number of feature maps produced by every convolution.
        output_shape: Number of output classes (size of the logits vector).
    """

    def __init__(self, input_shape: int, hidden_units: int, output_shape: int) -> None:
        super().__init__()

        # Block 1: lifts the input channels to `hidden_units` feature maps.
        first_block_layers = [
            nn.Conv2d(input_shape, hidden_units, kernel_size=2),
            nn.ReLU(),
            nn.Conv2d(hidden_units, hidden_units, kernel_size=2),
            nn.ReLU(),
            nn.MaxPool2d(2),  # stride defaults to the kernel size
        ]
        self.conv_block_1 = nn.Sequential(*first_block_layers)

        # Block 2: same layout, operating on `hidden_units` channels throughout.
        second_block_layers = [
            nn.Conv2d(in_channels=hidden_units, out_channels=hidden_units, kernel_size=2, stride=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_units, out_channels=hidden_units, kernel_size=2, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.conv_block_2 = nn.Sequential(*second_block_layers)

        # Head: flatten the (hidden_units, 5, 5) feature volume and project it
        # onto the class logits.
        flattened_features = hidden_units * 5 * 5
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flattened_features, output_shape),
        )

    def forward(self, x: torch.Tensor):
        """Run ``x`` through both conv blocks and the classifier; return logits."""
        for stage in (self.conv_block_1, self.conv_block_2, self.classifier):
            x = stage(x)
        return x
# Seed PyTorch's global RNG before the model is built so its initial weights
# are reproducible (they are replaced by the checkpoint loaded below).
torch.manual_seed(42)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# One input channel with ten feature maps per convolution and ten output
# classes.
model_0 = NumberClassifier(
    input_shape=1,
    hidden_units=10,
    output_shape=10,
)
model_0.to(device)  # nn.Module.to moves the parameters in place

# Deserialize the checkpoint onto the CPU regardless of where it was saved;
# load_state_dict then copies the values into the model's existing parameters.
model_0.load_state_dict(
    torch.load(
        "models/pytorch_num_classifier_final_model_with_EMNIST.pth",
        map_location=torch.device('cpu'),
    )
)
# Text shown on the demo page.
title = "Number Classifier Minimal"
description = "An Image feature extractor computer vision model to classify images of handwritten digits."
article = "Created at [09. PyTorch Model Deployment](https://www.learnpytorch.io/09_pytorch_model_deployment/)."

# Offer up to 25 randomly chosen files from examples/ as clickable examples.
# The sample size is clamped because random.sample raises ValueError when
# asked for more items than the population holds.
_example_paths = glob("examples/*")
example_list = [
    [path]  # Gradio expects one inner list per example (one entry per input)
    for path in random.sample(_example_paths, k=min(25, len(_example_paths)))
]

demo = gr.Interface(
    fn=predict_number,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Label(num_top_classes=10, label="Predictions")],
    # Fall back to None (Gradio's "no examples" default) when the folder is empty.
    examples=example_list or None,
    title=title,
    description=description,
    article=article,
)

demo.launch()