### 1. Imports and class names setup ###
import os
from timeit import default_timer as timer
from typing import Dict, Tuple

import gradio as gr
import torch

from model import create_resnet50_model

# Setup class names: the digits 0-9 followed by the lowercase letters a-z.
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
               'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j',
               'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't',
               'u', 'v', 'w', 'x', 'y', 'z']

### 2. Model and transforms preparation ###
# Create model. num_classes is derived from class_names (36 entries) so the
# classifier head and the label list can never drift apart.
resnet50, resnet50_transforms = create_resnet50_model(
    num_classes=len(class_names),
    seed=42,
)

# Load saved weights.
# NOTE(review): torch.load unpickles the checkpoint file, so only load a
# trusted "AMS.pth". On torch >= 1.13, weights_only=True restricts loading to
# tensors and is safe for a plain state_dict — consider enabling it.
resnet50.load_state_dict(
    torch.load(
        f="AMS.pth",
        map_location=torch.device("cpu"),  # load to CPU
    )
)

# Put the model into evaluation mode once at startup instead of on every
# request; nothing in this script switches it back to training mode.
resnet50.eval()


### 3. Predict function ###
def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify an image with the ResNet50 model and time the prediction.

    Args:
        img: The image delivered by the ``gr.Image(type="pil")`` input (a PIL
            image), or ``None`` when the user submits without uploading one.

    Returns:
        A tuple ``(pred_labels_and_probs, pred_time)``. ``pred_labels_and_probs``
        maps every entry of ``class_names`` to its softmax probability;
        ``pred_time`` is the elapsed wall-clock time in seconds, rounded to
        5 decimal places.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable message
            instead of a generic failure.
    """
    if img is None:
        # Without this guard, img.convert below raises AttributeError on None.
        raise gr.Error("Please upload an image to classify.")

    # Start the timer
    start_time = timer()

    # Force 3-channel RGB so RGBA / greyscale uploads go through the same
    # transform pipeline as ordinary photos.
    img = img.convert("RGB")

    # Apply the model's transforms, then add a leading batch dimension of size 1.
    batch = resnet50_transforms(img).unsqueeze(0)

    with torch.inference_mode():
        # Forward pass to obtain the raw logits.
        pred_logits = resnet50(batch)
        # Softmax over dim=1 (the class dimension) converts logits to probabilities.
        pred_probs = torch.softmax(pred_logits, dim=1)

    # The batch holds a single image, so row 0 contains its probabilities; pair
    # each one with its class label.
    pred_labels_and_probs = {
        name: float(prob)
        for name, prob in zip(class_names, pred_probs[0].tolist())
    }

    # Calculate the prediction time
    pred_time = round(timer() - start_time, 5)

    # Return the prediction dictionary and prediction time
    return pred_labels_and_probs, pred_time


### 4. Gradio app ###
# Create title, description and article strings
title = "AMERICAN SIGN LANGUAGE"
description = "A ResNet50 feature extractor computer vision model to classify American Sign Language."
# article = "Created at [09. PyTorch Model Deployment](https://www.learnpytorch.io/09_pytorch_model_deployment/)."

# Create the Gradio demo
demo = gr.Interface(
    fn=predict,  # mapping function from input to output
    inputs=gr.Image(type="pil"),  # predict() receives a PIL image (or None)
    outputs=[
        gr.Label(num_top_classes=5, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],  # predict() returns two values, so there are two output components
    title=title,
    description=description,
)

# Launch the demo!
demo.launch()