| # import gradio as gr | |
| # import cv2 | |
| # import numpy as np | |
| # import onnxruntime as ort | |
| # from huggingface_hub import hf_hub_download, list_repo_files | |
| # # --- STEP 1: Find and Download Model --- | |
| # REPO_ID = "alex-dinh/PP-DocLayoutV3-ONNX" | |
| # print(f"Searching for ONNX model in {REPO_ID}...") | |
| # all_files = list_repo_files(repo_id=REPO_ID) | |
| # onnx_filename = next((f for f in all_files if f.endswith('.onnx')), None) | |
| # if onnx_filename is None: | |
| # raise FileNotFoundError("No .onnx file found in repo.") | |
| # print(f"Found model file: {onnx_filename}") | |
| # model_path = hf_hub_download(repo_id=REPO_ID, filename=onnx_filename) | |
| # # --- STEP 2: Initialize Session --- | |
| # session = ort.InferenceSession(model_path) | |
| # model_inputs = session.get_inputs() | |
| # input_names = [i.name for i in model_inputs] | |
| # output_names = [o.name for o in session.get_outputs()] | |
| # print(f"Model expects inputs: {input_names}") | |
| # LABELS = {0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"} | |
| # # --- FIX: Hardcode target_size to 800x800 --- | |
| # # The ONNX graph requires exactly this dimension. | |
| # def preprocess_image(image, target_size=(800, 800)): | |
| # h, w = image.shape[:2] | |
| # # 1. Resize | |
| # # We use linear interpolation to ensure smooth gradients | |
| # img_resized = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR) | |
| # # 2. Normalize | |
| # img_data = img_resized.astype(np.float32) / 255.0 | |
| # mean = np.array([0.485, 0.456, 0.406], dtype=np.float32) | |
| # std = np.array([0.229, 0.224, 0.225], dtype=np.float32) | |
| # img_data = (img_data - mean) / std | |
| # # 3. Transpose (HWC -> CHW) | |
| # img_data = img_data.transpose(2, 0, 1)[None, :, :, :] | |
| # # 4. Prepare Metadata Inputs | |
| # # scale_factor = resized_shape / original_shape | |
| # scale_factor = np.array([target_size[0] / h, target_size[1] / w], dtype=np.float32).reshape(1, 2) | |
| # # im_shape needs to be the input size (800, 800) | |
| # im_shape = np.array([target_size[0], target_size[1]], dtype=np.float32).reshape(1, 2) | |
| # return img_data, scale_factor, im_shape | |
| # def analyze_layout(input_image): | |
| # if input_image is None: | |
| # return None, "No image uploaded" | |
| # image_np = np.array(input_image) | |
| # # --- INFERENCE --- | |
| # # This will now return an 800x800 blob | |
| # img_blob, scale_factor, im_shape = preprocess_image(image_np) | |
| # inputs = {} | |
| # for i in model_inputs: | |
| # name = i.name | |
| # if 'image' in name: | |
| # inputs[name] = img_blob | |
| # elif 'scale' in name: | |
| # inputs[name] = scale_factor | |
| # elif 'shape' in name: | |
| # inputs[name] = im_shape | |
| # # Run ONNX | |
| # outputs = session.run(output_names, inputs) | |
| # # --- PARSE RESULTS --- | |
| # # Output is [Batch, N, 6] -> [Class, Score, X1, Y1, X2, Y2] | |
| # detections = outputs[0] | |
| # if len(detections.shape) == 3: | |
| # detections = detections[0] | |
| # viz_image = image_np.copy() | |
| # log = [] | |
| # for det in detections: | |
| # score = det[1] | |
| # if score < 0.45: continue | |
| # class_id = int(det[0]) | |
| # bbox = det[2:] | |
| # # Map labels | |
| # label_name = LABELS.get(class_id, f"Class {class_id}") | |
| # # Draw Box | |
| # try: | |
| # x1, y1, x2, y2 = map(int, bbox) | |
| # # Color coding | |
| # color = (0, 255, 0) # Green | |
| # if "Title" in label_name: color = (0, 0, 255) | |
| # elif "Table" in label_name: color = (255, 255, 0) | |
| # elif "Figure" in label_name: color = (255, 0, 0) | |
| # cv2.rectangle(viz_image, (x1, y1), (x2, y2), color, 3) | |
| # label_text = f"{label_name} {score:.2f}" | |
| # (w, h), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2) | |
| # cv2.rectangle(viz_image, (x1, y1 - 20), (x1 + w, y1), color, -1) | |
| # cv2.putText(viz_image, label_text, (x1, y1 - 5), | |
| # cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2) | |
| # log.append(f"Found {label_name} at [{x1}, {y1}, {x2}, {y2}]") | |
| # except: pass | |
| # return viz_image, "\n".join(log) | |
| # with gr.Blocks(title="ONNX Layout Analysis") as demo: | |
| # gr.Markdown("## ⚡ Fast V3 Layout Analysis (ONNX)") | |
| # gr.Markdown(f"Running `{onnx_filename}` via ONNX Runtime (800x800).") | |
| # with gr.Row(): | |
| # with gr.Column(): | |
| # input_img = gr.Image(type="pil", label="Input Document") | |
| # submit_btn = gr.Button("Analyze Layout", variant="primary") | |
| # with gr.Column(): | |
| # output_img = gr.Image(label="Layout Visualization") | |
| # output_log = gr.Textbox(label="Detections", lines=10) | |
| # submit_btn.click(fn=analyze_layout, inputs=input_img, outputs=[output_img, output_log]) | |
| # if __name__ == "__main__": | |
| # demo.launch(server_name="0.0.0.0", server_port=7860) | |
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| import onnxruntime as ort | |
| from huggingface_hub import hf_hub_download, list_repo_files | |
# --- STEP 1: Find and Download Model ---
REPO_ID = "alex-dinh/PP-DocLayoutV3-ONNX"
print(f"Searching for ONNX model in {REPO_ID}...")
all_files = list_repo_files(repo_id=REPO_ID)
# Take the first .onnx file the repo exposes; the repo ships exactly one.
onnx_filename = None
for candidate in all_files:
    if candidate.endswith('.onnx'):
        onnx_filename = candidate
        break
if onnx_filename is None:
    raise FileNotFoundError("No .onnx file found in repo.")
print(f"Found model file: {onnx_filename}")
model_path = hf_hub_download(repo_id=REPO_ID, filename=onnx_filename)

# --- STEP 2: Initialize Session ---
session = ort.InferenceSession(model_path)
model_inputs = session.get_inputs()
input_names = [inp.name for inp in model_inputs]
output_names = [out.name for out in session.get_outputs()]
print(f"Model expects inputs: {input_names}")

# Class-id -> human-readable label for the layout categories.
LABELS = {0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"}
def preprocess_image(image, target_size=(800, 800)):
    """Resize and normalize an image for the PP-DocLayout ONNX model.

    Args:
        image: H x W x 3 array (uint8 expected). Grayscale (H x W) input
            is broadcast to 3 channels and an alpha channel is dropped.
        target_size: (height, width) the network expects. The default
            800x800 matches this ONNX export's fixed input dimension.

    Returns:
        Tuple of (img_blob, scale_factor, im_shape):
            img_blob: float32 (1, 3, target_h, target_w), ImageNet-normalized.
            scale_factor: float32 (1, 2) = [target_h / h, target_w / w].
            im_shape: float32 (1, 2) = [target_h, target_w].
    """
    # Normalize channel count first so the HWC->CHW transpose below
    # always sees a 3-channel image.
    if image.ndim == 2:  # grayscale -> replicate to 3 channels
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[2] == 4:  # RGBA -> drop alpha
        image = image[:, :, :3]

    h, w = image.shape[:2]
    target_h, target_w = target_size

    # 1. Resize. cv2.resize takes dsize as (width, height); the previous
    # code passed target_size directly while computing scale_factor with
    # target_size[0] as HEIGHT, which silently swapped axes for any
    # non-square target. Identical result for the square 800x800 default.
    img_resized = cv2.resize(image, (target_w, target_h), interpolation=cv2.INTER_LINEAR)

    # 2. Normalize to [0, 1], then apply ImageNet mean/std.
    img_data = img_resized.astype(np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    img_data = (img_data - mean) / std

    # 3. Transpose (HWC -> NCHW) and add the batch dimension.
    img_data = img_data.transpose(2, 0, 1)[None, :, :, :]

    # 4. Metadata tensors expected by PaddleDetection-style exports.
    # scale_factor = resized / original, ordered [scale_h, scale_w].
    scale_factor = np.array([target_h / h, target_w / w], dtype=np.float32).reshape(1, 2)
    # NOTE(review): this export appears to want the NETWORK input size as
    # im_shape (some exports want the original image size instead) —
    # preserved from the working behavior; confirm against the model card.
    im_shape = np.array([target_h, target_w], dtype=np.float32).reshape(1, 2)
    return img_data, scale_factor, im_shape
def analyze_layout(input_image):
    """Run ONNX layout detection on a document image and visualize it.

    Args:
        input_image: PIL image from the Gradio input, or None.

    Returns:
        Tuple of (annotated_image, log_text): the RGB numpy image with
        boxes drawn, and a newline-joined list of detections. Returns
        (None, "No image uploaded") when no image was provided.
    """
    if input_image is None:
        return None, "No image uploaded"
    image_np = np.array(input_image)

    # --- INFERENCE ---
    img_blob, scale_factor, im_shape = preprocess_image(image_np)

    # Match blobs to the model's input names by substring, since exact
    # input names vary between ONNX exports.
    inputs = {}
    for model_input in model_inputs:
        name = model_input.name
        if 'image' in name:
            inputs[name] = img_blob
        elif 'scale' in name:
            inputs[name] = scale_factor
        elif 'shape' in name:
            inputs[name] = im_shape
    outputs = session.run(output_names, inputs)

    # First output is [N, 6] (or [1, N, 6]) rows of
    # [class_id, score, x1, y1, x2, y2].
    detections = outputs[0]
    if len(detections.shape) == 3:
        detections = detections[0]

    # --- RAW DEBUG LOGGING ---
    print(f"\n[DEBUG] Raw Detections Shape: {detections.shape}")
    print(f"[DEBUG] Top 3 Raw Detections (Class, Score, BBox):")
    for i in range(min(3, len(detections))):
        print(f"  {detections[i]}")

    viz_image = image_np.copy()
    log = []
    for det in detections:
        score = det[1]
        # Deliberately low threshold while debugging model output.
        if score < 0.3:
            continue
        class_id = int(det[0])
        bbox = det[2:]
        label_name = LABELS.get(class_id, f"Class {class_id}")
        try:
            x1, y1, x2, y2 = map(int, bbox)
            # Per-category colors. The image came from PIL, so these
            # tuples are RGB (not OpenCV's usual BGR).
            color = (0, 255, 0)  # Green
            if "Title" in label_name:
                color = (0, 0, 255)
            elif "Table" in label_name:
                color = (255, 255, 0)
            elif "Figure" in label_name:
                color = (255, 0, 0)
            cv2.rectangle(viz_image, (x1, y1), (x2, y2), color, 3)
            label_text = f"{label_name} {score:.2f}"
            # Renamed from (w, h) to avoid confusion with image dims.
            (text_w, text_h), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
            cv2.rectangle(viz_image, (x1, y1 - 20), (x1 + text_w, y1), color, -1)
            cv2.putText(viz_image, label_text, (x1, y1 - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
            log.append(f"Found {label_name} at [{x1}, {y1}, {x2}, {y2}] (Conf: {score:.2f})")
        except (ValueError, TypeError, cv2.error) as exc:
            # Narrowed from a bare `except: pass`, which hid real bugs
            # (e.g. NaN/inf coordinates). Skip the bad box but report it.
            print(f"[DEBUG] Skipping malformed detection: {exc}")
    if not log:
        log.append("No layout regions detected above threshold.")
    return viz_image, "\n".join(log)
# Gradio front-end: two-column layout with the document input on the
# left and the annotated result plus a text log on the right.
with gr.Blocks(title="ONNX Layout Analysis (Debug)") as demo:
    gr.Markdown("## ⚡ Layout Analysis (Debug Mode)")
    with gr.Row():
        with gr.Column():
            doc_input = gr.Image(type="pil", label="Input Document")
            run_button = gr.Button("Analyze Layout", variant="primary")
        with gr.Column():
            result_image = gr.Image(label="Layout Visualization")
            result_log = gr.Textbox(label="Detections", lines=10)
    run_button.click(fn=analyze_layout, inputs=doc_input, outputs=[result_image, result_log])

if __name__ == "__main__":
    # Bind to 0.0.0.0 so the app is reachable from outside a container.
    demo.launch(server_name="0.0.0.0", server_port=7860)