| """ |
| PaddleOCR-VL-1.5 Model Wrapper |
| Provides an easy-to-use interface for text detection and recognition |
| """ |
| import re |
| import os |
| import torch |
| from typing import Dict, List, Tuple, Optional |
| from PIL import Image |
| from transformers import AutoProcessor, AutoModelForImageTextToText |
| import requests |
| from io import BytesIO |
|
|
|
|
class PaddleOCRVL:
    """Wrapper class for PaddleOCR-VL-1.5 model for text spotting tasks.

    Loads the processor and model once in ``__init__`` and exposes helpers to
    load an image, run generation on it, and turn the generated text into a
    list of pixel-space bounding boxes.
    """

    # Outputs shorter than this are never scanned for a repeated tail.
    _MIN_LENGTH_FOR_CLEANUP = 8000
    # A trailing unit must occur at least this many times back-to-back before
    # the tail is considered a degenerate repetition and trimmed.
    _MIN_REPEATS = 10
    # parse_detection_results divides LOC values by this to scale them to pixels.
    _COORD_SCALE = 1000
    # A run of text (no '<' or newline) followed by one or more <|LOC_n|> tokens.
    _DETECTION_RE = re.compile(r'([^<\n]+?)((?:<\|LOC_\d+\|>)+)')
    # Captures the integer inside a single <|LOC_n|> token.
    _LOC_RE = re.compile(r'<\|LOC_(\d+)\|>')

    def __init__(self, model_path: str = "PaddlePaddle/PaddleOCR-VL-1.5", device: Optional[str] = None):
        """
        Initialize the PaddleOCR-VL-1.5 model

        Args:
            model_path: Path or name of the model (default: "PaddlePaddle/PaddleOCR-VL-1.5")
            device: Device to load model on (cuda/mps/cpu). Auto-detected if None.
        """
        self.model_path = model_path
        self.device = device if device is not None else self._auto_select_device()

        print(f"Loading PaddleOCR-VL-1.5 model on {self.device}...")

        self.processor = self._load_pretrained(AutoProcessor, "processor", model_path)

        # Half precision on accelerators, full precision on CPU.
        if self.device == "cuda":
            torch_dtype = torch.bfloat16
        elif self.device == "mps":
            torch_dtype = torch.float16
        else:
            torch_dtype = torch.float32

        self.model = self._load_pretrained(
            AutoModelForImageTextToText,
            "model",
            model_path,
            dtype=torch_dtype,
            device_map="auto" if self.device == "cuda" else None,
        )

        # With device_map="auto" the weights are already placed; otherwise
        # move the model to the requested device explicitly.
        if self.device != "cuda":
            self.model = self.model.to(self.device)

        print("Model loaded successfully!")

    @staticmethod
    def _auto_select_device() -> str:
        """Return "cuda", "mps" or "cpu", preferring them in that order."""
        if torch.cuda.is_available():
            return "cuda"
        # Guarded lookup: not every torch build exposes the mps backend module.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    @staticmethod
    def _load_pretrained(loader, label: str, model_path: str, **kwargs):
        """Call ``loader.from_pretrained``; on any failure retry with local files only.

        Args:
            loader: Object exposing ``from_pretrained`` (e.g. ``AutoProcessor``).
            label: Word used in the fallback message ("processor" or "model").
            model_path: Path or name passed through to ``from_pretrained``.
            **kwargs: Extra keyword arguments forwarded to both attempts.

        Raises:
            Whatever the second (local-only) attempt raises, if it also fails.
        """
        try:
            return loader.from_pretrained(model_path, **kwargs)
        except Exception:
            # Deliberately broad: the first attempt is best-effort and any
            # failure is retried against the local cache. The retry runs
            # inside the handler so a second failure keeps the first one as
            # its exception context.
            print(f"Network error loading {label}, falling back to local cache...")
            return loader.from_pretrained(model_path, local_files_only=True, **kwargs)

    def clean_repeated_substrings(self, text: str) -> str:
        """
        Trim a trailing run of one substring repeated many times in a row

        Texts shorter than 8000 characters are returned unchanged. Otherwise
        unit lengths from 2 up to len(text) // 10 are tried in increasing
        order; for the first unit that occurs at least 10 times back-to-back
        at the very end of the text, all but one copy of it are removed.

        Args:
            text: Text to clean

        Returns:
            The text with the repeated tail collapsed, or the input unchanged
        """
        n = len(text)
        if n < self._MIN_LENGTH_FOR_CLEANUP:
            return text

        # A unit repeated _MIN_REPEATS times cannot be longer than n // _MIN_REPEATS.
        for length in range(2, n // self._MIN_REPEATS + 1):
            candidate = text[-length:]
            count = 0
            i = n - length

            # Walk backwards one unit at a time while the unit keeps matching.
            while i >= 0 and text[i:i + length] == candidate:
                count += 1
                i -= length

            if count >= self._MIN_REPEATS:
                # Keep everything before the run plus a single copy of the unit.
                return text[:n - length * (count - 1)]

        return text

    def load_image(self, image_source: str, timeout: float = 30.0) -> "Image.Image":
        """
        Load an image from a local path or an http(s) URL

        Args:
            image_source: Path or URL to image
            timeout: Seconds to wait for the HTTP request (URLs only)

        Returns:
            The image converted to RGB, detached from the underlying file

        Raises:
            requests.RequestException: If the download fails, times out, or
                returns an error status.
        """
        if image_source.startswith(('http://', 'https://')):
            # Without a timeout requests can block indefinitely on a stalled server.
            response = requests.get(image_source, timeout=timeout)
            response.raise_for_status()
            source = BytesIO(response.content)
        else:
            source = image_source

        # Image.open is lazy and keeps the file open; convert() reads the
        # pixel data into a new image so the handle can be closed right away.
        with Image.open(source) as opened:
            return opened.convert("RGB")

    def detect_text(self, image: "Image.Image", prompt: Optional[str] = None) -> str:
        """
        Detect and recognize text in image with bounding boxes

        Args:
            image: PIL Image object
            prompt: Custom prompt (default: "Spotting:")

        Returns:
            Model response with detected text and coordinates
        """
        if prompt is None:
            prompt = "Spotting:"

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs = self.processor(
            text=[text],
            images=[image],
            padding=True,
            return_tensors="pt",
        )

        # On cuda the model was loaded with device_map="auto", so the inputs
        # follow the device of the model's first parameter.
        if self.device == "cuda":
            target_device = next(self.model.parameters()).device
        else:
            target_device = self.device
        inputs = inputs.to(target_device)

        with torch.no_grad():
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=2048,
                do_sample=False
            )

        # Drop the prompt tokens so only the newly generated part is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]

        return self.clean_repeated_substrings(output_text)

    def parse_detection_results(self, response: str, image_width: int, image_height: int) -> List[Dict]:
        """
        Parse detection response into structured format with denormalized coordinates

        Each text run followed by exactly eight <|LOC_n|> tokens is read as
        four (x, y) points; the axis-aligned box around them is scaled by
        image size / 1000 and truncated to integers. Matches with any other
        number of LOC tokens are skipped.

        Args:
            response: Model output text
            image_width: Image width in pixels
            image_height: Image height in pixels

        Returns:
            List of dictionaries with 'text', 'x1', 'y1', 'x2', 'y2' keys
        """
        results = []

        for match in self._DETECTION_RE.finditer(response):
            locs = [int(v) for v in self._LOC_RE.findall(match.group(2))]
            if len(locs) != 8:
                continue

            # Values alternate x, y, x, y, ...
            xs = locs[0::2]
            ys = locs[1::2]

            try:
                box = {
                    'text': match.group(1).strip(),
                    'x1': int(min(xs) * image_width / self._COORD_SCALE),
                    'y1': int(min(ys) * image_height / self._COORD_SCALE),
                    'x2': int(max(xs) * image_width / self._COORD_SCALE),
                    'y2': int(max(ys) * image_height / self._COORD_SCALE),
                }
            except (TypeError, ValueError, ArithmeticError) as e:
                # Best-effort: unusable image dimensions skip the entry
                # instead of aborting the whole parse.
                print(f"Error parsing detection result: {str(e)}")
                continue

            results.append(box)

        return results

    def process_image(
        self, image_source: str, prompt: Optional[str] = None
    ) -> Tuple[str, List[Dict], "Image.Image"]:
        """
        Complete pipeline: load image, detect text, parse results

        Args:
            image_source: Path or URL to image
            prompt: Custom prompt for detection

        Returns:
            Tuple of (raw_response, parsed_results, image)
        """
        image = self.load_image(image_source)
        image_width, image_height = image.size

        response = self.detect_text(image, prompt)
        parsed_results = self.parse_detection_results(response, image_width, image_height)

        return response, parsed_results, image
|
|