abhishri-medewar commited on
Commit
6c7e7bf
·
verified ·
1 Parent(s): 1878203

Upload 9 files

Browse files
Files changed (9) hide show
  1. 0.png +0 -0
  2. 1.png +0 -0
  3. 2.png +0 -0
  4. 3.png +0 -0
  5. 4.png +0 -0
  6. README.md +12 -13
  7. app.py +84 -0
  8. mnist_model.pth +3 -0
  9. requirements.txt +10 -0
0.png ADDED
1.png ADDED
2.png ADDED
3.png ADDED
4.png ADDED
README.md CHANGED
@@ -1,13 +1,12 @@
1
- ---
2
- title: Classification
3
- emoji: 🏆
4
- colorFrom: purple
5
- colorTo: purple
6
- sdk: streamlit
7
- sdk_version: 1.36.0
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: Mnist Classification
3
+ emoji: 🌖
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: streamlit
7
+ sdk_version: 1.17.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ from PIL import Image, ImageOps
4
+ from torch import nn
5
+ import torchvision.transforms as T
6
+ import torch
7
+ import cv2
8
+ import numpy as np
9
+ import streamlit as st
10
+
11
+ st.set_page_config(layout="wide", page_title="Digit Recognition")
12
+
13
+ class Network(nn.Module):
14
+ def __init__(self):
15
+ super().__init__()
16
+
17
+ self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5), stride=(1, 1), padding=(0, 0))
18
+ self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=(5, 5), stride=(1, 1), padding=(0, 0))
19
+ self.conv3 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=(4, 4), stride=(1, 1), padding=(0, 0))
20
+
21
+ self.fully_connected1 = nn.Linear(in_features=120, out_features=84)
22
+ self.fully_connected2 = nn.Linear(in_features=84, out_features=10)
23
+
24
+ self.pooling_layer = nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2))
25
+ self.relu = nn.ReLU()
26
+ self.dropout = nn.Dropout(0.25)
27
+
28
+ def forward(self, x):
29
+ # Convolution Layer 1
30
+ x = self.conv1(x)
31
+ x = self.relu(x)
32
+ x = self.pooling_layer(x)
33
+
34
+ # Convolution Layer 2
35
+ x = self.conv2(x)
36
+ x = self.relu(x)
37
+ x = self.pooling_layer(x)
38
+ x = self.dropout(x)
39
+
40
+ # Convolution Layer 3
41
+ x = self.conv3(x)
42
+ x = self.relu(x)
43
+
44
+ # flatten x
45
+ x = x.view(-1, 120)
46
+
47
+ # Fully connected layer 1
48
+ x = self.fully_connected1(x)
49
+ x = self.relu(x)
50
+
51
+ # Fully connected layer 2
52
+ x = self.fully_connected2(x)
53
+
54
+ return x
55
+
56
+ device = "cuda" if torch.cuda.is_available() else "cpu"
57
+ model = Network()
58
+ model.load_state_dict(torch.load('mnist_model.pth', map_location=torch.device(device)))
59
+
60
+ st.title("MNIST Image Classification")
61
+ st.subheader("This is a simple image classification web application to predict handwritten digits")
62
+
63
+ st.sidebar.write('## Please upload an image file :camera:', unsafe_allow_html=True)
64
+ file = st.sidebar.file_uploader("## Upload", type=["png"])
65
+
66
+ if file is None:
67
+ imagefile = './0.png'
68
+ else:
69
+ imagefile = file
70
+
71
+ img = Image.open(imagefile)
72
+ img_copy = img
73
+ img = cv2.cvtColor(np.array(img), cv2.COLOR_BGR2GRAY)
74
+ transform = T.Compose([
75
+ T.ToTensor(),
76
+ T.Resize((28, 28))
77
+ ])
78
+ img = transform(img)
79
+ st.image(img_copy, width=150)
80
+ model.eval()
81
+ results = model(img)
82
+ category = torch.argmax(results)
83
+ print(category.numpy())
84
+ st.write('<hr font-size: 30px;>The image is digit </hr>', str(category.numpy()), unsafe_allow_html=True)
mnist_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9f58974b27543b4f1eb5b7a3c4609d0c2aca5fae2bde76a11b07e721e7b20e1
3
+ size 180807
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==1.12.1
2
+ torchaudio==0.12.1
3
+ torchvision==0.13.1
4
+ streamlit==1.20.0
5
+ pandas==1.4.2
6
+ opencv-python==4.6.0.66
7
+ numpy==1.21.5
8
+ matplotlib==3.5.1
9
+ Pillow==9.0.1
10
+ streamlit-image-select==0.5.1