Spaces:
Sleeping
Sleeping
File size: 7,386 Bytes
b8b55ff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 |
"""
Text Extraction Service
Handles OCR text extraction from images using olmOCR model.
Separated from UI concerns for better maintainability.
"""
import base64
import json
import os
import re
from io import BytesIO
from typing import Dict, Tuple, Optional
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from olmocr.prompts import build_no_anchoring_v4_yaml_prompt
class TextExtractionService:
    """
    Service class for extracting text from images using the olmOCR model.

    Handles model initialization, image preprocessing, OCR output parsing,
    and result persistence — kept separate from UI concerns.
    """

    def __init__(self, model_name: str = "allenai/olmOCR-2-7B-1025",
                 processor_name: str = "Qwen/Qwen2.5-VL-7B-Instruct"):
        """
        Initialize the text extraction service with model and processor.

        Args:
            model_name: Name of the olmOCR model to use
            processor_name: Name of the processor to use
        """
        self.model_name = model_name
        self.processor_name = processor_name
        self.model = None
        self.processor = None
        self.device = None
        self._initialize_model()

    def _initialize_model(self):
        """Load the model and processor, then move the model to the best device."""
        # bfloat16 halves memory vs float32; eval() disables dropout and
        # other training-only behavior.
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16
        ).eval()
        self.processor = AutoProcessor.from_pretrained(self.processor_name)
        # Device preference: CUDA, then Apple MPS, then CPU fallback.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        else:
            self.device = torch.device("cpu")
        self.model.to(self.device)

    def _parse_ocr_output(self, raw_text: str) -> Tuple[Dict, str]:
        """
        Parse OCR output with YAML frontmatter into metadata and text.

        The model emits ``---\\n<yaml>\\n---\\n<text>``. When no frontmatter
        is present, the entire payload is treated as extracted text.

        Args:
            raw_text: Raw output from the OCR model

        Returns:
            Tuple of (metadata_dict, extracted_text)
        """
        # Split at most twice so a literal "---" occurring inside the
        # extracted text is preserved instead of truncating everything
        # after it (unbounded split() dropped that content before).
        parts = raw_text.split("---", 2)
        metadata: Dict = {}
        extracted_text = ""
        if len(parts) >= 3:
            # parts[1] is the YAML between the first two markers;
            # parts[2] is everything after the second marker.
            yaml_content = parts[1].strip()
            extracted_text = parts[2].strip()
            # Lightweight YAML-ish parsing: one "key: value" per line.
            for line in yaml_content.split("\n"):
                line = line.strip()
                if ":" not in line:
                    continue
                key, value = line.split(":", 1)
                key = key.strip()
                value = value.strip()
                # Coerce booleans and (possibly negative) numbers; the
                # previous isdigit() check missed negative integers such
                # as a rotation_correction of -90.
                if value.lower() == "true":
                    value = True
                elif value.lower() == "false":
                    value = False
                elif re.fullmatch(r"-?\d+", value):
                    value = int(value)
                elif re.fullmatch(r"-?\d+\.\d+", value):
                    value = float(value)
                metadata[key] = value
        else:
            # No YAML frontmatter found, use the entire text.
            extracted_text = raw_text.strip()
        return metadata, extracted_text

    def extract_text_from_image(self, image: "Image.Image",
                                max_new_tokens: int = 2048) -> Dict:
        """
        Extract text from a PIL Image object.

        Args:
            image: PIL Image object to extract text from
            max_new_tokens: Maximum number of tokens to generate

        Returns:
            Dictionary containing extracted text and metadata
        """
        # Encode the image as base64 PNG for the chat message payload.
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
        # Build the chat message: prompt text plus the image reference.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": build_no_anchoring_v4_yaml_prompt()},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
                ],
            }
        ]
        # Render the chat template to a prompt string; the actual pixels
        # are supplied separately via the processor's images argument.
        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        inputs = self.processor(
            text=[text],
            images=[image],
            padding=True,
            return_tensors="pt",
        )
        inputs = {key: value.to(self.device) for (key, value) in inputs.items()}
        # NOTE(review): do_sample=True makes output non-deterministic even
        # at temperature 0.1 — confirm sampling (vs greedy) is intended.
        output = self.model.generate(
            **inputs,
            temperature=0.1,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            do_sample=True,
        )
        # Strip the prompt tokens; decode only the newly generated ones.
        prompt_length = inputs["input_ids"].shape[1]
        new_tokens = output[:, prompt_length:]
        text_output = self.processor.tokenizer.batch_decode(
            new_tokens,
            skip_special_tokens=True
        )
        raw_output = text_output[0] if text_output else ""
        metadata, extracted_text = self._parse_ocr_output(raw_output)
        # Missing frontmatter keys default to None so consumers can
        # distinguish "absent" from an explicit False/0 value.
        result_data = {
            "extracted_text": extracted_text,
            "primary_language": metadata.get("primary_language", None),
            "is_rotation_valid": metadata.get("is_rotation_valid", None),
            "rotation_correction": metadata.get("rotation_correction", None),
            "is_table": metadata.get("is_table", None),
            "is_diagram": metadata.get("is_diagram", None),
            "model": self.model_name,
            "processor": self.processor_name
        }
        return result_data

    def save_result_to_json(self, result_data: Dict, output_path: str,
                            source_image_name: Optional[str] = None):
        """
        Save extraction result to a JSON file.

        Args:
            result_data: Dictionary containing extraction results
            output_path: Path where to save the JSON file
            source_image_name: Optional name of the source image,
                recorded under the "source_image" key in the saved file
        """
        # Work on a shallow copy so the caller's dict is not mutated as a
        # side effect of saving.
        payload = dict(result_data)
        if source_image_name:
            payload["source_image"] = source_image_name
        # Ensure the output directory exists before writing.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        # ensure_ascii=False keeps non-Latin OCR text readable in the file.
        with open(output_path, "w", encoding="utf-8") as json_file:
            json.dump(payload, json_file, ensure_ascii=False, indent=2)
|