File size: 20,292 Bytes
5b14aa2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
"""GPU processor with OCR capabilities for images and PDFs."""

import os
import json
import logging
import tempfile
import re
from typing import Dict, Any, List, Optional
from pathlib import Path

from .base import BaseProcessor
from ..result import ConversionResult
from ..exceptions import ConversionError, FileNotFoundError
from ..pipeline.ocr_service import OCRServiceFactory

# Configure logging
logger = logging.getLogger(__name__)


class GPUConversionResult(ConversionResult):
    """Enhanced ConversionResult for GPU processing with Nanonets OCR capabilities."""
    
    def __init__(self, content: str, metadata: Optional[Dict[str, Any]] = None, 
                 gpu_processor: Optional['GPUProcessor'] = None, file_path: Optional[str] = None,
                 ocr_provider: str = "nanonets"):
        super().__init__(content, metadata)
        self.gpu_processor = gpu_processor
        self.file_path = file_path
        self.ocr_provider = ocr_provider
        
        # Add GPU-specific metadata
        if metadata is None:
            self.metadata = {}
        
        # Ensure GPU-specific metadata is present
        if 'processing_mode' not in self.metadata:
            self.metadata['processing_mode'] = 'gpu'
        if 'ocr_provider' not in self.metadata:
            self.metadata['ocr_provider'] = ocr_provider
        if 'gpu_processing' not in self.metadata:
            self.metadata['gpu_processing'] = True
    
    def get_ocr_info(self) -> Dict[str, Any]:
        """Get information about the OCR processing used.
        
        Returns:
            Dictionary with OCR processing information
        """
        return {
            'ocr_provider': self.ocr_provider,
            'processing_mode': 'gpu',
            'file_path': self.file_path,
            'gpu_processor_available': self.gpu_processor is not None
        }
    
    def extract_markdown(self) -> str:
        """Export as markdown without GPU processing metadata."""
        return self.content
    
    def extract_html(self) -> str:
        """Export as HTML with GPU processing styling."""
        # Get the base HTML from parent class
        html_content = super().extract_html()
        
        # Add GPU processing indicator
        gpu_indicator = f"""
        <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 1rem; border-radius: 8px; margin-bottom: 2rem; text-align: center;">
            <strong>๐Ÿš€ GPU Processed</strong> - Enhanced with {self.ocr_provider} OCR
        </div>
        """
        
        # Insert the indicator after the opening body tag
        body_start = html_content.find('<body')
        if body_start != -1:
            body_end = html_content.find('>', body_start) + 1
            return html_content[:body_end] + gpu_indicator + html_content[body_end:]
        
        return html_content
    
    def extract_data(self) -> Dict[str, Any]:
        """Export as structured JSON using Nanonets model with specific prompt."""
        logger.debug(f"GPUConversionResult.extract_data() called for {self.file_path}")

        try:
            # If we have a GPU processor and file path, use the model to extract JSON
            if self.gpu_processor and self.file_path and os.path.exists(self.file_path):
                logger.info("Using Nanonets model for JSON extraction")
                return self._extract_json_with_model()
            else:
                logger.info("Using fallback JSON conversion")
                # Fallback to base JSON conversion
                return self._convert_to_base_json()
        except Exception as e:
            logger.warning(f"Failed to extract JSON with model: {e}. Using fallback conversion.")
            return self._convert_to_base_json()
    
    def _extract_json_with_model(self) -> Dict[str, Any]:
        """Extract structured JSON using Nanonets model with specific prompt."""
        try:
            from PIL import Image
            from transformers import AutoTokenizer, AutoProcessor, AutoModelForImageTextToText
            
            # Get the model from the GPU processor's OCR service
            ocr_service = self.gpu_processor._get_ocr_service()
            
            # Access the model components from the OCR service
            if hasattr(ocr_service, 'processor') and hasattr(ocr_service, 'model') and hasattr(ocr_service, 'tokenizer'):
                model = ocr_service.model
                processor = ocr_service.processor
                tokenizer = ocr_service.tokenizer
            else:
                # Fallback: load model directly
                model_path = "nanonets/Nanonets-OCR-s"
                model = AutoModelForImageTextToText.from_pretrained(
                    model_path, 
                    torch_dtype="auto", 
                    device_map="auto"
                )
                model.eval()
                processor = AutoProcessor.from_pretrained(model_path)
                tokenizer = AutoTokenizer.from_pretrained(model_path)
            
            # Define the JSON extraction prompt
            prompt = """Extract all information from the above document and return it as a valid JSON object.

Instructions:
- The output should be a single JSON object.
- Keys should be meaningful field names.
- If multiple similar blocks (like invoice items or line items), return a list of JSON objects under a key.
- Use strings for all values.
- Wrap page numbers using: "page_number": "1"
- Wrap watermarks using: "watermark": "CONFIDENTIAL"
- Use โ˜ and โ˜‘ for checkboxes.

Example:
{
  "Name": "John Doe",
  "Invoice Number": "INV-4567",
  "Amount Due": "$123.45",
  "Items": [
    {"Description": "Widget A", "Price": "$20"},
    {"Description": "Widget B", "Price": "$30"}
  ],
  "page_number": "1",
  "watermark": "CONFIDENTIAL"
}"""
            
            # Load the image
            image = Image.open(self.file_path)
            
            # Prepare messages for the model
            messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": [
                    {"type": "image", "image": f"file://{self.file_path}"},
                    {"type": "text", "text": prompt},
                ]},
            ]
            
            # Apply chat template and process
            text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt")
            inputs = inputs.to(model.device)
            
            # Generate JSON response
            output_ids = model.generate(**inputs, max_new_tokens=15000, do_sample=False)
            generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
            
            json_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
            logger.debug(f"Generated JSON text: {json_text[:200]}...")
            
            # Try to parse the JSON response with improved parsing
            def try_parse_json(text):
                try:
                    return json.loads(text)
                except json.JSONDecodeError:
                    # Try cleaning and reparsing
                    try:
                        text = re.sub(r"(\w+):", r'"\1":', text)  # wrap keys
                        text = text.replace("'", '"')  # replace single quotes
                        return json.loads(text)
                    except (json.JSONDecodeError, Exception):
                        return {"raw_text": text}
            
            # Parse the JSON
            extracted_data = try_parse_json(json_text)
            
            # Create the result structure
            result = {
                "document": extracted_data,
                "format": "gpu_structured_json",
                "gpu_processing_info": {
                    'ocr_provider': self.ocr_provider,
                    'processing_mode': 'gpu',
                    'file_path': self.file_path,
                    'gpu_processor_available': self.gpu_processor is not None,
                    'json_extraction_method': 'nanonets_model'
                }
            }
            
            return result
                
        except Exception as e:
            logger.error(f"Failed to extract JSON with model: {e}")
            raise
    
    def _convert_to_base_json(self) -> Dict[str, Any]:
        """Fallback to base JSON conversion method."""
        # Get the base JSON from parent class
        base_json = super().extract_data()
        
        # Add GPU-specific metadata
        base_json['gpu_processing_info'] = {
            'ocr_provider': self.ocr_provider,
            'processing_mode': 'gpu',
            'file_path': self.file_path,
            'gpu_processor_available': self.gpu_processor is not None,
            'json_extraction_method': 'fallback_conversion'
        }
        
        # Update the format to indicate GPU processing
        base_json['format'] = 'gpu_structured_json'
        
        return base_json
    
    def extract_text(self) -> str:
        """Export as plain text without GPU processing header."""
        return self.content
    
    def get_processing_stats(self) -> Dict[str, Any]:
        """Get processing statistics and information.
        
        Returns:
            Dictionary with processing statistics
        """
        stats = {
            'processing_mode': 'gpu',
            'ocr_provider': self.ocr_provider,
            'file_path': self.file_path,
            'content_length': len(self.content),
            'word_count': len(self.content.split()),
            'line_count': len(self.content.split('\n')),
            'gpu_processor_available': self.gpu_processor is not None
        }
        
        # Add metadata if available
        if self.metadata:
            stats['metadata'] = self.metadata
        
        return stats


class GPUProcessor(BaseProcessor):
    """Processor for image files and PDFs with Nanonets OCR capabilities."""
    
    def __init__(self, preserve_layout: bool = True, include_images: bool = False, ocr_enabled: bool = True, use_markdownify: bool = None, ocr_service=None):
        super().__init__(preserve_layout, include_images, ocr_enabled, use_markdownify)
        self._ocr_service = ocr_service
    
    def can_process(self, file_path: str) -> bool:
        """Check if this processor can handle the given file.
        
        Args:
            file_path: Path to the file to check
            
        Returns:
            True if this processor can handle the file
        """
        if not os.path.exists(file_path):
            return False
        
        # Check file extension - ensure file_path is a string
        file_path_str = str(file_path)
        _, ext = os.path.splitext(file_path_str.lower())
        return ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp', '.gif', '.pdf']
    
    def _get_ocr_service(self):
        """Get OCR service instance."""
        if self._ocr_service is not None:
            return self._ocr_service
        # Use Nanonets OCR service by default
        self._ocr_service = OCRServiceFactory.create_service('nanonets')
        return self._ocr_service
    
    def process(self, file_path: str) -> GPUConversionResult:
        """Process image file or PDF with OCR capabilities.
        
        Args:
            file_path: Path to the image file or PDF
            
        Returns:
            GPUConversionResult with extracted content
        """
        try:
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"File not found: {file_path}")
            
            # Check file type
            file_path_str = str(file_path)
            _, ext = os.path.splitext(file_path_str.lower())
            
            if ext == '.pdf':
                logger.info(f"Processing PDF file: {file_path}")
                return self._process_pdf(file_path)
            else:
                logger.info(f"Processing image file: {file_path}")
                return self._process_image(file_path)
            
        except Exception as e:
            logger.error(f"Failed to process file {file_path}: {e}")
            raise ConversionError(f"GPU processing failed: {e}")
    
    def _process_image(self, file_path: str) -> GPUConversionResult:
        """Process image file with OCR capabilities.
        
        Args:
            file_path: Path to the image file
            
        Returns:
            GPUConversionResult with extracted content
        """
        # Get OCR service
        ocr_service = self._get_ocr_service()
        
        # Extract text with layout awareness if enabled
        if self.ocr_enabled and self.preserve_layout:
            logger.info("Extracting text with layout awareness using Nanonets OCR")
            extracted_text = ocr_service.extract_text_with_layout(file_path)
        elif self.ocr_enabled:
            logger.info("Extracting text without layout awareness using Nanonets OCR")
            extracted_text = ocr_service.extract_text(file_path)
        else:
            logger.warning("OCR is disabled, returning empty content")
            extracted_text = ""
        
        # Create GPU result
        result = GPUConversionResult(
            content=extracted_text,
            metadata={
                'file_path': file_path,
                'file_type': 'image',
                'ocr_enabled': self.ocr_enabled,
                'preserve_layout': self.preserve_layout,
                'ocr_provider': 'nanonets'
            },
            gpu_processor=self,
            file_path=file_path,
            ocr_provider='nanonets'
        )
        
        logger.info(f"Image processing completed. Extracted {len(extracted_text)} characters")
        return result
    
    def _process_pdf(self, file_path: str) -> GPUConversionResult:
        """Process PDF file by converting to images and using OCR.
        
        Args:
            file_path: Path to the PDF file
            
        Returns:
            GPUConversionResult with extracted content
        """
        try:
            # Convert PDF to images
            image_paths = self._convert_pdf_to_images(file_path)
            
            if not image_paths:
                logger.warning("No pages could be extracted from PDF")
                return GPUConversionResult(
                    content="",
                    metadata={
                        'file_path': file_path,
                        'file_type': 'pdf',
                        'ocr_enabled': self.ocr_enabled,
                        'preserve_layout': self.preserve_layout,
                        'ocr_provider': 'nanonets',
                        'pages_processed': 0
                    },
                    gpu_processor=self,
                    file_path=file_path,
                    ocr_provider='nanonets'
                )
            
            # Process each page with OCR
            all_texts = []
            ocr_service = self._get_ocr_service()
            
            for i, image_path in enumerate(image_paths):
                logger.info(f"Processing PDF page {i+1}/{len(image_paths)}")
                
                try:
                    if self.ocr_enabled and self.preserve_layout:
                        page_text = ocr_service.extract_text_with_layout(image_path)
                    elif self.ocr_enabled:
                        page_text = ocr_service.extract_text(image_path)
                    else:
                        page_text = ""
                    
                    # Add page header and content if there's text
                    if page_text.strip():
                        # Add page header (markdown style)
                        all_texts.append(f"\n## Page {i+1}\n\n")
                        all_texts.append(page_text)
                        
                        # Add horizontal rule after content (except for last page)
                        if i < len(image_paths) - 1:
                            all_texts.append("\n\n---\n\n")
                    
                except Exception as e:
                    logger.error(f"Failed to process page {i+1}: {e}")
                    # Add error page with markdown formatting
                    all_texts.append(f"\n## Page {i+1}\n\n*Error processing this page: {e}*\n\n")
                    if i < len(image_paths) - 1:
                        all_texts.append("---\n\n")
                
                finally:
                    # Clean up temporary image file
                    try:
                        os.unlink(image_path)
                    except OSError:
                        pass
            
            # Combine all page texts
            combined_text = ''.join(all_texts)
            
            # Create result
            result = GPUConversionResult(
                content=combined_text,
                metadata={
                    'file_path': file_path,
                    'file_type': 'pdf',
                    'ocr_enabled': self.ocr_enabled,
                    'preserve_layout': self.preserve_layout,
                    'ocr_provider': 'nanonets',
                    'pages_processed': len(image_paths)
                },
                gpu_processor=self,
                file_path=file_path,
                ocr_provider='nanonets'
            )
            
            logger.info(f"PDF processing completed. Processed {len(image_paths)} pages, extracted {len(combined_text)} characters")
            return result
            
        except Exception as e:
            logger.error(f"Failed to process PDF {file_path}: {e}")
            raise ConversionError(f"PDF processing failed: {e}")
    
    def _convert_pdf_to_images(self, pdf_path: str) -> List[str]:
        """Convert PDF pages to images.
        
        Args:
            pdf_path: Path to the PDF file
            
        Returns:
            List of paths to temporary image files
        """
        try:
            from pdf2image import convert_from_path
            from ..config import InternalConfig
            
            # Get DPI from config
            dpi = getattr(InternalConfig, 'pdf_image_dpi', 300)
            
            # Convert PDF pages to images using pdf2image
            images = convert_from_path(pdf_path, dpi=dpi)
            image_paths = []
            
            # Save each image to a temporary file
            for page_num, image in enumerate(images):
                persistent_image_path = tempfile.mktemp(suffix='.png')
                image.save(persistent_image_path, 'PNG')
                image_paths.append(persistent_image_path)
            
            logger.info(f"Converted PDF to {len(image_paths)} images")
            return image_paths
            
        except ImportError:
            logger.error("pdf2image not available. Please install it: pip install pdf2image")
            raise ConversionError("pdf2image is required for PDF processing")
        except Exception as e:
            logger.error(f"Failed to extract PDF to images: {e}")
            raise ConversionError(f"PDF to image conversion failed: {e}")
    
    @staticmethod
    def predownload_ocr_models():
        """Pre-download OCR models by running a dummy prediction."""
        try:
            from docstrange.pipeline.ocr_service import OCRServiceFactory
            ocr_service = OCRServiceFactory.create_service('nanonets')
            # Create a blank image for testing
            from PIL import Image
            import tempfile
            with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
                img = Image.new('RGB', (100, 100), color='white')
                img.save(tmp.name)
                ocr_service.extract_text_with_layout(tmp.name)
                os.unlink(tmp.name)
            logger.info("Nanonets OCR models pre-downloaded and cached.")
        except Exception as e:
            logger.error(f"Failed to pre-download Nanonets OCR models: {e}")