"""Gradio demo for document layout analysis with Docling layout models.

Loads a Docling DFine detection model, runs it on an uploaded document
image, optionally applies non-maximum suppression (custom IoMin-based or
torchvision's standard IoU NMS), and draws the detected layout regions
with class labels and confidence scores.
"""

import os

# Must be set before gradio is imported so it picks up the temp dir.
os.environ["GRADIO_TEMP_DIR"] = "./tmp"

import sys  # noqa: F401  kept from original; possibly used by deployment tooling

import torch
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from transformers import (
    DFineForObjectDetection,
    RTDetrImageProcessor,
)

# == select device ==
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Available models (display name -> Hugging Face model id)
MODELS = {
    "Egret XLarge": "ds4sd/docling-layout-egret-xlarge",
    "Egret Large": "ds4sd/docling-layout-egret-large",
    "Egret Medium": "ds4sd/docling-layout-egret-medium",
    "Heron 101": "ds4sd/docling-layout-heron-101",
    "Heron": "ds4sd/docling-layout-heron",
}

# Classes mapping for the docling model (label id -> human-readable name)
classes_map = {
    0: "Caption",
    1: "Footnote",
    2: "Formula",
    3: "List-item",
    4: "Page-footer",
    5: "Page-header",
    6: "Picture",
    7: "Section-header",
    8: "Table",
    9: "Text",
    10: "Title",
    11: "Document Index",
    12: "Code",
    13: "Checkbox-Selected",
    14: "Checkbox-Unselected",
    15: "Form",
    16: "Key-Value Region",
}

# Color mapping for visualization (one color per class id, cycled via modulo)
colors = [
    "#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4", "#FECA57",
    "#FF9FF3", "#54A0FF", "#5F27CD", "#00D2D3", "#FF9F43",
    "#10AC84", "#EE5A24", "#0ABDE3", "#006BA6", "#F79F1F",
    "#A3CB38", "#FDA7DF",
]

# Global variables for the currently loaded model
current_model = None
current_processor = None
current_model_name = None


def iomin(box1, box2):
    """Intersection over Minimum (IoMin) overlap measure.

    Args:
        box1: Tensor[1, 4] in (x1, y1, x2, y2) format.
        box2: Tensor[N, 4] in the same format.

    Returns:
        Tensor[N]: intersection area divided by the smaller of the two
        box areas, for each box in ``box2`` against ``box1``.
    """
    # Intersection rectangle (clamped so disjoint boxes give area 0)
    x1 = torch.max(box1[:, 0], box2[:, 0])
    y1 = torch.max(box1[:, 1], box2[:, 1])
    x2 = torch.min(box1[:, 2], box2[:, 2])
    y2 = torch.min(box1[:, 3], box2[:, 3])
    inter_area = torch.clamp(x2 - x1, min=0) * torch.clamp(y2 - y1, min=0)

    # Areas
    box1_area = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    box2_area = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
    min_area = torch.min(box1_area, box2_area)

    # FIX: guard against division by zero for degenerate (zero-area) boxes;
    # normal boxes are unaffected by the tiny clamp floor.
    return inter_area / torch.clamp(min_area, min=1e-8)


def nms(boxes, scores, iou_threshold=0.5):
    """Greedy non-maximum suppression using IoMin as the overlap measure.

    Args:
        boxes: Tensor[N, 4] candidate boxes in (x1, y1, x2, y2) format.
        scores: Tensor[N] confidence scores.
        iou_threshold: overlap above which a lower-scored box is suppressed.

    Returns:
        Tensor[K] (long): indices into ``boxes`` to keep, ordered by
        descending score.
    """
    keep = []
    _, order = scores.sort(descending=True)

    while order.numel() > 0:
        i = order[0]
        keep.append(i.item())
        if order.numel() == 1:
            break
        box_i = boxes[i].unsqueeze(0)  # [1, 4]
        rest = order[1:]
        overlaps = iomin(box_i, boxes[rest])
        # Keep only candidates that do not overlap the current box too much
        mask = (overlaps <= iou_threshold)
        order = rest[mask]

    return torch.tensor(keep, dtype=torch.long)


def load_model(model_name):
    """Load the selected model into the module-level globals.

    Args:
        model_name: key into ``MODELS``.

    Returns:
        Human-readable status string for the UI (success or error).
    """
    global current_model, current_processor, current_model_name

    if current_model_name == model_name:
        return f"✅ Model {model_name} is already loaded!"

    try:
        print(f"Loading model: {model_name}")
        model_path = MODELS[model_name]
        processor = RTDetrImageProcessor.from_pretrained(model_path)
        model = DFineForObjectDetection.from_pretrained(model_path)
        model = model.to(device)
        model.eval()

        current_processor = processor
        current_model = model
        current_model_name = model_name
        return f"✅ Successfully loaded {model_name}!"
    except Exception as e:
        # Broad catch is intentional at this UI boundary: any load failure
        # (network, disk, bad weights) is reported to the user as text.
        return f"❌ Error loading {model_name}: {str(e)}"


def visualize_bbox(image, boxes, labels, scores, classes_map, colors):
    """Draw labeled bounding boxes onto a copy of ``image``.

    Args:
        image: PIL Image or numpy array (HWC).
        boxes: iterable of (x1, y1, x2, y2) boxes.
        labels: iterable of class ids (int or 0-d tensor).
        scores: iterable of confidence scores (float or 0-d tensor).
        classes_map: dict mapping class id -> class name.
        colors: list of hex color strings, indexed by class id modulo length.

    Returns:
        numpy array of the annotated image.

    Raises:
        ValueError: if ``image`` is neither a PIL Image nor a numpy array.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif not isinstance(image, Image.Image):
        raise ValueError("Input image must be PIL Image or numpy array")

    # Create a copy to draw on
    draw_image = image.copy()
    draw = ImageDraw.Draw(draw_image)

    # Try to use a font, fallback to default if not available.
    # FIX: bare ``except:`` replaced with narrower exception handling
    # (ImageFont.truetype raises OSError when the font file is missing).
    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except OSError:
        try:
            font = ImageFont.load_default()
        except Exception:
            font = None

    for box, label_id, score in zip(boxes, labels, scores):
        # Convert tensor to int if needed
        if torch.is_tensor(label_id):
            label_id = label_id.item()
        if torch.is_tensor(score):
            score = score.item()

        label = classes_map.get(int(label_id), f"Class_{label_id}")
        color = colors[int(label_id) % len(colors)]

        # Convert box coordinates to integers
        x1, y1, x2, y2 = [int(coord) for coord in box]

        # Draw rectangle
        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)

        # Draw label background above the box, then the label text
        text = f"{label}: {score:.2f}"
        if font:
            bbox = draw.textbbox((x1, y1), text, font=font)
            text_width = bbox[2] - bbox[0]
            text_height = bbox[3] - bbox[1]
        else:
            # Estimate text size if no font available
            text_width = len(text) * 10
            text_height = 20

        draw.rectangle([x1, y1 - text_height - 4, x1 + text_width + 4, y1], fill=color)
        draw.text((x1 + 2, y1 - text_height - 2), text, fill="white", font=font)

    return np.array(draw_image)


def recognize_image(input_img, conf_threshold, iou_threshold, nms_method):
    """Process an image with the loaded docling layout model.

    Args:
        input_img: PIL Image or numpy array to analyze (may be None).
        conf_threshold: minimum detection confidence (0..1).
        iou_threshold: NMS overlap threshold; 1.0 disables NMS entirely.
        nms_method: "Custom IoMin" for the local implementation, anything
            else selects torchvision's standard IoU NMS.

    Returns:
        (annotated image as numpy array or None, status/info string).
    """
    if input_img is None:
        return None, "Please upload an image first."

    if current_model is None or current_processor is None:
        return None, "Please load a model first."

    try:
        # Ensure image is PIL Image
        if isinstance(input_img, np.ndarray):
            input_img = Image.fromarray(input_img)

        # Convert to RGB if needed
        if input_img.mode != 'RGB':
            input_img = input_img.convert('RGB')

        # Process image
        inputs = current_processor(images=[input_img], return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Run inference
        with torch.no_grad():
            outputs = current_model(**inputs)

        # Post-process results.
        # target_sizes expects (height, width); PIL ``size`` is (width, height),
        # hence the [::-1] reversal.
        results = current_processor.post_process_object_detection(
            outputs,
            target_sizes=torch.tensor([input_img.size[::-1]]),
            threshold=conf_threshold,
        )

        if not results or len(results) == 0:
            return np.array(input_img), "No detections found."

        result = results[0]

        # Get results
        boxes = result["boxes"]
        scores = result["scores"]
        labels = result["labels"]

        if len(boxes) == 0:
            return np.array(input_img), "No detections above confidence threshold."

        # Apply NMS if requested (iou_threshold of 1.0 disables suppression).
        # NOTE(review): NMS here runs across all classes jointly, not per class.
        if iou_threshold < 1.0:
            if nms_method == "Custom IoMin":
                # Use custom NMS with IoMin
                keep_indices = nms(
                    boxes=boxes,
                    scores=scores,
                    iou_threshold=iou_threshold,
                )
            else:
                # FIX: ``torch.ops.torchvision.nms`` is only registered after
                # torchvision is imported, which this module never did — the
                # original call raised at runtime. Import lazily and fall back
                # to the custom implementation when torchvision is missing.
                try:
                    from torchvision.ops import nms as torchvision_nms
                    keep_indices = torchvision_nms(boxes, scores, iou_threshold)
                except ImportError:
                    print("[WARN] torchvision unavailable; using custom IoMin NMS")
                    keep_indices = nms(
                        boxes=boxes,
                        scores=scores,
                        iou_threshold=iou_threshold,
                    )

            boxes = boxes[keep_indices]
            scores = scores[keep_indices]
            labels = labels[keep_indices]

        # Defensive: ensure a leading detection dimension. Indexing with a
        # 1-D long tensor above preserves 2-D shape, so this only matters
        # for unexpected 1-D inputs.
        if len(boxes.shape) == 1:
            boxes = boxes.unsqueeze(0)
            scores = scores.unsqueeze(0)
            labels = labels.unsqueeze(0)

        # Visualize results
        output = visualize_bbox(
            input_img, boxes, labels, scores, classes_map, colors
        )

        detection_info = f"Found {len(boxes)} detections after NMS ({nms_method})"
        return output, detection_info

    except Exception as e:
        # UI boundary: report any failure as text rather than crashing the app.
        print(f"[ERROR] recognize_image failed: {e}")
        error_msg = f"Error during processing: {str(e)}"
        # Return original image on error
        if input_img is not None:
            return np.array(input_img), error_msg
        return np.zeros((512, 512, 3), dtype=np.uint8), error_msg


def gradio_reset():
    """Clear the input image, output image, and info textbox."""
    return gr.update(value=None), gr.update(value=None), gr.update(value="")


if __name__ == "__main__":
    print(f"Using device: {device}")

    # Create header HTML.
    # NOTE(review): the original HTML markup was stripped during extraction;
    # the tags below are a reconstruction around the surviving text content.
    header_html = """
    <div style="text-align: center;">
        <h1>🔍 Document Layout Analysis</h1>
        <p>Using Docling Layout Models for document structure detection</p>
        <p>Select a model, upload an image and adjust the parameters to detect document elements</p>
    </div>
    """

    with gr.Blocks(title="Document Layout Analysis", theme=gr.themes.Soft()) as demo:
        gr.HTML(header_html)

        with gr.Row():
            with gr.Column():
                # Model selection
                model_dropdown = gr.Dropdown(
                    choices=list(MODELS.keys()),
                    value="Egret XLarge",
                    label="🤖 Select Model",
                    info="Choose which Docling model to use",
                )
                load_btn = gr.Button("📥 Load Model", variant="secondary")
                model_status = gr.Textbox(
                    label="Model Status",
                    interactive=False,
                    value="No model loaded",
                )

                input_img = gr.Image(
                    label="📄 Upload Document Image",
                    interactive=True,
                    type="pil",
                )

                with gr.Row():
                    clear = gr.Button("🗑️ Clear")
                    predict = gr.Button("🔍 Detect Layout", interactive=True, variant="primary")

                with gr.Row():
                    conf_threshold = gr.Slider(
                        label="Confidence Threshold",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.05,
                        value=0.6,
                        info="Minimum confidence score for detections",
                    )

                with gr.Row():
                    iou_threshold = gr.Slider(
                        label="NMS IoU Threshold",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.05,
                        value=0.5,
                        info="IoU threshold for Non-Maximum Suppression",
                    )
                    nms_method = gr.Radio(
                        choices=["Custom IoMin", "Standard IoU"],
                        value="Custom IoMin",
                        label="NMS Method",
                        info="Choose NMS algorithm",
                    )

                # Legend of detected classes with their display colors.
                # NOTE(review): legend markup also reconstructed (tags stripped
                # in the extracted original).
                with gr.Accordion("📋 Detected Classes", open=False):
                    legend_html = "<div style='display: flex; flex-wrap: wrap; gap: 6px;'>"
                    for class_id, class_name in classes_map.items():
                        color = colors[class_id % len(colors)]
                        legend_html += f"""
                        <span style="background-color: {color}; color: white; padding: 2px 8px; border-radius: 4px;">
                            {class_name}
                        </span>
                        """
                    legend_html += "</div>"
                    gr.HTML(legend_html)

            with gr.Column():
                gr.HTML("""
                <div style="text-align: center;">
                    <h3>🎯 Detection Results</h3>
                </div>
                """)
                output_img = gr.Image(
                    label="Detected Layout",
                    interactive=False,
                    type="numpy",
                )
                detection_info = gr.Textbox(
                    label="Detection Info",
                    interactive=False,
                    value="",
                )

        # Event handlers
        load_btn.click(
            load_model,
            inputs=[model_dropdown],
            outputs=[model_status],
        )

        clear.click(
            gradio_reset,
            inputs=None,
            outputs=[input_img, output_img, detection_info],
        )

        predict.click(
            recognize_image,
            inputs=[input_img, conf_threshold, iou_threshold, nms_method],
            outputs=[output_img, detection_info],
        )

    # Launch the demo
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        debug=True,
        share=False,
    )