import os

os.environ["GRADIO_TEMP_DIR"] = "./tmp"

import sys

import torch
import torchvision
import gradio as gr
import numpy as np
import cv2
from PIL import Image
from transformers import (
    DFineForObjectDetection,
    RTDetrV2ForObjectDetection,
    RTDetrImageProcessor,
)

# == Device configuration ==
device = "cuda" if torch.cuda.is_available() else "cpu"

# == Model configurations ==
MODELS = {
    "Egret XLarge": {
        "path": "ds4sd/docling-layout-egret-xlarge",
        "model_class": DFineForObjectDetection,
    },
    "Egret Large": {
        "path": "ds4sd/docling-layout-egret-large",
        "model_class": DFineForObjectDetection,
    },
    "Egret Medium": {
        "path": "ds4sd/docling-layout-egret-medium",
        "model_class": DFineForObjectDetection,
    },
    "Heron 101": {
        "path": "ds4sd/docling-layout-heron-101",
        "model_class": RTDetrV2ForObjectDetection,
    },
    "Heron": {
        "path": "ds4sd/docling-layout-heron",
        "model_class": RTDetrV2ForObjectDetection,
    },
}

# == Class mappings ==
classes_map = {
    0: "Caption",
    1: "Footnote",
    2: "Formula",
    3: "List-item",
    4: "Page-footer",
    5: "Page-header",
    6: "Picture",
    7: "Section-header",
    8: "Table",
    9: "Text",
    10: "Title",
    11: "Document Index",
    12: "Code",
    13: "Checkbox-Selected",
    14: "Checkbox-Unselected",
    15: "Form",
    16: "Key-Value Region",
}

# == Global model variables ==
current_model = None
current_processor = None
current_model_name = None


def colormap(N=256, normalized=False):
    """Generate dynamic colormap."""

    def bitget(byteval, idx):
        return (byteval & (1 << idx)) != 0

    cmap = np.zeros((N, 3), dtype=np.uint8)
    for i in range(N):
        r = g = b = 0
        c = i
        for j in range(8):
            r = r | (bitget(c, 0) << (7 - j))
            g = g | (bitget(c, 1) << (7 - j))
            b = b | (bitget(c, 2) << (7 - j))
            c = c >> 3
        cmap[i] = np.array([r, g, b])
    if normalized:
        cmap = cmap.astype(np.float32) / 255.0
    return cmap


def iomin(box1, box2):
    """Intersection over Minimum (IoMin)."""
    x1 = torch.max(box1[:, 0], box2[:, 0])
    y1 = torch.max(box1[:, 1], box2[:, 1])
    x2 = torch.min(box1[:, 2], box2[:, 2])
    y2 = torch.min(box1[:, 3], box2[:, 3])
    inter_area = torch.clamp(x2 - x1, min=0) * torch.clamp(y2 - y1, min=0)
    box1_area = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    box2_area = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
    min_area = torch.min(box1_area, box2_area)
    return inter_area / min_area


def nms_custom(boxes, scores, iou_threshold=0.5):
    """Custom NMS implementation using IoMin."""
    keep = []
    _, order = scores.sort(descending=True)
    while order.numel() > 0:
        i = order[0]
        keep.append(i.item())
        if order.numel() == 1:
            break
        box_i = boxes[i].unsqueeze(0)
        rest = order[1:]
        ious = iomin(box_i, boxes[rest])
        mask = ious <= iou_threshold
        order = rest[mask]
    return torch.tensor(keep, dtype=torch.long)


def load_model(model_name):
    """Load the selected model automatically, caching it across calls."""
    global current_model, current_processor, current_model_name
    if current_model_name == model_name:
        return current_model, current_processor
    try:
        model_info = MODELS[model_name]
        model_path = model_info["path"]
        model_class = model_info["model_class"]
        print(f"Loading {model_name} from {model_path}")
        processor = RTDetrImageProcessor.from_pretrained(model_path)
        model = model_class.from_pretrained(model_path)
        model = model.to(device)
        model.eval()
        current_processor = processor
        current_model = model
        current_model_name = model_name
        return model, processor
    except Exception as e:
        print(f"Error loading model: {e}")
        return None, None
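
# A minimal programmatic sketch (commented out; not part of the Gradio flow).
# It mirrors the preprocessing/inference steps used in process_image below;
# the local file "sample_page.png" is hypothetical, purely for illustration:
#
#   model, processor = load_model("Egret XLarge")
#   img = Image.open("sample_page.png").convert("RGB")
#   inputs = processor(images=[img], return_tensors="pt")
#   inputs = {k: v.to(device) for k, v in inputs.items()}
#   with torch.no_grad():
#       outputs = model(**inputs)
#   result = processor.post_process_object_detection(
#       outputs, target_sizes=torch.tensor([img.size[::-1]]), threshold=0.6
#   )[0]
#   print(result["boxes"].shape, result["labels"])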

def visualize_bbox(image_input, bboxes, classes, scores, id_to_names, alpha=0.3, show_labels=True):
    """Visualize bounding boxes with OpenCV."""
    if isinstance(image_input, Image.Image):
        image = np.array(image_input)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    elif isinstance(image_input, np.ndarray):
        if len(image_input.shape) == 3 and image_input.shape[2] == 3:
            image = cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR)
        else:
            image = image_input.copy()
    else:
        raise ValueError("Input must be a PIL Image or a numpy array")

    if len(bboxes) == 0:
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    overlay = image.copy()
    cmap = colormap(N=len(id_to_names), normalized=False)

    for i in range(len(bboxes)):
        try:
            bbox = bboxes[i]
            if torch.is_tensor(bbox):
                bbox = bbox.cpu().numpy()
            class_id = classes[i]
            if torch.is_tensor(class_id):
                class_id = class_id.item()
            score = scores[i]
            if torch.is_tensor(score):
                score = score.item()

            x_min, y_min, x_max, y_max = map(int, bbox)
            class_id = int(class_id)
            class_name = id_to_names.get(class_id, f"unknown_{class_id}")
            color = tuple(int(c) for c in cmap[class_id % len(cmap)])

            # Draw filled rectangle on overlay
            cv2.rectangle(overlay, (x_min, y_min), (x_max, y_max), color, -1)
            # Draw border on main image
            cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, 3)

            # Add text label only if show_labels is True
            if show_labels:
                text = f"{class_name}: {score:.3f}"
                (text_width, text_height), baseline = cv2.getTextSize(
                    text, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2
                )
                cv2.rectangle(
                    image,
                    (x_min, y_min - text_height - baseline - 4),
                    (x_min + text_width + 8, y_min),
                    color,
                    -1,
                )
                cv2.putText(
                    image,
                    text,
                    (x_min + 4, y_min - 6),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.8,
                    (255, 255, 255),
                    2,
                )
        except Exception as e:
            print(f"Skipping box {i} due to error: {e}")

    # Apply transparency
    cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def process_image(input_img, model_name, conf_threshold, iou_threshold, nms_method, alpha, show_labels):
    """Process image with document layout detection."""
    if input_img is None:
        return None, "❌ Please upload an image first."

    # Load model if needed
    model, processor = load_model(model_name)
    if model is None or processor is None:
        return None, f"❌ Error loading model {model_name}."

    try:
        # Prepare image
        if isinstance(input_img, np.ndarray):
            input_img = Image.fromarray(input_img)
        if input_img.mode != "RGB":
            input_img = input_img.convert("RGB")

        # Process with model
        inputs = processor(images=[input_img], return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)

        # Post-process results
        results = processor.post_process_object_detection(
            outputs,
            target_sizes=torch.tensor([input_img.size[::-1]]),
            threshold=conf_threshold,
        )

        if not results:
            return np.array(input_img), "ℹ️ No detections found."

        result = results[0]
        boxes = result["boxes"]
        scores = result["scores"]
        labels = result["labels"]

        if len(boxes) == 0:
            return np.array(input_img), f"ℹ️ No detections above threshold {conf_threshold:.2f}."
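
        # A note on the two suppression modes below: iomin() divides the
        # intersection area by the *smaller* box's area, so a region nested
        # inside a larger detection reaches IoMin 1.0 and is suppressed even
        # when its plain IoU with the larger box stays under the threshold.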
        # Apply NMS
        if iou_threshold < 1.0:
            if nms_method == "Custom IoMin":
                keep_indices = nms_custom(boxes=boxes, scores=scores, iou_threshold=iou_threshold)
            else:
                # Use torchvision NMS with standard IoU
                keep_indices = torchvision.ops.nms(boxes, scores, iou_threshold)
            boxes = boxes[keep_indices]
            scores = scores[keep_indices]
            labels = labels[keep_indices]

        # Visualize results
        output = visualize_bbox(
            input_img, boxes, labels, scores, classes_map, alpha=alpha, show_labels=show_labels
        )
        labels_status = "with labels" if show_labels else "without labels"
        info = (
            f"✅ Found {len(boxes)} detections ({labels_status}) | "
            f"Model: {model_name} | Confidence: {conf_threshold:.2f}"
        )
        return output, info
    except Exception as e:
        print(f"[ERROR] process_image failed: {e}")
        error_msg = f"❌ Processing error: {str(e)}"
        if input_img is not None:
            return np.array(input_img), error_msg
        return np.zeros((512, 512, 3), dtype=np.uint8), error_msg


if __name__ == "__main__":
    print("🚀 Starting Document Layout Analysis App")
    print(f"📱 Device: {device}")
    print(f"🤖 Available models: {len(MODELS)}")

    # Custom CSS for compact layout
    custom_css = """
    .gradio-container {
        max-width: 1400px !important;
        margin: 0 auto !important;
        padding: 20px !important;
    }
    .controls-container {
        background: #f8f9fa;
        border-radius: 12px;
        border: 1px solid #dee2e6;
        padding: 20px;
        margin-bottom: 20px;
    }
    .results-container {
        background: #ffffff;
        border-radius: 12px;
        border: 1px solid #dee2e6;
        padding: 20px;
    }
    .section-divider {
        border-top: 2px solid #e9ecef;
        margin: 20px 0;
        padding-top: 20px;
    }
    .analyze-btn {
        background: linear-gradient(45deg, #667eea, #764ba2) !important;
        border: none !important;
        color: white !important;
        font-weight: bold !important;
        font-size: 18px !important;
        padding: 15px 30px !important;
        border-radius: 10px !important;
    }
    """

    # Create Gradio interface
    with gr.Blocks(
        title="📄 Document Layout Analysis",
        theme=gr.themes.Soft(),
        css=custom_css,
    ) as demo:
        # Header
        gr.HTML(
            """
            <div style="text-align: center;">
                <h1>🔍 Document Layout Analysis</h1>
                <p>Compact interface for advanced document structure detection</p>
            </div>
            """
        )

        # Controls Section
        with gr.Group(elem_classes=["controls-container"]):
            # 1. Image Upload (First)
            gr.HTML("<h3>📄 Upload Document</h3>")

") input_img = gr.Image( label="Document Image", type="pil", height=300, interactive=True ) # Divider gr.HTML("
") # 2. Model Selection (Second) gr.HTML("

🤖 Model Selection
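            # The Egret checkpoints are D-FINE detectors and the Heron
            # checkpoints are RT-DETRv2 detectors (see MODELS above); all of
            # them share the same RTDetrImageProcessor.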

") model_dropdown = gr.Dropdown( choices=list(MODELS.keys()), value="Egret XLarge", label="AI Model", info="Model will load automatically when analyzing", interactive=True ) # Divider gr.HTML("
") # 3. Detection Parameters (Third) gr.HTML("

⚙️ Detection Settings
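            # Threshold semantics: conf_threshold filters detections during
            # post-processing, while NMS only runs when iou_threshold < 1.0
            # (see process_image), so the slider's maximum disables it.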

") with gr.Row(): conf_threshold = gr.Slider( minimum=0.0, maximum=1.0, value=0.6, step=0.05, label="Confidence Threshold", info="Minimum confidence for detections" ) iou_threshold = gr.Slider( minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="NMS IoU Threshold", info="Non-maximum suppression threshold" ) with gr.Row(): nms_method = gr.Radio( choices=["Custom IoMin", "Standard IoU"], value="Custom IoMin", label="NMS Algorithm", info="Choose suppression method" ) alpha_slider = gr.Slider( minimum=0.0, maximum=1.0, value=0.3, step=0.1, label="Overlay Transparency", info="Transparency of detection overlays" ) show_labels_checkbox = gr.Checkbox( value=True, label="Show Class Labels and Confidence Scores", info="Display detection labels on the output image", interactive=True ) # Divider gr.HTML("
") # 4. Analyze Button (Last) detect_btn = gr.Button( "🔍 Analyze Document", variant="primary", size="lg", elem_classes=["analyze-btn"] ) # Results Section with gr.Group(elem_classes=["results-container"]): gr.HTML("

🎯 Analysis Results

") output_img = gr.Image( label="Analyzed Document", type="numpy", height=600, interactive=False ) detection_info = gr.Textbox( label="Detection Summary", value="Ready for analysis. Upload an image and click 'Analyze Document'.", interactive=False, lines=2, show_copy_button=True ) # Event Handler detect_btn.click( fn=process_image, inputs=[ input_img, model_dropdown, conf_threshold, iou_threshold, nms_method, alpha_slider, show_labels_checkbox ], outputs=[output_img, detection_info] ) # Launch application demo.launch( server_name="0.0.0.0", server_port=7860, debug=True, share=False, show_error=True, inbrowser=True )