# Hugging Face Space - status: Running
| from PIL import Image, ImageDraw | |
| import torch | |
| from torchvision import transforms | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| # import sys | |
| # sys.path.insert(0, './') | |
| from test import create_letr, get_lines_and_draw | |
| from models.preprocessing import * | |
| from models.misc import nested_tensor_from_tensor_list | |
# Both LETR checkpoints are loaded once, at import time, so every request
# reuses them; predict() chooses between them from the backbone radio value.
# The tuple is evaluated left to right: ResNet-50 is loaded before ResNet-101.
model, model101 = (
    create_letr('resnet50/checkpoint0024.pth'),
    create_letr('resnet101/checkpoint0024.pth'),
)
# Input preprocessing for LETR: tensor conversion, per-channel normalisation
# with fixed statistics, then a resize. The three variants differ only in the
# size handed to Resize; predict() picks one from the "size" radio value.
_NORM_MEAN = (0.538, 0.494, 0.453)
_NORM_STD = (0.257, 0.263, 0.273)


def _build_transform(size):
    """Return the ToTensor -> Normalize -> Resize pipeline for ``size``.

    ``Compose``, ``ToTensor``, ``Normalize`` and ``Resize`` come from
    ``models.preprocessing``. ``Resize([size])`` is presumably a
    shorter-side resize -- TODO confirm in that module.
    """
    # Fresh lists per pipeline, exactly like the original literals, so no two
    # Normalize instances share (and could mutate) the same list object.
    return Compose([
        ToTensor(),
        Normalize(list(_NORM_MEAN), list(_NORM_STD)),
        Resize([size]),
    ])


normalize = _build_transform(256)
normalize_512 = _build_transform(512)
normalize_1100 = _build_transform(1100)
def predict(inp, size, model_name):
    """Run LETR line-segment detection on a single image.

    Args:
        inp: Image array from the Gradio ``Image`` input (presumably
            ``(H, W, 3)`` uint8-compatible; grayscale ``(H, W)`` and RGBA
            ``(H, W, 4)`` arrays are converted to RGB).
        size: Selected input resolution: ``'1100'`` or ``'512'``; any other
            value (including ``None``) falls back to the 256 pipeline.
        model_name: ``'resnet101'`` selects the ResNet-101 model; any other
            value (including ``None``) uses the ResNet-50 model.

    Returns:
        ``(image, lines_text)``: the PIL image passed to
        ``get_lines_and_draw`` and ``str()`` of the value it returned.

    Raises:
        ValueError: if no image was provided.
    """
    if inp is None:
        raise ValueError("No input image was provided.")

    # Let Pillow infer the mode from the array shape and then convert.
    # Forcing mode 'RGB' misreads 4-channel buffers and fails on 2-D arrays.
    image = Image.fromarray(inp.astype('uint8')).convert('RGB')
    h, w = image.height, image.width
    # Original (height, width), handed to get_lines_and_draw alongside the
    # model outputs (presumably to rescale predictions -- TODO confirm).
    orig_size = torch.as_tensor([int(h), int(w)])

    if size == '1100':
        transform = normalize_1100
    elif size == '512':
        transform = normalize_512
    else:
        transform = normalize
    img = transform(image)

    inputs = nested_tensor_from_tensor_list([img])
    net = model101 if model_name == 'resnet101' else model
    with torch.no_grad():  # inference only; no autograd graph needed
        outputs = net(inputs)[0]

    # NOTE(review): `image` is returned as the annotated output, so
    # get_lines_and_draw presumably draws on it in place -- confirm in test.py.
    lines = get_lines_and_draw(image, outputs, orig_size)
    return image, str(lines)
# Gradio UI. Inputs map positionally onto predict(inp, size, model_name);
# outputs map onto the (image, text) tuple predict returns.
inputs = [
    gr.inputs.Image(),
    gr.inputs.Radio(["256", "512", "1100"]),
    gr.inputs.Radio(["resnet50", "resnet101"]),
]
outputs = [
    gr.outputs.Image(label='Image with Lines', type='numpy'),
    gr.outputs.Textbox(label='Lines points List'),
]

# Each example row supplies one value per input above, in the same order.
_EXAMPLES = [
    ["demo.png", '256', "resnet50"],
    ["tappeto-per-calibrazione.jpg", '256', "resnet50"],
]
_TITLE = "LETR: Line Segment Detection Using Transformers without Edges"
_DESCRIPTION = "It is an end-to-end line segment detection algorithm using Transformers [published on CVPR 2021](https://github.com/mlpc-ucsd/LETR)."

demo = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    examples=_EXAMPLES,
    title=_TITLE,
    description=_DESCRIPTION,
)
demo.launch()