# app.py — Hugging Face Space "document" by iammraat (commit 019f2ad, verified)
# import gradio as gr
# import cv2
# import numpy as np
# import onnxruntime as ort
# from huggingface_hub import hf_hub_download, list_repo_files
# # --- STEP 1: Find and Download Model ---
# REPO_ID = "alex-dinh/PP-DocLayoutV3-ONNX"
# print(f"Searching for ONNX model in {REPO_ID}...")
# all_files = list_repo_files(repo_id=REPO_ID)
# onnx_filename = next((f for f in all_files if f.endswith('.onnx')), None)
# if onnx_filename is None:
# raise FileNotFoundError("No .onnx file found in repo.")
# print(f"Found model file: {onnx_filename}")
# model_path = hf_hub_download(repo_id=REPO_ID, filename=onnx_filename)
# # --- STEP 2: Initialize Session ---
# session = ort.InferenceSession(model_path)
# model_inputs = session.get_inputs()
# input_names = [i.name for i in model_inputs]
# output_names = [o.name for o in session.get_outputs()]
# print(f"Model expects inputs: {input_names}")
# LABELS = {0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"}
# # --- FIX: Hardcode target_size to 800x800 ---
# # The ONNX graph requires exactly this dimension.
# def preprocess_image(image, target_size=(800, 800)):
# h, w = image.shape[:2]
# # 1. Resize
# # We use linear interpolation to ensure smooth gradients
# img_resized = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
# # 2. Normalize
# img_data = img_resized.astype(np.float32) / 255.0
# mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
# std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
# img_data = (img_data - mean) / std
# # 3. Transpose (HWC -> CHW)
# img_data = img_data.transpose(2, 0, 1)[None, :, :, :]
# # 4. Prepare Metadata Inputs
# # scale_factor = resized_shape / original_shape
# scale_factor = np.array([target_size[0] / h, target_size[1] / w], dtype=np.float32).reshape(1, 2)
# # im_shape needs to be the input size (800, 800)
# im_shape = np.array([target_size[0], target_size[1]], dtype=np.float32).reshape(1, 2)
# return img_data, scale_factor, im_shape
# def analyze_layout(input_image):
# if input_image is None:
# return None, "No image uploaded"
# image_np = np.array(input_image)
# # --- INFERENCE ---
# # This will now return an 800x800 blob
# img_blob, scale_factor, im_shape = preprocess_image(image_np)
# inputs = {}
# for i in model_inputs:
# name = i.name
# if 'image' in name:
# inputs[name] = img_blob
# elif 'scale' in name:
# inputs[name] = scale_factor
# elif 'shape' in name:
# inputs[name] = im_shape
# # Run ONNX
# outputs = session.run(output_names, inputs)
# # --- PARSE RESULTS ---
# # Output is [Batch, N, 6] -> [Class, Score, X1, Y1, X2, Y2]
# detections = outputs[0]
# if len(detections.shape) == 3:
# detections = detections[0]
# viz_image = image_np.copy()
# log = []
# for det in detections:
# score = det[1]
# if score < 0.45: continue
# class_id = int(det[0])
# bbox = det[2:]
# # Map labels
# label_name = LABELS.get(class_id, f"Class {class_id}")
# # Draw Box
# try:
# x1, y1, x2, y2 = map(int, bbox)
# # Color coding
# color = (0, 255, 0) # Green
# if "Title" in label_name: color = (0, 0, 255)
# elif "Table" in label_name: color = (255, 255, 0)
# elif "Figure" in label_name: color = (255, 0, 0)
# cv2.rectangle(viz_image, (x1, y1), (x2, y2), color, 3)
# label_text = f"{label_name} {score:.2f}"
# (w, h), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
# cv2.rectangle(viz_image, (x1, y1 - 20), (x1 + w, y1), color, -1)
# cv2.putText(viz_image, label_text, (x1, y1 - 5),
# cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
# log.append(f"Found {label_name} at [{x1}, {y1}, {x2}, {y2}]")
# except: pass
# return viz_image, "\n".join(log)
# with gr.Blocks(title="ONNX Layout Analysis") as demo:
# gr.Markdown("## ⚡ Fast V3 Layout Analysis (ONNX)")
# gr.Markdown(f"Running `{onnx_filename}` via ONNX Runtime (800x800).")
# with gr.Row():
# with gr.Column():
# input_img = gr.Image(type="pil", label="Input Document")
# submit_btn = gr.Button("Analyze Layout", variant="primary")
# with gr.Column():
# output_img = gr.Image(label="Layout Visualization")
# output_log = gr.Textbox(label="Detections", lines=10)
# submit_btn.click(fn=analyze_layout, inputs=input_img, outputs=[output_img, output_log])
# if __name__ == "__main__":
# demo.launch(server_name="0.0.0.0", server_port=7860)
import gradio as gr
import cv2
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download, list_repo_files
# --- STEP 1: Find and Download Model ---
# NOTE: everything below runs at import time and hits the network
# (Hugging Face Hub); the Space fails to start if the repo is
# unreachable or contains no .onnx file.
REPO_ID = "alex-dinh/PP-DocLayoutV3-ONNX"
print(f"Searching for ONNX model in {REPO_ID}...")
all_files = list_repo_files(repo_id=REPO_ID)
# Pick the first .onnx file found anywhere in the repo listing.
onnx_filename = next((f for f in all_files if f.endswith('.onnx')), None)
if onnx_filename is None:
    raise FileNotFoundError("No .onnx file found in repo.")
print(f"Found model file: {onnx_filename}")
model_path = hf_hub_download(repo_id=REPO_ID, filename=onnx_filename)
# --- STEP 2: Initialize Session ---
session = ort.InferenceSession(model_path)
model_inputs = session.get_inputs()
input_names = [i.name for i in model_inputs]
output_names = [o.name for o in session.get_outputs()]
print(f"Model expects inputs: {input_names}")
# Class-id -> human-readable label for drawing/logging.
# NOTE(review): assumes the export uses this 5-class ordering — confirm
# against the model card; unknown ids fall back to "Class {id}" below.
LABELS = {0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"}
def preprocess_image(image, target_size=(800, 800)):
    """Prepare an RGB image array for the ONNX layout model.

    Args:
        image: HxWx3 uint8 array (3 channels expected).
        target_size: (height, width) the network was exported with.

    Returns:
        Tuple of (img_data, scale_factor, im_shape):
            img_data: float32 NCHW blob of shape (1, 3, H, W).
            scale_factor: (1, 2) float32 [scale_y, scale_x] = resized / original.
            im_shape: (1, 2) float32 network input size [H, W].
    """
    h, w = image.shape[:2]
    target_h, target_w = target_size
    # 1. Resize.
    # BUG FIX: cv2.resize takes dsize as (width, height), but the rest of
    # this function treats target_size as (height, width). Passing
    # target_size straight through only worked because the default is
    # square; resizing with (target_w, target_h) makes the (H, W)
    # convention used by scale_factor/im_shape hold for any target_size.
    img_resized = cv2.resize(image, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
    # 2. Normalize with ImageNet mean/std.
    img_data = img_resized.astype(np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    img_data = (img_data - mean) / std
    # 3. Transpose HWC -> CHW and add the batch axis.
    img_data = img_data.transpose(2, 0, 1)[None, :, :, :]
    # 4. Metadata inputs expected by PaddleDetection-style exports.
    # scale_factor = resized / original, ordered [y, x].
    scale_factor = np.array([target_h / h, target_w / w], dtype=np.float32).reshape(1, 2)
    # NOTE(review): this export appears to want the network INPUT size
    # here, not the original image size — kept as-is from the working
    # version; confirm against the exporter's preprocessing config.
    im_shape = np.array([target_h, target_w], dtype=np.float32).reshape(1, 2)
    return img_data, scale_factor, im_shape
def analyze_layout(input_image):
    """Run ONNX layout detection on a PIL image.

    Args:
        input_image: PIL image from the Gradio widget, or None.

    Returns:
        (annotated RGB numpy image, newline-joined detection log), or
        (None, "No image uploaded") when no image was supplied.
    """
    if input_image is None:
        return None, "No image uploaded"
    # ROBUSTNESS FIX: force 3-channel RGB. Gradio can hand over RGBA or
    # grayscale uploads, which would break the 3-channel mean/std
    # normalization in preprocess_image.
    image_np = np.array(input_image.convert("RGB"))
    # --- INFERENCE ---
    img_blob, scale_factor, im_shape = preprocess_image(image_np)
    # Feed each declared model input, matched by substring of its name.
    inputs = {}
    for i in model_inputs:
        name = i.name
        if 'image' in name:
            inputs[name] = img_blob
        elif 'scale' in name:
            inputs[name] = scale_factor
        elif 'shape' in name:
            inputs[name] = im_shape
    outputs = session.run(output_names, inputs)
    # Output rows are [class_id, score, x1, y1, x2, y2]; squeeze the
    # batch axis when present.
    detections = outputs[0]
    if len(detections.shape) == 3:
        detections = detections[0]
    # --- RAW DEBUG LOGGING ---
    print(f"\n[DEBUG] Raw Detections Shape: {detections.shape}")
    print(f"[DEBUG] Top 3 Raw Detections (Class, Score, BBox):")
    for i in range(min(3, len(detections))):
        print(f"  {detections[i]}")
    viz_image = image_np.copy()
    log = []
    for det in detections:
        score = det[1]
        # Lower threshold strictly for debugging
        if score < 0.3: continue
        class_id = int(det[0])
        bbox = det[2:]
        label_name = LABELS.get(class_id, f"Class {class_id}")
        try:
            x1, y1, x2, y2 = map(int, bbox)
        except (TypeError, ValueError, OverflowError):
            # BUG FIX: was a bare `except: pass` wrapped around the whole
            # drawing body, which silently hid real drawing errors (and
            # would even swallow KeyboardInterrupt). Only the numeric
            # conversion of the bbox is expected to fail (NaN/inf coords);
            # skip just that detection.
            continue
        # Color coding per label type (BGR-order tuples as in original).
        color = (0, 255, 0)  # Green
        if "Title" in label_name: color = (0, 0, 255)
        elif "Table" in label_name: color = (255, 255, 0)
        elif "Figure" in label_name: color = (255, 0, 0)
        cv2.rectangle(viz_image, (x1, y1), (x2, y2), color, 3)
        label_text = f"{label_name} {score:.2f}"
        (w, h), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
        # Filled label background above the box, then white caption text.
        cv2.rectangle(viz_image, (x1, y1 - 20), (x1 + w, y1), color, -1)
        cv2.putText(viz_image, label_text, (x1, y1 - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
        log.append(f"Found {label_name} at [{x1}, {y1}, {x2}, {y2}] (Conf: {score:.2f})")
    if not log:
        log.append("No layout regions detected above threshold.")
    return viz_image, "\n".join(log)
# --- UI: Gradio front-end wiring (runs at import time) ---
with gr.Blocks(title="ONNX Layout Analysis (Debug)") as demo:
    gr.Markdown("## ⚡ Layout Analysis (Debug Mode)")
    with gr.Row():
        with gr.Column():
            doc_image = gr.Image(type="pil", label="Input Document")
            analyze_button = gr.Button("Analyze Layout", variant="primary")
        with gr.Column():
            result_image = gr.Image(label="Layout Visualization")
            result_log = gr.Textbox(label="Detections", lines=10)
    # Clicking the button runs inference and fills both output widgets.
    analyze_button.click(
        fn=analyze_layout,
        inputs=doc_image,
        outputs=[result_image, result_log],
    )

if __name__ == "__main__":
    # Bind on all interfaces at the standard HF Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)