# Billy-06
# Added the file Architectures
# a263b83
# raw
# history blame
# 1.69 kB
import functools

import gradio as gr
import numpy as np
import torch

from model import *
def load_cub200_classes(path="classes.txt"):
    """Read the class-id -> class-name table from a ``classes.txt`` file.

    Each non-blank line has the form ``"<integer id> <name>"``. The line is
    split at its first space only, so a name that itself contains spaces is
    kept whole.

    Args:
        path: File to read. Defaults to ``"classes.txt"``, resolved relative
            to the current working directory.

    Returns:
        A dict mapping each integer id to its (whitespace-stripped) name.

    Raises:
        ValueError: If a non-blank line does not start with an integer id.
    """
    classes = {}
    with open(path, encoding="utf-8") as f:
        for raw in f:
            entry = raw.strip()
            if not entry:
                # Tolerate blank lines, e.g. extra newlines at the end of file.
                continue
            # partition() splits on the FIRST space only.
            class_id, _, name = entry.partition(" ")
            classes[int(class_id)] = name.strip()
    return classes
def load_model(weights_path="resnet.pt"):
    """Build the ResNet-50 classifier and load its trained weights.

    Args:
        weights_path: Path to a saved ``state_dict`` for this architecture.
            Defaults to ``"resnet.pt"``, resolved relative to the current
            working directory.

    Returns:
        The model with the trained weights loaded, switched to evaluation mode.
    """
    # 200 output classes; stride=[1, 2, 2, 1] is forwarded to the project's
    # resnet50 builder (see model.py for its exact meaning).
    model = resnet50(pretrained=False, stride=[1, 2, 2, 1], num_classes=200)
    # map_location maps every stored tensor to the CPU, so the checkpoint loads
    # even on machines without a GPU.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (passing weights_only=True, available on torch >= 1.13, hardens this).
    state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    # Evaluation mode changes the behavior of layers such as dropout and batch
    # norm (if the architecture contains them) for inference.
    model.eval()
    return model
@functools.lru_cache(maxsize=1)
def _load_model_and_classes():
    """Load the model and class table once and reuse them for every request.

    Returns:
        A ``(model, classes, first_class_id)`` tuple, where ``first_class_id``
        is the smallest id present in ``classes``.
    """
    classes = load_cub200_classes()
    return load_model(), classes, min(classes)


def predict_image(image):
    """Classify an image and return its label as ``"Class: <name>"``.

    Args:
        image: Image array in (height, width, channels) layout, as handed over
            by the Gradio ``"image"`` input (presumably uint8 RGB; confirm
            against the Gradio version in use).

    Returns:
        The string ``"Class: "`` followed by the predicted class name.
    """
    # Reading weights from disk is expensive; the helper caches the result so
    # it happens once per process rather than once per request.
    model, classes, first_class_id = _load_model_and_classes()
    # HWC -> CHW, cast to float, then add a batch dimension of 1.
    # NOTE(review): no resizing or mean/std normalisation is applied; confirm
    # this matches the preprocessing used when resnet.pt was trained.
    tensor = torch.from_numpy(image).permute(2, 0, 1).float().unsqueeze(0)
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        logits = model(tensor).numpy()[0]
    # Softmax is monotonic, so the most probable class is the largest logit.
    # Taking argmax of the logits directly avoids the float overflow
    # (exp -> inf, inf/inf -> nan) a naive softmax hits on large logits.
    class_idx = int(np.argmax(logits))
    # Model outputs are numbered 0..N-1, but ids in classes.txt may start at 1
    # (the stock CUB-200-2011 file does). Offsetting by the smallest id maps
    # output 0 to the first class for both 0-based and 1-based files.
    return "Class: " + classes[first_class_id + class_idx]
# Web UI: a single image input mapped to a single text output (the predicted
# label), served as soon as this script runs.
demo = gr.Interface(
    fn=predict_image,
    inputs="image",
    outputs="text",
)
demo.launch()