# import gradio as gr # import cv2 # import numpy as np # import onnxruntime as ort # from huggingface_hub import hf_hub_download, list_repo_files # # --- STEP 1: Find and Download Model --- # REPO_ID = "alex-dinh/PP-DocLayoutV3-ONNX" # print(f"Searching for ONNX model in {REPO_ID}...") # all_files = list_repo_files(repo_id=REPO_ID) # onnx_filename = next((f for f in all_files if f.endswith('.onnx')), None) # if onnx_filename is None: # raise FileNotFoundError("No .onnx file found in repo.") # print(f"Found model file: {onnx_filename}") # model_path = hf_hub_download(repo_id=REPO_ID, filename=onnx_filename) # # --- STEP 2: Initialize Session --- # session = ort.InferenceSession(model_path) # model_inputs = session.get_inputs() # input_names = [i.name for i in model_inputs] # output_names = [o.name for o in session.get_outputs()] # print(f"Model expects inputs: {input_names}") # LABELS = {0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"} # # --- FIX: Hardcode target_size to 800x800 --- # # The ONNX graph requires exactly this dimension. # def preprocess_image(image, target_size=(800, 800)): # h, w = image.shape[:2] # # 1. Resize # # We use linear interpolation to ensure smooth gradients # img_resized = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR) # # 2. Normalize # img_data = img_resized.astype(np.float32) / 255.0 # mean = np.array([0.485, 0.456, 0.406], dtype=np.float32) # std = np.array([0.229, 0.224, 0.225], dtype=np.float32) # img_data = (img_data - mean) / std # # 3. Transpose (HWC -> CHW) # img_data = img_data.transpose(2, 0, 1)[None, :, :, :] # # 4. 
Prepare Metadata Inputs # # scale_factor = resized_shape / original_shape # scale_factor = np.array([target_size[0] / h, target_size[1] / w], dtype=np.float32).reshape(1, 2) # # im_shape needs to be the input size (800, 800) # im_shape = np.array([target_size[0], target_size[1]], dtype=np.float32).reshape(1, 2) # return img_data, scale_factor, im_shape # def analyze_layout(input_image): # if input_image is None: # return None, "No image uploaded" # image_np = np.array(input_image) # # --- INFERENCE --- # # This will now return an 800x800 blob # img_blob, scale_factor, im_shape = preprocess_image(image_np) # inputs = {} # for i in model_inputs: # name = i.name # if 'image' in name: # inputs[name] = img_blob # elif 'scale' in name: # inputs[name] = scale_factor # elif 'shape' in name: # inputs[name] = im_shape # # Run ONNX # outputs = session.run(output_names, inputs) # # --- PARSE RESULTS --- # # Output is [Batch, N, 6] -> [Class, Score, X1, Y1, X2, Y2] # detections = outputs[0] # if len(detections.shape) == 3: # detections = detections[0] # viz_image = image_np.copy() # log = [] # for det in detections: # score = det[1] # if score < 0.45: continue # class_id = int(det[0]) # bbox = det[2:] # # Map labels # label_name = LABELS.get(class_id, f"Class {class_id}") # # Draw Box # try: # x1, y1, x2, y2 = map(int, bbox) # # Color coding # color = (0, 255, 0) # Green # if "Title" in label_name: color = (0, 0, 255) # elif "Table" in label_name: color = (255, 255, 0) # elif "Figure" in label_name: color = (255, 0, 0) # cv2.rectangle(viz_image, (x1, y1), (x2, y2), color, 3) # label_text = f"{label_name} {score:.2f}" # (w, h), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2) # cv2.rectangle(viz_image, (x1, y1 - 20), (x1 + w, y1), color, -1) # cv2.putText(viz_image, label_text, (x1, y1 - 5), # cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2) # log.append(f"Found {label_name} at [{x1}, {y1}, {x2}, {y2}]") # except: pass # return viz_image, "\n".join(log) # with 
import gradio as gr
import cv2
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download, list_repo_files

# --- STEP 1: Find and Download Model ---
REPO_ID = "alex-dinh/PP-DocLayoutV3-ONNX"

print(f"Searching for ONNX model in {REPO_ID}...")
all_files = list_repo_files(repo_id=REPO_ID)
onnx_filename = next((f for f in all_files if f.endswith('.onnx')), None)
if onnx_filename is None:
    raise FileNotFoundError("No .onnx file found in repo.")

print(f"Found model file: {onnx_filename}")
model_path = hf_hub_download(repo_id=REPO_ID, filename=onnx_filename)

# --- STEP 2: Initialize Session ---
session = ort.InferenceSession(model_path)
model_inputs = session.get_inputs()
input_names = [i.name for i in model_inputs]
output_names = [o.name for o in session.get_outputs()]
print(f"Model expects inputs: {input_names}")

# NOTE(review): class-id -> label mapping assumed from PP-DocLayout
# conventions — confirm against the model's exported label list.
LABELS = {0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"}

# Minimum confidence for a detection to be drawn/logged
# (deliberately low for debugging).
SCORE_THRESHOLD = 0.3


def preprocess_image(image, target_size=(800, 800)):
    """Resize, normalize, and lay out an RGB image for the ONNX model.

    Args:
        image: HxWx3 uint8 RGB array.
        target_size: (width, height) expected by the network. The exported
            graph is fixed at 800x800.

    Returns:
        Tuple of (img_blob, scale_factor, im_shape):
        - img_blob: 1x3xHxW float32 tensor, ImageNet-normalized.
        - scale_factor: [[scale_y, scale_x]] = resized / original
          (PaddleDetection convention).
        - im_shape: [[h, w]] of the network input, as float32.
    """
    h, w = image.shape[:2]

    # cv2.resize takes dsize as (width, height). The default target is
    # square so the order is harmless today, but index it correctly so
    # non-square targets would also work.
    img_resized = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)

    # Scale to [0, 1] then normalize with ImageNet mean/std
    # (broadcast over the channels-last axis).
    img_data = img_resized.astype(np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    img_data = (img_data - mean) / std

    # HWC -> NCHW with a leading batch dimension.
    img_data = img_data.transpose(2, 0, 1)[None, :, :, :]

    # scale_factor rows are [scale_y, scale_x]: resized / original.
    # target_size[1] is the output height, target_size[0] the output width.
    scale_factor = np.array(
        [target_size[1] / h, target_size[0] / w], dtype=np.float32
    ).reshape(1, 2)

    # DEBUG CHANGE kept from original: some exports want the network INPUT
    # size (800, 800) here rather than the original image size.
    im_shape = np.array(
        [target_size[1], target_size[0]], dtype=np.float32
    ).reshape(1, 2)

    return img_data, scale_factor, im_shape


def analyze_layout(input_image):
    """Run layout detection on a PIL image.

    Args:
        input_image: PIL.Image or None (nothing uploaded).

    Returns:
        (viz_image, log_text): annotated RGB numpy array and a newline-
        joined detection log. Returns (None, message) when no image given.
    """
    if input_image is None:
        return None, "No image uploaded"

    # Force 3-channel RGB: uploads may be RGBA or grayscale, which would
    # break the (3,)-shaped mean/std broadcast in preprocessing.
    image_np = np.array(input_image.convert("RGB"))

    # --- INFERENCE ---
    img_blob, scale_factor, im_shape = preprocess_image(image_np)

    # Route tensors to the model's inputs by name substring; Paddle
    # exports typically name them 'image', 'scale_factor', 'im_shape'.
    inputs = {}
    for model_input in model_inputs:
        name = model_input.name
        if 'image' in name:
            inputs[name] = img_blob
        elif 'scale' in name:
            inputs[name] = scale_factor
        elif 'shape' in name:
            inputs[name] = im_shape

    outputs = session.run(output_names, inputs)

    # Output rows are [class_id, score, x1, y1, x2, y2], optionally
    # batched as [1, N, 6].
    detections = outputs[0]
    if detections.ndim == 3:
        detections = detections[0]

    # --- RAW DEBUG LOGGING ---
    print(f"\n[DEBUG] Raw Detections Shape: {detections.shape}")
    print("[DEBUG] Top 3 Raw Detections (Class, Score, BBox):")
    for row in detections[:3]:
        print(f"  {row}")

    viz_image = image_np.copy()
    log = []

    for det in detections:
        score = float(det[1])
        if score < SCORE_THRESHOLD:
            continue

        class_id = int(det[0])
        label_name = LABELS.get(class_id, f"Class {class_id}")

        # Narrow try: only the coordinate conversion can plausibly fail
        # (NaN / short row). The previous bare `except: pass` silently
        # swallowed every drawing error.
        try:
            x1, y1, x2, y2 = map(int, det[2:6])
        except (ValueError, TypeError):
            continue

        # NOTE(review): viz_image came from PIL, so it is RGB — these
        # tuples render as RGB, not OpenCV's usual BGR.
        color = (0, 255, 0)
        if "Title" in label_name:
            color = (0, 0, 255)
        elif "Table" in label_name:
            color = (255, 255, 0)
        elif "Figure" in label_name:
            color = (255, 0, 0)

        cv2.rectangle(viz_image, (x1, y1), (x2, y2), color, 3)

        # Filled label background plus white text above the box.
        label_text = f"{label_name} {score:.2f}"
        (text_w, text_h), _ = cv2.getTextSize(
            label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
        cv2.rectangle(viz_image, (x1, y1 - 20), (x1 + text_w, y1), color, -1)
        cv2.putText(viz_image, label_text, (x1, y1 - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

        log.append(
            f"Found {label_name} at [{x1}, {y1}, {x2}, {y2}] (Conf: {score:.2f})")

    if not log:
        log.append("No layout regions detected above threshold.")
    return viz_image, "\n".join(log)


# --- Gradio UI ---
with gr.Blocks(title="ONNX Layout Analysis (Debug)") as demo:
    gr.Markdown("## ⚡ Layout Analysis (Debug Mode)")
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(type="pil", label="Input Document")
            submit_btn = gr.Button("Analyze Layout", variant="primary")
        with gr.Column():
            output_img = gr.Image(label="Layout Visualization")
            output_log = gr.Textbox(label="Detections", lines=10)

    submit_btn.click(fn=analyze_layout, inputs=input_img,
                     outputs=[output_img, output_log])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)