""" PaddleOCR-VL-1.5 Model Wrapper Provides an easy-to-use interface for text detection and recognition """ import re import os import torch from typing import Dict, List, Tuple, Optional from PIL import Image from transformers import AutoProcessor, AutoModelForImageTextToText import requests from io import BytesIO class PaddleOCRVL: """Wrapper class for PaddleOCR-VL-1.5 model for text spotting tasks""" def __init__(self, model_path: str = "PaddlePaddle/PaddleOCR-VL-1.5", device: Optional[str] = None): """ Initialize the PaddleOCR-VL-1.5 model Args: model_path: Path or name of the model (default: "PaddlePaddle/PaddleOCR-VL-1.5") device: Device to load model on (cuda/cpu). Auto-detected if None. """ self.model_path = model_path if device is None: if torch.cuda.is_available(): self.device = "cuda" elif torch.backends.mps.is_available(): self.device = "mps" else: self.device = "cpu" else: self.device = device print(f"Loading PaddleOCR-VL-1.5 model on {self.device}...") try: self.processor = AutoProcessor.from_pretrained(model_path) except Exception: print("Network error loading processor, falling back to local cache...") self.processor = AutoProcessor.from_pretrained(model_path, local_files_only=True) if self.device == "cuda": torch_dtype = torch.bfloat16 elif self.device == "mps": torch_dtype = torch.float16 else: torch_dtype = torch.float32 try: self.model = AutoModelForImageTextToText.from_pretrained( model_path, dtype=torch_dtype, device_map="auto" if self.device == "cuda" else None ) except Exception: print("Network error loading model, falling back to local cache...") self.model = AutoModelForImageTextToText.from_pretrained( model_path, dtype=torch_dtype, device_map="auto" if self.device == "cuda" else None, local_files_only=True ) if self.device != "cuda": self.model = self.model.to(self.device) print("Model loaded successfully!") def clean_repeated_substrings(self, text: str) -> str: n = len(text) if n < 8000: return text for length in range(2, n // 10 + 1): candidate = text[-length:] count = 0 i = n - length while i >= 0 and text[i:i + length] == candidate: count += 1 i -= length if count >= 10: return text[:n - length * (count - 1)] return text def load_image(self, image_source: str) -> Image.Image: if image_source.startswith(('http://', 'https://')): response = requests.get(image_source) response.raise_for_status() return Image.open(BytesIO(response.content)) else: return Image.open(image_source) def detect_text(self, image: Image.Image, prompt: Optional[str] = None) -> str: """ Detect and recognize text in image with bounding boxes Args: image: PIL Image object prompt: Custom prompt (default: text spotting prompt in Chinese) Returns: Model response with detected text and coordinates """ if prompt is None: prompt = "Spotting:" messages = [ { "role": "user", "content": [ {"type": "image", "image": image}, {"type": "text", "text": prompt}, ], } ] text = self.processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) inputs = self.processor( text=[text], images=[image], padding=True, return_tensors="pt", ) if self.device == "cuda": device = next(self.model.parameters()).device inputs = inputs.to(device) else: inputs = inputs.to(self.device) with torch.no_grad(): generated_ids = self.model.generate( **inputs, max_new_tokens=2048, do_sample=False ) generated_ids_trimmed = [ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] output_text = self.processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False )[0] output_text = self.clean_repeated_substrings(output_text) return output_text def parse_detection_results(self, response: str, image_width: int, image_height: int) -> List[Dict]: """ Parse detection response into structured format with denormalized coordinates Args: response: Model output text image_width: Image width in pixels image_height: Image height in pixels Returns: List of dictionaries with 'text', 'x1', 'y1', 'x2', 'y2' keys """ results = [] # Pattern to match text followed by <|LOC_xxx|> tokens (8 per detection, quadrilateral) for match in re.finditer(r'([^<\n]+?)((?:<\|LOC_\d+\|>)+)', response): try: text = match.group(1).strip() locs = [int(v) for v in re.findall(r'<\|LOC_(\d+)\|>', match.group(2))] if len(locs) != 8: continue xs = [locs[i] for i in range(0, 8, 2)] ys = [locs[i] for i in range(1, 8, 2)] x1 = int(min(xs) * image_width / 1000) y1 = int(min(ys) * image_height / 1000) x2 = int(max(xs) * image_width / 1000) y2 = int(max(ys) * image_height / 1000) results.append({ 'text': text, 'x1': x1, 'y1': y1, 'x2': x2, 'y2': y2 }) except Exception as e: print(f"Error parsing detection result: {str(e)}") continue return results def process_image(self, image_source: str, prompt: Optional[str] = None) -> Tuple[str, List[Dict]]: """ Complete pipeline: load image, detect text, parse results Args: image_source: Path or URL to image prompt: Custom prompt for detection Returns: Tuple of (raw_response, parsed_results, image) """ image = self.load_image(image_source) image_width, image_height = image.size response = self.detect_text(image, prompt) parsed_results = self.parse_detection_results(response, image_width, image_height) return response, parsed_results, image