Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| from typing import List, Tuple, Any, Union, Optional | |
| from scipy.spatial.distance import cdist | |
| from scipy.cluster.hierarchy import fcluster, linkage | |
| class OCRHandler: | |
| def __init__(self, lang_list: List[str] = ['en'], gpu: bool = False, ocr_engine: str = 'manga-ocr'): | |
| """ | |
| Initializes the OCR handler with lazy loading. | |
| Args: | |
| lang_list: List of languages to detect (default: ['en']). | |
| gpu: Boolean to enable GPU usage (default: False). | |
| ocr_engine: 'magi' (default), 'manga-ocr', 'paddleocr', or 'easyocr'. | |
| """ | |
| self.ocr_engine = ocr_engine | |
| self.lang_list = lang_list | |
| self.gpu = gpu | |
| # Lazy loading - modules are loaded on first use | |
| self._magi_model = None | |
| self._manga_ocr = None | |
| self._detector = None | |
| self._paddle_reader = None | |
| self._easy_reader = None | |
| print(f"OCR Handler initialized with engine: {ocr_engine} (lazy loading enabled)") | |
| def _load_magi(self): | |
| """Lazy load Magi model.""" | |
| if self._magi_model is None: | |
| print("Loading Magi (The Manga Whisperer)...") | |
| try: | |
| from transformers import AutoModel | |
| import torch | |
| self._magi_model = AutoModel.from_pretrained("ragavsachdeva/magi", trust_remote_code=True) | |
| if torch.cuda.is_available() and self.gpu: | |
| self._magi_model = self._magi_model.cuda() | |
| self._magi_model.eval() | |
| print("✓ Magi loaded successfully") | |
| except ImportError as e: | |
| raise ImportError( | |
| "Magi dependencies not installed. " | |
| "This should not happen as Magi is the default engine. " | |
| f"Error: {e}" | |
| ) | |
| return self._magi_model | |
| def _load_manga_ocr(self): | |
| """Lazy load Manga-OCR.""" | |
| if self._manga_ocr is None: | |
| print("Loading Manga-OCR...") | |
| try: | |
| from manga_ocr import MangaOcr | |
| from paddleocr import PaddleOCR | |
| self._manga_ocr = MangaOcr() | |
| # PaddleOCR 3.x API with minimal preprocessing for speed | |
| self._detector = PaddleOCR( | |
| use_doc_orientation_classify=False, | |
| use_doc_unwarping=False, | |
| use_textline_orientation=False | |
| ) | |
| print("✓ Manga-OCR loaded successfully") | |
| except ImportError: | |
| raise ImportError( | |
| "Manga-OCR not installed. Install with:\n" | |
| "pip install -r requirements-optional.txt\n" | |
| "or: pip install manga-ocr paddlepaddle paddleocr" | |
| ) | |
| return self._manga_ocr, self._detector | |
| def _load_paddleocr(self): | |
| """Lazy load PaddleOCR.""" | |
| if self._paddle_reader is None: | |
| print("Loading PaddleOCR...") | |
| try: | |
| from paddleocr import PaddleOCR | |
| # PaddleOCR 3.x API with minimal preprocessing | |
| self._paddle_reader = PaddleOCR( | |
| use_doc_orientation_classify=False, | |
| use_doc_unwarping=False, | |
| use_textline_orientation=False | |
| ) | |
| print("✓ PaddleOCR loaded successfully") | |
| except ImportError: | |
| raise ImportError( | |
| "PaddleOCR not installed. Install with:\n" | |
| "pip install paddlepaddle paddleocr" | |
| ) | |
| return self._paddle_reader | |
| def _load_easyocr(self): | |
| """Lazy load EasyOCR.""" | |
| if self._easy_reader is None: | |
| print("Loading EasyOCR (this may take a while on first run)...") | |
| try: | |
| import easyocr | |
| self._easy_reader = easyocr.Reader(self.lang_list, gpu=self.gpu) | |
| print("✓ EasyOCR loaded successfully") | |
| except ImportError: | |
| raise ImportError( | |
| "EasyOCR not installed. Install with:\n" | |
| "pip install easyocr" | |
| ) | |
| return self._easy_reader | |
| def preprocess_image(self, image: np.ndarray, mode: str = 'gentle') -> np.ndarray: | |
| """ | |
| Applies preprocessing to improve OCR quality. | |
| Args: | |
| image: Input image as numpy array (RGB). | |
| mode: Preprocessing mode: | |
| - 'none': No preprocessing, use original image | |
| - 'gentle': Light preprocessing (recommended for manga) | |
| - 'aggressive': Heavy preprocessing (old behavior) | |
| """ | |
| if mode == 'none': | |
| # Scale up 3x for better recognition of thin characters like "I" | |
| return cv2.resize(image, None, fx=3, fy=3, interpolation=cv2.INTER_CUBIC) | |
| # Convert to grayscale | |
| gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) | |
| # Scaling (2x) - helpful for small text | |
| scaled = cv2.resize(gray, None, fx=2, fy=2, interpolation=cv2.INTER_CUBIC) | |
| if mode == 'gentle': | |
| # Gentle preprocessing - preserve thin strokes like "I", "l", etc. | |
| # Light contrast enhancement instead of harsh binarization | |
| clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) | |
| enhanced = clahe.apply(scaled) | |
| # Very light denoising to preserve details | |
| denoised = cv2.fastNlMeansDenoising(enhanced, h=5, templateWindowSize=7, searchWindowSize=21) | |
| return denoised | |
| else: # aggressive | |
| # Denoising | |
| denoised = cv2.fastNlMeansDenoising(scaled, h=10, templateWindowSize=7, searchWindowSize=21) | |
| # Thresholding (Binarization) - can destroy thin characters! | |
| _, binary = cv2.threshold(denoised, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) | |
| return binary | |
| def detect_text(self, image: Union[Image.Image, np.ndarray], paragraph: bool = True, preprocess_mode: str = 'gentle', tesseract_psm: int = 6, tesseract_confidence: int = 60) -> List[Tuple[List[Tuple[int, int]], str]]: | |
| """ | |
| Detects text in an image. | |
| Args: | |
| image: PIL Image or numpy array. | |
| paragraph: If True, combines text lines into paragraphs (better for translation). | |
| preprocess_mode: Preprocessing mode ('gentle', 'none', 'aggressive'). | |
| Returns: | |
| List of tuples: (bounding_box, text) or (bounding_box, text, confidence) | |
| """ | |
| if isinstance(image, Image.Image): | |
| image = np.array(image) | |
| # Apply preprocessing for detection | |
| processed_image = self.preprocess_image(image, mode=preprocess_mode) | |
| # Scale factor depends on preprocessing mode | |
| scale_factor = 3 if preprocess_mode == 'none' else 2 | |
| if self.ocr_engine == 'magi': | |
| return self._detect_with_magi(processed_image, scale_factor) | |
| elif self.ocr_engine == 'manga-ocr': | |
| return self._detect_with_manga_ocr(processed_image, scale_factor) | |
| elif self.ocr_engine == 'paddleocr': | |
| return self._detect_with_paddleocr(processed_image, scale_factor) | |
| elif self.ocr_engine == 'easyocr': | |
| return self._detect_with_easyocr(processed_image, paragraph, scale_factor) | |
| else: | |
| raise ValueError(f"Unknown OCR engine: {self.ocr_engine}") | |
| def _detect_with_magi(self, processed_image: np.ndarray, scale_factor: int) -> List[Tuple]: | |
| """Detect text using Magi - The Manga Whisperer (best for manga).""" | |
| import torch | |
| model = self._load_magi() | |
| # Magi expects RGB numpy array | |
| if len(processed_image.shape) == 2: | |
| # Grayscale to RGB | |
| processed_image = np.stack([processed_image] * 3, axis=-1) | |
| with torch.no_grad(): | |
| # Detect text boxes | |
| results = model.predict_detections_and_associations([processed_image]) | |
| text_bboxes = [results[0]["texts"]] | |
| # Run OCR on detected text boxes | |
| ocr_results = model.predict_ocr([processed_image], text_bboxes) | |
| final_results = [] | |
| if results and len(results) > 0: | |
| text_boxes = results[0].get("texts", []) | |
| ocr_texts = ocr_results[0] if ocr_results else [] | |
| for i, bbox in enumerate(text_boxes): | |
| # bbox format: [x1, y1, x2, y2] | |
| x1, y1, x2, y2 = bbox | |
| # Convert to 4-point format and scale back | |
| bbox_4pt = [ | |
| [int(x1 / scale_factor), int(y1 / scale_factor)], | |
| [int(x2 / scale_factor), int(y1 / scale_factor)], | |
| [int(x2 / scale_factor), int(y2 / scale_factor)], | |
| [int(x1 / scale_factor), int(y2 / scale_factor)] | |
| ] | |
| # Get OCR text if available | |
| text = ocr_texts[i] if i < len(ocr_texts) else "" | |
| if text.strip(): | |
| final_results.append((bbox_4pt, text.strip(), 0.95)) | |
| return final_results | |
| def _detect_with_manga_ocr(self, processed_image: np.ndarray, scale_factor: int) -> List[Tuple]: | |
| """Detect text using Manga-OCR - specialized for manga/comic fonts.""" | |
| manga_ocr, detector = self._load_manga_ocr() | |
| # Ensure 3-channel image for PaddleOCR/PaddleX doc preprocessor | |
| if len(processed_image.shape) == 2: | |
| processed_image = np.stack([processed_image] * 3, axis=-1) | |
| # PaddleOCR 3.x uses predict() and returns result objects | |
| detection_results = list(detector.predict(processed_image)) | |
| final_results = [] | |
| for res in detection_results: | |
| # Access the result dict - PaddleOCR 3.x returns objects with 'dt_polys' attribute | |
| if hasattr(res, 'dt_polys') and res.dt_polys is not None: | |
| dt_polys = res.dt_polys | |
| elif hasattr(res, '__getitem__'): | |
| # Try dict-like access | |
| res_dict = res.get('res', res) if hasattr(res, 'get') else res | |
| dt_polys = res_dict.get('dt_polys', None) if hasattr(res_dict, 'get') else None | |
| else: | |
| continue | |
| if dt_polys is None: | |
| continue | |
| for bbox_raw in dt_polys: | |
| pts = np.array(bbox_raw).astype(int) | |
| x_min, y_min = pts.min(axis=0) | |
| x_max, y_max = pts.max(axis=0) | |
| # Ensure valid crop region | |
| x_min = max(0, x_min) | |
| y_min = max(0, y_min) | |
| x_max = min(processed_image.shape[1], x_max) | |
| y_max = min(processed_image.shape[0], y_max) | |
| if x_max <= x_min or y_max <= y_min: | |
| continue | |
| # Crop the text region | |
| cropped = processed_image[y_min:y_max, x_min:x_max] | |
| if cropped.size == 0: | |
| continue | |
| # Convert to PIL for manga-ocr | |
| cropped_pil = Image.fromarray(cropped) | |
| # Recognize with manga-ocr | |
| try: | |
| text = manga_ocr(cropped_pil) | |
| except Exception as e: | |
| print(f"Manga-OCR error: {e}") | |
| continue | |
| if not text.strip(): | |
| continue | |
| # Scale bbox back - bbox_raw is already a polygon array | |
| bbox = [[int(p[0]/scale_factor), int(p[1]/scale_factor)] for p in bbox_raw] | |
| final_results.append((bbox, text.strip(), 0.95)) | |
| return final_results | |
| def _detect_with_paddleocr(self, processed_image: np.ndarray, scale_factor: int) -> List[Tuple]: | |
| """Detect text using PaddleOCR - fast and general purpose.""" | |
| reader = self._load_paddleocr() | |
| # PaddleOCR expects 3-channel BGR/RGB numpy array | |
| if len(processed_image.shape) == 2: | |
| processed_image = np.stack([processed_image] * 3, axis=-1) | |
| # PaddleOCR 3.x uses predict() and returns result objects | |
| results = list(reader.predict(processed_image)) | |
| final_results = [] | |
| for res in results: | |
| # PaddleOCR 3.x returns objects with rec_polys, rec_texts, rec_scores | |
| rec_polys = getattr(res, 'rec_polys', None) or getattr(res, 'dt_polys', None) | |
| rec_texts = getattr(res, 'rec_texts', None) | |
| rec_scores = getattr(res, 'rec_scores', None) | |
| # Try dict-like access if attributes don't work | |
| if rec_polys is None and hasattr(res, 'get'): | |
| res_dict = res.get('res', res) if 'res' in res else res | |
| rec_polys = res_dict.get('rec_polys') or res_dict.get('dt_polys') | |
| rec_texts = res_dict.get('rec_texts') | |
| rec_scores = res_dict.get('rec_scores') | |
| if rec_polys is None or rec_texts is None: | |
| continue | |
| for i, (bbox_raw, text) in enumerate(zip(rec_polys, rec_texts)): | |
| confidence = rec_scores[i] if rec_scores is not None and i < len(rec_scores) else 0.9 | |
| # Skip empty or low confidence | |
| if not text.strip() or confidence < 0.5: | |
| continue | |
| # Scale bbox back | |
| bbox = [[int(p[0]/scale_factor), int(p[1]/scale_factor)] for p in bbox_raw] | |
| final_results.append((bbox, text.strip(), float(confidence))) | |
| return final_results | |
| def _detect_with_easyocr(self, processed_image: np.ndarray, paragraph: bool, scale_factor: int) -> List[Tuple]: | |
| """Detect text using EasyOCR.""" | |
| reader = self._load_easyocr() | |
| results = reader.readtext( | |
| processed_image, | |
| paragraph=paragraph, | |
| contrast_ths=0.05, | |
| text_threshold=0.5, | |
| low_text=0.2, | |
| width_ths=0.5, | |
| height_ths=0.5, | |
| min_size=5, | |
| rotation_info=[0], | |
| ) | |
| final_results = [] | |
| for item in results: | |
| if len(item) == 2: | |
| bbox, text = item | |
| new_bbox = [[int(p[0]/scale_factor), int(p[1]/scale_factor)] for p in bbox] | |
| final_results.append((new_bbox, text)) | |
| elif len(item) == 3: | |
| bbox, text, prob = item | |
| new_bbox = [[int(p[0]/scale_factor), int(p[1]/scale_factor)] for p in bbox] | |
| final_results.append((new_bbox, text, prob)) | |
| return final_results | |
| def get_text_regions(self, image: Union[Image.Image, np.ndarray]) -> List[Any]: | |
| """ | |
| Returns raw results from OCR. | |
| """ | |
| return self.detect_text(image) | |
| def group_text_into_bubbles(self, text_results: List[Tuple], distance_threshold: float = 50) -> List[Tuple[List[List[int]], str]]: | |
| """ | |
| Gruppiert nahe beieinanderliegende Textblöcke zu Sprechblasen. | |
| Args: | |
| text_results: Liste von (bbox, text) Tupeln aus detect_text. | |
| distance_threshold: Maximaler Abstand zwischen Textblöcken, um sie zu gruppieren. | |
| Returns: | |
| Liste von (merged_bbox, combined_text) Tupeln. | |
| """ | |
| if not text_results or len(text_results) == 0: | |
| return [] | |
| if len(text_results) == 1: | |
| # Nur ein Textblock, direkt zurückgeben | |
| bbox, text = text_results[0][:2] | |
| return [(bbox, text)] | |
| # Berechne Zentren aller Bounding Boxes | |
| centers = [] | |
| for item in text_results: | |
| bbox = item[0] | |
| pts = np.array(bbox) | |
| center_x = np.mean(pts[:, 0]) | |
| center_y = np.mean(pts[:, 1]) | |
| centers.append([center_x, center_y]) | |
| centers = np.array(centers) | |
| # Hierarchisches Clustering basierend auf Distanz | |
| if len(centers) > 1: | |
| linkage_matrix = linkage(centers, method='average') | |
| clusters = fcluster(linkage_matrix, t=distance_threshold, criterion='distance') | |
| else: | |
| clusters = [1] | |
| # Gruppiere Textblöcke nach Cluster | |
| cluster_groups = {} | |
| for idx, cluster_id in enumerate(clusters): | |
| if cluster_id not in cluster_groups: | |
| cluster_groups[cluster_id] = [] | |
| cluster_groups[cluster_id].append(idx) | |
| # Erstelle zusammengeführte Ergebnisse | |
| merged_results = [] | |
| for cluster_id, indices in cluster_groups.items(): | |
| # Sammle alle Bboxes und Texte dieser Gruppe | |
| all_bboxes = [] | |
| all_texts = [] | |
| # Sortiere nach Y-Position (oben nach unten) | |
| sorted_indices = sorted(indices, key=lambda i: np.mean(np.array(text_results[i][0])[:, 1])) | |
| for idx in sorted_indices: | |
| item = text_results[idx] | |
| bbox = item[0] | |
| text = item[1] | |
| all_bboxes.append(bbox) | |
| all_texts.append(text) | |
| # Kombiniere alle Bboxes zu einer großen Bbox | |
| all_points = [] | |
| for bbox in all_bboxes: | |
| all_points.extend(bbox) | |
| all_points = np.array(all_points) | |
| x_min = int(np.min(all_points[:, 0])) | |
| y_min = int(np.min(all_points[:, 1])) | |
| x_max = int(np.max(all_points[:, 0])) | |
| y_max = int(np.max(all_points[:, 1])) | |
| merged_bbox = [[x_min, y_min], [x_max, y_min], [x_max, y_max], [x_min, y_max]] | |
| # Kombiniere Texte mit Leerzeichen (für natürlichen Lesefluss) | |
| combined_text = ' '.join(all_texts) | |
| merged_results.append((merged_bbox, combined_text)) | |
| return merged_results | |
| def detect_and_group_text(self, image: Union[Image.Image, np.ndarray], distance_threshold: float = 50, preprocess_mode: str = 'gentle') -> List[Tuple[List[List[int]], str]]: | |
| """ | |
| Erkennt Text und gruppiert ihn automatisch nach Sprechblasen. | |
| Args: | |
| image: PIL Image oder numpy array. | |
| distance_threshold: Maximaler Abstand für Gruppierung (in Pixeln). | |
| preprocess_mode: Preprocessing mode ('gentle', 'none', 'aggressive'). | |
| Returns: | |
| Liste von (bbox, combined_text) Tupeln, gruppiert nach Sprechblasen. | |
| """ | |
| # Erst einzelne Textblöcke erkennen (paragraph=False für feinere Kontrolle) | |
| raw_results = self.detect_text(image, paragraph=False, preprocess_mode=preprocess_mode) | |
| # Dann nach räumlicher Nähe gruppieren | |
| grouped_results = self.group_text_into_bubbles(raw_results, distance_threshold) | |
| return grouped_results | |