Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from torchvision import models, transforms | |
| import torchvision | |
| from PIL import Image | |
| import numpy as np | |
| import requests | |
| from models.convmodel import MNISTnet | |
| from pathlib import Path | |
# Fashion-MNIST class names, indexed by the position of the model's output logit.
LABEL_MAPPING = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
                 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# Resolve the weights relative to this file rather than the process CWD, so the
# app still finds them when it is launched from a different directory.
MODEL_STATE_DICT_PATH = Path(__file__).resolve().parent / "models" / "MNISTnet_state_dict.pt"

# Preprocessing applied to every uploaded image: single-channel, fixed size, tensor.
_PREPROCESS = torchvision.transforms.Compose([
    torchvision.transforms.Grayscale(),
    # NOTE(review): assumes MNISTnet expects 28x28 inputs (the Fashion-MNIST image
    # size) — confirm against models/convmodel.py. Without a resize, arbitrary
    # upload sizes are fed straight into the network.
    torchvision.transforms.Resize((28, 28)),
    torchvision.transforms.ToTensor(),
])

# Lazily-initialised model instance shared by all requests.
_model = None


def _get_model():
    """Build MNISTnet and load its trained weights once; reuse it on later calls."""
    global _model
    if _model is None:
        model = MNISTnet(input_channels=1, num_labels=10, hidden_layers=5)
        # map_location="cpu" lets weights saved from a GPU run load on CPU-only hosts.
        state_dict = torch.load(MODEL_STATE_DICT_PATH, map_location="cpu")
        model.load_state_dict(state_dict)
        _model = model.eval()
    return _model


def classify_image(img):
    """Classify an uploaded image into one of the Fashion-MNIST categories.

    Args:
        img: PIL image supplied by the Gradio ``Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        The predicted class name (one entry of ``LABEL_MAPPING``).

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable message.
    """
    if img is None:
        raise gr.Error("Please upload an image before submitting.")

    tensor_image = _PREPROCESS(img)
    # Add the batch dimension the model expects: (N, C, H, W).
    model_input = tensor_image.unsqueeze(dim=0)

    with torch.inference_mode():
        predicted_idx = _get_model()(model_input).argmax(dim=1)

    return LABEL_MAPPING[predicted_idx.item()]
# Gradio interface: the user uploads an image (passed to classify_image as a PIL
# image) and the predicted label is displayed.
# NOTE: classify_image returns a single label string, so num_top_classes has no
# visible effect unless it is changed to return a {label: confidence} mapping.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=10),
    title="Predict the Image",
)

# Launch the web server only when this file is run as a script (as Spaces does
# with `python app.py`), so importing the module does not start a server.
if __name__ == "__main__":
    iface.launch()