Spaces:
Runtime error
Runtime error
| from pathlib import Path | |
| from typing import Dict, List, Optional, Any, Union | |
| import os | |
| import base64 | |
| import tempfile | |
| import json | |
| import logging | |
| from PIL import Image | |
| import io | |
| from src.parsers.parser_interface import DocumentParser | |
| from src.parsers.parser_registry import ParserRegistry | |
| from src.core.config import config | |
| from src.core.exceptions import DocumentProcessingError, ConversionError | |
| # Import the Mistral AI client | |
| try: | |
| from mistralai import Mistral | |
| MISTRAL_AVAILABLE = True | |
| except ImportError: | |
| MISTRAL_AVAILABLE = False | |
| # Get logger | |
| logger = logging.getLogger(__name__) | |
| # Check if API key is available and log a message if not | |
| if not config.api.mistral_api_key: | |
| logger.warning("MISTRAL_API_KEY environment variable not found. Mistral OCR parser may not work.") | |
| class MistralOcrParser(DocumentParser): | |
| """Parser that uses Mistral OCR to convert documents to markdown.""" | |
| def get_name(cls) -> str: | |
| return "Mistral OCR" | |
| def get_supported_ocr_methods(cls) -> List[Dict[str, Any]]: | |
| return [ | |
| { | |
| "id": "ocr", | |
| "name": "OCR Only", | |
| "default_params": {} | |
| }, | |
| { | |
| "id": "understand", | |
| "name": "Document Understanding", | |
| "default_params": {} | |
| } | |
| ] | |
| def get_description(cls) -> str: | |
| return "Mistral OCR parser for extracting text from documents and images with optional document understanding" | |
| def encode_image(self, image_path): | |
| """Encode the image to base64.""" | |
| try: | |
| with open(image_path, "rb") as image_file: | |
| return base64.b64encode(image_file.read()).decode('utf-8') | |
| except FileNotFoundError: | |
| logger.error(f"File not found: {image_path}") | |
| raise DocumentProcessingError(f"File not found: {image_path}") | |
| except Exception as e: | |
| logger.error(f"Error encoding file {image_path}: {e}") | |
| raise DocumentProcessingError(f"Error encoding file: {e}") | |
| def parse(self, file_path: Union[str, Path], ocr_method: Optional[str] = None, **kwargs) -> str: | |
| """Parse a document using Mistral OCR.""" | |
| if not MISTRAL_AVAILABLE: | |
| raise DocumentProcessingError( | |
| "The Mistral AI client is not installed. " | |
| "Please install it with 'pip install mistralai'." | |
| ) | |
| # Use the API key from centralized config | |
| if not config.api.mistral_api_key: | |
| raise DocumentProcessingError( | |
| "MISTRAL_API_KEY environment variable is not set. " | |
| "Please set it to your Mistral API key." | |
| ) | |
| # Check the OCR method | |
| use_document_understanding = ocr_method == "understand" | |
| try: | |
| # Initialize the Mistral client | |
| client = Mistral(api_key=config.api.mistral_api_key) | |
| # Determine file type based on extension | |
| file_path = Path(file_path) | |
| file_extension = file_path.suffix.lower() | |
| # Process the document with OCR | |
| if use_document_understanding: | |
| # Use document understanding via chat API for enhanced extraction | |
| return self._extract_with_document_understanding(client, file_path, file_extension) | |
| else: | |
| # Use regular OCR for basic text extraction | |
| return self._extract_with_ocr(client, file_path, file_extension) | |
| except (DocumentProcessingError, ConversionError): | |
| # Re-raise our custom exceptions | |
| raise | |
| except Exception as e: | |
| error_message = f"Error parsing document with Mistral OCR: {str(e)}" | |
| logger.error(error_message) | |
| raise DocumentProcessingError(error_message) | |
| def _extract_with_ocr(self, client, file_path, file_extension): | |
| """Extract document content using basic OCR.""" | |
| try: | |
| # Process according to file type | |
| if file_extension in ['.pdf', '.docx', '.pptx']: | |
| # For documents (PDF, DOCX, PPTX), we need to upload the file to the Mistral API first | |
| try: | |
| # Upload the file to Mistral API | |
| uploaded_pdf = client.files.upload( | |
| file={ | |
| "file_name": file_path.name, | |
| "content": open(file_path, "rb"), | |
| }, | |
| purpose="ocr" | |
| ) | |
| # Get signed URL for the file | |
| signed_url = client.files.get_signed_url(file_id=uploaded_pdf.id) | |
| # Use the signed URL for OCR processing | |
| ocr_response = client.ocr.process( | |
| model="mistral-ocr-latest", | |
| document={ | |
| "type": "document_url", | |
| "document_url": signed_url.url | |
| }, | |
| include_image_base64=True | |
| ) | |
| except Exception as e: | |
| # If file upload fails, try to use a direct URL method with base64 | |
| logger.warning(f"Failed to upload document, trying alternate method: {str(e)}") | |
| base64_doc = self.encode_image(file_path) | |
| if base64_doc: | |
| mime_type = self._get_mime_type(file_extension) | |
| ocr_response = client.ocr.process( | |
| model="mistral-ocr-latest", | |
| document={ | |
| "type": "document_url", | |
| "document_url": f"data:{mime_type};base64,{base64_doc}" | |
| }, | |
| include_image_base64=True | |
| ) | |
| else: | |
| raise DocumentProcessingError("Failed to process document") | |
| else: | |
| # For images (jpg, png, etc.), use image_url with base64 | |
| base64_image = self.encode_image(file_path) | |
| mime_type = self._get_mime_type(file_extension) | |
| ocr_response = client.ocr.process( | |
| model="mistral-ocr-latest", | |
| document={ | |
| "type": "image_url", | |
| "image_url": f"data:{mime_type};base64,{base64_image}" | |
| }, | |
| include_image_base64=True | |
| ) | |
| # Process the OCR response | |
| # The Mistral OCR response is structured with pages that contain text content | |
| markdown_text = "" | |
| # Check if the response contains pages | |
| if hasattr(ocr_response, 'pages') and ocr_response.pages: | |
| for page in ocr_response.pages: | |
| # Add page number as heading | |
| page_num = page.index if hasattr(page, 'index') else "Unknown" | |
| markdown_text += f"## Page {page_num}\n\n" | |
| # Add text content if available | |
| if hasattr(page, 'text'): | |
| markdown_text += page.text + "\n\n" | |
| # Or markdown content if that's how it's structured | |
| elif hasattr(page, 'markdown'): | |
| markdown_text += page.markdown + "\n\n" | |
| # Add any extracted tables with markdown formatting | |
| if hasattr(page, 'tables') and page.tables: | |
| for i, table in enumerate(page.tables): | |
| markdown_text += f"### Table {i+1}\n\n" | |
| if hasattr(table, 'markdown'): | |
| markdown_text += table.markdown + "\n\n" | |
| elif hasattr(table, 'data'): | |
| # Convert table data to markdown format | |
| markdown_text += self._convert_table_to_markdown(table.data) + "\n\n" | |
| # If no markdown was generated, check for raw content | |
| if not markdown_text and hasattr(ocr_response, 'content'): | |
| markdown_text = ocr_response.content | |
| # If still no content, try to access any available data | |
| if not markdown_text: | |
| # Try to get a JSON representation to extract data | |
| try: | |
| response_dict = ocr_response.to_dict() if hasattr(ocr_response, 'to_dict') else ocr_response.__dict__ | |
| markdown_text = "# Extracted Content\n\n" | |
| # Look for content or text in the response dictionary | |
| if 'content' in response_dict: | |
| markdown_text += response_dict['content'] | |
| elif 'text' in response_dict: | |
| markdown_text += response_dict['text'] | |
| elif 'pages' in response_dict: | |
| for page in response_dict['pages']: | |
| if 'text' in page: | |
| markdown_text += page['text'] + "\n\n" | |
| else: | |
| # Just dump what we got as JSON | |
| markdown_text += f"```json\n{json.dumps(response_dict, indent=2)}\n```" | |
| except Exception as e: | |
| markdown_text = f"# Error Processing Response\n\nCould not process the OCR response: {str(e)}" | |
| # If we still have no content, raise an error | |
| if not markdown_text: | |
| raise ConversionError("No text was extracted from the document") | |
| return f"# Document Content\n\n{markdown_text}" | |
| except (DocumentProcessingError, ConversionError): | |
| # Re-raise our custom exceptions | |
| raise | |
| except Exception as e: | |
| logger.error(f"OCR extraction error: {str(e)}") | |
| raise ConversionError(f"OCR extraction failed: {str(e)}") | |
| def _extract_with_document_understanding(self, client, file_path, file_extension): | |
| """Extract and understand document content using chat completion.""" | |
| try: | |
| # For documents and images, we'll use Mistral's document understanding capability | |
| if file_extension in ['.pdf', '.docx', '.pptx']: | |
| # Upload document first | |
| try: | |
| # Upload the file | |
| uploaded_pdf = client.files.upload( | |
| file={ | |
| "file_name": file_path.name, | |
| "content": open(file_path, "rb"), | |
| }, | |
| purpose="ocr" | |
| ) | |
| # Get the signed URL | |
| signed_url = client.files.get_signed_url(file_id=uploaded_pdf.id) | |
| # Send to chat completion API with document understanding prompt | |
| chat_response = client.chat.complete( | |
| model="mistral-large-latest", | |
| max_tokens=config.model.max_tokens, | |
| temperature=config.model.temperature, | |
| messages=[ | |
| { | |
| "role": "user", | |
| "content": [ | |
| { | |
| "type": "text", | |
| "text": "Convert this document to well-formatted markdown. Preserve all important content, structure, headings, lists, and tables. Include brief descriptions of any images." | |
| }, | |
| { | |
| "type": "document_url", | |
| "document_url": signed_url.url | |
| } | |
| ] | |
| } | |
| ] | |
| ) | |
| # Get the markdown result | |
| return chat_response.choices[0].message.content | |
| except Exception as e: | |
| # Fall back to OCR if document understanding fails | |
| logger.warning(f"Document understanding failed, falling back to OCR: {str(e)}") | |
| return self._extract_with_ocr(client, file_path, file_extension) | |
| else: | |
| # For images, encode to base64 | |
| base64_image = self.encode_image(file_path) | |
| mime_type = self._get_mime_type(file_extension) | |
| # Use the chat API with the image for document understanding | |
| chat_response = client.chat.complete( | |
| model="mistral-large-latest", | |
| max_tokens=config.model.max_tokens, | |
| temperature=config.model.temperature, | |
| messages=[ | |
| { | |
| "role": "user", | |
| "content": [ | |
| { | |
| "type": "text", | |
| "text": "Extract all text from this image and convert it to well-formatted markdown. Preserve the structure and layout as much as possible." | |
| }, | |
| { | |
| "type": "image_url", | |
| "image_url": { | |
| "url": f"data:{mime_type};base64,{base64_image}" | |
| } | |
| } | |
| ] | |
| } | |
| ] | |
| ) | |
| # Get the markdown result | |
| return chat_response.choices[0].message.content | |
| except Exception as e: | |
| logger.error(f"Document understanding error: {str(e)}") | |
| raise ConversionError(f"Document understanding failed: {str(e)}") | |
| def _get_mime_type(self, file_extension: str) -> str: | |
| """Get the MIME type for a file extension supported by Mistral OCR.""" | |
| mime_types = { | |
| # Document formats supported by Mistral OCR | |
| ".pdf": "application/pdf", | |
| ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", | |
| ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation", | |
| # Image formats supported by Mistral OCR | |
| ".jpg": "image/jpeg", | |
| ".jpeg": "image/jpeg", | |
| ".png": "image/png", | |
| ".gif": "image/gif", | |
| ".bmp": "image/bmp", | |
| ".tiff": "image/tiff", | |
| ".tif": "image/tiff", | |
| ".avif": "image/avif", | |
| ".webp": "image/webp", | |
| } | |
| return mime_types.get(file_extension, "application/octet-stream") | |
| def _convert_table_to_markdown(self, table_data) -> str: | |
| """Convert a table data structure to markdown format.""" | |
| if not table_data or not isinstance(table_data, list): | |
| return "" | |
| # Create markdown table | |
| markdown = "" | |
| # Add header row | |
| if table_data and isinstance(table_data[0], list): | |
| header = table_data[0] | |
| markdown += "| " + " | ".join(str(cell) for cell in header) + " |\n" | |
| # Add separator row | |
| markdown += "| " + " | ".join(["---"] * len(header)) + " |\n" | |
| # Add data rows | |
| for row in table_data[1:]: | |
| markdown += "| " + " | ".join(str(cell) for cell in row) + " |\n" | |
| return markdown | |
| def _validate_batch_files(self, file_paths: List[Path]) -> None: | |
| """Validate batch of files for multi-document processing.""" | |
| if len(file_paths) == 0: | |
| raise DocumentProcessingError("No files provided for processing") | |
| if len(file_paths) > 5: | |
| raise DocumentProcessingError("Maximum 5 files allowed for batch processing") | |
| total_size = 0 | |
| for fp in file_paths: | |
| if not fp.exists(): | |
| raise DocumentProcessingError(f"File not found: {fp}") | |
| size = fp.stat().st_size | |
| if size > 10 * 1024 * 1024: | |
| raise DocumentProcessingError(f"Individual file size exceeds 10MB: {fp.name}") | |
| total_size += size | |
| if total_size > 20 * 1024 * 1024: | |
| raise DocumentProcessingError(f"Combined file size ({total_size / (1024*1024):.1f}MB) exceeds 20MB limit") | |
| # simple mime validation | |
| for fp in file_paths: | |
| if self._get_mime_type(fp.suffix.lower()) == "application/octet-stream": | |
| raise DocumentProcessingError(f"Unsupported file type: {fp.name}") | |
| def _create_document_part(self, file_path: Path) -> Dict[str, Any]: | |
| """Return a dict representing an image_url or document_url part for Mistral chat/OCR.""" | |
| ext = file_path.suffix.lower() | |
| if ext in ['.pdf', '.docx', '.pptx']: | |
| # upload and get signed url | |
| client = Mistral(api_key=config.api.mistral_api_key) | |
| uploaded = client.files.upload( | |
| file={ | |
| "file_name": file_path.name, | |
| "content": open(file_path, "rb"), | |
| }, | |
| purpose="ocr", | |
| ) | |
| signed = client.files.get_signed_url(file_id=uploaded.id) | |
| return { | |
| "type": "document_url", | |
| "document_url": signed.url, | |
| } | |
| else: | |
| # encode image | |
| b64 = self.encode_image(file_path) | |
| mime = self._get_mime_type(ext) | |
| return { | |
| "type": "image_url", | |
| "image_url": { | |
| "url": f"data:{mime};base64,{b64}" | |
| } | |
| } | |
| def _create_batch_prompt(self, file_paths: List[Path], processing_type: str, original_filenames: Optional[List[str]] = None) -> str: | |
| if original_filenames: | |
| names = original_filenames | |
| else: | |
| names = [fp.name for fp in file_paths] | |
| file_list = "\n".join([f"- {name}" for name in names]) | |
| base = f"I will provide you with {len(file_paths)} documents.\n{file_list}\n\n" | |
| if processing_type == "individual": | |
| return base + "Please convert each document to markdown as its own section, preserving structure." | |
| if processing_type == "summary": | |
| return base + ( | |
| "Please first write an EXECUTIVE SUMMARY of all documents, then include converted markdown sections per document." | |
| ) | |
| if processing_type == "comparison": | |
| return base + ( | |
| "Please provide a comparison table of the documents, then individual summaries and cross-document insights." | |
| ) | |
| # default combined | |
| return base + "Please merge the content of all documents into a single cohesive markdown document." | |
| def _format_batch_output(self, response_text: str, file_paths: List[Path], processing_type: str, original_filenames: Optional[List[str]] = None) -> str: | |
| if original_filenames: | |
| names = original_filenames | |
| else: | |
| names = [fp.name for fp in file_paths] | |
| header = ( | |
| f"<!-- Multi-Document Processing Results -->\n" | |
| f"<!-- Processing Type: {processing_type} -->\n" | |
| f"<!-- Files Processed: {len(file_paths)} -->\n" | |
| f"<!-- File Names: {', '.join(names)} -->\n\n" | |
| ) | |
| return header + response_text | |
| def parse_multiple( | |
| self, | |
| file_paths: List[Union[str, Path]], | |
| processing_type: str = "combined", | |
| original_filenames: Optional[List[str]] = None, | |
| ocr_method: Optional[str] = None, | |
| output_format: str = "markdown", | |
| **kwargs, | |
| ) -> str: | |
| """Parse multiple documents, supporting the same processing types as Gemini parser.""" | |
| if not MISTRAL_AVAILABLE: | |
| raise DocumentProcessingError("Mistral client not installed. Install with 'pip install mistralai'.") | |
| if not config.api.mistral_api_key: | |
| raise DocumentProcessingError("MISTRAL_API_KEY not set.") | |
| try: | |
| # convert to Path objects | |
| paths = [Path(p) for p in file_paths] | |
| self._validate_batch_files(paths) | |
| if self._check_cancellation(): | |
| return "Conversion cancelled." | |
| use_understanding = ocr_method == "understand" | |
| client = Mistral(api_key=config.api.mistral_api_key) | |
| if use_understanding: | |
| # Build chat content with document parts | |
| prompt = self._create_batch_prompt(paths, processing_type, original_filenames) | |
| content_parts = [ | |
| {"type": "text", "text": prompt}, | |
| ] | |
| for p in paths: | |
| if self._check_cancellation(): | |
| return "Conversion cancelled." | |
| content_parts.append(self._create_document_part(p)) | |
| chat_response = client.chat.complete( | |
| model="mistral-large-latest", | |
| max_tokens=config.model.max_tokens, | |
| temperature=config.model.temperature, | |
| messages=[{"role": "user", "content": content_parts}], | |
| ) | |
| markdown_text = chat_response.choices[0].message.content | |
| return self._format_batch_output(markdown_text, paths, processing_type, original_filenames) | |
| # else basic OCR path | |
| results = [] | |
| for idx, p in enumerate(paths): | |
| if self._check_cancellation(): | |
| return "Conversion cancelled." | |
| text = self._extract_with_ocr(client, p, p.suffix.lower()) | |
| if processing_type == "individual": | |
| name = (original_filenames[idx] if original_filenames else p.name) | |
| text = f"# Document {idx+1}: {name}\n\n" + text | |
| results.append(text) | |
| combined_md = "\n\n---\n\n".join(results) if processing_type in ["individual", "combined"] else "\n\n".join(results) | |
| # For summary/comparison we now ask chat to summarise | |
| if processing_type in ["summary", "comparison"]: | |
| prompt = self._create_batch_prompt(paths, processing_type, original_filenames) | |
| chat_response = client.chat.complete( | |
| model="mistral-large-latest", | |
| max_tokens=config.model.max_tokens, | |
| temperature=config.model.temperature, | |
| messages=[ | |
| {"role": "user", "content": prompt + "\n\n" + combined_md} | |
| ], | |
| ) | |
| combined_md = chat_response.choices[0].message.content | |
| return self._format_batch_output(combined_md, paths, processing_type, original_filenames) | |
| except Exception as e: | |
| logger.error(f"Error parsing multiple documents with Mistral OCR: {str(e)}") | |
| raise DocumentProcessingError(f"Batch processing failed: {str(e)}") | |
| # Register the parser with the registry | |
| if MISTRAL_AVAILABLE: | |
| ParserRegistry.register(MistralOcrParser) | |
| else: | |
| print("Mistral OCR parser not registered: mistralai package not installed") |