| | import gradio as gr |
| | import cv2 |
| | import numpy as np |
| | import onnxruntime as ort |
| | from huggingface_hub import hf_hub_download |
| |
|
| | |
# Fetch the pretrained PP-DocLayoutV3 ONNX model from the Hugging Face Hub.
# hf_hub_download caches locally, so repeated launches skip the download.
print("Downloading ONNX model...")
model_path = hf_hub_download(repo_id="alex-dinh/PP-DocLayoutV3-ONNX", filename="model.onnx")
print(f"Model downloaded to: {model_path}")

# One ONNX Runtime session (default providers) created at import time and
# reused for every request; record the graph's input/output tensor names.
session = ort.InferenceSession(model_path)
input_names = [i.name for i in session.get_inputs()]
output_names = [o.name for o in session.get_outputs()]

# Class-id -> human-readable label for drawing/logging.
# NOTE(review): assumed mapping for this export's label set — confirm against
# the model's config; ids outside this dict are reported as "Unknown".
LABELS = {1: "Text", 2: "Title", 3: "List", 4: "Table", 5: "Figure"}
| |
|
def preprocess_image(image, target_size=(800, 800)):
    """Resize and normalize an RGB image for the PP-DocLayoutV3 ONNX model.

    Args:
        image: HxWx3 RGB ``numpy`` array (uint8 or float).
        target_size: ``(width, height)`` expected by the network; this is
            passed straight to ``cv2.resize``, whose ``dsize`` argument is
            ``(width, height)``.

    Returns:
        Tuple ``(blob, scale_factor)``:
            blob: contiguous float32 NCHW tensor, shape
                ``(1, 3, target_h, target_w)``, ImageNet mean/std normalized.
            scale_factor: float32 array of shape ``(1, 2)`` holding
                ``[h_scale, w_scale]`` so the model can map predicted boxes
                back to original-image coordinates.
    """
    h, w = image.shape[:2]
    target_w, target_h = target_size  # cv2 convention: (width, height)

    img_resized = cv2.resize(image, (target_w, target_h))

    # Scale to [0, 1], then apply ImageNet channel statistics.
    img_data = img_resized.astype(np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    img_data = (img_data - mean) / std

    # HWC -> CHW, add the batch dimension; make the buffer contiguous,
    # since transpose() returns a non-contiguous view.
    img_data = np.ascontiguousarray(img_data.transpose(2, 0, 1)[None, :, :, :])

    # BUG FIX: pair the height scale with the image height and the width
    # scale with the image width. The old code computed
    # [target_size[0] / h, target_size[1] / w], i.e. width/height — which
    # only coincides with the correct value when target_size is square.
    scale_factor = np.array([target_h / h, target_w / w], dtype=np.float32).reshape(1, 2)

    return img_data, scale_factor
| |
|
def analyze_layout(input_image):
    """Run layout detection on an uploaded document image.

    Args:
        input_image: ``PIL.Image`` from Gradio, or ``None`` when the user
            submits without an upload.

    Returns:
        Tuple of (annotated RGB numpy image or ``None``, newline-joined
        text log of detections).
    """
    if input_image is None:
        return None, "No image uploaded"

    # Force 3-channel RGB: a grayscale ("L") or RGBA upload would otherwise
    # break the 3-channel normalization inside preprocess_image.
    image_np = np.array(input_image.convert("RGB"))

    input_blob, scale_factor = preprocess_image(image_np)

    inputs = {
        input_names[0]: input_blob,
        input_names[1]: scale_factor,
    }

    outputs = session.run(output_names, inputs)

    # NOTE(review): assumes PaddleDetection-style output rows of
    # [class_id, score, x1, y1, x2, y2], already mapped to original-image
    # coordinates via scale_factor — confirm against the exported graph.
    detections = outputs[0]

    viz_image = image_np.copy()
    log = []

    for det in detections:
        score = det[1]
        # Skip low-confidence detections before doing any further work.
        if score < 0.5:
            continue

        class_id = int(det[0])
        label_name = LABELS.get(class_id, "Unknown")
        x1, y1, x2, y2 = map(int, det[2:])

        # viz_image is RGB (it came from PIL), so these tuples are (R, G, B).
        color = (0, 255, 0)  # default
        if label_name == "Title":
            color = (0, 0, 255)
        elif label_name == "Table":
            color = (255, 255, 0)
        elif label_name == "Figure":
            color = (255, 0, 0)

        cv2.rectangle(viz_image, (x1, y1), (x2, y2), color, 3)
        cv2.putText(viz_image, f"{label_name} {score:.2f}", (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

        log.append(f"Found {label_name} at [{x1}, {y1}, {x2}, {y2}] (Conf: {score:.2f})")

    # Previously an empty result produced an empty textbox; say so explicitly.
    if not log:
        return viz_image, "No regions detected with confidence >= 0.5"
    return viz_image, "\n".join(log)
| |
|
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks(title="ONNX Layout Analysis") as demo:
    gr.Markdown("## ⚡ Fast V3 Layout Analysis (ONNX)")
    gr.Markdown("Uses **PP-DocLayoutV3** via ONNX Runtime. No Paddle dependencies.")

    with gr.Row():
        with gr.Column():
            # type="pil" means analyze_layout receives a PIL.Image (or None).
            input_img = gr.Image(type="pil", label="Input Document")
            submit_btn = gr.Button("Analyze Layout", variant="primary")

        with gr.Column():
            output_img = gr.Image(label="Layout Visualization")
            output_log = gr.Textbox(label="Detections", lines=10)

    # Wire the button: one image in, annotated image + text log out.
    submit_btn.click(fn=analyze_layout, inputs=input_img, outputs=[output_img, output_log])

# Bind to all interfaces on port 7860 (the convention for HF Spaces / Docker).
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)