kusatmer committed on
Commit
b8b55ff
·
0 Parent(s):

feat: Implement initial image text extraction application with Streamlit UI, OCR service, and tests.

.gitignore ADDED
@@ -0,0 +1,141 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # with no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # macOS
+ .DS_Store
+
+ # VS Code
+
.vscode/settings.json ADDED
@@ -0,0 +1,4 @@
+ {
+     "python.defaultInterpreterPath": "/opt/homebrew/bin/python3",
+     "python.analysis.typeCheckingMode": "off"
+ }
README.md ADDED
@@ -0,0 +1,61 @@
+ # Image Text Extractor
+
+ This project is a Streamlit application that uses the `olmOCR` model (based on Qwen2.5-VL) to extract text from images. It provides a user-friendly interface to upload images and view the extracted text along with metadata.
+
+ ## Features
+
+ - **Image Upload**: Support for PNG, JPG, and JPEG formats.
+ - **Text Extraction**: Uses state-of-the-art vision-language models for accurate OCR.
+ - **Metadata Extraction**: Extracts additional information such as primary language, rotation, and content type (table, diagram).
+ - **JSON Export**: Download extraction results as JSON files.
+ - **Configurable**: Adjust the maximum number of generated tokens for longer documents.
+
+ ## Installation
+
+ 1. **Clone the repository**:
+    ```bash
+    git clone <repository-url>
+    cd image-text-extractor
+    ```
+
+ 2. **Create a virtual environment** (recommended):
+    ```bash
+    python -m venv venv
+    source venv/bin/activate  # On Windows: venv\Scripts\activate
+    ```
+
+ 3. **Install dependencies**:
+    ```bash
+    pip install -r requirements.txt
+    ```
+
+ ## Usage
+
+ 1. **Run the Streamlit app**:
+    ```bash
+    streamlit run streamlit_app.py
+    ```
+
+ 2. **Open your browser**:
+    The app should automatically open in your default browser at `http://localhost:8501`.
+
+ ## Testing
+
+ This project uses `pytest` for unit testing.
+
+ 1. **Run tests**:
+    ```bash
+    pytest tests/
+    ```
+
+ ## Project Structure
+
+ - `streamlit_app.py`: The main entry point for the Streamlit application.
+ - `service/`: Contains the backend logic for text extraction.
+   - `text_extraction_service.py`: The core service class handling model interaction.
+ - `tests/`: Unit tests for the application.
+ - `requirements.txt`: Python dependencies.
+
+ ## License
+
+ [Add License Here]
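
For reference, a downloaded JSON file mirrors the result dictionary assembled in `service/text_extraction_service.py` (plus the `source_image` field added on save). The values below are illustrative only, not real model output:

```json
{
  "extracted_text": "First line of recognized text...",
  "primary_language": "en",
  "is_rotation_valid": true,
  "rotation_correction": 0,
  "is_table": false,
  "is_diagram": false,
  "model": "allenai/olmOCR-2-7B-1025",
  "processor": "Qwen/Qwen2.5-VL-7B-Instruct",
  "source_image": "sample_page.png"
}
```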
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ streamlit>=1.28.0
+ torch>=2.0.0
+ torchvision>=0.15.0
+ transformers>=4.55.2
+ pillow>=10.0.0
+ olmocr>=0.4.6
+ pytest>=7.0.0
+ pytest-mock>=3.10.0
+
+
service/__init__.py ADDED
@@ -0,0 +1,3 @@
+ # Services package for text extraction functionality
+
+
service/text_extraction_service.py ADDED
@@ -0,0 +1,213 @@
+ """
+ Text Extraction Service
+ Handles OCR text extraction from images using the olmOCR model.
+ Separated from UI concerns for better maintainability.
+ """
+ import base64
+ import json
+ import os
+ import re
+ from io import BytesIO
+ from typing import Dict, Tuple, Optional
+
+ import torch
+ from PIL import Image
+ from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
+ from olmocr.prompts import build_no_anchoring_v4_yaml_prompt
+
+
+ class TextExtractionService:
+     """
+     Service class for extracting text from images using the olmOCR model.
+     Handles model initialization, image processing, and result formatting.
+     """
+
+     def __init__(self, model_name: str = "allenai/olmOCR-2-7B-1025",
+                  processor_name: str = "Qwen/Qwen2.5-VL-7B-Instruct"):
+         """
+         Initialize the text extraction service with a model and processor.
+
+         Args:
+             model_name: Name of the olmOCR model to use
+             processor_name: Name of the processor to use
+         """
+         self.model_name = model_name
+         self.processor_name = processor_name
+         self.model = None
+         self.processor = None
+         self.device = None
+         self._initialize_model()
+
+     def _initialize_model(self):
+         """Initialize the model and processor, and set up the device."""
+         # Initialize model
+         self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+             self.model_name,
+             torch_dtype=torch.bfloat16
+         ).eval()
+
+         # Initialize processor
+         self.processor = AutoProcessor.from_pretrained(self.processor_name)
+
+         # Determine device (CUDA, MPS for Mac, or CPU)
+         if torch.cuda.is_available():
+             self.device = torch.device("cuda")
+         elif torch.backends.mps.is_available():
+             self.device = torch.device("mps")
+         else:
+             self.device = torch.device("cpu")
+
+         # Move model to device
+         self.model.to(self.device)
+
+     def _parse_ocr_output(self, raw_text: str) -> Tuple[Dict, str]:
+         """
+         Parse OCR output that contains YAML frontmatter and extract metadata and text separately.
+
+         Args:
+             raw_text: Raw output from the OCR model
+
+         Returns:
+             Tuple of (metadata_dict, extracted_text)
+         """
+         # Split by YAML delimiters
+         parts = raw_text.split("---")
+
+         metadata = {}
+         extracted_text = ""
+
+         if len(parts) >= 3:
+             # Extract metadata from between the first two --- markers
+             yaml_content = parts[1].strip()
+             # Extract text after the second --- marker
+             extracted_text = parts[2].strip()
+
+             # Parse YAML-like key-value pairs
+             for line in yaml_content.split("\n"):
+                 line = line.strip()
+                 if ":" in line:
+                     key, value = line.split(":", 1)
+                     key = key.strip()
+                     value = value.strip()
+
+                     # Convert string booleans and numbers
+                     if value.lower() == "true":
+                         value = True
+                     elif value.lower() == "false":
+                         value = False
+                     elif value.isdigit():
+                         value = int(value)
+                     elif re.match(r"^-?\d+\.\d+$", value):
+                         value = float(value)
+
+                     metadata[key] = value
+         else:
+             # No YAML frontmatter found; use the entire text
+             extracted_text = raw_text.strip()
+
+         return metadata, extracted_text
+
+     def extract_text_from_image(self, image: Image.Image,
+                                 max_new_tokens: int = 2048) -> Dict:
+         """
+         Extract text from a PIL Image object.
+
+         Args:
+             image: PIL Image object to extract text from
+             max_new_tokens: Maximum number of tokens to generate
+
+         Returns:
+             Dictionary containing the extracted text and metadata
+         """
+         # Convert image to base64
+         buffered = BytesIO()
+         image.save(buffered, format="PNG")
+         image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+         # Build the full prompt
+         messages = [
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "text", "text": build_no_anchoring_v4_yaml_prompt()},
+                     {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
+                 ],
+             }
+         ]
+
+         # Apply the chat template
+         text = self.processor.apply_chat_template(
+             messages,
+             tokenize=False,
+             add_generation_prompt=True
+         )
+
+         # Process inputs
+         inputs = self.processor(
+             text=[text],
+             images=[image],
+             padding=True,
+             return_tensors="pt",
+         )
+         inputs = {key: value.to(self.device) for (key, value) in inputs.items()}
+
+         # Generate the output
+         output = self.model.generate(
+             **inputs,
+             temperature=0.1,
+             max_new_tokens=max_new_tokens,
+             num_return_sequences=1,
+             do_sample=True,
+         )
+
+         # Decode only the newly generated tokens (skip the prompt)
+         prompt_length = inputs["input_ids"].shape[1]
+         new_tokens = output[:, prompt_length:]
+         text_output = self.processor.tokenizer.batch_decode(
+             new_tokens,
+             skip_special_tokens=True
+         )
+
+         # Extract the text content
+         raw_output = text_output[0] if text_output else ""
+
+         # Parse the output
+         metadata, extracted_text = self._parse_ocr_output(raw_output)
+
+         # Prepare result data structure
+         result_data = {
+             "extracted_text": extracted_text,
+             "primary_language": metadata.get("primary_language", None),
+             "is_rotation_valid": metadata.get("is_rotation_valid", None),
+             "rotation_correction": metadata.get("rotation_correction", None),
+             "is_table": metadata.get("is_table", None),
+             "is_diagram": metadata.get("is_diagram", None),
+             "model": self.model_name,
+             "processor": self.processor_name
+         }
+
+         return result_data
+
+     def save_result_to_json(self, result_data: Dict, output_path: str,
+                             source_image_name: Optional[str] = None):
+         """
+         Save the extraction result to a JSON file.
+
+         Args:
+             result_data: Dictionary containing extraction results
+             output_path: Path where the JSON file will be saved
+             source_image_name: Optional name of the source image
+         """
+         # Add source image name if provided
+         if source_image_name:
+             result_data["source_image"] = source_image_name
+
+         # Ensure the output directory exists
+         output_dir = os.path.dirname(output_path)
+         if output_dir:
+             os.makedirs(output_dir, exist_ok=True)
+
+         # Save to JSON file
+         with open(output_path, "w", encoding="utf-8") as json_file:
+             json.dump(result_data, json_file, ensure_ascii=False, indent=2)
+
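
Taken together, the service exposes a small API: construct once, call `extract_text_from_image`, and optionally persist results with `save_result_to_json`. A minimal standalone sketch (hypothetical file names `extract_cli.py` and `sample_page.png`; assumes the dependencies from `requirements.txt` are installed and the model weights can be downloaded):

```python
# extract_cli.py -- hypothetical standalone driver for TextExtractionService.
from PIL import Image

from service.text_extraction_service import TextExtractionService

if __name__ == "__main__":
    service = TextExtractionService()  # loads the model and processor once
    image = Image.open("sample_page.png").convert("RGB")  # hypothetical input image
    result = service.extract_text_from_image(image, max_new_tokens=2048)
    print(result["extracted_text"])
    # save_result_to_json also records the source image name in the output
    service.save_result_to_json(result, "output/sample_page.json",
                                source_image_name="sample_page.png")
```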
streamlit_app.py ADDED
@@ -0,0 +1,145 @@
+ """
+ Streamlit App for Text Extraction from Images
+ UI layer for the text extraction service.
+ """
+ import html
+ import json
+ from pathlib import Path
+
+ import streamlit as st
+ from PIL import Image
+
+ from service.text_extraction_service import TextExtractionService
+
+
+ # Page configuration
+ st.set_page_config(
+     page_title="Text Extraction from Images",
+     page_icon="📄",
+     layout="wide"
+ )
+
+ # Initialize session state
+ if "extraction_service" not in st.session_state:
+     st.session_state.extraction_service = None
+ if "extraction_result" not in st.session_state:
+     st.session_state.extraction_result = None
+
+
+ @st.cache_resource
+ def get_extraction_service():
+     """
+     Get or create the text extraction service instance.
+     Cached to avoid reloading the model on every interaction.
+     """
+     if st.session_state.extraction_service is None:
+         with st.spinner("Loading OCR model... This may take a moment."):
+             service = TextExtractionService()
+             st.session_state.extraction_service = service
+     return st.session_state.extraction_service
+
+
+ def main():
+     """Main application function."""
+     st.title("📄 Text Extraction from Images")
+     st.markdown("Upload an image to extract text using the olmOCR model.")
+
+     # Sidebar for settings
+     with st.sidebar:
+         st.header("⚙️ Settings")
+         max_tokens = st.slider(
+             "Max Tokens",
+             min_value=512,
+             max_value=4096,
+             value=2048,
+             step=256,
+             help="Maximum number of tokens to generate. Higher values allow longer text extraction."
+         )
+
+     # File uploader
+     uploaded_file = st.file_uploader(
+         "Choose an image file",
+         type=["png", "jpg", "jpeg"],
+         help="Upload an image file (PNG, JPG, JPEG)"
+     )
+
+     if uploaded_file is not None:
+         # Display uploaded image
+         st.subheader("📷 Uploaded Image")
+         image = Image.open(uploaded_file)
+         st.image(image)
+         st.caption(f"File: {uploaded_file.name}")
+
+         st.divider()
+
+         # Extract button
+         st.subheader("📝 Text Extraction")
+         if st.button("🔍 Extract Text", type="primary"):
+             try:
+                 # Get extraction service
+                 service = get_extraction_service()
+
+                 # Extract text
+                 with st.spinner("Extracting text from image... This may take a while."):
+                     result = service.extract_text_from_image(
+                         image,
+                         max_new_tokens=max_tokens
+                     )
+
+                 # Store result in session state
+                 st.session_state.extraction_result = result
+                 st.session_state.extraction_result["source_image"] = uploaded_file.name
+
+             except Exception as e:
+                 st.error(f"❌ Error during extraction: {str(e)}")
+                 st.exception(e)
+
+         # Display results if available
+         if st.session_state.extraction_result:
+             st.divider()
+             result = st.session_state.extraction_result
+
+             st.subheader("📄 Extracted Text")
+             # Display extracted text with black color
+             extracted_text = result.get("extracted_text", "")
+             # Escape HTML to prevent injection and ensure proper display
+             escaped_text = html.escape(extracted_text)
+             st.markdown(
+                 f'<div style="background-color: #f0f2f6; padding: 15px; border-radius: 5px; max-height: 300px; overflow-y: auto; color: #000000; white-space: pre-wrap; font-family: monospace;">{escaped_text}</div>',
+                 unsafe_allow_html=True
+             )
+
+             # Display metadata (full JSON)
+             with st.expander("📊 Full JSON Metadata"):
+                 st.json(result)
+
+             # Download JSON button
+             json_str = json.dumps(result, ensure_ascii=False, indent=2)
+             st.download_button(
+                 label="💾 Download JSON",
+                 data=json_str,
+                 file_name=f"{Path(uploaded_file.name).stem}.json",
+                 mime="application/json"
+             )
+
+     else:
+         # Show instructions when no file is uploaded
+         st.info("👆 Please upload an image file to get started.")
+
+         # Example section
+         with st.expander("ℹ️ How to use"):
+             st.markdown("""
+             1. **Upload an image**: Click on the file uploader and select an image file (PNG, JPG, JPEG)
+             2. **Adjust settings** (optional): Use the sidebar to adjust max tokens if needed
+             3. **Extract text**: Click the "Extract Text" button
+             4. **View results**: The extracted text and metadata will be displayed
+             5. **Download**: Download the results as JSON if needed
+
+             **Note**: The first extraction may take longer because the model needs to be loaded.
+             Subsequent extractions will be faster.
+             """)
+
+
+ if __name__ == "__main__":
+     main()
+
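
One design note on `get_extraction_service`: `@st.cache_resource` already keeps a single `TextExtractionService` alive per server process across reruns and sessions, so the extra `st.session_state` bookkeeping is redundant (though harmless). A cache-only sketch, assuming a Streamlit version where `show_spinner` accepts a message string:

```python
@st.cache_resource(show_spinner="Loading OCR model... This may take a moment.")
def get_extraction_service() -> TextExtractionService:
    # Built once per process; later calls return the cached instance
    # instead of reloading the model.
    return TextExtractionService()
```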
tests/test_text_extraction_service.py ADDED
@@ -0,0 +1,112 @@
+ import pytest
+ from unittest.mock import MagicMock, patch
+ from PIL import Image
+ from service.text_extraction_service import TextExtractionService
+
+
+ @pytest.fixture
+ def mock_service(mocker):
+     """Fixture to create a TextExtractionService with a mocked model and processor."""
+     with patch("service.text_extraction_service.Qwen2_5_VLForConditionalGeneration") as mock_model_cls, \
+          patch("service.text_extraction_service.AutoProcessor") as mock_processor_cls, \
+          patch("torch.cuda.is_available", return_value=False), \
+          patch("torch.backends.mps.is_available", return_value=False):
+
+         mock_model = MagicMock()
+         mock_model_cls.from_pretrained.return_value = mock_model
+
+         mock_processor = MagicMock()
+         mock_processor_cls.from_pretrained.return_value = mock_processor
+
+         service = TextExtractionService()
+         return service, mock_model, mock_processor
+
+
+ def test_parse_ocr_output_with_yaml(mock_service):
+     service, _, _ = mock_service
+
+     raw_text = """Some prefix text
+ ---
+ primary_language: English
+ is_rotation_valid: true
+ rotation_correction: 0
+ is_table: false
+ ---
+ This is the extracted text content.
+ It has multiple lines.
+ """
+     metadata, text = service._parse_ocr_output(raw_text)
+
+     assert metadata["primary_language"] == "English"
+     assert metadata["is_rotation_valid"] is True
+     assert metadata["rotation_correction"] == 0
+     assert metadata["is_table"] is False
+     assert text == "This is the extracted text content.\nIt has multiple lines."
+
+
+ def test_parse_ocr_output_without_yaml(mock_service):
+     service, _, _ = mock_service
+
+     raw_text = "Just some plain text without any YAML frontmatter."
+     metadata, text = service._parse_ocr_output(raw_text)
+
+     assert metadata == {}
+     assert text == "Just some plain text without any YAML frontmatter."
+
+
+ def test_parse_ocr_output_malformed_yaml(mock_service):
+     service, _, _ = mock_service
+
+     # Missing the second separator
+     raw_text = """---
+ key: value
+ This should probably fail to parse as YAML but return text.
+ """
+     metadata, text = service._parse_ocr_output(raw_text)
+
+     # With only one "---" separator, split("---") yields two parts, so the
+     # len(parts) >= 3 check fails and the implementation falls back to
+     # returning the entire input as text with empty metadata.
+     assert metadata == {}
+     assert "key: value" in text
+
+
+ def test_extract_text_from_image(mock_service):
+     service, mock_model, mock_processor = mock_service
+
+     # Mock image
+     image = Image.new('RGB', (100, 100), color='red')
+
+     # Mock processor output
+     mock_processor.apply_chat_template.return_value = "mock_prompt"
+     mock_processor.return_value = {"input_ids": MagicMock(), "pixel_values": MagicMock()}
+     mock_processor.return_value["input_ids"].shape = [1, 10]  # Mock shape
+
+     # Mock tokenizer decode
+     mock_processor.tokenizer.batch_decode.return_value = ["""---
+ primary_language: English
+ ---
+ Extracted Text"""]
+
+     # Mock model generate (the return value is sliced, then discarded, since batch_decode is mocked)
+     mock_model.generate.return_value = MagicMock()
+
+     result = service.extract_text_from_image(image)
+
+     assert result["extracted_text"] == "Extracted Text"
+     assert result["primary_language"] == "English"
+     assert result["model"] == service.model_name
+
+
+ def test_initialization_device_selection():
+     """Test that the correct device is selected based on availability."""
+     with patch("service.text_extraction_service.Qwen2_5_VLForConditionalGeneration"), \
+          patch("service.text_extraction_service.AutoProcessor"):
+
+         # Test CPU
+         with patch("torch.cuda.is_available", return_value=False), \
+              patch("torch.backends.mps.is_available", return_value=False):
+             service = TextExtractionService()
+             assert service.device.type == "cpu"
+
+         # Test CUDA
+         with patch("torch.cuda.is_available", return_value=True):
+             service = TextExtractionService()
+             assert service.device.type == "cuda"
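
The device-selection test covers the CPU and CUDA branches; a companion case for the MPS branch, sketched in the same style (hypothetical test name, same imports as the module above), would patch CUDA off and MPS on:

```python
def test_initialization_device_selection_mps():
    """Hypothetical companion test: MPS is chosen when CUDA is unavailable."""
    with patch("service.text_extraction_service.Qwen2_5_VLForConditionalGeneration"), \
         patch("service.text_extraction_service.AutoProcessor"), \
         patch("torch.cuda.is_available", return_value=False), \
         patch("torch.backends.mps.is_available", return_value=True):
        service = TextExtractionService()
        assert service.device.type == "mps"
```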