# Document Layout Analysis — Gradio Space (Docling layout models)
| import os | |
| os.environ["GRADIO_TEMP_DIR"] = "./tmp" | |
| import sys | |
| import torch | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image, ImageDraw, ImageFont | |
| from transformers import ( | |
| DFineForObjectDetection, | |
| RTDetrImageProcessor, | |
| ) | |
# == select device ==
# Prefer GPU inference when CUDA is available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Available models: UI display name -> Hugging Face Hub repository id.
MODELS = {
    "Egret XLarge": "ds4sd/docling-layout-egret-xlarge",
    "Egret Large": "ds4sd/docling-layout-egret-large",
    "Egret Medium": "ds4sd/docling-layout-egret-medium",
    "Heron 101": "ds4sd/docling-layout-heron-101",
    "Heron": "ds4sd/docling-layout-heron"
}

# Classes mapping for the docling model: class id -> human-readable label.
classes_map = {
    0: "Caption",
    1: "Footnote",
    2: "Formula",
    3: "List-item",
    4: "Page-footer",
    5: "Page-header",
    6: "Picture",
    7: "Section-header",
    8: "Table",
    9: "Text",
    10: "Title",
    11: "Document Index",
    12: "Code",
    13: "Checkbox-Selected",
    14: "Checkbox-Unselected",
    15: "Form",
    16: "Key-Value Region",
}

# Color mapping for visualization; indexed by class id modulo len(colors).
colors = [
    "#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4", "#FECA57",
    "#FF9FF3", "#54A0FF", "#5F27CD", "#00D2D3", "#FF9F43",
    "#10AC84", "#EE5A24", "#0ABDE3", "#006BA6", "#F79F1F",
    "#A3CB38", "#FDA7DF"
]

# Global variables holding the currently loaded model/processor; populated
# by load_model() and read by recognize_image().
current_model = None
current_processor = None
current_model_name = None
def iomin(box1, box2):
    """Compute Intersection-over-Minimum between one box and a set of boxes.

    Args:
        box1: Tensor of shape [1, 4] in (x1, y1, x2, y2) format.
        box2: Tensor of shape [N, 4] in the same format.

    Returns:
        Tensor of shape [N]: intersection area divided by the smaller of the
        two box areas, for each pair.
    """
    # Corners of each pairwise intersection rectangle.
    left = torch.max(box1[:, 0], box2[:, 0])
    top = torch.max(box1[:, 1], box2[:, 1])
    right = torch.min(box1[:, 2], box2[:, 2])
    bottom = torch.min(box1[:, 3], box2[:, 3])

    # Clamp at zero so non-overlapping pairs contribute no area.
    overlap = (right - left).clamp(min=0) * (bottom - top).clamp(min=0)

    # Normalize by the smaller of the two box areas.
    area_a = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    area_b = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
    smaller = torch.min(area_a, area_b)
    return overlap / smaller
def nms(boxes, scores, iou_threshold=0.5):
    """Greedy non-maximum suppression using IoMin as the overlap measure.

    Args:
        boxes: Tensor[N, 4] of (x1, y1, x2, y2) boxes.
        scores: Tensor[N] of confidence scores.
        iou_threshold: candidates whose IoMin with a kept box exceeds this
            value are suppressed.

    Returns:
        LongTensor of kept indices, ordered by descending score.
    """
    kept = []
    _, candidates = torch.sort(scores, descending=True)
    while candidates.numel() > 0:
        # Highest-scoring remaining box is always kept.
        best = candidates[0]
        kept.append(best.item())
        if candidates.numel() == 1:
            break
        # Drop all remaining boxes that overlap the kept one too much.
        remaining = candidates[1:]
        overlaps = iomin(boxes[best].unsqueeze(0), boxes[remaining])
        candidates = remaining[overlaps <= iou_threshold]
    return torch.tensor(kept, dtype=torch.long)
def load_model(model_name):
    """Load the requested checkpoint and make it the active model.

    Updates the module-level current_model / current_processor /
    current_model_name on success. Reloading is skipped when the requested
    model is already active.

    Args:
        model_name: key into the MODELS mapping.

    Returns:
        Human-readable status string for the UI.
    """
    global current_model, current_processor, current_model_name

    # No-op when the requested model is already the active one.
    if model_name == current_model_name:
        return f"β Model {model_name} is already loaded!"

    try:
        print(f"Loading model: {model_name}")
        repo_id = MODELS[model_name]
        loaded_processor = RTDetrImageProcessor.from_pretrained(repo_id)
        loaded_model = DFineForObjectDetection.from_pretrained(repo_id).to(device)
        loaded_model.eval()
    except Exception as e:
        return f"β Error loading {model_name}: {str(e)}"

    # Publish the new model only after everything loaded cleanly.
    current_processor = loaded_processor
    current_model = loaded_model
    current_model_name = model_name
    return f"β Successfully loaded {model_name}!"
def visualize_bbox(image, boxes, labels, scores, classes_map, colors):
    """
    Draw detection boxes with class labels and scores onto an image.

    Args:
        image: PIL Image or numpy array to annotate.
        boxes: iterable of (x1, y1, x2, y2) coordinates (tensors or numbers).
        labels: iterable of class ids, parallel to boxes.
        scores: iterable of confidence scores, parallel to boxes.
        classes_map: dict mapping class id -> class name.
        colors: list of color strings, indexed by class id modulo its length.

    Returns:
        numpy array of the annotated image (the input image is not modified).

    Raises:
        ValueError: if image is neither a PIL Image nor a numpy array.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif not isinstance(image, Image.Image):
        raise ValueError("Input image must be PIL Image or numpy array")
    # Create a copy to draw on so the caller's image is left untouched
    draw_image = image.copy()
    draw = ImageDraw.Draw(draw_image)
    # Try to use a font, fallback to default if not available.
    # FIX: the original used bare `except:` clauses, which also swallow
    # KeyboardInterrupt/SystemExit; narrowed to Exception.
    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except Exception:
        try:
            font = ImageFont.load_default()
        except Exception:
            font = None
    for box, label_id, score in zip(boxes, labels, scores):
        # Unwrap 0-d tensors to plain Python numbers if needed
        if torch.is_tensor(label_id):
            label_id = label_id.item()
        if torch.is_tensor(score):
            score = score.item()
        label = classes_map.get(int(label_id), f"Class_{label_id}")
        color = colors[int(label_id) % len(colors)]
        # Convert box coordinates to integers
        x1, y1, x2, y2 = [int(coord) for coord in box]
        # Draw rectangle
        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
        # Draw label background above the box's top-left corner
        text = f"{label}: {score:.2f}"
        if font:
            bbox = draw.textbbox((x1, y1), text, font=font)
            text_width = bbox[2] - bbox[0]
            text_height = bbox[3] - bbox[1]
        else:
            # Estimate text size if no font available
            text_width = len(text) * 10
            text_height = 20
        # NOTE(review): when y1 is near 0 the label background extends above
        # the image and PIL clips it silently; acceptable for a demo overlay.
        draw.rectangle([x1, y1-text_height-4, x1+text_width+4, y1], fill=color)
        draw.text((x1+2, y1-text_height-2), text, fill="white", font=font)
    return np.array(draw_image)
def recognize_image(input_img, conf_threshold, iou_threshold, nms_method):
    """
    Run document-layout detection on one image with the currently loaded model.

    Args:
        input_img: PIL Image or numpy array to analyze (None -> error message).
        conf_threshold: minimum confidence for a detection to be kept.
        iou_threshold: NMS overlap threshold; values >= 1.0 disable NMS.
        nms_method: "Custom IoMin" for the IoMin-based NMS defined in this
            file; any other value uses torchvision's standard IoU NMS.

    Returns:
        (annotated image as numpy array or None, status/info string).
    """
    if input_img is None:
        return None, "Please upload an image first."
    if current_model is None or current_processor is None:
        return None, "Please load a model first."
    try:
        # Normalize input to an RGB PIL image
        if isinstance(input_img, np.ndarray):
            input_img = Image.fromarray(input_img)
        if input_img.mode != 'RGB':
            input_img = input_img.convert('RGB')
        # Preprocess and move tensors to the inference device
        inputs = current_processor(images=[input_img], return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # Run inference
        with torch.no_grad():
            outputs = current_model(**inputs)
        # Post-process results; PIL size is (W, H), the processor wants (H, W)
        results = current_processor.post_process_object_detection(
            outputs,
            target_sizes=torch.tensor([input_img.size[::-1]]),
            threshold=conf_threshold,
        )
        if not results or len(results) == 0:
            return np.array(input_img), "No detections found."
        result = results[0]
        # Get results
        boxes = result["boxes"]
        scores = result["scores"]
        labels = result["labels"]
        if len(boxes) == 0:
            return np.array(input_img), "No detections above confidence threshold."
        # Apply NMS unless disabled (threshold >= 1.0 keeps everything)
        if iou_threshold < 1.0:
            if nms_method == "Custom IoMin":
                # Use custom NMS with IoMin
                keep_indices = nms(
                    boxes=boxes,
                    scores=scores,
                    iou_threshold=iou_threshold
                )
            else:
                # FIX: the torchvision op schema is
                # nms(Tensor dets, Tensor scores, float iou_threshold), so it
                # must be called positionally -- the original keyword call
                # used `boxes=`, which does not match the schema name `dets`
                # and raises at runtime.
                keep_indices = torch.ops.torchvision.nms(
                    boxes, scores, float(iou_threshold)
                )
            boxes = boxes[keep_indices]
            scores = scores[keep_indices]
            labels = labels[keep_indices]
        # Restore a batch dimension if indexing collapsed to a single box
        if len(boxes.shape) == 1:
            boxes = boxes.unsqueeze(0)
            scores = scores.unsqueeze(0)
            labels = labels.unsqueeze(0)
        # Visualize results
        output = visualize_bbox(
            input_img,
            boxes,
            labels,
            scores,
            classes_map,
            colors
        )
        detection_info = f"Found {len(boxes)} detections after NMS ({nms_method})"
        return output, detection_info
    except Exception as e:
        print(f"[ERROR] recognize_image failed: {e}")
        error_msg = f"Error during processing: {str(e)}"
        # Best effort: return the original image so the UI still shows something
        if input_img is not None:
            return np.array(input_img), error_msg
        return np.zeros((512, 512, 3), dtype=np.uint8), error_msg
def gradio_reset():
    """Clear the input image, the output image, and the detection-info text."""
    cleared_input = gr.update(value=None)
    cleared_output = gr.update(value=None)
    cleared_info = gr.update(value="")
    return cleared_input, cleared_output, cleared_info
if __name__ == "__main__":
    print(f"Using device: {device}")
    # Static header shown at the top of the page.
    header_html = """
    <div style="text-align: center; margin-bottom: 20px;">
    <h1>π Document Layout Analysis</h1>
    <p>Using Docling Layout Models for document structure detection</p>
    <p>Select a model, upload an image and adjust the parameters to detect document elements</p>
    </div>
    """
    # Build the UI. Component nesting order inside the context managers
    # determines the page layout, so statement order matters here.
    with gr.Blocks(title="Document Layout Analysis", theme=gr.themes.Soft()) as demo:
        gr.HTML(header_html)
        with gr.Row():
            # Left column: model selection, image upload, and parameters.
            with gr.Column():
                # Model selection
                model_dropdown = gr.Dropdown(
                    choices=list(MODELS.keys()),
                    value="Egret XLarge",
                    label="π€ Select Model",
                    info="Choose which Docling model to use"
                )
                load_btn = gr.Button("π₯ Load Model", variant="secondary")
                model_status = gr.Textbox(
                    label="Model Status",
                    interactive=False,
                    value="No model loaded"
                )
                input_img = gr.Image(
                    label="π Upload Document Image",
                    interactive=True,
                    type="pil"
                )
                with gr.Row():
                    clear = gr.Button("ποΈ Clear")
                    predict = gr.Button("π Detect Layout", interactive=True, variant="primary")
                with gr.Row():
                    conf_threshold = gr.Slider(
                        label="Confidence Threshold",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.05,
                        value=0.6,
                        info="Minimum confidence score for detections"
                    )
                with gr.Row():
                    iou_threshold = gr.Slider(
                        label="NMS IoU Threshold",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.05,
                        value=0.5,
                        info="IoU threshold for Non-Maximum Suppression"
                    )
                nms_method = gr.Radio(
                    choices=["Custom IoMin", "Standard IoU"],
                    value="Custom IoMin",
                    label="NMS Method",
                    info="Choose NMS algorithm"
                )
                # Legend: one color swatch per detectable class.
                with gr.Accordion("π Detected Classes", open=False):
                    legend_html = "<div style='display: grid; grid-template-columns: repeat(2, 1fr); gap: 10px;'>"
                    for class_id, class_name in classes_map.items():
                        color = colors[class_id % len(colors)]
                        legend_html += f"""
                        <div style='display: flex; align-items: center; padding: 5px;'>
                        <div style='width: 20px; height: 20px; background-color: {color}; margin-right: 10px; border: 1px solid #ccc;'></div>
                        <span>{class_name}</span>
                        </div>
                        """
                    legend_html += "</div>"
                    gr.HTML(legend_html)
            # Right column: detection output image and info text.
            with gr.Column():
                gr.HTML("<h3>π― Detection Results</h3>")
                output_img = gr.Image(
                    label="Detected Layout",
                    interactive=False,
                    type="numpy"
                )
                detection_info = gr.Textbox(
                    label="Detection Info",
                    interactive=False,
                    value=""
                )
        # Event handlers (must be registered inside the Blocks context).
        load_btn.click(
            load_model,
            inputs=[model_dropdown],
            outputs=[model_status]
        )
        clear.click(
            gradio_reset,
            inputs=None,
            outputs=[input_img, output_img, detection_info]
        )
        predict.click(
            recognize_image,
            inputs=[input_img, conf_threshold, iou_threshold, nms_method],
            outputs=[output_img, detection_info]
        )
    # Launch the demo; 0.0.0.0 exposes it on all interfaces (Spaces default).
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        debug=True,
        share=False
    )