Spaces:
Running
Running
import functools

import gradio as gr
import torch
from torch import nn
import torchvision.transforms as T
from PIL import Image, ImageDraw

from linea.models import build_linea
from linea.util.slconfig import DictAction, SLConfig
# Model name shown in the dropdown -> path of its SLConfig file.
# Every size follows the same naming scheme, differing only in the size letter.
LINEA_MODELS = {
    f"LINEA-{size_letter}": f"./linea/configs/linea/linea_hgnetv2_{size_letter.lower()}.py"
    for size_letter in ("N", "S", "M", "L")
}
# Preprocessing applied to every input image before inference:
# resize to a fixed 640x640, convert to a float tensor, then standardize
# each channel with the mean/std below (presumably the statistics the
# checkpoints were trained with -- confirm against the training config).
transforms = T.Compose([
    T.Resize((640, 640)),
    T.ToTensor(),
    T.Normalize(
        mean=[0.538, 0.494, 0.453],
        std=[0.257, 0.263, 0.273],
    ),
])
# Example inputs for gr.Examples: one single-value row (an image path) per example.
example_images = [[f"assets/example{number}.jpg"] for number in range(1, 5)]
# Markdown/HTML header rendered at the top of the demo via gr.Markdown.
# The blank lines are significant: an HTML block only ends at a blank line,
# so without them the "## Getting Started" heading would be shown verbatim.
description = """
<h1 align="center">
<ins>LINEA</ins>
<br>
Fast and accurate line detection using scalable transformers
</h1>

<h2 align="center">
<a href="https://www.linkedin.com/in/sebastianjr/">Sebastian Janampa</a>
and
<a href="https://www.linkedin.com/in/marios-pattichis-207b0119/">Marios Pattichis</a>
</h2>

<h2 align="center">
<a href="https://github.com/SebastianJanampa/LINEA.git">GitHub</a> |
<a href="https://colab.research.google.com/github/SebastianJanampa/LINEA/blob/master/LINEA_tutorial.ipynb">Colab</a>
</h2>

## Getting Started

LINEA is a family of transformer models that detect line segments in an image.
Its key component is a new attention mechanism called **line attention**.
To get started, upload an image or select one of the examples below.
You can choose between different model sizes, change the confidence threshold and visualize the results.
"""
class _DeployedLinea(nn.Module):
    """Inference wrapper chaining a deployed LINEA model and its postprocessor."""

    def __init__(self, model, postprocessor):
        super().__init__()
        self.model = model.deploy()
        self.postprocessor = postprocessor.deploy()

    @torch.no_grad()  # inference only: never build an autograd graph
    def forward(self, images, orig_target_sizes):
        outputs = self.model(images)
        return self.postprocessor(outputs, orig_target_sizes)


# Building the model and loading its checkpoint is expensive, so it is done
# once per model size and reused for every later request. The key space is
# bounded by the keys of LINEA_MODELS (unknown keys raise and are not cached).
@functools.lru_cache(maxsize=None)
def create_model(model_size):
    """Build a LINEA model, load its pretrained checkpoint and wrap it for inference.

    The result is memoized per ``model_size``. Every call with the same size
    returns the same module instance.

    Args:
        model_size: A key of ``LINEA_MODELS`` (e.g. ``"LINEA-M"``).

    Returns:
        An ``nn.Module`` in eval mode. Its ``forward(images, orig_target_sizes)``
        runs the deployed model followed by the deployed postprocessor.

    Raises:
        KeyError: If ``model_size`` is not a key of ``LINEA_MODELS``.
    """
    cfg = SLConfig.fromfile(LINEA_MODELS[model_size])
    # Presumably disables downloading separate backbone weights; the full
    # checkpoint is loaded below anyway -- confirm against build_linea.
    cfg.pretrained = False
    linea, postprocessor = build_linea(cfg)

    # "LINEA-M" -> "m": checkpoint file names use the lowercase size letter.
    letter = model_size[-1].lower()
    file_name = f"linea_hgnetv2_{letter}.pth"
    url = f"https://github.com/SebastianJanampa/storage/releases/download/LINEA/{file_name}"
    state_dict = torch.hub.load_state_dict_from_url(
        url, map_location="cpu", file_name=file_name
    )
    linea.load_state_dict(state_dict['model'], strict=True)

    deployed = _DeployedLinea(linea, postprocessor)
    deployed.eval()
    return deployed
def draw(images, lines, scores, thrh):
    """Draw the line segments whose score exceeds ``thrh`` onto ``images``.

    The images are modified in place and also returned.

    Args:
        images: Sequence of PIL images to annotate.
        lines: Per-image collections of segments. Each row is read as
            ``(x1, y1, x2, y2)``.
        scores: Per-image score tensors, aligned with ``lines``.
        thrh: A segment is drawn only if its score is strictly greater than this.

    Returns:
        ``images``, now annotated.
    """
    for im, img_lines, img_scores in zip(images, lines, scores):
        keep = img_scores > thrh  # compute the mask once and reuse it
        canvas = ImageDraw.Draw(im)
        # .tolist() hands PIL plain Python floats instead of 0-d tensors.
        for segment, score in zip(img_lines[keep].tolist(), img_scores[keep].tolist()):
            canvas.line(segment, fill="red", width=5)
            # Label each segment with its score at the segment's first endpoint.
            canvas.text(
                (segment[0], segment[1]),
                text=f"{round(score, 2)}",
                fill="blue",
            )
    return images
def filter(lines, scores, threshold):
    """Keep, per image, only the segments whose score is strictly above ``threshold``.

    Args:
        lines: Per-image segment tensors.
        scores: Per-image score tensors aligned with ``lines``.
        threshold: Strict lower bound on the score.

    Returns:
        A pair of lists ``(kept_lines, kept_scores)`` with one entry per image.
    """
    def _keep_above(img_lines, img_scores):
        mask = img_scores > threshold
        return img_lines[mask], img_scores[mask]

    kept = [_keep_above(img_lines, img_scores) for img_lines, img_scores in zip(lines, scores)]
    return [pair[0] for pair in kept], [pair[1] for pair in kept]
def _format_line_row(line, score):
    """Format one segment and its score as a tab-separated report row."""
    x1, y1, x2, y2 = (line[k].item() for k in range(4))
    return (
        f"\tx1: {x1:.2f}\ty1: {y1:.2f}\tx2: {x2:.2f}\ty2: {y2:.2f}"
        f"\tscore: {score.item():.2f}\n"
    )


def format_output(lines, scores):
    """Render the detections of the first image as a plain-text report.

    Args:
        lines: Per-image segment tensors. Only ``lines[0]`` is reported, and
            each row is read as ``(x1, y1, x2, y2)``.
        scores: Per-image score tensors aligned with ``lines``.

    Returns:
        A header giving the number of segments, then one line per segment.
        Empty ``lines``/``scores`` are reported as zero detections.
    """
    first_lines = lines[0] if len(lines) else ()
    first_scores = scores[0] if len(scores) else ()
    header = f"{len(first_lines)} lines were detected\nDetected lines:\n"
    # Join once instead of growing the string with += inside the loop.
    rows = "".join(
        _format_line_row(line, score) for line, score in zip(first_lines, first_scores)
    )
    return header + rows
def process_results(
    image_path,
    model_size,
    threshold
):
    """Process the image and return the detected lines.

    Args:
        image_path: Path of the uploaded image, or None if nothing was uploaded.
        model_size: Key of ``LINEA_MODELS`` selecting the model.
        threshold: Segments with a score strictly above this are drawn and listed.

    Returns:
        A tuple ``(report, annotated_image, raw_results)``:
        - ``report``: the text built by ``format_output``.
        - ``annotated_image``: the image with the above-threshold segments drawn on it.
        - ``raw_results``: the unfiltered ``(lines, scores)`` model output.

    Raises:
        gr.Error: If no image was provided.
    """
    if image_path is None:
        raise gr.Error("Please upload an image first.")
    model = create_model(model_size)
    # convert() copies the pixels, so the file handle can be released immediately.
    with Image.open(image_path) as im:
        im_pil = im.convert("RGB")
    w, h = im_pil.size
    orig_size = torch.tensor([[w, h]])
    im_data = transforms(im_pil).unsqueeze(0)
    # Inference only: skip autograd so no graph is built or kept in the state.
    with torch.no_grad():
        lines, scores = model(im_data, orig_size)
    result_images = draw([im_pil], lines, scores, thrh=threshold)
    filtered_lines, filtered_scores = filter(lines, scores, threshold)
    return format_output(filtered_lines, filtered_scores), result_images[0], (lines, scores)
def update_threshold(
    image_path,
    raw_results,
    threshold
):
    """Refilter and redraw previously computed detections for a new threshold.

    The model is not re-run.

    Args:
        image_path: Path of the current image, or None.
        raw_results: The ``(lines, scores)`` pair returned by an earlier
            detection, or None if no detection has run yet.
        threshold: New strict lower bound on the score.

    Returns:
        ``(report, annotated_image)``. If there are no cached results or no
        image, returns two ``gr.update()`` no-ops so the outputs stay unchanged.
    """
    # The slider can move before "Detect Lines" was ever pressed (state is
    # still None) or after the image was cleared; nothing to redraw then.
    if raw_results is None or image_path is None:
        return gr.update(), gr.update()
    lines, scores = raw_results
    # convert() copies the pixels, so the file handle can be released immediately.
    with Image.open(image_path) as im:
        im_pil = im.convert("RGB")
    result_images = draw([im_pil], lines, scores, thrh=threshold)
    filtered_lines, filtered_scores = filter(lines, scores, threshold)
    return format_output(filtered_lines, filtered_scores), result_images[0]
def update_model(
    image_path,
    model_size,
    threshold
):
    """Re-run detection on the current image with the selected model size.

    Args:
        image_path: Path of the current image, or None.
        model_size: Key of ``LINEA_MODELS`` selecting the model.
        threshold: Confidence threshold forwarded to ``process_results``.

    Returns:
        The same ``(report, annotated_image, raw_results)`` tuple as ``process_results``.

    Raises:
        gr.Error: If no image has been uploaded yet.
    """
    # Validate first: building a model is expensive and pointless without an image.
    if image_path is None:
        raise gr.Error("Please upload an image first.")
    # process_results obtains the model itself, so no separate create_model call is needed.
    return process_results(image_path, model_size, threshold)
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            gr.Markdown("""## Input Image""")
            image_path = gr.Image(label="Upload image", type="filepath")
            model_size = gr.Dropdown(
                choices=list(LINEA_MODELS.keys()), label="Choose a LINEA model.", value="LINEA-M"
            )
            threshold = gr.Slider(
                label="Confidence Threshold",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                interactive=True,
                value=0.30,
            )
            submit_btn = gr.Button("Detect Lines")
            # Each example row only provides an image path, so it only fills the image input.
            gr.Examples(examples=example_images, inputs=[image_path])
        with gr.Column():
            gr.Markdown("""## Results""")
            image_output = gr.Image(label="Detected Lines")
            text_output = gr.Textbox(label="Predicted lines", type="text", lines=5)

    # Unfiltered (lines, scores) from the last detection. Moving the slider
    # refilters these instead of re-running the model.
    raw_results = gr.State()

    # Run the model when the button is clicked.
    submit_btn.click(
        fn=process_results,
        inputs=[image_path, model_size, threshold],
        outputs=[text_output, image_output, raw_results],
    )
    # Refilter and redraw the cached detections when the confidence threshold changes.
    threshold.change(
        fn=update_threshold,
        inputs=[image_path, raw_results, threshold],
        outputs=[text_output, image_output],
    )

if __name__ == "__main__":
    demo.launch()