Ayaan Sharif
Add picture classification with higher accuracy (images_scale=3.0) and improved bbox matching
1d76058
import json
import logging
from collections import Counter

import fitz  # PyMuPDF
import gradio as gr
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import PdfPipelineOptions, TableFormerMode
from docling.document_converter import DocumentConverter
from docling.document_converter import PdfFormatOption
from PIL import Image, ImageDraw, ImageFont
# Hex outline/caption colours keyed by layout label or picture-classification
# class name. Keys are the lowercase, underscore-separated label strings.
COLORS = {
    # Layout elements
    "title": "#FF6B6B",
    "text": "#4ECDC4",
    "section_header": "#95E1D3",
    "table": "#F38181",
    "list": "#AA96DA",
    # docling_core's DocItemLabel names list entries "list_item" rather than
    # "list" (NOTE(review): re-check on docling upgrades); reuse the list colour.
    "list_item": "#AA96DA",
    "figure": "#FCBAD3",
    "caption": "#A8D8EA",
    "formula": "#FFD93D",
    "footnote": "#6BCB77",
    "page_header": "#4D96FF",
    "page_footer": "#9D84B7",
    "picture": "#FF8C42",
    # Further DocItemLabel values the layout model may emit
    "code": "#5C677D",
    "document_index": "#B5838D",
    "checkbox_selected": "#43AA8B",
    "checkbox_unselected": "#90BE6D",
    "form": "#F9C74F",
    "key_value_region": "#577590",
    # Picture classifications
    "signature": "#9D4EDD",
    "qr_code": "#06FFA5",
    "bar_code": "#06FFA5",
    "logo": "#FFB627",
    "stamp": "#E63946",
    "icon": "#F4A261",
    "bar_chart": "#2A9D8F",
    "pie_chart": "#E76F51",
    "line_chart": "#264653",
    "flow_chart": "#8338EC",
    "map": "#3A86FF",
    "screenshot": "#FB5607",
    "chemistry_molecular_structure": "#7209B7",
    "chemistry_markush_structure": "#B5179E",
    "remote_sensing": "#4361EE",
    "other": "#CCCCCC",
}
def draw_layout_boxes(image_path, layout_data, scale_x=1.0, scale_y=1.0):
    """Draw labelled bounding boxes for layout predictions onto a page image.

    Args:
        image_path: Path to an image file (``str``), or an already-open PIL
            image. Either way the image is converted to a fresh RGB copy, so a
            caller-supplied image is never drawn on directly.
        layout_data: Iterable of dicts. Each may carry ``"label"``,
            ``"bbox"`` (``[x0, y0, x1, y1]`` in source coordinates) and
            ``"classification"`` (a dict whose ``"confidence"`` is a fraction,
            rendered as a percentage). Entries without a bbox are skipped.
        scale_x: Multiplier mapping source x-coordinates to image pixels.
        scale_y: Multiplier mapping source y-coordinates to image pixels.

    Returns:
        The annotated RGB ``PIL.Image.Image``.
    """
    if isinstance(image_path, str):
        # Close the file handle once the pixels have been copied out.
        with Image.open(image_path) as source:
            img = source.convert("RGB")
    else:
        img = image_path.convert("RGB")
    draw = ImageDraw.Draw(img)

    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14)
    except OSError:
        # DejaVu is not installed on every host; fall back to PIL's built-in font.
        font = ImageFont.load_default()

    caption_gap = 25  # distance of the caption above the box's top edge, in pixels

    for cluster in layout_data:
        bbox = cluster.get("bbox")
        if not bbox:
            continue
        label = cluster.get("label") or "unknown"
        classification = cluster.get("classification")

        # Scale into pixel space and order each axis: recent Pillow versions
        # raise ValueError for rectangles whose second corner precedes the first.
        x0, y0, x1, y1 = bbox
        x0, x1 = sorted((x0 * scale_x, x1 * scale_x))
        y0, y1 = sorted((y0 * scale_y, y1 * scale_y))

        color = COLORS.get(label, "#999999")
        draw.rectangle([x0, y0, x1, y1], outline=color, width=3)

        label_text = label.replace("_", " ").title()
        confidence = classification.get("confidence") if classification else None
        if confidence is not None:
            label_text = f"{label_text} ({confidence * 100:.0f}%)"

        # Put the caption above the box, or just inside it when the box hugs
        # the top edge and the caption would otherwise be clipped off-image.
        text_y = y0 - caption_gap if y0 >= caption_gap else y0 + 4
        text_box = draw.textbbox((x0, text_y), label_text, font=font)
        draw.rectangle(
            [text_box[0] - 2, text_box[1] - 2, text_box[2] + 2, text_box[3] + 2],
            fill=color,
        )
        draw.text((x0, text_y), label_text, fill="white", font=font)

    return img
logger = logging.getLogger(__name__)

# File extensions rendered directly as raster images (layout coordinates are
# used as pixel coordinates); anything else is rasterised as a PDF.
_IMAGE_EXTENSIONS = frozenset({"jpg", "jpeg", "png", "tif", "tiff", "bmp"})

# Minimum intersection-over-union for a "picture" layout cluster and a
# classified picture item to be treated as the same region.
_PICTURE_MATCH_MIN_IOU = 0.5


def _build_converter(mode, enable_ocr, enable_tables):
    """Create a DocumentConverter configured from the UI options."""
    pipeline_options = PdfPipelineOptions()
    pipeline_options.do_table_structure = enable_tables
    if enable_tables:
        pipeline_options.table_structure_options.mode = (
            TableFormerMode.ACCURATE if mode == "Accurate" else TableFormerMode.FAST
        )
    pipeline_options.do_ocr = enable_ocr
    pipeline_options.generate_page_images = True
    pipeline_options.generate_picture_images = True
    pipeline_options.do_picture_classification = True
    # Higher-resolution images for better classification accuracy.
    pipeline_options.images_scale = 3.0
    return DocumentConverter(
        format_options={
            InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options),
            InputFormat.IMAGE: PdfFormatOption(pipeline_options=pipeline_options),
        }
    )


def _bbox_iou(box_a, box_b):
    """Return the intersection-over-union of two ``(l, t, r, b)`` boxes.

    Both boxes must share a coordinate frame; corner order within each axis
    does not matter. Returns 0.0 when the union area is zero.
    """
    a_left, a_right = sorted((box_a[0], box_a[2]))
    a_top, a_bottom = sorted((box_a[1], box_a[3]))
    b_left, b_right = sorted((box_b[0], box_b[2]))
    b_top, b_bottom = sorted((box_b[1], box_b[3]))
    overlap_w = max(0.0, min(a_right, b_right) - max(a_left, b_left))
    overlap_h = max(0.0, min(a_bottom, b_bottom) - max(a_top, b_top))
    intersection = overlap_w * overlap_h
    union = (
        (a_right - a_left) * (a_bottom - a_top)
        + (b_right - b_left) * (b_bottom - b_top)
        - intersection
    )
    return intersection / union if union > 0 else 0.0


def _collect_picture_classifications(document):
    """Group the top-1 classification of each picture by page number.

    Returns:
        ``{page_no: [{"bbox": (l, t, r, b), "class": ..., "confidence": ...}]}``.
        Pictures without provenance or without a classification annotation
        are skipped.
    """
    by_page = {}
    for picture in document.pictures:
        if not picture.prov:
            continue  # no provenance -> cannot be placed on a page
        prov = picture.prov[0]
        top_pred = next(
            (
                annotation.predicted_classes[0]
                for annotation in picture.annotations
                if getattr(annotation, "predicted_classes", None)
            ),
            None,
        )
        if top_pred is None:
            continue
        bbox = prov.bbox
        page_item = document.pages.get(prov.page_no)
        if page_item is not None:
            # Layout clusters are drawn with ``t`` as the top edge, whereas
            # provenance boxes may use a bottom-left origin. Normalise so all
            # four edges are comparable (a no-op for top-left boxes).
            bbox = bbox.to_top_left_origin(page_height=page_item.size.height)
        by_page.setdefault(prov.page_no, []).append(
            {
                "bbox": (bbox.l, bbox.t, bbox.r, bbox.b),
                "class": top_pred.class_name,
                "confidence": top_pred.confidence,
            }
        )
    return by_page


def _pop_best_match(cluster_box, candidates):
    """Remove and return the candidate that best overlaps *cluster_box*.

    Only candidates with IoU >= ``_PICTURE_MATCH_MIN_IOU`` qualify; returns
    None when none do. Popping ensures one classification labels one cluster.
    """
    best_index, best_iou = None, 0.0
    for index, candidate in enumerate(candidates):
        iou = _bbox_iou(cluster_box, candidate["bbox"])
        if iou >= _PICTURE_MATCH_MIN_IOU and iou > best_iou:
            best_index, best_iou = index, iou
    return None if best_index is None else candidates.pop(best_index)


def _render_first_page(file_path, first_page_layout):
    """Render page 1 of *file_path* and overlay *first_page_layout* on it.

    Raster images are used as-is with a 1:1 coordinate scale. PDFs are
    rasterised with PyMuPDF at 2x zoom, and boxes are scaled from PDF points
    to rendered pixels.
    """
    extension = str(file_path).lower().rsplit(".", 1)[-1]
    if extension in _IMAGE_EXTENSIONS:
        with Image.open(file_path) as source:
            page_image = source.convert("RGB")
        return draw_layout_boxes(page_image, first_page_layout, scale_x=1.0, scale_y=1.0)

    zoom = 2.0  # render at 2x for a sharper preview
    with fitz.open(file_path) as doc:
        page = doc[0]
        pix = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom))
        page_image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
        # Rendered pixels per PDF point along each axis.
        scale_x = pix.width / page.rect.width
        scale_y = pix.height / page.rect.height
    return draw_layout_boxes(page_image, first_page_layout, scale_x=scale_x, scale_y=scale_y)


def process_document(file_path, mode, enable_ocr, enable_tables):
    """Run docling layout analysis on a document and build the UI outputs.

    Args:
        file_path: Path to a PDF or image file.
        mode: ``"Accurate"`` selects TableFormer's accurate mode; any other
            value selects fast mode. Only used when ``enable_tables`` is true.
        enable_ocr: Whether docling should run OCR.
        enable_tables: Whether to run table-structure recognition.

    Returns:
        ``(visualization, summary, markdown, json_layout)``:
        - an annotated image of page 1 (None if it could not be rendered);
        - a Markdown summary;
        - the document exported as Markdown;
        - the per-cluster layout predictions as a JSON string.
        On failure the first element is None and the other three carry the
        same error message.
    """
    try:
        result = _build_converter(mode, enable_ocr, enable_tables).convert(file_path)
        document = result.document

        pictures_by_page = _collect_picture_classifications(document)
        logger.debug(
            "Classified pictures found: %d",
            sum(len(entries) for entries in pictures_by_page.values()),
        )

        layout_info = []
        total_clusters = 0
        table_count = 0
        for page_no, page in enumerate(result.pages, 1):
            layout = page.predictions.layout
            if layout:
                total_clusters += len(layout.clusters)
                # Classifications still unclaimed on this page; matches are consumed.
                candidates = pictures_by_page.get(page_no, [])
                for cluster in layout.clusters:
                    box = cluster.bbox
                    bbox = [box.l, box.t, box.r, box.b]
                    label = cluster.label
                    classification = None
                    if cluster.label == "picture" and candidates:
                        match = _pop_best_match(bbox, candidates)
                        if match is not None:
                            classification = {
                                "class": match["class"],
                                "confidence": match["confidence"],
                            }
                            # Report the picture under its classified type.
                            label = f"{match['class']}"
                            logger.debug(
                                "Page %d picture %s classified as %s (%.2f)",
                                page_no, bbox, label, match["confidence"],
                            )
                        else:
                            logger.debug("Page %d picture %s: no classification match", page_no, bbox)
                    layout_info.append(
                        {
                            "page": page_no,
                            "label": label,
                            "bbox": bbox,
                            "confidence": getattr(cluster, "confidence", None),
                            "classification": classification,
                        }
                    )
            tablestructure = page.predictions.tablestructure
            if tablestructure and tablestructure.table_map:
                table_count += len(tablestructure.table_map)

        markdown_output = document.export_to_markdown()

        visualization = None
        if result.pages and layout_info:
            first_page_layout = [item for item in layout_info if item["page"] == 1]
            try:
                visualization = _render_first_page(file_path, first_page_layout)
            except Exception:
                # Best effort: the text outputs are still useful without the image.
                logger.exception("Could not create visualization")

        element_counts = Counter(item["label"] for item in layout_info)
        summary_lines = [
            "## Document Analysis Summary",
            "",
            f"📄 **Total Pages:** {len(document.pages)}",
            "",
            f"🏷️ **Layout Elements Detected:** {total_clusters}",
            "",
            f"📊 **Tables Found:** {table_count}",
            "",
            "### Layout Elements by Type:",
            "",
        ]
        summary_lines.extend(
            f"- **{label.replace('_', ' ').title()}**: {count}"
            for label, count in sorted(element_counts.items())
        )
        summary = "\n".join(summary_lines) + "\n"

        json_output = json.dumps(layout_info, indent=2)
        return visualization, summary, markdown_output, json_output
    except Exception as e:
        # UI boundary: keep the traceback in the logs and surface the message
        # in every output pane.
        logger.exception("Error processing document %s", file_path)
        error_msg = f"Error processing document: {str(e)}"
        return None, error_msg, error_msg, error_msg
def gradio_interface(file, mode, enable_ocr, enable_tables):
    """Bridge Gradio inputs to ``process_document``.

    Args:
        file: Value of the ``gr.File`` input. It is ``None`` when the input is
            empty. Otherwise it is either a path (``str``/``os.PathLike``) or
            an object whose ``.name`` holds the uploaded file's path
            (NOTE(review): which form arrives depends on the Gradio version).
        mode: Processing mode, forwarded unchanged.
        enable_ocr: OCR toggle, forwarded unchanged.
        enable_tables: Table-detection toggle, forwarded unchanged.

    Returns:
        The 4-tuple for the visualization, summary, Markdown and JSON outputs.
    """
    import os

    if file is None:
        return None, "Please upload a document", "", ""
    if isinstance(file, (str, os.PathLike)):
        # A str (or str subclass) already is the path. A Path's ``.name``
        # would be only the basename, so never read ``.name`` here.
        file_path = os.fspath(file)
    else:
        file_path = file.name
    return process_document(file_path, mode, enable_ocr, enable_tables)
# Build the Gradio interface.
with gr.Blocks(title="Document Layout Detection", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 📄 Document Layout & Structure Detection

Upload a document (PDF, image, etc.) to automatically detect its layout structure including text, tables, figures, and more!

**Features:**
- **AI-Powered Layout Detection**: Automatically identifies document elements
- **Table Structure Extraction**: Recognizes and extracts table data
- **OCR Support**: Reads text from scanned documents and images
""")

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                label="Upload Document",
                file_types=[".pdf", ".jpg", ".jpeg", ".png", ".tiff", ".bmp"],
            )
            mode_dropdown = gr.Dropdown(
                choices=["Fast", "Accurate"],
                value="Fast",
                label="Processing Mode",
                info="Accurate mode is slower but better for complex tables",
            )
            ocr_checkbox = gr.Checkbox(
                label="Enable OCR",
                value=True,
                info="Use OCR for scanned documents and images",
            )
            tables_checkbox = gr.Checkbox(
                label="Enable Table Detection",
                value=True,
                info="Detect and extract table structures",
            )
            process_btn = gr.Button("🚀 Process Document", variant="primary", size="lg")

        with gr.Column(scale=2):
            visualization_output = gr.Image(label="Layout Visualization (First Page)")
            summary_output = gr.Markdown(label="Summary")

    with gr.Tabs():
        with gr.Tab("📝 Markdown Output"):
            markdown_output = gr.Textbox(
                label="Extracted Content (Markdown)",
                lines=20,
                max_lines=30,
            )
        with gr.Tab("🔧 JSON Layout Data"):
            json_output = gr.Code(
                label="Layout Predictions (JSON)",
                language="json",
                lines=20,
            )

    gr.Markdown("""
### Legend

Different colors represent different document elements:

**Layout Elements:**
- 🔴 Title • 🔵 Text • 🟢 Section Header • 🟠 Table • 🟣 List/Figure/Formula

**Picture Classifications (AI-detected):**
- 🟣 Signature • 🟢 QR Code • 🟢 Barcode • 🟡 Logo • 🔴 Stamp
- 🟦 Charts (Bar/Pie/Line) • 🟣 Flow Chart • 🟠 Screenshot • ⚪ Other

### How to Use
1. Upload your document (PDF or image of ID card, invoice, report, etc.)
2. Choose processing options (Fast mode recommended for quick results)
3. Click "Process Document"
4. View the visualization with bounding boxes and explore the outputs

### 💡 Try Examples Below!

Click on any example document to see instant results on different document types.
""")

    # The examples, the button and the upload trigger all share these wirings.
    app_inputs = [file_input, mode_dropdown, ocr_checkbox, tables_checkbox]
    app_outputs = [visualization_output, summary_output, markdown_output, json_output]

    with gr.Row():
        gr.Examples(
            examples=[
                ["sample/Screenshot 2025-10-13 114010.png", "Fast", True, True],
                ["sample/Screenshot 2025-10-13 114606.png", "Fast", True, True],
                ["sample/Screenshot 2025-10-15 191615.png", "Fast", True, True],
            ],
            inputs=app_inputs,
            outputs=app_outputs,
            fn=gradio_interface,
            cache_examples=False,
            label="📚 Example Documents",
            examples_per_page=3,
        )

    process_btn.click(fn=gradio_interface, inputs=app_inputs, outputs=app_outputs)

    # Also process automatically whenever the file input changes.
    file_input.change(fn=gradio_interface, inputs=app_inputs, outputs=app_outputs)

# Launch the app only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()