# NOTE: stripped non-source viewer metadata (file size, blame hashes, and a
# line-number gutter) that had been pasted above the code and broke parsing.
# import gradio as gr
# import cv2
# import numpy as np
# import onnxruntime as ort
# from huggingface_hub import hf_hub_download, list_repo_files
# # --- STEP 1: Find and Download Model ---
# REPO_ID = "alex-dinh/PP-DocLayoutV3-ONNX"
# print(f"Searching for ONNX model in {REPO_ID}...")
# all_files = list_repo_files(repo_id=REPO_ID)
# onnx_filename = next((f for f in all_files if f.endswith('.onnx')), None)
# if onnx_filename is None:
# raise FileNotFoundError("No .onnx file found in repo.")
# print(f"Found model file: {onnx_filename}")
# model_path = hf_hub_download(repo_id=REPO_ID, filename=onnx_filename)
# # --- STEP 2: Initialize Session ---
# session = ort.InferenceSession(model_path)
# model_inputs = session.get_inputs()
# input_names = [i.name for i in model_inputs]
# output_names = [o.name for o in session.get_outputs()]
# print(f"Model expects inputs: {input_names}")
# LABELS = {0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"}
# # --- FIX: Hardcode target_size to 800x800 ---
# # The ONNX graph requires exactly this dimension.
# def preprocess_image(image, target_size=(800, 800)):
# h, w = image.shape[:2]
# # 1. Resize
# # We use linear interpolation to ensure smooth gradients
# img_resized = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
# # 2. Normalize
# img_data = img_resized.astype(np.float32) / 255.0
# mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
# std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
# img_data = (img_data - mean) / std
# # 3. Transpose (HWC -> CHW)
# img_data = img_data.transpose(2, 0, 1)[None, :, :, :]
# # 4. Prepare Metadata Inputs
# # scale_factor = resized_shape / original_shape
# scale_factor = np.array([target_size[0] / h, target_size[1] / w], dtype=np.float32).reshape(1, 2)
# # im_shape needs to be the input size (800, 800)
# im_shape = np.array([target_size[0], target_size[1]], dtype=np.float32).reshape(1, 2)
# return img_data, scale_factor, im_shape
# def analyze_layout(input_image):
# if input_image is None:
# return None, "No image uploaded"
# image_np = np.array(input_image)
# # --- INFERENCE ---
# # This will now return an 800x800 blob
# img_blob, scale_factor, im_shape = preprocess_image(image_np)
# inputs = {}
# for i in model_inputs:
# name = i.name
# if 'image' in name:
# inputs[name] = img_blob
# elif 'scale' in name:
# inputs[name] = scale_factor
# elif 'shape' in name:
# inputs[name] = im_shape
# # Run ONNX
# outputs = session.run(output_names, inputs)
# # --- PARSE RESULTS ---
# # Output is [Batch, N, 6] -> [Class, Score, X1, Y1, X2, Y2]
# detections = outputs[0]
# if len(detections.shape) == 3:
# detections = detections[0]
# viz_image = image_np.copy()
# log = []
# for det in detections:
# score = det[1]
# if score < 0.45: continue
# class_id = int(det[0])
# bbox = det[2:]
# # Map labels
# label_name = LABELS.get(class_id, f"Class {class_id}")
# # Draw Box
# try:
# x1, y1, x2, y2 = map(int, bbox)
# # Color coding
# color = (0, 255, 0) # Green
# if "Title" in label_name: color = (0, 0, 255)
# elif "Table" in label_name: color = (255, 255, 0)
# elif "Figure" in label_name: color = (255, 0, 0)
# cv2.rectangle(viz_image, (x1, y1), (x2, y2), color, 3)
# label_text = f"{label_name} {score:.2f}"
# (w, h), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
# cv2.rectangle(viz_image, (x1, y1 - 20), (x1 + w, y1), color, -1)
# cv2.putText(viz_image, label_text, (x1, y1 - 5),
# cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
# log.append(f"Found {label_name} at [{x1}, {y1}, {x2}, {y2}]")
# except: pass
# return viz_image, "\n".join(log)
# with gr.Blocks(title="ONNX Layout Analysis") as demo:
# gr.Markdown("## ⚡ Fast V3 Layout Analysis (ONNX)")
# gr.Markdown(f"Running `{onnx_filename}` via ONNX Runtime (800x800).")
# with gr.Row():
# with gr.Column():
# input_img = gr.Image(type="pil", label="Input Document")
# submit_btn = gr.Button("Analyze Layout", variant="primary")
# with gr.Column():
# output_img = gr.Image(label="Layout Visualization")
# output_log = gr.Textbox(label="Detections", lines=10)
# submit_btn.click(fn=analyze_layout, inputs=input_img, outputs=[output_img, output_log])
# if __name__ == "__main__":
# demo.launch(server_name="0.0.0.0", server_port=7860)
import gradio as gr
import cv2
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download, list_repo_files
# --- STEP 1: Find and Download Model ---
# Hugging Face repo hosting the PP-DocLayout V3 export in ONNX format.
REPO_ID = "alex-dinh/PP-DocLayoutV3-ONNX"
print(f"Searching for ONNX model in {REPO_ID}...")
# The .onnx filename is not hardcoded; discover the first one in the repo
# listing so the app keeps working if the file is renamed upstream.
all_files = list_repo_files(repo_id=REPO_ID)
onnx_filename = next((f for f in all_files if f.endswith('.onnx')), None)
if onnx_filename is None:
    raise FileNotFoundError("No .onnx file found in repo.")
print(f"Found model file: {onnx_filename}")
# Downloads to (and on later runs reuses) the local Hugging Face cache.
model_path = hf_hub_download(repo_id=REPO_ID, filename=onnx_filename)
# --- STEP 2: Initialize Session ---
# Default execution providers (CPU unless onnxruntime-gpu is installed).
session = ort.InferenceSession(model_path)
# Cache input/output metadata once; analyze_layout() reads these at call time.
model_inputs = session.get_inputs()
input_names = [i.name for i in model_inputs]
output_names = [o.name for o in session.get_outputs()]
print(f"Model expects inputs: {input_names}")
# Class-id -> human-readable label. NOTE(review): assumed to match this
# export's label set — PP-DocLayout variants ship more classes; unknown ids
# fall back to "Class {id}" in analyze_layout(). Verify against the model card.
LABELS = {0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"}
def preprocess_image(image, target_size=(800, 800)):
    """Turn an HWC image array into the three tensors the ONNX model expects.

    Args:
        image: numpy array of shape (H, W, 3). Assumed 3-channel RGB uint8
            (from the Gradio PIL input) — TODO confirm; a grayscale 2-D array
            would fail at the transpose below.
        target_size: (width, height) the graph was exported with; the ONNX
            graph requires exactly 800x800 by default.

    Returns:
        (img_blob, scale_factor, im_shape):
            img_blob: float32 NCHW tensor, shape (1, 3, target_h, target_w).
            scale_factor: float32 (1, 2) array [scale_h, scale_w], i.e.
                resized / original, in PaddleDetection's [h, w] ordering.
            im_shape: float32 (1, 2) array [h, w] of the network input size.
    """
    target_w, target_h = target_size
    orig_h, orig_w = image.shape[:2]

    # 1. Resize. cv2.resize takes dsize as (width, height); INTER_LINEAR
    # keeps gradients smooth.
    img_resized = cv2.resize(image, (target_w, target_h),
                             interpolation=cv2.INTER_LINEAR)

    # 2. Normalize with ImageNet mean/std (broadcasts over the channel axis).
    img_data = img_resized.astype(np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    img_data = (img_data - mean) / std

    # 3. Transpose (HWC -> CHW) and add the batch dimension.
    img_data = img_data.transpose(2, 0, 1)[None, :, :, :]

    # 4. Metadata inputs. Height scale first ([h, w] ordering). This fixes an
    # h/w mix-up (target_size[0] is WIDTH, yet it divided the height) that
    # only mattered for non-square sizes; identical for the 800x800 default.
    scale_factor = np.array([target_h / orig_h, target_w / orig_w],
                            dtype=np.float32).reshape(1, 2)
    # Some exports want the INPUT size here, not the original image size.
    im_shape = np.array([target_h, target_w], dtype=np.float32).reshape(1, 2)
    return img_data, scale_factor, im_shape
def analyze_layout(input_image):
    """Run ONNX layout detection on an uploaded image and draw the results.

    Args:
        input_image: PIL image from Gradio, or None when nothing was uploaded.

    Returns:
        (annotated image as a numpy array or None,
         newline-joined text log of detections).
    """
    if input_image is None:
        return None, "No image uploaded"
    image_np = np.array(input_image)

    # --- INFERENCE ---
    img_blob, scale_factor, im_shape = preprocess_image(image_np)

    # Route the three tensors to whichever input names this export declares
    # (name substrings vary between exports, hence the fuzzy match).
    inputs = {}
    for model_input in model_inputs:
        name = model_input.name
        if 'image' in name:
            inputs[name] = img_blob
        elif 'scale' in name:
            inputs[name] = scale_factor
        elif 'shape' in name:
            inputs[name] = im_shape
    outputs = session.run(output_names, inputs)

    # Detections are rows of [class, score, x1, y1, x2, y2], possibly with a
    # leading batch dimension that we strip.
    detections = outputs[0]
    if len(detections.shape) == 3:
        detections = detections[0]

    # --- RAW DEBUG LOGGING ---
    print(f"\n[DEBUG] Raw Detections Shape: {detections.shape}")
    print("[DEBUG] Top 3 Raw Detections (Class, Score, BBox):")
    for i in range(min(3, len(detections))):
        print(f"  {detections[i]}")

    viz_image = image_np.copy()
    log = []
    for det in detections:
        score = det[1]
        # Threshold lowered to 0.3 strictly for debugging.
        if score < 0.3:
            continue
        class_id = int(det[0])
        bbox = det[2:]
        # Map class id to a label; unknown ids get a generic name.
        label_name = LABELS.get(class_id, f"Class {class_id}")
        try:
            x1, y1, x2, y2 = map(int, bbox)
            # Color coding per label type.
            color = (0, 255, 0)  # Green
            if "Title" in label_name: color = (0, 0, 255)
            elif "Table" in label_name: color = (255, 255, 0)
            elif "Figure" in label_name: color = (255, 0, 0)
            cv2.rectangle(viz_image, (x1, y1), (x2, y2), color, 3)
            label_text = f"{label_name} {score:.2f}"
            (tw, th), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
            cv2.rectangle(viz_image, (x1, y1 - 20), (x1 + tw, y1), color, -1)
            cv2.putText(viz_image, label_text, (x1, y1 - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
            log.append(f"Found {label_name} at [{x1}, {y1}, {x2}, {y2}] (Conf: {score:.2f})")
        except (ValueError, TypeError, OverflowError, cv2.error) as exc:
            # Was a bare `except: pass`, which hid every failure in the draw
            # loop. Keep best-effort behavior (skip the bad box), but only
            # swallow conversion/drawing errors, and surface them for debug.
            print(f"[DEBUG] Skipping malformed detection {det}: {exc}")
    if not log:
        log.append("No layout regions detected above threshold.")
    return viz_image, "\n".join(log)
# Two-column debug UI: document image in, annotated image + text log out.
with gr.Blocks(title="ONNX Layout Analysis (Debug)") as demo:
    gr.Markdown("## ⚡ Layout Analysis (Debug Mode)")
    with gr.Row():
        with gr.Column():
            # type="pil" so analyze_layout receives a PIL image.
            input_img = gr.Image(type="pil", label="Input Document")
            submit_btn = gr.Button("Analyze Layout", variant="primary")
        with gr.Column():
            output_img = gr.Image(label="Layout Visualization")
            output_log = gr.Textbox(label="Detections", lines=10)
    # Wire the button to inference: one input image, two outputs.
    submit_btn.click(fn=analyze_layout, inputs=input_img, outputs=[output_img, output_log])

if __name__ == "__main__":
    # Bind all interfaces so the app is reachable from outside a container.
    # (Removed a stray trailing `|` artifact that made this line a SyntaxError.)
    demo.launch(server_name="0.0.0.0", server_port=7860)