"""Gradio app: ASL alphabet classification with a fine-tuned ViT-B/16."""

import os

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T
import gradio as gr
import PIL.Image as Image
import numpy as np

# --- Model setup -----------------------------------------------------------
WEIGHTS_PATH = "vit_base_state_dict.pth"

# ViT-B/16 with its classification head replaced for 29 classes
# (26 letters + 'del' + 'nothing' + 'space').
model = models.vit_b_16()
model.heads = nn.Sequential(nn.Linear(768, 29))
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location="cpu"))
model.eval()  # FIX: disable dropout; original ran inference in training mode

# Preprocessing: 224x224, tensor in [0, 1], then normalized to [-1, 1].
# NOTE(review): assumes this matches the training-time pipeline — confirm.
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# Class index -> label string (FIX: was misleadingly named `label_to_idx`).
idx_to_label = {i: ch for i, ch in enumerate("ABCDEFGHIJKLMNOPQRSTUVWXYZ")}
idx_to_label.update({26: "del", 27: "nothing", 28: "space"})


def main(input_image: np.ndarray) -> str:
    """Classify a single image and return its predicted label.

    Args:
        input_image: H x W x 3 uint8 RGB array as supplied by ``gr.Image()``.

    Returns:
        The predicted class label, e.g. ``'A'`` or ``'space'``.
    """
    pil_image = Image.fromarray(input_image)
    tensor_image = transform(pil_image)
    with torch.inference_mode():
        logits = model(tensor_image.unsqueeze(0)).squeeze(0)
    # argmax over softmax equals argmax over raw logits; softmax kept to
    # preserve the original computation exactly.
    pred_idx = torch.argmax(torch.softmax(logits, dim=0), dim=0)
    return idx_to_label[pred_idx.item()]


# FIX: skip hidden files (e.g. macOS ".DS_Store") via a filter instead of
# list.remove(), which raised ValueError whenever the file was absent.
examples = [
    os.path.join("examples", name)
    for name in os.listdir("examples")
    if not name.startswith(".")
]

app = gr.Interface(
    fn=main,
    inputs=gr.Image(),
    outputs=gr.Textbox(),
    examples=examples,
)

# FIX: guard the server start so importing this module has no side effect
# beyond building `app`.
if __name__ == "__main__":
    app.launch()