File size: 7,386 Bytes
b8b55ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
"""
Text Extraction Service
Handles OCR text extraction from images using olmOCR model.
Separated from UI concerns for better maintainability.
"""
import base64
import json
import os
import re
from io import BytesIO
from typing import Dict, Tuple, Optional

import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from olmocr.prompts import build_no_anchoring_v4_yaml_prompt


class TextExtractionService:
    """
    Service class for extracting text from images using the olmOCR model.

    Handles model initialization, image processing, and result formatting.
    The model emits YAML frontmatter (metadata) followed by the page text;
    ``_parse_ocr_output`` separates the two.
    """

    # Precompiled scalar patterns, hoisted out of the per-line parse loop.
    # _INT_RE accepts an optional leading sign -- str.isdigit() would wrongly
    # reject negative values such as "rotation_correction: -90".
    _INT_RE = re.compile(r"^-?\d+$")
    _FLOAT_RE = re.compile(r"^-?\d+\.\d+$")

    def __init__(self, model_name: str = "allenai/olmOCR-2-7B-1025",
                 processor_name: str = "Qwen/Qwen2.5-VL-7B-Instruct"):
        """
        Initialize the text extraction service with model and processor.

        Args:
            model_name: Name of the olmOCR model to use
            processor_name: Name of the processor to use
        """
        self.model_name = model_name
        self.processor_name = processor_name
        self.model = None
        self.processor = None
        self.device = None
        self._initialize_model()

    def _initialize_model(self):
        """Load model and processor, pick the best available device, move model there."""
        # Load the model in bfloat16 and switch to eval mode (inference only).
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
        ).eval()

        self.processor = AutoProcessor.from_pretrained(self.processor_name)

        # Device preference: CUDA, then MPS (Apple Silicon), then CPU.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        else:
            self.device = torch.device("cpu")

        self.model.to(self.device)

    @classmethod
    def _coerce_scalar(cls, value: str):
        """
        Convert a YAML-like scalar string to bool/int/float where possible.

        Returns the original string unchanged when no conversion applies.
        """
        lowered = value.lower()
        if lowered == "true":
            return True
        if lowered == "false":
            return False
        if cls._INT_RE.match(value):
            return int(value)
        if cls._FLOAT_RE.match(value):
            return float(value)
        return value

    def _parse_ocr_output(self, raw_text: str) -> Tuple[Dict, str]:
        """
        Parse OCR output that contains YAML frontmatter and extract metadata and text separately.

        Args:
            raw_text: Raw output from OCR model

        Returns:
            Tuple of (metadata_dict, extracted_text). metadata_dict is empty
            when no frontmatter is present, in which case extracted_text is
            the whole (stripped) input.
        """
        # maxsplit=2: only the first two "---" markers delimit frontmatter.
        # Without it, any "---" occurring inside the document body would
        # truncate the extracted text (only the slice up to the next marker
        # would be kept).
        parts = raw_text.split("---", 2)

        metadata: Dict = {}

        if len(parts) >= 3:
            # parts[1] is the YAML between the first two markers;
            # parts[2] is everything after the second marker.
            yaml_content = parts[1].strip()
            extracted_text = parts[2].strip()

            # Parse simple "key: value" pairs; lines without ":" are ignored.
            for line in yaml_content.split("\n"):
                line = line.strip()
                if ":" in line:
                    key, value = line.split(":", 1)
                    metadata[key.strip()] = self._coerce_scalar(value.strip())
        else:
            # No YAML frontmatter found; treat the entire output as text.
            extracted_text = raw_text.strip()

        return metadata, extracted_text

    def extract_text_from_image(self, image: "Image.Image",
                                max_new_tokens: int = 2048) -> Dict:
        """
        Extract text from a PIL Image object.

        Args:
            image: PIL Image object to extract text from
            max_new_tokens: Maximum number of tokens to generate

        Returns:
            Dictionary containing extracted text and metadata
            (primary_language, rotation info, table/diagram flags,
            plus the model and processor names used).
        """
        # Encode the image as a base64 PNG data URL for the chat message.
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

        # Build the chat message: OCR prompt text plus the image.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": build_no_anchoring_v4_yaml_prompt()},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
                ],
            }
        ]

        # Render the chat template to a plain prompt string.
        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Tokenize text and preprocess the image, then move tensors to device.
        inputs = self.processor(
            text=[text],
            images=[image],
            padding=True,
            return_tensors="pt",
        )
        inputs = {key: value.to(self.device) for key, value in inputs.items()}

        # Generate; inference_mode avoids autograd bookkeeping during decoding.
        with torch.inference_mode():
            output = self.model.generate(
                **inputs,
                temperature=0.1,
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                do_sample=True,
            )

        # Strip the prompt tokens and decode only the newly generated part.
        prompt_length = inputs["input_ids"].shape[1]
        new_tokens = output[:, prompt_length:]
        text_output = self.processor.tokenizer.batch_decode(
            new_tokens,
            skip_special_tokens=True,
        )

        raw_output = text_output[0] if text_output else ""

        # Split the raw output into structured metadata and the page text.
        metadata, extracted_text = self._parse_ocr_output(raw_output)

        result_data = {
            "extracted_text": extracted_text,
            "primary_language": metadata.get("primary_language", None),
            "is_rotation_valid": metadata.get("is_rotation_valid", None),
            "rotation_correction": metadata.get("rotation_correction", None),
            "is_table": metadata.get("is_table", None),
            "is_diagram": metadata.get("is_diagram", None),
            "model": self.model_name,
            "processor": self.processor_name,
        }

        return result_data

    def save_result_to_json(self, result_data: Dict, output_path: str,
                           source_image_name: Optional[str] = None):
        """
        Save extraction result to JSON file.

        Args:
            result_data: Dictionary containing extraction results
                (not mutated by this call)
            output_path: Path where to save the JSON file
            source_image_name: Optional name of the source image
        """
        # Work on a shallow copy so the caller's dict is never mutated.
        payload = dict(result_data)
        if source_image_name:
            payload["source_image"] = source_image_name

        # Ensure the output directory exists before writing.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        with open(output_path, "w", encoding="utf-8") as json_file:
            json.dump(payload, json_file, ensure_ascii=False, indent=2)