lihao57 commited on
Commit
840a94e
·
1 Parent(s): bc77b9c

Add application file

Browse files
Files changed (4) hide show
  1. .gitignore +2 -0
  2. .pre-commit-config.yaml +13 -0
  3. app.py +120 -0
  4. requirements.txt +5 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .vscode
2
+ .gradio
.pre-commit-config.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/psf/black
3
+ rev: 22.3.0
4
+ hooks:
5
+ - id: black
6
+ args: [--line-length=120]
7
+ - repo: https://github.com/pre-commit/pre-commit-hooks
8
+ rev: v3.2.0
9
+ hooks:
10
+ - id: check-json
11
+ - id: end-of-file-fixer
12
+ - id: trailing-whitespace
13
+ - id: requirements-txt-fixer
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- encoding: utf-8 -*-
2
+
3
+ """
4
+ @File : app.py
5
+ @Time : 2025/8/29 15:25:00
6
+ @Author : lh9171338
7
+ @Version : 1.0
8
+ @Contact : 2909171338@qq.com
9
+ """
10
+
11
+ import os
12
+ import gradio as gr
13
+ import cv2
14
+ from PIL import Image
15
+ import io
16
+ import matplotlib.pyplot as plt
17
+ import numpy as np
18
+ from datasets import load_dataset
19
+
20
+ ds = None
21
+
22
+
23
+ def get_dataset():
24
+ """
25
+ get dataset
26
+
27
+ Args:
28
+ None
29
+
30
+ Returns:
31
+ ds (datasets.Dataset): dataset
32
+ """
33
+ global ds
34
+ if ds is None:
35
+ # ds = load_dataset("parquet", data_files={"train": "train/metadata.parquet", "test": "test/metadata.parquet"})
36
+ ds = load_dataset("lh9171338/Wireframe")
37
+ return ds
38
+
39
+
40
+ def selector_change_callback(value):
41
+ """
42
+ callback function for split selector
43
+
44
+ Args:
45
+ value (str): selected split, value must be one of ["train", "test"]
46
+
47
+ Returns:
48
+ slider_info (dict): updated slider info
49
+ image (np.ndarray): updated image
50
+ """
51
+ ds = get_dataset()
52
+ maximum = len(ds[value]) - 1
53
+ slider_info = gr.update(minimum=0, maximum=maximum, value=0)
54
+ image = show_image(split=value, index=0)
55
+ return slider_info, image
56
+
57
+
58
+ def draw_lines(image, lines):
59
+ """
60
+ draw lines on image
61
+
62
+ Args:
63
+ image (np.ndarray): input image
64
+ lines (np.ndarray): list of lines, with shape [N, 2, 2]
65
+
66
+ Returns:
67
+ image (PIL.Image): drawn image
68
+ """
69
+ height, width = image.shape[:2]
70
+ fig = plt.figure()
71
+ fig.set_size_inches(width / height, 1, forward=False)
72
+ ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
73
+ ax.set_axis_off()
74
+ fig.add_axes(ax)
75
+ plt.xlim([-0.5, width - 0.5])
76
+ plt.ylim([height - 0.5, -0.5])
77
+ plt.imshow(image[:, :, ::-1])
78
+ for pts in lines:
79
+ pts = pts - 0.5
80
+ plt.plot(pts[:, 0], pts[:, 1], color="orange", linewidth=0.5)
81
+ plt.scatter(pts[[0, -1], 0], pts[[0, -1], 1], color="#33FFFF", s=1.2, edgecolors="none", zorder=5)
82
+
83
+ buf = io.BytesIO()
84
+ fig.savefig(buf, format="png", dpi=height, bbox_inches=0)
85
+ buf.seek(0)
86
+ plt.close(fig)
87
+ image = Image.open(buf)
88
+ return image
89
+
90
+
91
+ def show_image(split, index):
92
+ """
93
+ show image
94
+
95
+ Args:
96
+ split (str): split name, value must be one of ["train", "test"]
97
+ index (int): index of the example
98
+
99
+ Returns:
100
+ image (PIL.Image): drawn image
101
+ """
102
+ ds = get_dataset()
103
+ example = ds[split][index]
104
+ image_file = os.path.join(split, example["file_name"])
105
+ image = cv2.imread(image_file)
106
+ lines = np.array(example["lines"])
107
+ image = draw_lines(image, lines)
108
+ return image
109
+
110
+
111
+ if __name__ == "__main__":
112
+ with gr.Blocks() as demo:
113
+ split_selector = gr.Dropdown(["train", "test"], label="Split", value="train")
114
+ index_slider = gr.Slider(0, 1, step=1, label="Index", value=0)
115
+ output = gr.Image()
116
+
117
+ split_selector.change(selector_change_callback, split_selector, [index_slider, output])
118
+ index_slider.change(show_image, [split_selector, index_slider], output)
119
+ demo.load(selector_change_callback, split_selector, [index_slider, output])
120
+ demo.launch(share=False)
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ datasets
2
+ matplotlib
3
+ numpy
4
+ opencv-python
5
+ pillow