# draftnewapp / app.py
# Last commit by binaychandra: 24d5a03 "adding title"
from functools import lru_cache
from pathlib import Path

import gradio as gr
import numpy as np
import requests
import torch
import torchvision
from PIL import Image
from torchvision import models, transforms

from models.convmodel import MNISTnet
# Class names indexed by the model's output position. This list matches the
# Fashion-MNIST class order (0 = 'T-shirt/top' ... 9 = 'Ankle boot').
LABEL_MAPPING = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
                 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot')

# Resolve the weights next to this script rather than the current working
# directory, so loading works no matter where the app is launched from.
MODEL_STATE_DICT_PATH = (
    Path(__file__).resolve().parent / "models" / "MNISTnet_state_dict.pt"
)

# Preprocessing applied to every uploaded image. Built once, not per request.
# NOTE(review): the 28x28 size assumes the model was trained on native
# Fashion-MNIST resolution - confirm against the training script. Resizing an
# image that is already 28x28 leaves it unchanged.
_PREPROCESS = torchvision.transforms.Compose([
    torchvision.transforms.Grayscale(),
    torchvision.transforms.Resize((28, 28)),
    torchvision.transforms.ToTensor(),
])


@lru_cache(maxsize=1)
def _load_model():
    """Build MNISTnet, load its trained weights once, and return it in eval mode."""
    model = MNISTnet(input_channels=1, num_labels=len(LABEL_MAPPING),
                     hidden_layers=5)
    # map_location="cpu": inference inputs are CPU tensors, and this lets a
    # checkpoint saved from a GPU session load on a CPU-only host.
    state_dict = torch.load(MODEL_STATE_DICT_PATH, map_location="cpu")
    model.load_state_dict(state_dict)
    return model.eval()


def classify_image(img):
    """Classify an image and return the predicted class name.

    Args:
        img: PIL image as delivered by ``gr.Image(type="pil")``. It is
            converted to single-channel grayscale and resized to 28x28
            before inference.

    Returns:
        The predicted class name, one of ``LABEL_MAPPING``.

    Raises:
        gr.Error: if no image was supplied (e.g. Submit pressed with an
            empty input).
    """
    if img is None:
        raise gr.Error("Please upload an image to classify.")
    # ToTensor yields (C, H, W); the model expects a batch: (N, C, H, W).
    batch = _PREPROCESS(img).unsqueeze(dim=0)
    with torch.inference_mode():
        logits = _load_model()(batch)
    predicted_idx = logits.argmax(dim=1).item()
    return LABEL_MAPPING[predicted_idx]
# Gradio interface: one image upload in, the predicted class label out.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    # classify_image returns a single class-name string, which gr.Label shows
    # as-is; num_top_classes only takes effect for a {label: confidence} dict.
    outputs=gr.Label(num_top_classes=10),
    title="Predict the Image",
)

# Launch the Gradio app only when this file is executed as a script, so that
# merely importing the module (e.g. from a test) does not start a web server.
if __name__ == "__main__":
    iface.launch()