File size: 1,542 Bytes
2fa6329
30bfd08
 
e2f7ccb
30bfd08
e2f7ccb
4533cc6
e2f7ccb
 
 
 
f641d1c
 
b3a91df
 
 
f641d1c
e2f7ccb
 
 
 
 
 
 
a864ac4
17d21ea
e2f7ccb
 
 
 
 
 
 
 
30bfd08
 
 
e2f7ccb
 
24d5a03
 
30bfd08
 
 
24d5a03
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from functools import lru_cache
from pathlib import Path

import gradio as gr
import numpy as np
import requests
import torch
import torchvision
from PIL import Image
from torchvision import models, transforms

from models.convmodel import MNISTnet

# Class names indexed by the model's output position. These are the
# Fashion-MNIST classes in the dataset's label order.
LABEL_MAPPING = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
                 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot')

# Saved weights for MNISTnet, relative to the working directory the app runs in.
MODEL_STATE_DICT_PATH = Path("models/MNISTnet_state_dict.pt")


@lru_cache(maxsize=None)
def _load_model():
    """Build MNISTnet, load its saved weights, and return it in eval mode.

    Cached so every request reuses one model instead of re-reading the
    state dict from disk on each prediction.
    """
    model = MNISTnet(input_channels=1, num_labels=len(LABEL_MAPPING), hidden_layers=5)
    # map_location="cpu" lets a checkpoint saved from a GPU session load on a
    # CPU-only host (the inputs built below are CPU tensors).
    state_dict = torch.load(MODEL_STATE_DICT_PATH, map_location="cpu")
    model.load_state_dict(state_dict)
    return model.eval()


def _preprocess(img, image_size):
    """Turn a PIL image into a (1, 1, H, W) float tensor for the model."""
    steps = [torchvision.transforms.Grayscale()]
    if image_size is not None:
        steps.append(torchvision.transforms.Resize(image_size))
    steps.append(torchvision.transforms.ToTensor())
    tensor_image = torchvision.transforms.Compose(steps)(img)
    # Add the batch dimension: (C, H, W) -> (N, C, H, W).
    return tensor_image.unsqueeze(dim=0)


def classify_image(img, *, image_size=(28, 28)):
    """Predict the clothing class of a single image.

    Args:
        img: PIL image, as supplied by ``gr.Image(type="pil")``. It is
            converted to one grayscale channel before inference.
        image_size: (height, width) the image is resized to before inference.
            Assumes MNISTnet was trained on 28x28 inputs (Fashion-MNIST size);
            TODO confirm against models/convmodel.py. Pass ``None`` to skip
            resizing (the previous behaviour).

    Returns:
        The predicted class name from ``LABEL_MAPPING``.

    Raises:
        gr.Error: if no image was provided (Gradio passes ``None`` when the
            user submits an empty input).
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    model_input = _preprocess(img, image_size)
    model = _load_model()
    with torch.inference_mode():
        predicted_idx = model(model_input).argmax(dim=1)
    return LABEL_MAPPING[predicted_idx.item()]

# Gradio interface: one PIL image in, predicted class label out.
# Kept at module level so other tooling can import `iface` without starting a server.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=10),
    title="Predict the Image"
)

# Launch the Gradio app only when this file is run as a script, so importing
# the module (e.g. for tests) does not start a web server.
if __name__ == "__main__":
    iface.launch()