Spaces:
Sleeping
Sleeping
File size: 1,542 Bytes
2fa6329 30bfd08 e2f7ccb 30bfd08 e2f7ccb 4533cc6 e2f7ccb f641d1c b3a91df f641d1c e2f7ccb a864ac4 17d21ea e2f7ccb 30bfd08 e2f7ccb 24d5a03 30bfd08 24d5a03 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import gradio as gr
import torch
from torchvision import models, transforms
import torchvision
from PIL import Image
import numpy as np
import requests
from models.convmodel import MNISTnet
from pathlib import Path
# Function to perform image classification

# Class names in model-output order: output index i is reported as LABEL_MAPPING[i].
LABEL_MAPPING = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
                 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot')

# Trained weights, resolved relative to the current working directory.
MODEL_STATE_DICT_PATH = Path("models/MNISTnet_state_dict.pt")

# Lazily populated by _get_model() so the weights are read from disk only once
# per process instead of once per request.
_loaded_model = None


def _get_model():
    """Return the MNISTnet classifier in eval mode, building it on first use.

    The first call constructs the network and loads its state dict from
    ``MODEL_STATE_DICT_PATH``; later calls return the same cached instance.
    """
    global _loaded_model
    if _loaded_model is None:
        model = MNISTnet(input_channels=1, num_labels=len(LABEL_MAPPING),
                         hidden_layers=5)
        # Inference below runs on CPU; map_location="cpu" also lets a state
        # dict that was saved from a GPU session load on a CPU-only host.
        state_dict = torch.load(MODEL_STATE_DICT_PATH, map_location="cpu")
        model.load_state_dict(state_dict)
        _loaded_model = model.eval()
    return _loaded_model


def classify_image(img):
    """Predict the clothing category of a single image.

    Args:
        img: A PIL image (the Gradio input below is created with
            ``type="pil"``). It is converted to one grayscale channel and fed
            to the model at its native resolution.
            NOTE(review): presumably the image must match the resolution the
            model was trained on (likely 28x28, given the Fashion-MNIST label
            set) — TODO confirm and consider adding a Resize transform.

    Returns:
        The predicted class name, one of ``LABEL_MAPPING``.

    Raises:
        ValueError: if ``img`` is None (presumably what Gradio passes when no
            image has been uploaded).
    """
    if img is None:
        raise ValueError("No image provided; please upload an image.")

    to_model_input = torchvision.transforms.Compose([
        torchvision.transforms.Grayscale(),  # model is built with input_channels=1
        torchvision.transforms.ToTensor(),
    ])
    # Add a leading batch dimension: the model expects N, C, H, W.
    model_input = to_model_input(img).unsqueeze(dim=0)

    model = _get_model()
    with torch.inference_mode():
        predicted_idx = model(model_input).argmax(dim=1)
    return LABEL_MAPPING[predicted_idx.item()]
# Gradio interface: upload an image, display the predicted class name.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    # classify_image returns a single class-name string rather than a
    # {label: confidence} dict, so num_top_classes presumably has no visible
    # effect here; it would matter if confidences were returned.
    outputs=gr.Label(num_top_classes=10),
    title="Predict the Image",
)

# Launch the Gradio app only when this file is run as a script, so importing
# the module (e.g. from tests) does not start a web server.
if __name__ == "__main__":
    iface.launch()