JH-BK commited on
Commit
36b7e0f
ยท
1 Parent(s): 622627d

fix: build failure

Browse files
Files changed (1) hide show
  1. app.py +82 -8
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import gradio as gr
2
  import numpy as np
3
  import math
 
 
4
  import torch
5
 
6
  from PIL import Image, ImageDraw
@@ -13,9 +15,70 @@ config = Config()
13
  cuda = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
15
 
16
- docscanner = load_docscanner_model(
17
- cuda, path_l=config.get_rec_model_path, path_m=config.get_seg_model_path
18
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  # ์ขŒํ‘œ๋ฅผ ์ดˆ๊ธฐํ™”ํ•˜๋Š” ํ•จ์ˆ˜
21
  def reset_points(image, state):
@@ -38,7 +101,8 @@ def rotate_image(image):
38
  return rotated_image, state
39
 
40
  def reset_image(image, state):
41
- out_image, msk_np = docscanner_rec(image, docscanner)
 
42
  state = list(get_corner(mask2point(mask=msk_np)))
43
 
44
  img = Image.fromarray(image)
@@ -59,6 +123,7 @@ def reset_image(image, state):
59
  return img, state
60
 
61
  def auto_point_detect(image):
 
62
  out_image, msk_np = docscanner_rec(image, docscanner, cuda)
63
  state = list(get_corner(mask2point(mask=msk_np)))
64
 
@@ -126,6 +191,7 @@ def sort_corners(corners):
126
  def convert(image, state):
127
  h,w = image.shape[:2]
128
  if len(state) < 4:
 
129
  out_image, msk_np = docscanner_rec(image, docscanner, cuda)
130
  out_image = out_image[:,:,::-1]
131
  elif len(state) ==4:
@@ -151,7 +217,12 @@ with gr.Blocks(css=css) as demo:
151
  with gr.Row():
152
  with gr.Column():
153
  text = gr.Textbox("์ž…๋ ฅ ์ด๋ฏธ์ง€(์ฝ”๋„ˆ๋ฅผ ํด๋ฆญํ•˜์„ธ์š”)", show_label=False)
154
- image_input = gr.Image(show_label=False, interactive=True, elem_classes="image-container")
 
 
 
 
 
155
  clear_button = gr.Button("Clear Points")
156
  cutting_button = gr.Button("Cutting Image(need more than 2 points)")
157
  rotating_button = gr.Button("Rotate Image(clock wise 90 degree)")
@@ -159,13 +230,13 @@ with gr.Blocks(css=css) as demo:
159
  convert_button = gr.Button("Convert Image")
160
  with gr.Column():
161
  text = gr.Textbox("๋ณ€ํ™˜๋  ์˜์—ญ", show_label=False)
162
- image_output = gr.Image(show_label=False)
163
  # state_display = gr.Textbox(label="Current State")
164
  # coordinates_text = gr.Textbox(label="Coordinates", placeholder="Enter coordinates (x, y) for each point")
165
  # update_coords_button = gr.Button("Update Coordinates")
166
  with gr.Column():
167
  text = gr.Textbox("๊ฒฐ๊ณผ ์ด๋ฏธ์ง€", show_label=False)
168
- result_image = gr.Image(show_label=False, format="png")
169
 
170
  # # ์ด๋ฏธ์ง€ ์œ„์—์„œ ํด๋ฆญ ์ด๋ฒคํŠธ ์ฒ˜๋ฆฌ
171
  image_input.select(draw_polygon_on_image, inputs=[image_input,state], outputs=[image_output,state])
@@ -183,4 +254,7 @@ with gr.Blocks(css=css) as demo:
183
  # ๋ณ€ํ™˜ ๋ฒ„ํŠผ
184
  convert_button.click(fn=convert, inputs=[image_input,state], outputs=result_image)
185
 
186
- demo.launch(share=True)
 
 
 
 
1
  import gradio as gr
2
  import numpy as np
3
  import math
4
+ import os
5
+ import shutil
6
  import torch
7
 
8
  from PIL import Image, ImageDraw
 
15
  cuda = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
 
18
+ _docscanner = None
19
+
20
+
21
+ def _is_git_lfs_pointer_file(path: str) -> bool:
22
+ try:
23
+ with open(path, "rb") as f:
24
+ head = f.read(200)
25
+ return b"version https://git-lfs.github.com/spec/v1" in head
26
+ except FileNotFoundError:
27
+ return False
28
+
29
+
30
+ def _try_download_from_hf_hub(repo_id: str, filename: str, local_path: str) -> bool:
31
+ try:
32
+ from huggingface_hub import hf_hub_download # type: ignore
33
+
34
+ downloaded = hf_hub_download(repo_id=repo_id, filename=filename)
35
+ os.makedirs(os.path.dirname(local_path), exist_ok=True)
36
+ shutil.copyfile(downloaded, local_path)
37
+ return True
38
+ except Exception:
39
+ return False
40
+
41
+
42
+ def _ensure_weights_present() -> None:
43
+ seg_path = config.get_seg_model_path
44
+ rec_path = config.get_rec_model_path
45
+
46
+ needs_seg = (not os.path.exists(seg_path)) or _is_git_lfs_pointer_file(seg_path)
47
+ needs_rec = (not os.path.exists(rec_path)) or _is_git_lfs_pointer_file(rec_path)
48
+ if not (needs_seg or needs_rec):
49
+ return
50
+
51
+ repo_id = (
52
+ os.getenv("SPACE_ID")
53
+ or os.getenv("HF_SPACE_ID")
54
+ or os.getenv("HF_REPO_ID")
55
+ or os.getenv("REPO_ID")
56
+ )
57
+ if repo_id:
58
+ if needs_seg:
59
+ _try_download_from_hf_hub(repo_id, os.path.basename(seg_path), seg_path)
60
+ if needs_rec:
61
+ _try_download_from_hf_hub(repo_id, os.path.basename(rec_path), rec_path)
62
+
63
+
64
+ def get_docscanner():
65
+ global _docscanner
66
+ if _docscanner is not None:
67
+ return _docscanner
68
+
69
+ _ensure_weights_present()
70
+
71
+ seg_path = config.get_seg_model_path
72
+ rec_path = config.get_rec_model_path
73
+ if _is_git_lfs_pointer_file(seg_path) or _is_git_lfs_pointer_file(rec_path):
74
+ raise RuntimeError(
75
+ "Model weight files look like Git LFS pointers. "
76
+ "Make sure LFS objects are downloaded (e.g. `git lfs pull`) "
77
+ "or allow the Space to download them from the Hub at runtime."
78
+ )
79
+
80
+ _docscanner = load_docscanner_model(cuda, path_l=rec_path, path_m=seg_path)
81
+ return _docscanner
82
 
83
  # ์ขŒํ‘œ๋ฅผ ์ดˆ๊ธฐํ™”ํ•˜๋Š” ํ•จ์ˆ˜
84
  def reset_points(image, state):
 
101
  return rotated_image, state
102
 
103
  def reset_image(image, state):
104
+ docscanner = get_docscanner()
105
+ out_image, msk_np = docscanner_rec(image, docscanner, cuda)
106
  state = list(get_corner(mask2point(mask=msk_np)))
107
 
108
  img = Image.fromarray(image)
 
123
  return img, state
124
 
125
  def auto_point_detect(image):
126
+ docscanner = get_docscanner()
127
  out_image, msk_np = docscanner_rec(image, docscanner, cuda)
128
  state = list(get_corner(mask2point(mask=msk_np)))
129
 
 
191
  def convert(image, state):
192
  h,w = image.shape[:2]
193
  if len(state) < 4:
194
+ docscanner = get_docscanner()
195
  out_image, msk_np = docscanner_rec(image, docscanner, cuda)
196
  out_image = out_image[:,:,::-1]
197
  elif len(state) ==4:
 
217
  with gr.Row():
218
  with gr.Column():
219
  text = gr.Textbox("์ž…๋ ฅ ์ด๋ฏธ์ง€(์ฝ”๋„ˆ๋ฅผ ํด๋ฆญํ•˜์„ธ์š”)", show_label=False)
220
+ image_input = gr.Image(
221
+ show_label=False,
222
+ interactive=True,
223
+ elem_classes="image-container",
224
+ type="numpy",
225
+ )
226
  clear_button = gr.Button("Clear Points")
227
  cutting_button = gr.Button("Cutting Image(need more than 2 points)")
228
  rotating_button = gr.Button("Rotate Image(clock wise 90 degree)")
 
230
  convert_button = gr.Button("Convert Image")
231
  with gr.Column():
232
  text = gr.Textbox("๋ณ€ํ™˜๋  ์˜์—ญ", show_label=False)
233
+ image_output = gr.Image(show_label=False, type="pil")
234
  # state_display = gr.Textbox(label="Current State")
235
  # coordinates_text = gr.Textbox(label="Coordinates", placeholder="Enter coordinates (x, y) for each point")
236
  # update_coords_button = gr.Button("Update Coordinates")
237
  with gr.Column():
238
  text = gr.Textbox("๊ฒฐ๊ณผ ์ด๋ฏธ์ง€", show_label=False)
239
+ result_image = gr.Image(show_label=False, format="png", type="numpy")
240
 
241
  # # ์ด๋ฏธ์ง€ ์œ„์—์„œ ํด๋ฆญ ์ด๋ฒคํŠธ ์ฒ˜๋ฆฌ
242
  image_input.select(draw_polygon_on_image, inputs=[image_input,state], outputs=[image_output,state])
 
254
  # ๋ณ€ํ™˜ ๋ฒ„ํŠผ
255
  convert_button.click(fn=convert, inputs=[image_input,state], outputs=result_image)
256
 
257
+ is_spaces = bool(
258
+ os.getenv("SPACE_ID") or os.getenv("HF_SPACE_ID") or os.getenv("SYSTEM") == "spaces"
259
+ )
260
+ demo.launch(share=not is_spaces and bool(os.getenv("GRADIO_SHARE")))