Spaces:
Running
Running
| import os | |
| import sys | |
| import time | |
| from enum import Enum | |
| from pathlib import Path | |
| import json | |
| import base64 | |
| import pycountry | |
| import logging | |
| from pydantic import BaseModel | |
| from mistralai import Mistral | |
| from mistralai import DocumentURLChunk, ImageURLChunk, TextChunk | |
# Configure the root logger at import time: INFO level, timestamped records
# that also carry the logger name and the level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Import utilities for OCR processing
try:
    from ocr_utils import replace_images_in_markdown, get_combined_markdown
except ImportError:
    # ocr_utils is optional; define equivalent fallbacks when it is missing.
    def replace_images_in_markdown(markdown_str, images_dict):
        """Inline image data into the markdown image links that reference it.

        Args:
            markdown_str: Markdown text containing ``![id](id)`` image links.
            images_dict: Mapping of image id to the string that should become
                the link target (presumably a base64 data URI as returned by
                the OCR API -- verify against the SDK response).

        Returns:
            The markdown with every ``![id](id)`` link rewritten to
            ``![id](<image data>)``. Text without matching links is unchanged.
        """
        for img_name, base64_str in images_dict.items():
            markdown_str = markdown_str.replace(
                f"![{img_name}]({img_name})", f"![{img_name}]({base64_str})"
            )
        return markdown_str

    def get_combined_markdown(ocr_response):
        """Join the markdown of every OCR page, with page images inlined.

        Args:
            ocr_response: Object exposing ``pages``; each page exposes
                ``markdown`` and ``images`` (items with ``id`` and
                ``image_base64``).

        Returns:
            All page markdowns separated by a blank line.
        """
        markdowns = []
        for page in ocr_response.pages:
            image_data = {img.id: img.image_base64 for img in page.images}
            markdowns.append(replace_images_in_markdown(page.markdown, image_data))
        return "\n\n".join(markdowns)
| # Import config directly (now local to historical-ocr) | |
| from config import MISTRAL_API_KEY, OCR_MODEL, TEXT_MODEL, VISION_MODEL | |
# Create language enum for structured output.
# Maps each pycountry language that has a two-letter ``alpha_2`` code to its name.
languages = {
    lang.alpha_2: lang.name
    for lang in pycountry.languages
    if hasattr(lang, "alpha_2")
}


class LanguageMeta(type(Enum)):
    """Enum metaclass that injects one member per entry of ``languages``."""

    def __new__(metacls, cls, bases, classdict):
        # Member key is the upper-cased name with spaces turned into
        # underscores; member value is the language name itself.
        for language_name in languages.values():
            member_key = language_name.upper().replace(" ", "_")
            classdict[member_key] = language_name
        return super().__new__(metacls, cls, bases, classdict)


# Enum of language names; all members are supplied by LanguageMeta.
# (Documented with comments rather than a docstring so the class body stays
# exactly as before.)
class Language(Enum, metaclass=LanguageMeta):
    pass
# Pydantic schema for the structured output; it is passed as
# ``response_format`` to ``client.chat.parse`` by StructuredOCR.
class StructuredOCRModel(BaseModel):
    # Name of the processed file.
    file_name: str
    # Topics extracted from the document text.
    topics: list[str]
    # Languages found in the document, restricted to Language members.
    languages: list[Language]
    # OCR contents organized as a free-form dictionary (no fixed keys).
    ocr_contents: dict
class StructuredOCR:
    """OCR a PDF or image with the Mistral API and extract structured data.

    The text is recognized with ``OCR_MODEL`` and then converted into a
    ``StructuredOCRModel``-shaped dictionary by ``VISION_MODEL`` or
    ``TEXT_MODEL``. Errors raised while talking to the API or reading the
    file are caught and reported through an ``error`` key in the result.
    """

    # Files larger than this (in MB) are rejected before any API call.
    MAX_FILE_SIZE_MB = 50
    # PDFs larger than this (in MB) are capped at LARGE_PDF_MAX_PAGES pages
    # unless the caller chose pages explicitly.
    LARGE_PDF_THRESHOLD_MB = 20
    LARGE_PDF_MAX_PAGES = 10
    # Placeholder reported when the OCR response carries no usable confidence.
    DEFAULT_CONFIDENCE = 0.85
    # MIME type used when the image type cannot be guessed from the file name.
    FALLBACK_IMAGE_MIME = "image/jpeg"

    def __init__(self, api_key=None):
        """Initialize the OCR processor with API key

        Args:
            api_key: Mistral API key; ``MISTRAL_API_KEY`` from the local
                config is used when this is omitted or empty.
        """
        self.api_key = api_key or MISTRAL_API_KEY
        self.client = Mistral(api_key=self.api_key)

    def process_file(self, file_path, file_type=None, use_vision=True, max_pages=None, file_size_mb=None, custom_pages=None):
        """Process a file and return structured OCR results

        Args:
            file_path: Path to the file to process
            file_type: 'pdf' or 'image' (will be auto-detected if None)
            use_vision: Whether to use vision model for improved analysis
            max_pages: Optional limit on number of pages to process
            file_size_mb: Optional file size in MB (used for automatic page limiting)
            custom_pages: Optional list of specific page numbers to process

        Returns:
            Dictionary with structured OCR results. It always carries
            ``confidence_score``; ``processing_time`` (seconds) is added
            whenever processing was attempted.
        """
        # Convert file_path to Path object if it's a string
        file_path = Path(file_path)

        # Auto-detect file type if not provided
        if file_type is None:
            file_type = "pdf" if file_path.suffix.lower() == ".pdf" else "image"

        # Get file size if not provided
        if file_size_mb is None and file_path.exists():
            file_size_mb = file_path.stat().st_size / (1024 * 1024)  # bytes -> MB

        # Reject files that exceed the API limit before uploading anything
        if file_size_mb and file_size_mb > self.MAX_FILE_SIZE_MB:
            limit = self.MAX_FILE_SIZE_MB
            logging.warning(f"File size {file_size_mb:.2f} MB exceeds Mistral API limit of {limit} MB")
            return self._error_result(
                file_path.name,
                f"File size {file_size_mb:.2f} MB exceeds API limit of {limit} MB",
                f"Failed to process file: File size {file_size_mb:.2f} MB exceeds Mistral API limit of {limit} MB",
                "Document could not be processed due to size limitations.",
            )

        # For PDFs, cap the page count of medium-sized files when the caller
        # gave no explicit limit. Anything above MAX_FILE_SIZE_MB has already
        # been rejected, so only this one tier can apply.
        if (
            file_type == "pdf"
            and file_size_mb
            and max_pages is None
            and custom_pages is None
            and file_size_mb > self.LARGE_PDF_THRESHOLD_MB
        ):
            max_pages = self.LARGE_PDF_MAX_PAGES

        # perf_counter is monotonic, so the elapsed time cannot go negative
        start_time = time.perf_counter()
        if file_type == "pdf":
            result = self._process_pdf(file_path, use_vision, max_pages, custom_pages)
        else:
            result = self._process_image(file_path, use_vision)
        result['processing_time'] = time.perf_counter() - start_time

        # Add a default confidence score if not present
        result.setdefault('confidence_score', self.DEFAULT_CONFIDENCE)
        return result

    @staticmethod
    def _error_result(file_name, error, ocr_error, partial_text):
        """Build the placeholder result returned when a file cannot be processed."""
        return {
            "file_name": file_name,
            "topics": ["Document"],
            "languages": ["English"],
            "confidence_score": 0.0,
            "error": error,
            "ocr_contents": {
                "error": ocr_error,
                "partial_text": partial_text,
            },
        }

    @classmethod
    def _average_confidence(cls, pages):
        """Average the ``confidence`` attribute of the given pages.

        Returns ``DEFAULT_CONFIDENCE`` when no page has the attribute or the
        values cannot be averaged (for example ``None``).
        """
        values = [page.confidence for page in pages if hasattr(page, 'confidence')]
        if not values:
            return cls.DEFAULT_CONFIDENCE
        try:
            return sum(values) / len(values)
        except TypeError:
            return cls.DEFAULT_CONFIDENCE

    @staticmethod
    def _select_pages(pages, max_pages, custom_pages, logger):
        """Pick the pages to process.

        Args:
            pages: All OCR pages of the document.
            max_pages: Optional cap on the number of leading pages.
            custom_pages: Optional 1-based page numbers; takes precedence
                over ``max_pages`` whenever it is non-empty.
            logger: Logger used to report the selection.

        Returns:
            ``(selected_pages, limited)`` where ``limited`` tells whether only
            a subset of the document was selected.
        """
        total_pages = len(pages)
        if custom_pages:
            # Convert to 0-based indexing and drop out-of-range page numbers
            valid_indices = [i - 1 for i in custom_pages if 0 < i <= total_pages]
            if valid_indices:
                logger.info(f"Processing {len(valid_indices)} custom-selected pages")
                return [pages[i] for i in valid_indices], True
            logger.warning("No requested custom page exists in the document; processing all pages")
            return pages, False
        if max_pages and total_pages > max_pages:
            logger.info(f"Processing only first {max_pages} pages out of {total_pages} total pages")
            return pages[:max_pages], True
        return pages, False

    @classmethod
    def _guess_image_mime(cls, file_path):
        """Guess the image MIME type from the file name.

        Falls back to ``FALLBACK_IMAGE_MIME`` when the type is unknown or is
        not an image type.
        """
        import mimetypes  # stdlib; imported here so the class needs no new file-level import

        mime_type, _ = mimetypes.guess_type(str(file_path))
        if mime_type and mime_type.startswith("image/"):
            return mime_type
        return cls.FALLBACK_IMAGE_MIME

    @staticmethod
    def _parsed_response_to_dict(chat_response):
        """Convert a ``chat.parse`` response into a plain JSON-compatible dict."""
        parsed = chat_response.choices[0].message.parsed
        result = parsed.model_dump(mode="json")
        # Ensure languages is a list of strings, not Language enum objects
        if 'languages' in result:
            result['languages'] = [str(lang) for lang in result['languages']]
        return result

    def _process_pdf(self, file_path, use_vision=True, max_pages=None, custom_pages=None):
        """Process a PDF file with OCR

        Args:
            file_path: Path to the PDF file
            use_vision: Whether to use vision model
            max_pages: Optional limit on the number of pages to process
            custom_pages: Optional list of specific page numbers to process

        Returns:
            Structured result dictionary; on any exception, a placeholder
            result with an ``error`` key.
        """
        logger = logging.getLogger("pdf_processor")
        logger.info(f"Processing PDF: {file_path}")
        try:
            # Upload the PDF file
            logger.info("Uploading PDF file to Mistral API")
            uploaded_file = self.client.files.upload(
                file={
                    "file_name": file_path.stem,
                    "content": file_path.read_bytes(),
                },
                purpose="ocr",
            )
            # Get a signed URL for the uploaded file
            signed_url = self.client.files.get_signed_url(file_id=uploaded_file.id, expiry=1)

            # Process the PDF with OCR
            logger.info(f"Processing PDF with OCR using {OCR_MODEL}")
            pdf_response = self.client.ocr.process(
                document=DocumentURLChunk(document_url=signed_url.url),
                model=OCR_MODEL,
                include_image_base64=True,
            )

            # Limit pages if requested
            total_pages = len(pdf_response.pages)
            logger.info(f"PDF has {total_pages} total pages")
            pages_to_process, limited_pages = self._select_pages(
                pdf_response.pages, max_pages, custom_pages, logger
            )

            confidence_score = self._average_confidence(pages_to_process)

            # Combine pages' markdown into a single string
            all_markdown = "\n\n".join(page.markdown for page in pages_to_process)

            # The vision model is given the first image found on the first
            # selected page, when there is one.
            first_page_image = None
            if use_vision and pages_to_process and pages_to_process[0].images:
                first_page_image = pages_to_process[0].images[0].image_base64

            if first_page_image:
                logger.info(f"Using vision model: {VISION_MODEL}")
                result = self._extract_structured_data_with_vision(first_page_image, all_markdown, file_path.name)
            elif use_vision:
                # Fall back to text-only model if no image available
                logger.info(f"No images in PDF, falling back to text model: {TEXT_MODEL}")
                result = self._extract_structured_data_text_only(all_markdown, file_path.name)
            else:
                logger.info(f"Using text-only model: {TEXT_MODEL}")
                result = self._extract_structured_data_text_only(all_markdown, file_path.name)

            # Add page limit info to result if needed
            if limited_pages:
                result['limited_pages'] = {
                    'processed': len(pages_to_process),
                    'total': total_pages,
                }
            result['confidence_score'] = confidence_score
            # Store the raw OCR response for image rendering
            # (an SDK object, not JSON-serializable)
            result['raw_response'] = pdf_response
            logger.info("PDF processing completed successfully")
            return result
        except Exception as e:
            logger.error(f"Error processing PDF: {str(e)}")
            # Return basic result on error
            return self._error_result(
                file_path.name,
                str(e),
                f"Failed to process PDF: {str(e)}",
                "Document could not be fully processed.",
            )

    def _process_image(self, file_path, use_vision=True):
        """Process an image file with OCR

        Args:
            file_path: Path to the image file
            use_vision: Whether to use vision model

        Returns:
            Structured result dictionary; on any exception, a placeholder
            result with an ``error`` key.
        """
        logger = logging.getLogger("image_processor")
        logger.info(f"Processing image: {file_path}")
        try:
            # Read and encode the image file as a data URL, labelled with the
            # MIME type guessed from the file name instead of always JPEG.
            logger.info("Encoding image for API")
            encoded_image = base64.b64encode(file_path.read_bytes()).decode()
            mime_type = self._guess_image_mime(file_path)
            base64_data_url = f"data:{mime_type};base64,{encoded_image}"

            # Process the image with OCR
            logger.info(f"Processing image with OCR using {OCR_MODEL}")
            image_response = self.client.ocr.process(
                document=ImageURLChunk(image_url=base64_data_url),
                model=OCR_MODEL,
                include_image_base64=True,
            )

            # Get the OCR markdown from the first page
            pages = image_response.pages
            image_ocr_markdown = pages[0].markdown if pages else ""

            # Confidence of the first page, or the default when unavailable
            confidence_score = self._average_confidence(pages[:1] if pages else [])

            # Extract structured data using the appropriate model
            if use_vision:
                logger.info(f"Using vision model: {VISION_MODEL}")
                result = self._extract_structured_data_with_vision(base64_data_url, image_ocr_markdown, file_path.name)
            else:
                logger.info(f"Using text-only model: {TEXT_MODEL}")
                result = self._extract_structured_data_text_only(image_ocr_markdown, file_path.name)

            result['confidence_score'] = confidence_score
            # Store the raw OCR response for image rendering
            # (an SDK object, not JSON-serializable)
            result['raw_response'] = image_response
            logger.info("Image processing completed successfully")
            return result
        except Exception as e:
            logger.error(f"Error processing image: {str(e)}")
            # Return basic result on error
            return self._error_result(
                file_path.name,
                str(e),
                f"Failed to process image: {str(e)}",
                "Image could not be processed.",
            )

    def _extract_structured_data_with_vision(self, image_base64, ocr_markdown, filename):
        """Extract structured data using vision model

        Args:
            image_base64: Image reference handed to ``ImageURLChunk``
                (callers pass a data URL or the OCR response's image data).
            ocr_markdown: OCR text of the document in markdown.
            filename: File name used if the text-only fallback has to build
                a basic result.

        Returns:
            Structured result dictionary. If the vision call fails for any
            reason, the text-only extraction is used instead.
        """
        try:
            # No timeout is configured for this request.
            chat_response = self.client.chat.parse(
                model=VISION_MODEL,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            ImageURLChunk(image_url=image_base64),
                            TextChunk(text=(
                                f"This is a historical document's OCR in markdown:\n"
                                f"<BEGIN_IMAGE_OCR>\n{ocr_markdown}\n<END_IMAGE_OCR>.\n"
                                f"Convert this into a structured JSON response with the OCR contents in a sensible dictionary. "
                                f"Extract topics, languages, and organize the content logically."
                            )),
                        ],
                    },
                ],
                response_format=StructuredOCRModel,
                temperature=0,
            )
            result = self._parsed_response_to_dict(chat_response)
        except Exception as e:
            # Fall back to text-only model if vision model fails
            logging.getLogger("structured_ocr").warning(
                f"Vision model failed: {str(e)}. Falling back to text-only model."
            )
            result = self._extract_structured_data_text_only(ocr_markdown, filename)
        return result

    def _extract_structured_data_text_only(self, ocr_markdown, filename):
        """Extract structured data using text-only model

        Args:
            ocr_markdown: OCR text of the document in markdown.
            filename: File name stored in the basic result if parsing fails.

        Returns:
            Structured result dictionary. If the model call fails for any
            reason, a basic result holding the raw OCR text under
            ``ocr_contents['raw_text']``.
        """
        try:
            # No timeout is configured for this request.
            chat_response = self.client.chat.parse(
                model=TEXT_MODEL,
                messages=[
                    {
                        "role": "user",
                        "content": f"This is a historical document's OCR in markdown:\n"
                                   f"<BEGIN_IMAGE_OCR>\n{ocr_markdown}\n<END_IMAGE_OCR>.\n"
                                   f"Convert this into a structured JSON response with the OCR contents. "
                                   f"Extract topics, languages, and organize the content logically."
                    },
                ],
                response_format=StructuredOCRModel,
                temperature=0,
            )
            result = self._parsed_response_to_dict(chat_response)
        except Exception as e:
            # Create a basic result if parsing fails
            logging.getLogger("structured_ocr").warning(
                f"Text model failed: {str(e)}. Creating basic result."
            )
            result = {
                "file_name": filename,
                "topics": ["Document"],
                "languages": ["English"],
                "ocr_contents": {
                    "raw_text": ocr_markdown,
                },
            }
        return result
# For testing directly: OCR the file given on the command line and print the
# structured result as JSON.
if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage: python structured_ocr.py <file_path>")
        sys.exit(1)
    file_path = sys.argv[1]
    processor = StructuredOCR()
    result = processor.process_file(file_path)
    # 'raw_response' holds the SDK's OCR response object, which json.dumps
    # cannot serialize; leave it out of the printed output.
    printable = {key: value for key, value in result.items() if key != "raw_response"}
    print(json.dumps(printable, indent=2))