"""Gradio demo: classify an uploaded image with a pre-trained ``MNISTnet``."""

import functools
from pathlib import Path
from typing import Optional

import gradio as gr
import numpy as np  # noqa: F401  (kept: file-level import from the original script)
import requests  # noqa: F401  (kept: file-level import from the original script)
import torch
import torchvision
from PIL import Image
from torchvision import models, transforms  # noqa: F401  (kept from the original script)

from models.convmodel import MNISTnet

# Location of the trained weights, resolved relative to the working directory.
MODEL_STATE_DICT_PATH = Path("models/MNISTnet_state_dict.pt")

# Class names indexed by the model's output position.
# NOTE(review): these look like the Fashion-MNIST classes - confirm against
# the dataset the checkpoint was trained on.
LABEL_MAPPING = (
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot',
)

# Built once at import time; the pipeline is stateless, so it can be shared
# by every request.
# NOTE(review): the image is not resized here - presumably MNISTnet accepts
# arbitrary H x W or callers upload correctly sized images; verify.
_PREPROCESS = torchvision.transforms.Compose([
    torchvision.transforms.Grayscale(),
    torchvision.transforms.ToTensor(),
])


@functools.lru_cache(maxsize=1)
def _load_model() -> MNISTnet:
    """Build ``MNISTnet``, load its weights and return it in eval mode.

    The result is cached so the network is constructed and the checkpoint is
    read from disk only on the first prediction instead of on every request.
    """
    model = MNISTnet(input_channels=1, num_labels=10, hidden_layers=5)
    # map_location="cpu" lets a checkpoint saved from a GPU load on a
    # CPU-only host instead of failing while deserializing CUDA tensors.
    state_dict = torch.load(MODEL_STATE_DICT_PATH, map_location="cpu")
    model.load_state_dict(state_dict)
    return model.eval()


def classify_image(img: Optional[Image.Image]) -> str:
    """Predict the class label for a single image.

    Args:
        img: The PIL image supplied by the Gradio ``Image`` input. Gradio
            passes ``None`` when the user submits without providing an image.

    Returns:
        The entry of ``LABEL_MAPPING`` at the index of the model's
        highest-scoring output.

    Raises:
        gr.Error: If ``img`` is ``None``; Gradio shows the message in the UI.
    """
    if img is None:
        # Fail with a readable message rather than an opaque traceback from
        # inside the torchvision transforms.
        raise gr.Error("Please provide an image to classify.")

    # Convert to a single-channel tensor, then add a leading batch dimension
    # to get the N, C, H, W layout the model is called with.
    model_input = _PREPROCESS(img).unsqueeze(dim=0)

    with torch.inference_mode():
        predicted_idx = _load_model()(model_input).argmax(dim=1)

    return LABEL_MAPPING[predicted_idx.item()]


# Gradio interface
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=10),
    title="Predict the Image",
)

# Launch the Gradio app
iface.launch()