#!/usr/bin/env python3
"""
Vision Model Figure Extraction Test Script

This script uses DocLayout-YOLO to detect and extract figures, tables, and
charts from PDF documents. It processes PDFs in the uploaded_pdfs/ directory
and saves extracted figures as separate image files with metadata.

Usage:
    python test_figure_extraction.py                    # Process all PDFs in uploaded_pdfs/
    python test_figure_extraction.py path/to/file.pdf   # Process specific PDF

Integration Notes for main.py:
- This script demonstrates the figure extraction pipeline
- For integration: modify extract_text_by_page() to also extract figures
- Store figure embeddings in Qdrant alongside text embeddings
- Use multimodal retrieval (CLIP embeddings) for figure search
"""

import argparse
import json
import os
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List

import cv2
import numpy as np
import requests
import torch
from doclayout_yolo import YOLOv10
from pdf2image import convert_from_path
from PIL import Image

# Configuration
EXTRACTED_FIGURES_DIR = Path("extracted_figures")
UPLOADED_PDFS_DIR = Path("uploaded_pdfs")
MODEL_NAME = "doclayout_yolo_docstructbench_imgsz1024.pt"
CONFIDENCE_THRESHOLD = 0.25
IMAGE_SIZE = 1024

# (connect, read) timeouts in seconds so a stalled download cannot hang forever.
DOWNLOAD_TIMEOUT = (10, 60)

# Figure-related class labels in DocLayout-YOLO
# NOTE(review): the module docstring mentions tables, but no table label is
# listed here, so table detections are filtered out -- confirm this is intended.
FIGURE_CLASSES = ['figure', 'picture', 'chart', 'diagram', 'graph', 'plot']


def setup_directories():
    """Create necessary directories for output."""
    # parents=True so a nested output path also works if the constant changes.
    EXTRACTED_FIGURES_DIR.mkdir(parents=True, exist_ok=True)
    print(f"Created output directory: {EXTRACTED_FIGURES_DIR}")


def download_model():
    """Download the DocLayout-YOLO model if it doesn't exist.

    Returns:
        The path to the local model file, as a string.

    Exits the process with status 1 if the download fails.
    """
    model_path = Path(MODEL_NAME)
    if model_path.exists():
        print(f"Model already exists: {MODEL_NAME}")
        return str(model_path)

    print("Downloading DocLayout-YOLO model...")
    model_url = "https://huggingface.co/juliozhao/DocLayout-YOLO-DocStructBench/resolve/main/doclayout_yolo_docstructbench_imgsz1024.pt"

    # Stream into a sibling ".part" file and rename only on success; otherwise
    # an interrupted download would leave a truncated file that the exists()
    # check above would accept on the next run.
    partial_path = model_path.with_name(model_path.name + ".part")
    try:
        with requests.get(model_url, stream=True, timeout=DOWNLOAD_TIMEOUT) as response:
            response.raise_for_status()
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        partial_path.replace(model_path)
    except (requests.RequestException, OSError) as e:
        try:
            partial_path.unlink()
        except OSError:
            # Nothing to remove, or removal failed; we exit below either way.
            pass
        print(f"Error downloading model: {e}")
        print("Make sure you have internet connection for model download")
        sys.exit(1)

    print(f"Model downloaded successfully: {MODEL_NAME}")
    return str(model_path)


def load_model():
    """Load the DocLayout-YOLO model, downloading the weights first if needed.

    Returns:
        The loaded YOLOv10 model instance.

    Exits the process with status 1 if the model cannot be loaded.
    """
    print("Loading DocLayout-YOLO model...")
    try:
        # Download model if not exists
        model_path = download_model()
        model = YOLOv10(model_path)
    except Exception as e:
        # Top-level boundary for this script: report and stop.
        print(f"Error loading model: {e}")
        print("Make sure you have internet connection for model download")
        sys.exit(1)

    print(f"Model loaded successfully: {MODEL_NAME}")
    return model


def convert_pdf_to_images(pdf_path: Path) -> List[Image.Image]:
    """Convert PDF pages to PIL Images.

    Args:
        pdf_path: Path to the PDF file to rasterize.

    Returns:
        One image per page, or an empty list if the conversion failed.
    """
    print(f"Converting PDF to images: {pdf_path.name}")
    try:
        images = convert_from_path(pdf_path, dpi=200)
    except Exception as e:
        # Best effort: the caller treats an empty list as a failed PDF.
        print(f"Error converting PDF: {e}")
        return []

    print(f"Converted {len(images)} pages to images")
    return images


def detect_figures(model, image: Image.Image) -> List[Dict[str, Any]]:
    """Detect figures in a single page image.

    Args:
        model: Loaded detection model exposing ``predict`` and ``names``.
        image: The page image to analyse.

    Returns:
        One dict per detection whose class is in FIGURE_CLASSES, with keys
        ``class_name``, ``confidence``, ``bbox`` (``[x1, y1, x2, y2]`` floats)
        and ``detection_id`` (index of the box among all boxes on the page).
    """
    # Convert PIL to numpy array for YOLO
    # NOTE(review): the array keeps PIL's channel order; confirm the model does
    # not expect BGR for numpy input.
    image_np = np.array(image)

    # Run detection
    results = model.predict(
        image_np,
        imgsz=IMAGE_SIZE,
        conf=CONFIDENCE_THRESHOLD,
        device='cuda' if torch.cuda.is_available() else 'cpu',
        verbose=False,
    )

    detections: List[Dict[str, Any]] = []
    if not results or results[0].boxes is None:
        return detections

    for i, box in enumerate(results[0].boxes):
        # Get class name
        class_id = int(box.cls[0])
        class_name = model.names[class_id]

        # Keep only figure-related classes
        if class_name.lower() not in FIGURE_CLASSES:
            continue

        # Get bounding box coordinates
        x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
        detections.append({
            'class_name': class_name,
            'confidence': float(box.conf[0]),
            'bbox': [float(x1), float(y1), float(x2), float(y2)],
            'detection_id': i,
        })

    return detections


def extract_and_save_figures(image: Image.Image, detections: List[Dict[str, Any]],
                             page_num: int, pdf_name: str) -> List[Dict[str, Any]]:
    """Extract and save detected figures.

    Args:
        image: The page image the detections refer to.
        detections: Detection dicts as produced by ``detect_figures``.
        page_num: Zero-based page index (stored and printed one-based).
        pdf_name: PDF stem used as the filename prefix.

    Returns:
        One metadata dict per figure written to EXTRACTED_FIGURES_DIR.
    """
    saved_figures = []
    page_width, page_height = image.size

    for idx, detection in enumerate(detections):
        x1, y1, x2, y2 = detection['bbox']

        # Clamp to the page so the crop never includes padding from outside
        # the image, and so degenerate boxes can be detected below.
        left = max(0, min(int(round(x1)), page_width))
        upper = max(0, min(int(round(y1)), page_height))
        right = max(0, min(int(round(x2)), page_width))
        lower = max(0, min(int(round(y2)), page_height))
        if right <= left or lower <= upper:
            # A zero-area crop cannot be written as a PNG.
            print(f"  Skipping empty detection {idx+1} on page {page_num+1}")
            continue

        # Crop the figure from the image
        figure_crop = image.crop((left, upper, right, lower))

        # Generate filename
        figure_filename = f"{pdf_name}_page{page_num+1}_figure{idx+1}_{detection['class_name']}.png"
        figure_path = EXTRACTED_FIGURES_DIR / figure_filename

        # Save the figure
        figure_crop.save(figure_path, "PNG")

        # Store metadata
        saved_figures.append({
            'filename': figure_filename,
            'page_number': page_num + 1,
            'class_name': detection['class_name'],
            'confidence': detection['confidence'],
            'bbox': detection['bbox'],
            'image_size': figure_crop.size,
            'pdf_name': pdf_name,
        })
        print(f"  Saved figure: {figure_filename} (confidence: {detection['confidence']:.3f})")

    return saved_figures


def process_pdf(pdf_path: Path, model) -> Dict[str, Any]:
    """Process a single PDF file.

    Args:
        pdf_path: Path to the PDF to process.
        model: Loaded detection model, passed through to ``detect_figures``.

    Returns:
        The metadata dict that was also written to
        ``<pdf stem>_metadata.json``, or ``{'error': ...}`` if the PDF could
        not be converted to images.
    """
    print(f"\n{'='*60}")
    print(f"Processing PDF: {pdf_path.name}")
    print(f"{'='*60}")

    # Convert PDF to images
    images = convert_pdf_to_images(pdf_path)
    if not images:
        return {'error': 'Failed to convert PDF to images'}

    pdf_name = pdf_path.stem
    all_figures = []
    total_pages = len(images)

    # Process each page
    for page_num, image in enumerate(images):
        print(f"\nProcessing page {page_num + 1}/{total_pages}...")

        # Detect figures on this page
        detections = detect_figures(model, image)
        if not detections:
            print(f"  No figures detected on page {page_num + 1}")
            continue

        print(f"  Found {len(detections)} figures on page {page_num + 1}")

        # Extract and save figures
        all_figures.extend(
            extract_and_save_figures(image, detections, page_num, pdf_name)
        )

    # Save metadata
    metadata = {
        'pdf_name': pdf_name,
        'pdf_path': str(pdf_path),
        'total_pages': total_pages,
        'total_figures': len(all_figures),
        'figures': all_figures,
        # Timezone-aware UTC time of processing (previously the working
        # directory was stored here by mistake).
        'processing_timestamp': datetime.now(timezone.utc).isoformat(),
        'model_used': MODEL_NAME,
        'confidence_threshold': CONFIDENCE_THRESHOLD,
    }

    metadata_filename = f"{pdf_name}_metadata.json"
    metadata_path = EXTRACTED_FIGURES_DIR / metadata_filename
    with open(metadata_path, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)

    print(f"\nSummary for {pdf_name}:")
    print(f"  Pages processed: {total_pages}")
    print(f"  Figures extracted: {len(all_figures)}")
    print(f"  Metadata saved: {metadata_filename}")

    return metadata


def main():
    """Main function to process PDFs.

    Processes the PDF given on the command line, or every ``*.pdf`` in
    UPLOADED_PDFS_DIR when no argument is supplied, then prints a summary.
    """
    parser = argparse.ArgumentParser(description='Extract figures from PDFs using DocLayout-YOLO')
    parser.add_argument('pdf_path', nargs='?', help='Path to specific PDF file (optional)')
    args = parser.parse_args()

    print("Vision Model Figure Extraction Test Script")
    print("=" * 50)

    # Setup
    setup_directories()
    model = load_model()

    # Determine which PDFs to process
    if args.pdf_path:
        pdf_path = Path(args.pdf_path)
        if not pdf_path.exists():
            print(f"Error: PDF file not found: {pdf_path}")
            sys.exit(1)
        pdf_files = [pdf_path]
    else:
        # Process all PDFs in uploaded_pdfs directory
        if not UPLOADED_PDFS_DIR.exists():
            print(f"Error: Directory not found: {UPLOADED_PDFS_DIR}")
            sys.exit(1)

        pdf_files = list(UPLOADED_PDFS_DIR.glob("*.pdf"))
        if not pdf_files:
            print(f"No PDF files found in {UPLOADED_PDFS_DIR}")
            sys.exit(1)

    print(f"Found {len(pdf_files)} PDF files to process")

    # Process each PDF
    all_results = []
    failed_pdfs = []
    total_figures = 0

    for pdf_file in pdf_files:
        result = process_pdf(pdf_file, model)
        if 'error' in result:
            # Remember failures so they show up in the summary instead of
            # silently disappearing from the counts.
            failed_pdfs.append(pdf_file)
            continue
        all_results.append(result)
        total_figures += result['total_figures']

    # Final summary
    print(f"\n{'='*60}")
    print("FINAL SUMMARY")
    print(f"{'='*60}")
    print(f"PDFs processed: {len(all_results)}")
    if failed_pdfs:
        print(f"PDFs failed: {len(failed_pdfs)}")
        for failed in failed_pdfs:
            print(f"  - {failed}")
    print(f"Total figures extracted: {total_figures}")
    print(f"Output directory: {EXTRACTED_FIGURES_DIR}")

    if total_figures > 0:
        print(f"\nExtracted figures are saved in: {EXTRACTED_FIGURES_DIR}")
        print("Each PDF has a corresponding metadata JSON file with detailed information.")

    print("\nIntegration Notes:")
    print("- Modify extract_text_by_page() in main.py to include figure extraction")
    print("- Store figure embeddings in Qdrant using CLIP or similar vision encoders")
    print("- Implement multimodal retrieval for combined text + figure search")


if __name__ == "__main__":
    main()