iammraat commited on
Commit
b364284
·
verified ·
1 Parent(s): 2f851c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -15
app.py CHANGED
@@ -1,3 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import cv2
3
  import numpy as np
@@ -22,17 +171,13 @@ model_inputs = session.get_inputs()
22
  input_names = [i.name for i in model_inputs]
23
  output_names = [o.name for o in session.get_outputs()]
24
 
25
- print(f"Model expects inputs: {input_names}")
26
-
27
  LABELS = {0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"}
28
 
29
- # --- FIX: Hardcode target_size to 800x800 ---
30
- # The ONNX graph requires exactly this dimension.
31
  def preprocess_image(image, target_size=(800, 800)):
32
- h, w = image.shape[:2]
 
33
 
34
- # 1. Resize
35
- # We use linear interpolation to ensure smooth gradients
36
  img_resized = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
37
 
38
  # 2. Normalize
@@ -46,10 +191,12 @@ def preprocess_image(image, target_size=(800, 800)):
46
 
47
  # 4. Prepare Metadata Inputs
48
  # scale_factor = resized_shape / original_shape
49
- scale_factor = np.array([target_size[0] / h, target_size[1] / w], dtype=np.float32).reshape(1, 2)
50
 
51
- # im_shape needs to be the input size (800, 800)
52
- im_shape = np.array([target_size[0], target_size[1]], dtype=np.float32).reshape(1, 2)
 
 
53
 
54
  return img_data, scale_factor, im_shape
55
 
@@ -60,7 +207,6 @@ def analyze_layout(input_image):
60
  image_np = np.array(input_image)
61
 
62
  # --- INFERENCE ---
63
- # This will now return an 800x800 blob
64
  img_blob, scale_factor, im_shape = preprocess_image(image_np)
65
 
66
  inputs = {}
@@ -77,7 +223,6 @@ def analyze_layout(input_image):
77
  outputs = session.run(output_names, inputs)
78
 
79
  # --- PARSE RESULTS ---
80
- # Output is [Batch, N, 6] -> [Class, Score, X1, Y1, X2, Y2]
81
  detections = outputs[0]
82
  if len(detections.shape) == 3:
83
  detections = detections[0]
@@ -85,9 +230,16 @@ def analyze_layout(input_image):
85
  viz_image = image_np.copy()
86
  log = []
87
 
 
 
 
 
 
88
  for det in detections:
89
  score = det[1]
90
- if score < 0.45: continue
 
 
91
 
92
  class_id = int(det[0])
93
  bbox = det[2:]
@@ -113,14 +265,17 @@ def analyze_layout(input_image):
113
  cv2.putText(viz_image, label_text, (x1, y1 - 5),
114
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
115
 
116
- log.append(f"Found {label_name} at [{x1}, {y1}, {x2}, {y2}]")
117
  except: pass
 
 
 
118
 
119
  return viz_image, "\n".join(log)
120
 
121
  with gr.Blocks(title="ONNX Layout Analysis") as demo:
122
  gr.Markdown("## ⚡ Fast V3 Layout Analysis (ONNX)")
123
- gr.Markdown(f"Running `{onnx_filename}` via ONNX Runtime (800x800).")
124
 
125
  with gr.Row():
126
  with gr.Column():
 
1
+ # import gradio as gr
2
+ # import cv2
3
+ # import numpy as np
4
+ # import onnxruntime as ort
5
+ # from huggingface_hub import hf_hub_download, list_repo_files
6
+
7
+ # # --- STEP 1: Find and Download Model ---
8
+ # REPO_ID = "alex-dinh/PP-DocLayoutV3-ONNX"
9
+ # print(f"Searching for ONNX model in {REPO_ID}...")
10
+
11
+ # all_files = list_repo_files(repo_id=REPO_ID)
12
+ # onnx_filename = next((f for f in all_files if f.endswith('.onnx')), None)
13
+ # if onnx_filename is None:
14
+ # raise FileNotFoundError("No .onnx file found in repo.")
15
+
16
+ # print(f"Found model file: {onnx_filename}")
17
+ # model_path = hf_hub_download(repo_id=REPO_ID, filename=onnx_filename)
18
+
19
+ # # --- STEP 2: Initialize Session ---
20
+ # session = ort.InferenceSession(model_path)
21
+ # model_inputs = session.get_inputs()
22
+ # input_names = [i.name for i in model_inputs]
23
+ # output_names = [o.name for o in session.get_outputs()]
24
+
25
+ # print(f"Model expects inputs: {input_names}")
26
+
27
+ # LABELS = {0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"}
28
+
29
+ # # --- FIX: Hardcode target_size to 800x800 ---
30
+ # # The ONNX graph requires exactly this dimension.
31
+ # def preprocess_image(image, target_size=(800, 800)):
32
+ # h, w = image.shape[:2]
33
+
34
+ # # 1. Resize
35
+ # # We use linear interpolation to ensure smooth gradients
36
+ # img_resized = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
37
+
38
+ # # 2. Normalize
39
+ # img_data = img_resized.astype(np.float32) / 255.0
40
+ # mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
41
+ # std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
42
+ # img_data = (img_data - mean) / std
43
+
44
+ # # 3. Transpose (HWC -> CHW)
45
+ # img_data = img_data.transpose(2, 0, 1)[None, :, :, :]
46
+
47
+ # # 4. Prepare Metadata Inputs
48
+ # # scale_factor = resized_shape / original_shape
49
+ # scale_factor = np.array([target_size[0] / h, target_size[1] / w], dtype=np.float32).reshape(1, 2)
50
+
51
+ # # im_shape needs to be the input size (800, 800)
52
+ # im_shape = np.array([target_size[0], target_size[1]], dtype=np.float32).reshape(1, 2)
53
+
54
+ # return img_data, scale_factor, im_shape
55
+
56
+ # def analyze_layout(input_image):
57
+ # if input_image is None:
58
+ # return None, "No image uploaded"
59
+
60
+ # image_np = np.array(input_image)
61
+
62
+ # # --- INFERENCE ---
63
+ # # This will now return an 800x800 blob
64
+ # img_blob, scale_factor, im_shape = preprocess_image(image_np)
65
+
66
+ # inputs = {}
67
+ # for i in model_inputs:
68
+ # name = i.name
69
+ # if 'image' in name:
70
+ # inputs[name] = img_blob
71
+ # elif 'scale' in name:
72
+ # inputs[name] = scale_factor
73
+ # elif 'shape' in name:
74
+ # inputs[name] = im_shape
75
+
76
+ # # Run ONNX
77
+ # outputs = session.run(output_names, inputs)
78
+
79
+ # # --- PARSE RESULTS ---
80
+ # # Output is [Batch, N, 6] -> [Class, Score, X1, Y1, X2, Y2]
81
+ # detections = outputs[0]
82
+ # if len(detections.shape) == 3:
83
+ # detections = detections[0]
84
+
85
+ # viz_image = image_np.copy()
86
+ # log = []
87
+
88
+ # for det in detections:
89
+ # score = det[1]
90
+ # if score < 0.45: continue
91
+
92
+ # class_id = int(det[0])
93
+ # bbox = det[2:]
94
+
95
+ # # Map labels
96
+ # label_name = LABELS.get(class_id, f"Class {class_id}")
97
+
98
+ # # Draw Box
99
+ # try:
100
+ # x1, y1, x2, y2 = map(int, bbox)
101
+
102
+ # # Color coding
103
+ # color = (0, 255, 0) # Green
104
+ # if "Title" in label_name: color = (0, 0, 255)
105
+ # elif "Table" in label_name: color = (255, 255, 0)
106
+ # elif "Figure" in label_name: color = (255, 0, 0)
107
+
108
+ # cv2.rectangle(viz_image, (x1, y1), (x2, y2), color, 3)
109
+
110
+ # label_text = f"{label_name} {score:.2f}"
111
+ # (w, h), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
112
+ # cv2.rectangle(viz_image, (x1, y1 - 20), (x1 + w, y1), color, -1)
113
+ # cv2.putText(viz_image, label_text, (x1, y1 - 5),
114
+ # cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
115
+
116
+ # log.append(f"Found {label_name} at [{x1}, {y1}, {x2}, {y2}]")
117
+ # except: pass
118
+
119
+ # return viz_image, "\n".join(log)
120
+
121
+ # with gr.Blocks(title="ONNX Layout Analysis") as demo:
122
+ # gr.Markdown("## ⚡ Fast V3 Layout Analysis (ONNX)")
123
+ # gr.Markdown(f"Running `{onnx_filename}` via ONNX Runtime (800x800).")
124
+
125
+ # with gr.Row():
126
+ # with gr.Column():
127
+ # input_img = gr.Image(type="pil", label="Input Document")
128
+ # submit_btn = gr.Button("Analyze Layout", variant="primary")
129
+
130
+ # with gr.Column():
131
+ # output_img = gr.Image(label="Layout Visualization")
132
+ # output_log = gr.Textbox(label="Detections", lines=10)
133
+
134
+ # submit_btn.click(fn=analyze_layout, inputs=input_img, outputs=[output_img, output_log])
135
+
136
+ # if __name__ == "__main__":
137
+ # demo.launch(server_name="0.0.0.0", server_port=7860)
138
+
139
+
140
+
141
+
142
+
143
+
144
+
145
+
146
+
147
+
148
+
149
+
150
  import gradio as gr
151
  import cv2
152
  import numpy as np
 
171
  input_names = [i.name for i in model_inputs]
172
  output_names = [o.name for o in session.get_outputs()]
173
 
 
 
174
  LABELS = {0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"}
175
 
 
 
176
  def preprocess_image(image, target_size=(800, 800)):
177
+ # Original dimensions
178
+ orig_h, orig_w = image.shape[:2]
179
 
180
+ # 1. Resize (Warping to 800x800 is required by this graph)
 
181
  img_resized = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
182
 
183
  # 2. Normalize
 
191
 
192
  # 4. Prepare Metadata Inputs
193
  # scale_factor = resized_shape / original_shape
194
+ scale_factor = np.array([target_size[0] / orig_h, target_size[1] / orig_w], dtype=np.float32).reshape(1, 2)
195
 
196
+ # --- CRITICAL FIX: im_shape must be the ORIGINAL image size ---
197
+ # This tells the model the valid area to keep boxes.
198
+ # If we put 800x800 here, it clips valid boxes on large documents.
199
+ im_shape = np.array([orig_h, orig_w], dtype=np.float32).reshape(1, 2)
200
 
201
  return img_data, scale_factor, im_shape
202
 
 
207
  image_np = np.array(input_image)
208
 
209
  # --- INFERENCE ---
 
210
  img_blob, scale_factor, im_shape = preprocess_image(image_np)
211
 
212
  inputs = {}
 
223
  outputs = session.run(output_names, inputs)
224
 
225
  # --- PARSE RESULTS ---
 
226
  detections = outputs[0]
227
  if len(detections.shape) == 3:
228
  detections = detections[0]
 
230
  viz_image = image_np.copy()
231
  log = []
232
 
233
+ # DEBUG: Print max score to check if model is working at all
234
+ if len(detections) > 0:
235
+ max_score = np.max(detections[:, 1])
236
+ print(f"DEBUG: Max confidence score found: {max_score}")
237
+
238
  for det in detections:
239
  score = det[1]
240
+
241
+ # Lowered threshold to 0.2 to catch faint detections
242
+ if score < 0.2: continue
243
 
244
  class_id = int(det[0])
245
  bbox = det[2:]
 
265
  cv2.putText(viz_image, label_text, (x1, y1 - 5),
266
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
267
 
268
+ log.append(f"Found {label_name} at [{x1}, {y1}, {x2}, {y2}] (Conf: {score:.2f})")
269
  except: pass
270
+
271
+ if not log:
272
+ log.append("No layout regions detected above threshold.")
273
 
274
  return viz_image, "\n".join(log)
275
 
276
  with gr.Blocks(title="ONNX Layout Analysis") as demo:
277
  gr.Markdown("## ⚡ Fast V3 Layout Analysis (ONNX)")
278
+ gr.Markdown(f"Running `{onnx_filename}` via ONNX Runtime.")
279
 
280
  with gr.Row():
281
  with gr.Column():