Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import os | |
| import torch | |
| import torchvision | |
| from model import create_model | |
| from timeit import default_timer as timer | |
# Class labels, one per line. predict() pairs class_names[i] with output logit i,
# so the file order must match the order used at training time.
# Blank / whitespace-only lines are skipped: an empty "class" would inflate
# num_classes below and make load_state_dict fail with a shape mismatch.
with open("class_names.txt", "r", encoding="utf-8") as f:
    class_names = [name for name in (line.strip() for line in f) if name]

# create_model is project-local (model.py, not visible here). It returns the network
# and a transform pipeline; the pipeline is not used by this file.
model, model_transforms = create_model(num_classes=len(class_names))

# The checkpoint holds a state_dict, so it is loaded via load_state_dict
# ("load -> load_state_dict"). map_location remaps the stored tensors onto the CPU,
# so a checkpoint saved on another device still loads on a CPU-only host.
model.load_state_dict(
    torch.load(
        "ConvNeXt_Tiny_101classes_20_10epochs.pth",
        map_location=torch.device("cpu"),
    )
)

# This app only runs inference, so switch to eval mode once at startup.
model.eval()
# Preprocessing for the ConvNeXt-Tiny default weights (resize, center-crop, normalize).
# Built once at import time: weights.transforms() yields an equivalent pipeline on
# every call, so rebuilding it per request was wasted work that was also counted in
# the reported inference time.
# NOTE(review): create_model() also returns `model_transforms`. If that is the same
# pipeline, it could be used here instead — confirm against model.py.
_CONVNEXT_TINY_TRANSFORMS = torchvision.models.ConvNeXt_Tiny_Weights.DEFAULT.transforms()


def predict(img):
    """Classify one image and report how long the prediction took.

    Args:
        img: Value from the ``gr.Image(type="pil")`` input, presumably a
            ``PIL.Image.Image``. It is ``None`` when the user submits without
            uploading an image.

    Returns:
        A ``(labels_to_probs, seconds)`` tuple:
        - ``labels_to_probs`` maps every name in the module-level ``class_names``
          to its softmax probability, which is the mapping ``gr.Label`` expects.
        - ``seconds`` is the elapsed wall-clock time of the prediction, rounded
          to 5 decimal places.

    Raises:
        gr.Error: If no image was provided; Gradio shows the message to the user.
    """
    if img is None:
        # Without this guard the transform fails with an opaque TypeError.
        raise gr.Error("Please upload an image to classify.")

    time_start = timer()

    # [Channels, Height, Width] -> [Batch_size, Channels, Height, Width]
    img_tensor = _CONVNEXT_TINY_TRANSFORMS(img).unsqueeze(dim=0)

    # Cheap, and guarantees inference behaviour even if something toggled train mode.
    model.eval()
    with torch.inference_mode():
        predicted_probs = model(img_tensor).softmax(dim=1)

    # Class name & predicted probability for each class (required by Gradio).
    # Output index i corresponds to class_names[i].
    pred_labels_probs = dict(zip(class_names, predicted_probs[0].tolist()))

    return pred_labels_probs, round(timer() - time_start, 5)
def _example_paths(directory="examples"):
    """Return Gradio example rows ``[[path], ...]`` for the files in *directory*.

    Entries are sorted so the examples gallery has a stable order (``os.listdir``
    order is arbitrary). Sub-directories and dotfiles (e.g. ``.gitkeep``) are
    skipped because they are not images. A missing or empty directory yields
    ``None`` (no examples) instead of crashing the app at startup.
    """
    try:
        names = sorted(os.listdir(directory))
    except FileNotFoundError:
        return None
    paths = [os.path.join(directory, name) for name in names if not name.startswith(".")]
    return [[path] for path in paths if os.path.isfile(path)] or None


app = gr.Interface(
    fn=predict,  # mapping function for [ input -> output ]
    inputs=gr.Image(type="pil"),  # input image, handed to predict as a PIL image
    outputs=[
        # One output component per value in predict's returned tuple, in order.
        gr.Label(num_top_classes=3, label="Predictions"),
        gr.Number(label="Inference time (s)"),
    ],
    examples=_example_paths("examples"),
    title='ConvNeXt_Food101',
    description='A ConvNext CV model to classify 101 foods',
    article='Model trained on 150 images per class',
)
app.launch()