Update app.py
Browse files
app.py
CHANGED
|
@@ -1,3 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import cv2
|
| 3 |
import numpy as np
|
|
@@ -22,17 +171,13 @@ model_inputs = session.get_inputs()
|
|
| 22 |
input_names = [i.name for i in model_inputs]
|
| 23 |
output_names = [o.name for o in session.get_outputs()]
|
| 24 |
|
| 25 |
-
print(f"Model expects inputs: {input_names}")
|
| 26 |
-
|
| 27 |
LABELS = {0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"}
|
| 28 |
|
| 29 |
-
# --- FIX: Hardcode target_size to 800x800 ---
|
| 30 |
-
# The ONNX graph requires exactly this dimension.
|
| 31 |
def preprocess_image(image, target_size=(800, 800)):
|
| 32 |
-
|
|
|
|
| 33 |
|
| 34 |
-
# 1. Resize
|
| 35 |
-
# We use linear interpolation to ensure smooth gradients
|
| 36 |
img_resized = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
|
| 37 |
|
| 38 |
# 2. Normalize
|
|
@@ -46,10 +191,12 @@ def preprocess_image(image, target_size=(800, 800)):
|
|
| 46 |
|
| 47 |
# 4. Prepare Metadata Inputs
|
| 48 |
# scale_factor = resized_shape / original_shape
|
| 49 |
-
scale_factor = np.array([target_size[0] / h, target_size[1] / w], dtype=np.float32).reshape(1, 2)
|
| 50 |
|
| 51 |
-
# im_shape needs to be the input size (800, 800)
|
| 52 |
-
|
|
|
|
|
|
|
| 53 |
|
| 54 |
return img_data, scale_factor, im_shape
|
| 55 |
|
|
@@ -60,7 +207,6 @@ def analyze_layout(input_image):
|
|
| 60 |
image_np = np.array(input_image)
|
| 61 |
|
| 62 |
# --- INFERENCE ---
|
| 63 |
-
# This will now return an 800x800 blob
|
| 64 |
img_blob, scale_factor, im_shape = preprocess_image(image_np)
|
| 65 |
|
| 66 |
inputs = {}
|
|
@@ -77,7 +223,6 @@ def analyze_layout(input_image):
|
|
| 77 |
outputs = session.run(output_names, inputs)
|
| 78 |
|
| 79 |
# --- PARSE RESULTS ---
|
| 80 |
-
# Output is [Batch, N, 6] -> [Class, Score, X1, Y1, X2, Y2]
|
| 81 |
detections = outputs[0]
|
| 82 |
if len(detections.shape) == 3:
|
| 83 |
detections = detections[0]
|
|
@@ -85,9 +230,16 @@ def analyze_layout(input_image):
|
|
| 85 |
viz_image = image_np.copy()
|
| 86 |
log = []
|
| 87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
for det in detections:
|
| 89 |
score = det[1]
|
| 90 |
-
|
|
|
|
|
|
|
| 91 |
|
| 92 |
class_id = int(det[0])
|
| 93 |
bbox = det[2:]
|
|
@@ -113,14 +265,17 @@ def analyze_layout(input_image):
|
|
| 113 |
cv2.putText(viz_image, label_text, (x1, y1 - 5),
|
| 114 |
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
|
| 115 |
|
| 116 |
-
log.append(f"Found {label_name} at [{x1}, {y1}, {x2}, {y2}]")
|
| 117 |
except: pass
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
return viz_image, "\n".join(log)
|
| 120 |
|
| 121 |
with gr.Blocks(title="ONNX Layout Analysis") as demo:
|
| 122 |
gr.Markdown("## ⚡ Fast V3 Layout Analysis (ONNX)")
|
| 123 |
-
gr.Markdown(f"Running `{onnx_filename}` via ONNX Runtime (800x800).")
|
| 124 |
|
| 125 |
with gr.Row():
|
| 126 |
with gr.Column():
|
|
|
|
| 1 |
+
# import gradio as gr
|
| 2 |
+
# import cv2
|
| 3 |
+
# import numpy as np
|
| 4 |
+
# import onnxruntime as ort
|
| 5 |
+
# from huggingface_hub import hf_hub_download, list_repo_files
|
| 6 |
+
|
| 7 |
+
# # --- STEP 1: Find and Download Model ---
|
| 8 |
+
# REPO_ID = "alex-dinh/PP-DocLayoutV3-ONNX"
|
| 9 |
+
# print(f"Searching for ONNX model in {REPO_ID}...")
|
| 10 |
+
|
| 11 |
+
# all_files = list_repo_files(repo_id=REPO_ID)
|
| 12 |
+
# onnx_filename = next((f for f in all_files if f.endswith('.onnx')), None)
|
| 13 |
+
# if onnx_filename is None:
|
| 14 |
+
# raise FileNotFoundError("No .onnx file found in repo.")
|
| 15 |
+
|
| 16 |
+
# print(f"Found model file: {onnx_filename}")
|
| 17 |
+
# model_path = hf_hub_download(repo_id=REPO_ID, filename=onnx_filename)
|
| 18 |
+
|
| 19 |
+
# # --- STEP 2: Initialize Session ---
|
| 20 |
+
# session = ort.InferenceSession(model_path)
|
| 21 |
+
# model_inputs = session.get_inputs()
|
| 22 |
+
# input_names = [i.name for i in model_inputs]
|
| 23 |
+
# output_names = [o.name for o in session.get_outputs()]
|
| 24 |
+
|
| 25 |
+
# print(f"Model expects inputs: {input_names}")
|
| 26 |
+
|
| 27 |
+
# LABELS = {0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"}
|
| 28 |
+
|
| 29 |
+
# # --- FIX: Hardcode target_size to 800x800 ---
|
| 30 |
+
# # The ONNX graph requires exactly this dimension.
|
| 31 |
+
# def preprocess_image(image, target_size=(800, 800)):
|
| 32 |
+
# h, w = image.shape[:2]
|
| 33 |
+
|
| 34 |
+
# # 1. Resize
|
| 35 |
+
# # We use linear interpolation to ensure smooth gradients
|
| 36 |
+
# img_resized = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
|
| 37 |
+
|
| 38 |
+
# # 2. Normalize
|
| 39 |
+
# img_data = img_resized.astype(np.float32) / 255.0
|
| 40 |
+
# mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
|
| 41 |
+
# std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
|
| 42 |
+
# img_data = (img_data - mean) / std
|
| 43 |
+
|
| 44 |
+
# # 3. Transpose (HWC -> CHW)
|
| 45 |
+
# img_data = img_data.transpose(2, 0, 1)[None, :, :, :]
|
| 46 |
+
|
| 47 |
+
# # 4. Prepare Metadata Inputs
|
| 48 |
+
# # scale_factor = resized_shape / original_shape
|
| 49 |
+
# scale_factor = np.array([target_size[0] / h, target_size[1] / w], dtype=np.float32).reshape(1, 2)
|
| 50 |
+
|
| 51 |
+
# # im_shape needs to be the input size (800, 800)
|
| 52 |
+
# im_shape = np.array([target_size[0], target_size[1]], dtype=np.float32).reshape(1, 2)
|
| 53 |
+
|
| 54 |
+
# return img_data, scale_factor, im_shape
|
| 55 |
+
|
| 56 |
+
# def analyze_layout(input_image):
|
| 57 |
+
# if input_image is None:
|
| 58 |
+
# return None, "No image uploaded"
|
| 59 |
+
|
| 60 |
+
# image_np = np.array(input_image)
|
| 61 |
+
|
| 62 |
+
# # --- INFERENCE ---
|
| 63 |
+
# # This will now return an 800x800 blob
|
| 64 |
+
# img_blob, scale_factor, im_shape = preprocess_image(image_np)
|
| 65 |
+
|
| 66 |
+
# inputs = {}
|
| 67 |
+
# for i in model_inputs:
|
| 68 |
+
# name = i.name
|
| 69 |
+
# if 'image' in name:
|
| 70 |
+
# inputs[name] = img_blob
|
| 71 |
+
# elif 'scale' in name:
|
| 72 |
+
# inputs[name] = scale_factor
|
| 73 |
+
# elif 'shape' in name:
|
| 74 |
+
# inputs[name] = im_shape
|
| 75 |
+
|
| 76 |
+
# # Run ONNX
|
| 77 |
+
# outputs = session.run(output_names, inputs)
|
| 78 |
+
|
| 79 |
+
# # --- PARSE RESULTS ---
|
| 80 |
+
# # Output is [Batch, N, 6] -> [Class, Score, X1, Y1, X2, Y2]
|
| 81 |
+
# detections = outputs[0]
|
| 82 |
+
# if len(detections.shape) == 3:
|
| 83 |
+
# detections = detections[0]
|
| 84 |
+
|
| 85 |
+
# viz_image = image_np.copy()
|
| 86 |
+
# log = []
|
| 87 |
+
|
| 88 |
+
# for det in detections:
|
| 89 |
+
# score = det[1]
|
| 90 |
+
# if score < 0.45: continue
|
| 91 |
+
|
| 92 |
+
# class_id = int(det[0])
|
| 93 |
+
# bbox = det[2:]
|
| 94 |
+
|
| 95 |
+
# # Map labels
|
| 96 |
+
# label_name = LABELS.get(class_id, f"Class {class_id}")
|
| 97 |
+
|
| 98 |
+
# # Draw Box
|
| 99 |
+
# try:
|
| 100 |
+
# x1, y1, x2, y2 = map(int, bbox)
|
| 101 |
+
|
| 102 |
+
# # Color coding
|
| 103 |
+
# color = (0, 255, 0) # Green
|
| 104 |
+
# if "Title" in label_name: color = (0, 0, 255)
|
| 105 |
+
# elif "Table" in label_name: color = (255, 255, 0)
|
| 106 |
+
# elif "Figure" in label_name: color = (255, 0, 0)
|
| 107 |
+
|
| 108 |
+
# cv2.rectangle(viz_image, (x1, y1), (x2, y2), color, 3)
|
| 109 |
+
|
| 110 |
+
# label_text = f"{label_name} {score:.2f}"
|
| 111 |
+
# (w, h), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
|
| 112 |
+
# cv2.rectangle(viz_image, (x1, y1 - 20), (x1 + w, y1), color, -1)
|
| 113 |
+
# cv2.putText(viz_image, label_text, (x1, y1 - 5),
|
| 114 |
+
# cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
|
| 115 |
+
|
| 116 |
+
# log.append(f"Found {label_name} at [{x1}, {y1}, {x2}, {y2}]")
|
| 117 |
+
# except: pass
|
| 118 |
+
|
| 119 |
+
# return viz_image, "\n".join(log)
|
| 120 |
+
|
| 121 |
+
# with gr.Blocks(title="ONNX Layout Analysis") as demo:
|
| 122 |
+
# gr.Markdown("## ⚡ Fast V3 Layout Analysis (ONNX)")
|
| 123 |
+
# gr.Markdown(f"Running `{onnx_filename}` via ONNX Runtime (800x800).")
|
| 124 |
+
|
| 125 |
+
# with gr.Row():
|
| 126 |
+
# with gr.Column():
|
| 127 |
+
# input_img = gr.Image(type="pil", label="Input Document")
|
| 128 |
+
# submit_btn = gr.Button("Analyze Layout", variant="primary")
|
| 129 |
+
|
| 130 |
+
# with gr.Column():
|
| 131 |
+
# output_img = gr.Image(label="Layout Visualization")
|
| 132 |
+
# output_log = gr.Textbox(label="Detections", lines=10)
|
| 133 |
+
|
| 134 |
+
# submit_btn.click(fn=analyze_layout, inputs=input_img, outputs=[output_img, output_log])
|
| 135 |
+
|
| 136 |
+
# if __name__ == "__main__":
|
| 137 |
+
# demo.launch(server_name="0.0.0.0", server_port=7860)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
|
| 150 |
import gradio as gr
|
| 151 |
import cv2
|
| 152 |
import numpy as np
|
|
|
|
| 171 |
input_names = [i.name for i in model_inputs]
|
| 172 |
output_names = [o.name for o in session.get_outputs()]
|
| 173 |
|
|
|
|
|
|
|
| 174 |
LABELS = {0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"}
|
| 175 |
|
|
|
|
|
|
|
| 176 |
def preprocess_image(image, target_size=(800, 800)):
|
| 177 |
+
# Original dimensions
|
| 178 |
+
orig_h, orig_w = image.shape[:2]
|
| 179 |
|
| 180 |
+
# 1. Resize (Warping to 800x800 is required by this graph)
|
|
|
|
| 181 |
img_resized = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
|
| 182 |
|
| 183 |
# 2. Normalize
|
|
|
|
| 191 |
|
| 192 |
# 4. Prepare Metadata Inputs
|
| 193 |
# scale_factor = resized_shape / original_shape
|
| 194 |
+
scale_factor = np.array([target_size[0] / orig_h, target_size[1] / orig_w], dtype=np.float32).reshape(1, 2)
|
| 195 |
|
| 196 |
+
# --- CRITICAL FIX: im_shape must be the ORIGINAL image size ---
|
| 197 |
+
# This tells the model the valid area to keep boxes.
|
| 198 |
+
# If we put 800x800 here, it clips valid boxes on large documents.
|
| 199 |
+
im_shape = np.array([orig_h, orig_w], dtype=np.float32).reshape(1, 2)
|
| 200 |
|
| 201 |
return img_data, scale_factor, im_shape
|
| 202 |
|
|
|
|
| 207 |
image_np = np.array(input_image)
|
| 208 |
|
| 209 |
# --- INFERENCE ---
|
|
|
|
| 210 |
img_blob, scale_factor, im_shape = preprocess_image(image_np)
|
| 211 |
|
| 212 |
inputs = {}
|
|
|
|
| 223 |
outputs = session.run(output_names, inputs)
|
| 224 |
|
| 225 |
# --- PARSE RESULTS ---
|
|
|
|
| 226 |
detections = outputs[0]
|
| 227 |
if len(detections.shape) == 3:
|
| 228 |
detections = detections[0]
|
|
|
|
| 230 |
viz_image = image_np.copy()
|
| 231 |
log = []
|
| 232 |
|
| 233 |
+
# DEBUG: Print max score to check if model is working at all
|
| 234 |
+
if len(detections) > 0:
|
| 235 |
+
max_score = np.max(detections[:, 1])
|
| 236 |
+
print(f"DEBUG: Max confidence score found: {max_score}")
|
| 237 |
+
|
| 238 |
for det in detections:
|
| 239 |
score = det[1]
|
| 240 |
+
|
| 241 |
+
# Lowered threshold to 0.2 to catch faint detections
|
| 242 |
+
if score < 0.2: continue
|
| 243 |
|
| 244 |
class_id = int(det[0])
|
| 245 |
bbox = det[2:]
|
|
|
|
| 265 |
cv2.putText(viz_image, label_text, (x1, y1 - 5),
|
| 266 |
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
|
| 267 |
|
| 268 |
+
log.append(f"Found {label_name} at [{x1}, {y1}, {x2}, {y2}] (Conf: {score:.2f})")
|
| 269 |
except: pass
|
| 270 |
+
|
| 271 |
+
if not log:
|
| 272 |
+
log.append("No layout regions detected above threshold.")
|
| 273 |
|
| 274 |
return viz_image, "\n".join(log)
|
| 275 |
|
| 276 |
with gr.Blocks(title="ONNX Layout Analysis") as demo:
|
| 277 |
gr.Markdown("## ⚡ Fast V3 Layout Analysis (ONNX)")
|
| 278 |
+
gr.Markdown(f"Running `{onnx_filename}` via ONNX Runtime.")
|
| 279 |
|
| 280 |
with gr.Row():
|
| 281 |
with gr.Column():
|