# document / app.py
# Last commit: "Update app.py" by iammraat (e55fda2, verified) - 4.01 kB
import gradio as gr
import cv2
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
# --- STEP 1: Fetch the ONNX weights from the Hugging Face Hub ---
print("Downloading ONNX model...")
model_path = hf_hub_download(
    repo_id="alex-dinh/PP-DocLayoutV3-ONNX",
    filename="model.onnx",
)
print(f"Model downloaded to: {model_path}")

# --- STEP 2: Create the ONNX Runtime session ---
# The session runs the exported graph directly, so Paddle is not required.
session = ort.InferenceSession(model_path)
# Graph input/output names, in the order the session reports them.
input_names = [node.name for node in session.get_inputs()]
output_names = [node.name for node in session.get_outputs()]

# Class id -> human-readable label; ids missing here are handled by callers.
LABELS = dict(enumerate(("Text", "Title", "List", "Table", "Figure"), start=1))
def preprocess_image(image, target_size=(800, 800)):
    """
    Resize and normalize an image into the batched, channel-first blob
    that is fed to the model.

    Args:
        image: Array of shape (H, W, 3). The normalization constants below
            have one entry per channel, so exactly three channels are needed.
        target_size: ``(height, width)`` of the network input. The aspect
            ratio of ``image`` is NOT preserved.

    Returns:
        Tuple ``(img_data, scale_factor)``:
            img_data: float32 array of shape ``(1, 3, height, width)``.
            scale_factor: float32 array of shape ``(1, 2)`` holding
                ``[height_scale, width_scale]`` (resized size / original
                size), used to map detections back to the original image.
    """
    h, w = image.shape[:2]
    target_h, target_w = target_size

    # 1. Resize
    # cv2.resize expects dsize as (width, height) -- the reverse of the
    # (height, width) order used for the scale factors below. Passing
    # target_size through unchanged only works for square targets.
    img_resized = cv2.resize(image, (target_w, target_h))

    # 2. Normalize (standard ImageNet mean/std, applied per channel)
    img_data = img_resized.astype(np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    img_data = (img_data - mean) / std

    # 3. HWC -> CHW, then add a leading batch axis
    img_data = img_data.transpose(2, 0, 1)[None, :, :, :]

    # Scale factors that map detections back to the original image
    scale_factor = np.array(
        [target_h / h, target_w / w], dtype=np.float32
    ).reshape(1, 2)
    return img_data, scale_factor
# Box/text color per label; labels not listed here are drawn in green.
_DEFAULT_COLOR = (0, 255, 0)
_LABEL_COLORS = {
    "Title": (0, 0, 255),
    "Table": (255, 255, 0),
    "Figure": (255, 0, 0),
}


def _build_feeds(names, image_blob, scale_factor):
    """Pair each model input name with the matching preprocessed array."""
    known = {
        "image": image_blob,
        "scale_factor": scale_factor,
        # NOTE(review): presumably the (height, width) of the resized network
        # input, as in Paddle detection exports -- confirm against the model.
        "im_shape": np.array([image_blob.shape[2:4]], dtype=np.float32),
    }
    if all(name in known for name in names):
        return {name: known[name] for name in names}
    # Unrecognized names: keep the positional convention
    # (first input = image data, second input = resize scale).
    return {names[0]: image_blob, names[1]: scale_factor}


def _flatten_detections(raw):
    """Collapse a (N, C) or (Batch, N, C) output into rows of C values."""
    raw = np.asarray(raw)
    if raw.size == 0:
        return []
    return raw.reshape(-1, raw.shape[-1])


def analyze_layout(input_image, score_threshold=0.5):
    """
    Run layout detection on an image and draw the detected regions.

    Args:
        input_image: PIL image (or array-like image) to analyze, or None.
        score_threshold: Detections scoring below this value are skipped.

    Returns:
        Tuple ``(viz_image, log_text)``. ``viz_image`` is a copy of the input
        as a numpy array with boxes and labels drawn on it; ``log_text`` has
        one line per kept detection. Returns ``(None, "No image uploaded")``
        when ``input_image`` is None.
    """
    if input_image is None:
        return None, "No image uploaded"

    # Preprocessing normalizes exactly three channels, so force RGB when the
    # input can be converted (e.g. RGBA or grayscale PIL images).
    if hasattr(input_image, "convert"):
        input_image = input_image.convert("RGB")
    image_np = np.array(input_image)

    # --- INFERENCE ---
    input_blob, scale_factor = preprocess_image(image_np)
    inputs = _build_feeds(input_names, input_blob, scale_factor)
    outputs = session.run(output_names, inputs)

    # --- POST-PROCESSING ---
    # Each row is read as [Class, Score, X1, Y1, X2, Y2, ...]; a leading
    # batch axis, if present, is flattened away.
    detections = _flatten_detections(outputs[0])

    viz_image = image_np.copy()
    log = []

    for det in detections:
        score = float(det[1])
        if score < score_threshold:
            continue  # Filter weak detections

        label_name = LABELS.get(int(det[0]), "Unknown")

        # Only columns 2..5 are box corners; ignore any extra columns.
        x1, y1, x2, y2 = (int(v) for v in det[2:6])

        color = _LABEL_COLORS.get(label_name, _DEFAULT_COLOR)

        cv2.rectangle(viz_image, (x1, y1), (x2, y2), color, 3)
        # Keep the caption on-canvas for boxes that touch the top edge.
        text_y = max(y1 - 10, 15)
        cv2.putText(viz_image, f"{label_name} {score:.2f}", (x1, text_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

        log.append(f"Found {label_name} at [{x1}, {y1}, {x2}, {y2}] (Conf: {score:.2f})")

    return viz_image, "\n".join(log)
# Build the UI: input image + button on the left, results on the right.
demo = gr.Blocks(title="ONNX Layout Analysis")
with demo:
    gr.Markdown("## ⚡ Fast V3 Layout Analysis (ONNX)")
    gr.Markdown("Uses **PP-DocLayoutV3** via ONNX Runtime. No Paddle dependencies.")

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(type="pil", label="Input Document")
            submit_btn = gr.Button("Analyze Layout", variant="primary")
        with gr.Column():
            output_img = gr.Image(label="Layout Visualization")
            output_log = gr.Textbox(label="Detections", lines=10)

    # Clicking the button runs the analysis and fills both output widgets.
    submit_btn.click(
        fn=analyze_layout,
        inputs=input_img,
        outputs=[output_img, output_log],
    )

# Start the web server only when executed as a script.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)