212hjh / ocr_model.py
jzhang533's picture
fix PaddleOCR-VL-1.5 text spotting: use correct prompt and coordinate format
adabb98
"""
PaddleOCR-VL-1.5 Model Wrapper
Provides an easy-to-use interface for text detection and recognition
"""
import re
import os
import torch
from typing import Dict, List, Tuple, Optional
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText
import requests
from io import BytesIO
class PaddleOCRVL:
    """Wrapper class for PaddleOCR-VL-1.5 model for text spotting tasks.

    ``__init__`` loads the processor and model once. The remaining methods load
    an image, run a text-spotting prompt on it, and parse the ``<|LOC_n|>``
    location tokens in the response into pixel-space bounding boxes.
    """

    # A run of text (no '<' and no newline) followed by one or more
    # <|LOC_n|> tokens. A usable detection carries exactly 8 of them
    # (four x/y pairs of a quadrilateral).
    _DETECTION_RE = re.compile(r'([^<\n]+?)((?:<\|LOC_\d+\|>)+)')
    _LOC_RE = re.compile(r'<\|LOC_(\d+)\|>')
    # LOC values are divided by this to obtain a fraction of the image size.
    _LOC_SCALE = 1000

    def __init__(self, model_path: str = "PaddlePaddle/PaddleOCR-VL-1.5", device: Optional[str] = None):
        """
        Initialize the PaddleOCR-VL-1.5 model.

        Args:
            model_path: Path or Hugging Face Hub name of the model
                (default: "PaddlePaddle/PaddleOCR-VL-1.5").
            device: Device to load the model on, e.g. "cuda", "cuda:1", "mps"
                or "cpu". Auto-detected (cuda, then mps, then cpu) if None.
        """
        self.model_path = model_path
        if device is None:
            if torch.cuda.is_available():
                device = "cuda"
            elif torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"
        self.device = device
        print(f"Loading PaddleOCR-VL-1.5 model on {self.device}...")

        self.processor = self._load_pretrained(AutoProcessor, model_path, "processor")

        # Choose the dtype from the device *type*, so that indexed devices such
        # as "cuda:1" also get bfloat16 instead of silently using float32.
        device_type = torch.device(self.device).type
        if device_type == "cuda":
            torch_dtype = torch.bfloat16
        elif device_type == "mps":
            torch_dtype = torch.float16
        else:
            torch_dtype = torch.float32

        # Only a bare "cuda" delegates placement to device_map="auto". An
        # explicit index (or any other device) is honoured with .to() below.
        use_device_map = self.device == "cuda"
        self.model = self._load_pretrained(
            AutoModelForImageTextToText,
            model_path,
            "model",
            dtype=torch_dtype,
            device_map="auto" if use_device_map else None,
        )
        if not use_device_map:
            self.model = self.model.to(self.device)
        print("Model loaded successfully!")

    @staticmethod
    def _load_pretrained(loader, model_path: str, what: str, **kwargs):
        """Call ``loader.from_pretrained``, retrying from the local cache on failure.

        If the first (possibly online) attempt raises, the error is printed and
        a second attempt is made with ``local_files_only=True``. If that attempt
        also fails, its exception propagates, implicitly chained to the first.

        Args:
            loader: Class exposing ``from_pretrained`` (processor or model class).
            model_path: Path or Hub name passed to ``from_pretrained``.
            what: Human-readable label used in the fallback message.
            **kwargs: Extra keyword arguments forwarded to both attempts.
        """
        try:
            return loader.from_pretrained(model_path, **kwargs)
        except Exception as err:
            print(
                f"Failed to load {what} from {model_path!r} "
                f"({type(err).__name__}: {err}); falling back to local cache..."
            )
            return loader.from_pretrained(model_path, local_files_only=True, **kwargs)

    def clean_repeated_substrings(self, text: str, *, min_length: int = 8000, min_repeats: int = 10) -> str:
        """Collapse a degenerate repetition loop at the end of generated text.

        Only texts of at least ``min_length`` characters are inspected. Unit
        lengths from 2 to ``len(text) // min_repeats`` are tried shortest
        first. The first unit that occurs at least ``min_repeats`` times
        back-to-back at the very end of the text is reduced to one copy. If no
        unit qualifies, ``text`` is returned unchanged.

        Args:
            text: Decoded model output.
            min_length: Texts shorter than this are returned as-is.
            min_repeats: Minimum consecutive trailing copies to trigger a trim.

        Raises:
            ValueError: If ``min_repeats`` is less than 1.
        """
        if min_repeats < 1:
            raise ValueError("min_repeats must be >= 1")
        n = len(text)
        if n < min_length:
            return text
        for length in range(2, n // min_repeats + 1):
            candidate = text[-length:]
            # Count how many copies of `candidate` sit back-to-back at the end.
            count = 0
            i = n - length
            while i >= 0 and text[i:i + length] == candidate:
                count += 1
                i -= length
            if count >= min_repeats:
                # Drop all but one trailing copy.
                return text[:n - length * (count - 1)]
        return text

    def load_image(self, image_source: str, *, timeout: Optional[float] = 30.0) -> Image.Image:
        """Load an image from an http(s) URL or a local file path.

        Args:
            image_source: URL starting with "http://" or "https://", or a path.
            timeout: Seconds to wait for the HTTP request. It is ignored for
                local paths; None waits indefinitely.

        Returns:
            The decoded PIL image.

        Raises:
            requests.HTTPError: If the server responds with an error status.
        """
        if image_source.startswith(('http://', 'https://')):
            response = requests.get(image_source, timeout=timeout)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
        else:
            image = Image.open(image_source)
        # Image.open is lazy. Decoding now surfaces corrupt-file errors here and
        # lets Pillow release the file it opened for single-frame images.
        image.load()
        return image

    def detect_text(self, image: Image.Image, prompt: Optional[str] = None, *, max_new_tokens: int = 2048) -> str:
        """
        Detect and recognize text in an image, with bounding boxes.

        Args:
            image: PIL Image object.
            prompt: Custom prompt (default: the "Spotting:" text-spotting prompt).
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded model response, with trailing repetition loops collapsed
            by ``clean_repeated_substrings``. For the default prompt this is
            expected to be text interleaved with <|LOC_n|> tokens (see
            ``parse_detection_results``).
        """
        if prompt is None:
            prompt = "Spotting:"
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = self.processor(
            text=[text],
            images=[image],
            padding=True,
            return_tensors="pt",
        )
        if self.device == "cuda":
            # With device_map="auto" the model may be spread across GPUs, so
            # send the inputs to the device holding its first parameters.
            device = next(self.model.parameters()).device
            inputs = inputs.to(device)
        else:
            inputs = inputs.to(self.device)
        with torch.no_grad():
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )
        # generate() echoes the prompt tokens, so keep only the new ones.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        return self.clean_repeated_substrings(output_text)

    def parse_detection_results(self, response: str, image_width: int, image_height: int) -> List[Dict]:
        """
        Parse a detection response into boxes with denormalized coordinates.

        Each match is a text run followed by <|LOC_n|> tokens. Only matches
        with exactly 8 tokens are kept. The tokens are read as four (x, y)
        corners, and their axis-aligned bounding box is scaled from the
        0.._LOC_SCALE range to pixels (truncated to int).

        Args:
            response: Model output text.
            image_width: Image width in pixels.
            image_height: Image height in pixels.

        Returns:
            List of dictionaries with 'text', 'x1', 'y1', 'x2', 'y2' keys.
        """
        results = []
        for match in self._DETECTION_RE.finditer(response):
            locs = [int(v) for v in self._LOC_RE.findall(match.group(2))]
            if len(locs) != 8:
                continue
            xs = locs[0::2]
            ys = locs[1::2]
            results.append({
                'text': match.group(1).strip(),
                'x1': int(min(xs) * image_width / self._LOC_SCALE),
                'y1': int(min(ys) * image_height / self._LOC_SCALE),
                'x2': int(max(xs) * image_width / self._LOC_SCALE),
                'y2': int(max(ys) * image_height / self._LOC_SCALE),
            })
        return results

    def process_image(self, image_source: str, prompt: Optional[str] = None) -> Tuple[str, List[Dict], Image.Image]:
        """
        Complete pipeline: load image, detect text, parse results.

        Args:
            image_source: Path or URL to image.
            prompt: Custom prompt for detection (None uses the default).

        Returns:
            Tuple of (raw_response, parsed_results, image).
        """
        image = self.load_image(image_source)
        image_width, image_height = image.size
        response = self.detect_text(image, prompt)
        parsed_results = self.parse_detection_results(response, image_width, image_height)
        return response, parsed_results, image