File size: 691 Bytes
e34d96b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import gradio as gr
from model import Net, predict
import torch
import torchvision.transforms as transforms
from PIL import Image

# Build the network and load its trained weights onto the CPU so the app also
# runs on machines without a GPU.
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict contains. This prevents a tampered checkpoint
# from executing arbitrary code during torch.load.
model = Net()
model.load_state_dict(
    torch.load("mnist_model.pth", map_location=torch.device("cpu"), weights_only=True)
)
# Put the model into evaluation mode for serving predictions. This matters
# for layers such as dropout or batch-norm, if Net has any.
model.eval()

# Preprocessing applied to every uploaded picture before it reaches the model:
# collapse it to one channel, shrink it to 28x28, then convert it to a tensor.
# NOTE(review): no colour inversion or mean/std normalization is applied here.
# Confirm this matches how mnist_model.pth was trained.
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),  # single-channel image
        transforms.Resize(size=(28, 28)),             # fixed height x width
        transforms.ToTensor(),                        # PIL image -> torch tensor
    ]
)


def predict_image(image):
    """Classify an uploaded image with the loaded MNIST model.

    Args:
        image: The value produced by the ``gr.Image()`` input component.
            With Gradio's defaults this is presumably a NumPy image array
            (TODO confirm against the installed Gradio version). It is
            ``None`` when the user submits without providing an image.

    Returns:
        Whatever ``predict(model, ...)`` returns for the preprocessed image,
        or the string ``"No image provided"`` when ``image`` is ``None``.
    """
    # Gradio passes None for an empty input. Image.fromarray would raise on
    # it, so answer with a readable message instead of an error.
    if image is None:
        return "No image provided"

    # Preprocess with the module-level pipeline, then add a leading batch
    # dimension so the model receives a batch of one.
    input_tensor = transform(Image.fromarray(image)).unsqueeze(0)

    # Serving predictions never needs gradients. Disabling autograd saves
    # memory and time. This assumes `predict` does not call backward itself.
    with torch.no_grad():
        result = predict(model, input_tensor)

    return result


# Wire the classifier into a minimal web UI: one image input, one text output.
app = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(),
    outputs="text",
)

# Start the Gradio web server for the interface defined above.
app.launch()