File size: 1,195 Bytes
52a915d
 
bf2bb8a
 
52a915d
 
bf2bb8a
f2e4334
52a915d
 
 
bf2bb8a
 
 
 
 
 
 
 
52a915d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42632eb
52a915d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import gradio as gr
import torch
from PIL import Image
from model import model 
from torchvision import transforms

# Load the trained weights into the architecture imported from model.py.
# map_location="cpu" lets a checkpoint that was saved from a CUDA device be
# loaded on a CPU-only host; load_state_dict then copies the values into the
# model's existing parameters, so the model's own device is unaffected.
model.load_state_dict(torch.load('mnist_model.pth', map_location="cpu"))
# Put the model in evaluation mode for inference (matters for layers such as
# dropout / batch-norm, if the architecture in model.py has any).
model.eval()

# Built once at import time: the pipeline holds no per-call state, so there is
# no reason to reconstruct it on every (live) prediction.
_PREPROCESS = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    # Maps ToTensor's [0, 1] output to [-1, 1]; presumably this matches the
    # normalization used when mnist_model.pth was trained -- TODO confirm.
    transforms.Normalize((0.5,), (0.5,)),
])


def preprocess_image(image):
    """Convert a drawing into a model-ready batch tensor.

    Args:
        image: The drawing as an array accepted by ``PIL.Image.fromarray``
            (presumably what the sketchpad input delivers), or an
            already-decoded ``PIL.Image.Image``.

    Returns:
        A float tensor of shape ``(1, 1, 28, 28)``: a batch of one
        single-channel 28x28 image, normalized with mean 0.5 / std 0.5.
    """
    # Accept PIL images as-is; only arrays need converting.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # unsqueeze(0) adds the batch dimension the model expects.
    return _PREPROCESS(image).unsqueeze(0)

def classify(image):
    """Predict which digit is drawn in ``image``.

    Args:
        image: The drawing passed on by the Gradio input, or ``None`` when
            no drawing is available (presumably an empty or cleared canvas;
            confirm against the Gradio version in use).

    Returns:
        The index of the highest-scoring class as a string (shown by the
        ``label`` output), or ``None`` when ``image`` is ``None``. Returning
        ``None`` avoids failing inside ``preprocess_image``.
    """
    if image is None:
        return None
    batch = preprocess_image(image)
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        scores = model(batch)
    # Batch of one: pick the highest-scoring class and unwrap it to an int.
    prediction = scores.argmax(dim=1).item()
    return str(prediction)

# Gradio UI: a sketchpad drawing is fed to classify() and the predicted digit
# is shown as a label, updated live as the user draws.
_DESCRIPTION = (
    "Draw a Digit 0-9 and the algorithm will detect it in real time! "
    "This is tiny model Kindly Draw digits in center of drawing area"
)
_ARTICLE = "<p style='text-align: center'>Digit Recognition | Demo Model by Jugal</p>"

iface = gr.Interface(
    fn=classify,
    inputs="sketchpad",
    outputs="label",
    live=True,
    title="Digit Recognition",
    description=_DESCRIPTION,
    article=_ARTICLE,
    theme="huggingface",
)
iface.launch(debug=True)