Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| from PIL import Image, ImageOps | |
| from torch import nn | |
| import torchvision.transforms as T | |
| import torch | |
| import cv2 | |
| import numpy as np | |
| import streamlit as st | |
# Configure the page (wide layout, browser-tab title) before any other
# Streamlit element in this script is rendered.
st.set_page_config(page_title="Digit Recognition", layout="wide")
class Network(nn.Module):
    """LeNet-style CNN that maps single-channel images to 10 class scores.

    The flatten step in ``forward`` reshapes to rows of 120 values, which
    lines up with one row per sample when the third convolution yields a
    1x1 spatial map -- the case for 28x28 inputs
    (28 -> 24 -> 12 -> 8 -> 4 -> 1).
    """

    def __init__(self):
        super().__init__()
        # Attribute names double as the state_dict keys of the saved
        # checkpoint, so they are kept exactly as they were.
        # Stride 1 and zero padding are the Conv2d defaults.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=4)
        self.fully_connected1 = nn.Linear(120, 84)
        self.fully_connected2 = nn.Linear(84, 10)
        # Parameter-free layers, shared by every stage that needs them.
        self.pooling_layer = nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        """Return the raw (unnormalised) scores for the 10 classes."""
        # Stage 1: convolution -> ReLU -> average pooling.
        x = self.pooling_layer(self.relu(self.conv1(x)))
        # Stage 2: same pattern, followed by dropout.
        x = self.dropout(self.pooling_layer(self.relu(self.conv2(x))))
        # Stage 3: convolution -> ReLU, no pooling.
        x = self.relu(self.conv3(x))
        # Collapse the feature maps into rows of 120 features.
        x = x.view(-1, 120)
        # Classifier head: hidden layer with ReLU, then the output layer.
        hidden = self.relu(self.fully_connected1(x))
        return self.fully_connected2(hidden)
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Rebuild the network, restore the trained weights and place the model on the
# selected device (load_state_dict alone copies into the CPU-resident module).
# NOTE(review): Streamlit re-executes this script on each interaction, so the
# checkpoint is re-read every time -- consider caching if that becomes slow.
model = Network()
model.load_state_dict(torch.load('mnist_model.pth', map_location=torch.device(device)))
model.to(device)
model.eval()  # inference mode: disables the dropout layer

st.title("MNIST Image Classification")
st.subheader("This is a simple image classification web application to predict handwritten digits")
st.sidebar.write('## Please upload an image file :camera:', unsafe_allow_html=True)
file = st.sidebar.file_uploader("## Upload", type=["png"])

# Fall back to the bundled sample image when nothing has been uploaded.
imagefile = './0.png' if file is None else file

try:
    img_copy = Image.open(imagefile)
    img_copy.load()  # force decoding now so corrupt files are caught here
except OSError as err:
    # Covers missing files and unreadable/corrupt image data.
    st.error(f"Could not read the image: {err}")
    st.stop()

# Convert to single-channel grayscale with PIL: this handles RGB, RGBA,
# palette and already-grayscale PNGs alike, and uses RGB channel order
# (the previous cv2 BGR conversion mis-weighted channels and failed on
# images that were not 3- or 4-channel).
transform = T.Compose([
    T.ToTensor(),
    T.Resize((28, 28)),
])
# unsqueeze(0) adds the batch dimension the network's layers expect.
img = transform(img_copy.convert("L")).unsqueeze(0).to(device)

st.image(img_copy, width=150)

# No gradients are needed for a single forward pass.
with torch.no_grad():
    results = model(img)
category = torch.argmax(results)
digit = category.item()  # plain int; works for CPU and CUDA tensors alike
print(digit)
# NOTE(review): the markup below is not valid HTML (<hr> is a void element and
# the size is not a proper attribute); kept as-is to avoid changing the UI.
st.write('<hr font-size: 30px;>The image is digit </hr>', str(digit), unsafe_allow_html=True)