callmeumer commited on
Commit
504e605
·
verified ·
1 Parent(s): 106fe7e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -89
app.py CHANGED
@@ -1,11 +1,11 @@
1
- from typing import Optional, Tuple
2
  import spaces
3
  import gradio as gr
4
  import numpy as np
5
  import torch
6
  from PIL import Image
7
  import io
8
- import json
9
 
10
  import base64, os
11
  from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
@@ -23,6 +23,7 @@ snapshot_download(repo_id=repo_id, local_dir=local_dir)
23
 
24
  print(f"Repository downloaded to: {local_dir}")
25
 
 
26
  yolo_model = get_yolo_model(model_path='weights/icon_detect/model.pt')
27
  caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="weights/icon_caption")
28
  # caption_model_processor = get_caption_model_processor(model_name="blip2", model_name_or_path="weights/icon_caption_blip2")
@@ -42,14 +43,18 @@ DEVICE = torch.device('cuda')
42
 
43
  @spaces.GPU
44
  @torch.inference_mode()
 
45
  def process(
46
  image_input,
47
  box_threshold,
48
  iou_threshold,
49
  use_paddleocr,
50
  imgsz
51
- ) -> Tuple[str, str]:
52
 
 
 
 
53
  box_overlay_ratio = image_input.size[0] / 3200
54
  draw_bbox_config = {
55
  'text_scale': 0.8 * box_overlay_ratio,
@@ -57,93 +62,38 @@ def process(
57
  'text_padding': max(int(3 * box_overlay_ratio), 1),
58
  'thickness': max(int(3 * box_overlay_ratio), 1),
59
  }
 
60
 
61
- ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
62
- image_input,
63
- display_img=False,
64
- output_bb_format='xyxy',
65
- goal_filtering=None,
66
- easyocr_args={'paragraph': False, 'text_threshold': 0.9},
67
- use_paddleocr=use_paddleocr
68
- )
69
  text, ocr_bbox = ocr_bbox_rslt
70
-
71
- dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
72
- image_input,
73
- yolo_model,
74
- BOX_TRESHOLD=box_threshold,
75
- output_coord_in_ratio=True,
76
- ocr_bbox=ocr_bbox,
77
- draw_bbox_config=draw_bbox_config,
78
- caption_model_processor=caption_model_processor,
79
- ocr_text=text,
80
- iou_threshold=iou_threshold,
81
- imgsz=imgsz
82
- )
83
-
84
  print('finish processing')
85
- parsed_content_list = '\n'.join([f'icon {i}: ' + str(v) for i, v in enumerate(parsed_content_list)])
86
-
87
- # Convert label_coordinates to JSON string for API consumption
88
- label_coordinates_json = json.dumps(label_coordinates)
89
-
90
- return str(parsed_content_list), label_coordinates_json
91
-
92
- # Create interface with simplified component definitions
93
  with gr.Blocks() as demo:
94
  gr.Markdown(MARKDOWN)
95
-
96
  with gr.Row():
97
  with gr.Column():
98
  image_input_component = gr.Image(
99
- type='pil',
100
- label='Upload image'
101
- )
102
-
103
- # Simplified slider definitions
104
  box_threshold_component = gr.Slider(
105
- minimum=0.01,
106
- maximum=1.0,
107
- value=0.05,
108
- step=0.01,
109
- label='Box Threshold'
110
- )
111
-
112
  iou_threshold_component = gr.Slider(
113
- minimum=0.01,
114
- maximum=1.0,
115
- value=0.1,
116
- step=0.01,
117
- label='IOU Threshold'
118
- )
119
-
120
  use_paddleocr_component = gr.Checkbox(
121
- value=True,
122
- label='Use PaddleOCR'
123
- )
124
-
125
  imgsz_component = gr.Slider(
126
- minimum=640,
127
- maximum=1920,
128
- value=640,
129
- step=32,
130
- label='Icon Detect Image Size'
131
- )
132
-
133
  submit_button_component = gr.Button(
134
- value='Submit',
135
- variant='primary'
136
- )
137
-
138
  with gr.Column():
139
- text_output_component = gr.Textbox(
140
- label='Parsed screen elements',
141
- placeholder='Text Output'
142
- )
143
- coordinates_output_component = gr.Textbox(
144
- label='Label Coordinates (JSON)',
145
- placeholder='Coordinates JSON Output'
146
- )
147
 
148
  submit_button_component.click(
149
  fn=process,
@@ -154,18 +104,10 @@ with gr.Blocks() as demo:
154
  use_paddleocr_component,
155
  imgsz_component
156
  ],
157
- outputs=[text_output_component, coordinates_output_component]
158
  )
159
 
160
- # Try launching with different configurations
161
- try:
162
- demo.queue().launch(share=True)
163
- except Exception as e:
164
- print(f"Error launching with queue: {e}")
165
- # Fallback: try without queue
166
- try:
167
- demo.launch(share=True)
168
- except Exception as e2:
169
- print(f"Error launching without queue: {e2}")
170
- # Final fallback: basic launch
171
- demo.launch(debug=True, show_error=True, share=True)
 
1
+ from typing import Optional
2
  import spaces
3
  import gradio as gr
4
  import numpy as np
5
  import torch
6
  from PIL import Image
7
  import io
8
+
9
 
10
  import base64, os
11
  from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
 
23
 
24
  print(f"Repository downloaded to: {local_dir}")
25
 
26
+
27
  yolo_model = get_yolo_model(model_path='weights/icon_detect/model.pt')
28
  caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="weights/icon_caption")
29
  # caption_model_processor = get_caption_model_processor(model_name="blip2", model_name_or_path="weights/icon_caption_blip2")
 
43
 
44
  @spaces.GPU
45
  @torch.inference_mode()
46
+ # @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
47
  def process(
48
  image_input,
49
  box_threshold,
50
  iou_threshold,
51
  use_paddleocr,
52
  imgsz
53
+ ) -> Optional[Image.Image]:
54
 
55
+ # image_save_path = 'imgs/saved_image_demo.png'
56
+ # image_input.save(image_save_path)
57
+ # image = Image.open(image_save_path)
58
  box_overlay_ratio = image_input.size[0] / 3200
59
  draw_bbox_config = {
60
  'text_scale': 0.8 * box_overlay_ratio,
 
62
  'text_padding': max(int(3 * box_overlay_ratio), 1),
63
  'thickness': max(int(3 * box_overlay_ratio), 1),
64
  }
65
+ # import pdb; pdb.set_trace()
66
 
67
+ ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_input, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9}, use_paddleocr=use_paddleocr)
 
 
 
 
 
 
 
68
  text, ocr_bbox = ocr_bbox_rslt
69
+ dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_input, yolo_model, BOX_TRESHOLD = box_threshold, output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text,iou_threshold=iou_threshold, imgsz=imgsz,)
70
+ image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
 
 
 
 
 
 
 
 
 
 
 
 
71
  print('finish processing')
72
+ parsed_content_list = '\n'.join([f'icon {i}: ' + str(v) for i,v in enumerate(parsed_content_list)])
73
+ # parsed_content_list = str(parsed_content_list)
74
+ return image, str(parsed_content_list)
75
+
 
 
 
 
76
  with gr.Blocks() as demo:
77
  gr.Markdown(MARKDOWN)
 
78
  with gr.Row():
79
  with gr.Column():
80
  image_input_component = gr.Image(
81
+ type='pil', label='Upload image')
82
+ # set the threshold for removing the bounding boxes with low confidence, default is 0.05
 
 
 
83
  box_threshold_component = gr.Slider(
84
+ label='Box Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.05)
85
+ # set the threshold for removing the bounding boxes with large overlap, default is 0.1
 
 
 
 
 
86
  iou_threshold_component = gr.Slider(
87
+ label='IOU Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.1)
 
 
 
 
 
 
88
  use_paddleocr_component = gr.Checkbox(
89
+ label='Use PaddleOCR', value=True)
 
 
 
90
  imgsz_component = gr.Slider(
91
+ label='Icon Detect Image Size', minimum=640, maximum=1920, step=32, value=640)
 
 
 
 
 
 
92
  submit_button_component = gr.Button(
93
+ value='Submit', variant='primary')
 
 
 
94
  with gr.Column():
95
+ image_output_component = gr.Image(type='pil', label='Image Output')
96
+ text_output_component = gr.Textbox(label='Parsed screen elements', placeholder='Text Output')
 
 
 
 
 
 
97
 
98
  submit_button_component.click(
99
  fn=process,
 
104
  use_paddleocr_component,
105
  imgsz_component
106
  ],
107
+ outputs=[image_output_component, text_output_component]
108
  )
109
 
110
+ # demo.launch(debug=False, show_error=True, share=True)
111
+ # demo.launch(share=True, server_port=7861, server_name='0.0.0.0')
112
+ demo.queue().launch(share=True)
113
+