Spaces:
Paused
Paused
| import os, io, stat, logging, sys, asyncio | |
| from typing import Any, Dict, Iterable, List, Tuple, Union | |
| from fastapi import FastAPI, UploadFile, File, Form, Header, HTTPException, Security | |
| from fastapi.security import APIKeyHeader | |
| from fastapi.responses import JSONResponse | |
| from PIL import Image, ImageEnhance, ImageFilter | |
| import numpy as np | |
# Logging goes to stdout (for HuggingFace Spaces), at DEBUG verbosity.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)

# Module logger, explicitly pinned to DEBUG.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# The root logger is set to DEBUG as well, independent of basicConfig above.
logging.getLogger().setLevel(logging.DEBUG)
# API key authentication settings.
# Clients send the key in this request header.
API_KEY_NAME = "X-API-Key"
# Expected key, read from the environment (set this in HuggingFace Spaces Secrets).
# None means no key is configured.
API_KEY = os.getenv("API_KEY")
# auto_error=False: header extraction does not reject requests on its own;
# verify_api_key decides what to do with a missing header.
api_key_header = APIKeyHeader(auto_error=False, name=API_KEY_NAME)
async def verify_api_key(api_key: str = Security(api_key_header)):
    """Verify the request's API key if authentication is enabled.

    Args:
        api_key: Header value extracted by ``api_key_header``; ``None`` when
            the header was not supplied.

    Returns:
        ``None`` when no ``API_KEY`` is configured (every request is allowed),
        otherwise the validated key.

    Raises:
        HTTPException: 401 when the header is missing, 403 when the supplied
            key does not match ``API_KEY``.
    """
    # Local import so this block does not depend on a file-level import change.
    import secrets

    if API_KEY is None:
        # No API key configured - allow all requests
        logger.warning("API_KEY not set - endpoint is unprotected!")
        return None

    if api_key is None:
        logger.warning("Request missing API key")
        raise HTTPException(
            status_code=401,
            detail="Missing API Key. Include 'X-API-Key' header."
        )

    # Constant-time comparison so response timing does not reveal how many
    # leading characters matched. Compare as bytes: compare_digest rejects
    # non-ASCII str arguments.
    keys_match = secrets.compare_digest(
        api_key.encode("utf-8"),
        API_KEY.encode("utf-8"),
    )
    if not keys_match:
        # Never log any part of the submitted key: a near-miss (e.g. a typo in
        # the real key) would otherwise leak most of the secret into the logs.
        logger.warning("Invalid API key attempt")
        raise HTTPException(
            status_code=403,
            detail="Invalid API Key"
        )

    return api_key
# -----------------------------------------------------------------------------
# Writable caches (HF/Docker safe) & clear thread envs (suppress OpenBLAS warn)
# -----------------------------------------------------------------------------
# Default HOME/TMPDIR and the cache roots to /tmp locations; values that are
# already present in the environment are left untouched (setdefault).
os.environ.setdefault("HOME", "/tmp")
os.environ.setdefault("TMPDIR", "/tmp")
os.environ.setdefault("XDG_CACHE_HOME", "/tmp/.cache")
os.environ.setdefault("PADDLE_HOME", "/tmp/.paddle")
os.environ.setdefault("PADDLEX_HOME", "/tmp/.paddlex")

# Create each cache directory and open its permissions to rwx for everyone.
for d in [
    os.environ["XDG_CACHE_HOME"],
    os.path.join(os.environ["XDG_CACHE_HOME"], "paddle"),
    os.environ["PADDLE_HOME"],
    os.path.join(os.environ["PADDLEX_HOME"], "temp"),
]:
    try:
        os.makedirs(d, exist_ok=True)
        os.chmod(d, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
    except Exception as exc:
        # Still best-effort (startup continues), but leave a trace so a
        # read-only or mis-owned cache path can be diagnosed from the logs.
        logger.warning("Could not prepare cache directory %s: %s", d, exc)
# Remove inherited BLAS/OMP thread caps. This must run BEFORE paddle/paddleocr
# is imported, which is why the import sits below rather than at the top.
for v in (
    "OMP_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "MKL_NUM_THREADS",
    "NUMEXPR_NUM_THREADS",
):
    # pop with a default: a variable that is not set is simply skipped.
    os.environ.pop(v, None)

logger.info("Environment setup complete. Cache directories configured.")
logger.info(f"PADDLE_HOME: {os.environ['PADDLE_HOME']}")
logger.info(f"XDG_CACHE_HOME: {os.environ['XDG_CACHE_HOME']}")

# Deliberately late import: environment cleanup above has to happen first.
from paddleocr import PaddleOCR  # import AFTER env cleanup

logger.info("PaddleOCR module imported successfully")
| # ============================================================================= | |
| # THREAD-SAFE OCR POOL - NEW IMPLEMENTATION | |
| # ============================================================================= | |
class OCRPool:
    """
    Pool of PaddleOCR instances, one per language, for use from async handlers.

    Instances are created lazily on first request for a language and kept in
    memory afterwards. Two ``asyncio`` locks coordinate concurrent requests:

    - ``_pool_lock`` guards instance creation, so each language is built once.
    - ``_gpu_lock`` serializes model construction and inference, so only one
      OCR operation runs at a time.

    Blocking PaddleOCR calls are executed in the event loop's default executor
    so the loop stays free to serve other requests while OCR is running.

    Note: ``asyncio.Lock`` coordinates coroutines on one event loop; it is not
    a lock between OS threads.
    """

    def __init__(self):
        # Language code -> cached PaddleOCR instance.
        self._instances: Dict[str, "PaddleOCR"] = {}
        self._pool_lock = asyncio.Lock()  # Protects instance creation
        self._gpu_lock = asyncio.Lock()  # Serializes GPU access
        logger.info("OCRPool initialized")

    @staticmethod
    async def _run_blocking(func):
        """Run a blocking zero-argument callable in the default executor."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, func)

    async def get_ocr(self, lang: str = "en") -> "PaddleOCR":
        """
        Get or create the OCR instance for the specified language.

        A cached instance is returned without taking any lock. Otherwise the
        pool lock is acquired, the cache is checked again (another request may
        have created the instance while this one waited), and only then is a
        new instance built.

        Args:
            lang: Language code (e.g., "en", "fr", "es", "zh")

        Returns:
            PaddleOCR instance configured for the specified language

        Raises:
            Exception: Whatever ``PaddleOCR(...)`` raises; it is logged and
                re-raised, and nothing is cached for ``lang``.
        """
        # Fast path: instance already exists (no lock needed)
        instance = self._instances.get(lang)
        if instance is not None:
            logger.debug(f"Using cached OCR instance for language: {lang}")
            return instance

        # Slow path: need to create instance (acquire lock)
        async with self._pool_lock:
            # Re-check after acquiring the lock
            instance = self._instances.get(lang)
            if instance is not None:
                logger.debug(f"OCR instance for {lang} created by another request")
                return instance

            logger.info(f"Creating new OCR instance for language: {lang}")
            try:
                # Hold the GPU lock too: construction runs off the event loop,
                # so without it a model could be built while another request's
                # inference is in flight. Lock order is always pool -> gpu.
                async with self._gpu_lock:
                    instance = await self._run_blocking(
                        lambda: PaddleOCR(
                            use_angle_cls=True,
                            lang=lang,
                            use_gpu=True,
                            gpu_mem=500,  # GPU memory limit in MB
                        )
                    )
            except Exception as e:
                logger.error(f"Failed to create OCR instance for {lang}: {e}")
                raise

            self._instances[lang] = instance
            logger.info(f"✓ OCR instance created successfully for {lang}")
            return instance

    async def run_ocr(self, lang: str, image_array: np.ndarray) -> List:
        """
        Run OCR on an image array, one operation at a time.

        The GPU lock is held for the whole inference call, so concurrent
        requests queue up instead of running OCR simultaneously. The blocking
        ``ocr.ocr`` call runs in the default executor so the event loop can
        keep serving other requests meanwhile.

        Args:
            lang: Language code for OCR
            image_array: Numpy array of the image (HxWx3, RGB)

        Returns:
            The value returned by ``PaddleOCR.ocr(image_array, cls=True)``
        """
        # Get the OCR instance for this language
        ocr = await self.get_ocr(lang)

        # Serialize GPU access (only one OCR operation at a time)
        async with self._gpu_lock:
            logger.debug(f"Running OCR on GPU with {lang} model...")
            results = await self._run_blocking(
                lambda: ocr.ocr(image_array, cls=True)
            )
            logger.debug(f"OCR completed for {lang}")
            return results

    def get_stats(self) -> dict:
        """Return the cached language codes and the number of cached instances."""
        return {
            "cached_languages": list(self._instances.keys()),
            "total_instances": len(self._instances),
        }
# Single module-level OCR pool shared by the request handlers below.
# The name is bound once here and never reassigned.
ocr_pool = OCRPool()
logger.info("Global OCR pool created")

# =============================================================================
# FASTAPI APP INITIALIZATION
# =============================================================================
# Application object; interactive docs at /docs, schema at /openapi.json.
app = FastAPI(
    version="2.8.1-gpu-threadsafe",
    title="PaddleOCR 2.8 API (GPU-Accelerated)",
    openapi_url="/openapi.json",
    docs_url="/docs",
    root_path="/",
)
logger.info("FastAPI app initialized")
async def startup_event():
    """Log a startup banner followed by the list of available endpoints."""
    # NOTE(review): no startup-hook registration (e.g. a decorator) is visible
    # for this coroutine in this view - confirm it is actually attached to `app`.
    rule = "=" * 50
    banner_lines = (
        rule,
        "PaddleOCR GPU API APPLICATION STARTED",
        "PaddleOCR Version: 2.8.1 (Thread-Safe)",
        "CUDA Version: 11.8",
        "Source: PyPI (fast downloads)",
        "Thread Safety: ENABLED (OCRPool)",
        rule,
        "Available endpoints:",
        " GET / - Health check",
        " GET /test - Test endpoint",
        " GET /stats - OCR pool statistics",
        " GET /docs - API documentation",
        " POST /ocr - OCR processing (thread-safe)",
        rule,
    )
    # One INFO record per line, in order.
    for message in banner_lines:
        logger.info(message)
| # ============================================================================= | |
| # HELPER FUNCTIONS (unchanged, already thread-safe) | |
| # ============================================================================= | |
| def _is_number(x: Any) -> bool: | |
| """Check if a value can be converted to float.""" | |
| try: | |
| float(x) | |
| return True | |
| except Exception: | |
| return False | |
| def _is_point(pt: Any) -> bool: | |
| """Check if pt is a valid 2D point [x, y].""" | |
| return ( | |
| isinstance(pt, (list, tuple)) and | |
| len(pt) == 2 and | |
| _is_number(pt[0]) and | |
| _is_number(pt[1]) | |
| ) | |
| def _is_quad(box: Any) -> bool: | |
| """Check if box is a valid quadrilateral (4 points).""" | |
| return ( | |
| isinstance(box, (list, tuple)) and | |
| len(box) == 4 and | |
| all(_is_point(p) for p in box) | |
| ) | |
def _coerce_box(box: Any) -> Union[List[List[float]], None]:
    """Normalize a box into four ``[x, y]`` float pairs, or return None.

    Accepted inputs:
      - a numpy array (converted with ``tolist()`` first),
      - a list/tuple of four points,
      - a dict holding such a quad under "points", "box" or "polygon",
      - a flat list/tuple of four numbers, treated as
        ``[x_min, y_min, x_max, y_max]`` and expanded to its four corners.
    """
    def _as_float_quad(points):
        # Copy the quad with every coordinate converted to float.
        return [[float(pt[0]), float(pt[1])] for pt in points]

    # Numpy arrays become nested lists so the checks below apply.
    if isinstance(box, np.ndarray):
        box = box.tolist()

    # Case 1: already four points.
    if _is_quad(box):
        return _as_float_quad(box)

    # Case 2: dict wrapper; the first key that holds a valid quad wins.
    if isinstance(box, dict):
        for key in ("points", "box", "polygon"):
            if key in box and _is_quad(box[key]):
                return _as_float_quad(box[key])

    # Case 3: flat rectangle of four numbers.
    is_flat_rect = (
        isinstance(box, (list, tuple))
        and len(box) == 4
        and all(_is_number(value) for value in box)
    )
    if is_flat_rect:
        xa, ya, xb, yb = (float(value) for value in box)
        return [[xa, ya], [xb, ya], [xb, yb], [xa, yb]]

    # Nothing matched.
    return None
| def _format_as_markdown(results: List[dict]) -> str: | |
| """Format OCR results as clean, readable markdown with table detection.""" | |
| if not results: | |
| return "" | |
| # Sort by Y position (top to bottom), then X position (left to right) | |
| sorted_results = sorted(results, key=lambda x: ( | |
| min(p[1] for p in x["box"]), # Y position | |
| min(p[0] for p in x["box"]) # X position | |
| )) | |
| # Group into rows based on Y position | |
| rows = [] | |
| current_row = [] | |
| last_y = None | |
| y_threshold = 15 # Pixels - items within this are on same line | |
| for item in sorted_results: | |
| box = item["box"] | |
| y_center = sum(p[1] for p in box) / 4 | |
| x_min = min(p[0] for p in box) | |
| x_max = max(p[0] for p in box) | |
| text = item["text"].strip() | |
| if not text: | |
| continue | |
| # Check if we're on a new line | |
| if last_y is None or abs(y_center - last_y) > y_threshold: | |
| # Save previous line | |
| if current_row: | |
| rows.append(current_row) | |
| current_row = [{ | |
| "text": text, | |
| "x_min": x_min, | |
| "x_max": x_max, | |
| "x_center": (x_min + x_max) / 2, | |
| "y_center": y_center | |
| }] | |
| last_y = y_center | |
| else: | |
| # Same line - add to current row | |
| current_row.append({ | |
| "text": text, | |
| "x_min": x_min, | |
| "x_max": x_max, | |
| "x_center": (x_min + x_max) / 2, | |
| "y_center": y_center | |
| }) | |
| # Don't forget the last row | |
| if current_row: | |
| rows.append(current_row) | |
| # Sort items within each row by X position | |
| for row in rows: | |
| row.sort(key=lambda x: x["x_min"]) | |
| # Detect tables | |
| markdown = [] | |
| i = 0 | |
| while i < len(rows): | |
| row = rows[i] | |
| # Only consider table if row has 2+ columns | |
| if len(row) >= 2: | |
| # Look ahead for similar column structure | |
| table_rows = _detect_table(rows[i:]) | |
| if len(table_rows) >= 3: # Need at least 3 rows to be a table | |
| # Format as table | |
| markdown.append("") # Spacing before table | |
| _add_table_to_markdown(table_rows, markdown) | |
| markdown.append("") # Spacing after table | |
| i += len(table_rows) | |
| continue | |
| # Not a table - format as regular text | |
| line_text = " ".join(item["text"] for item in row) | |
| # Format based on content | |
| if not line_text.strip(): | |
| i += 1 | |
| continue | |
| # Title (first line if short enough) | |
| if i == 0 and len(line_text) < 100: | |
| markdown.append(f"# {line_text}") | |
| markdown.append("") | |
| # Section headers (short lines with colons or all caps) | |
| elif (len(line_text) < 60 and | |
| (line_text.endswith(':') or line_text.isupper())): | |
| if markdown: | |
| markdown.append("") # Spacing before header | |
| markdown.append(f"**{line_text}**") | |
| markdown.append("") | |
| # Numbered items | |
| elif (len(line_text) <= 3 and | |
| any(line_text.startswith(str(n)) for n in range(1, 20))): | |
| markdown.append(f"\n{line_text}") | |
| # Regular paragraph | |
| else: | |
| markdown.append(line_text) | |
| i += 1 | |
| return "\n".join(markdown).strip() | |
| def _detect_table(rows: List[List[dict]]) -> List[List[dict]]: | |
| """ | |
| Detect if rows form a table by checking for consistent column alignment. | |
| Returns the rows that form a table (empty if not a table). | |
| """ | |
| if len(rows) < 3: # Need at least 3 rows for a table | |
| return [] | |
| first_row = rows[0] | |
| if len(first_row) < 2: # Need at least 2 columns | |
| return [] | |
| # Extract column X positions from first row | |
| col_positions = [item["x_center"] for item in first_row] | |
| num_cols = len(col_positions) | |
| table_rows = [first_row] | |
| col_threshold = 40 # Pixels - columns must align within this | |
| # Check subsequent rows for alignment | |
| for row in rows[1:]: | |
| if len(row) < 2: # Skip single-column rows | |
| break | |
| # Check if this row's columns align with the first row | |
| if _row_aligns_with_columns(row, col_positions, col_threshold): | |
| table_rows.append(row) | |
| else: | |
| # Stop at first non-aligning row | |
| break | |
| # Stop checking after 20 rows (max table size) | |
| if len(table_rows) >= 20: | |
| break | |
| # Only return as table if we found at least 3 aligned rows | |
| return table_rows if len(table_rows) >= 3 else [] | |
| def _row_aligns_with_columns(row: List[dict], col_positions: List[float], threshold: float) -> bool: | |
| """Check if a row's columns align with expected column positions.""" | |
| if len(row) != len(col_positions): | |
| # Allow rows with fewer columns (merged cells) | |
| if len(row) > len(col_positions): | |
| return False | |
| # Check if each item in the row aligns with a column position | |
| for item in row: | |
| item_x = item["x_center"] | |
| # Find closest column position | |
| min_distance = min(abs(item_x - col_x) for col_x in col_positions) | |
| if min_distance > threshold: | |
| return False | |
| return True | |
| def _add_table_to_markdown(table_rows: List[List[dict]], markdown: List[str]): | |
| """Add a formatted markdown table to the markdown list.""" | |
| if not table_rows: | |
| return | |
| # Determine max columns | |
| max_cols = max(len(row) for row in table_rows) | |
| # Format each row | |
| for row_idx, row in enumerate(table_rows): | |
| # Pad row to max columns | |
| row_texts = [item["text"] for item in row] | |
| while len(row_texts) < max_cols: | |
| row_texts.append("") | |
| # Add row | |
| markdown.append("| " + " | ".join(row_texts) + " |") | |
| # Add separator after first row (header) | |
| if row_idx == 0: | |
| markdown.append("| " + " | ".join(["---"] * max_cols) + " |") | |
| # ============================================================================= | |
| # API ENDPOINTS | |
| # ============================================================================= | |
def health_check():
    """Return a JSON status document: engine info, pool stats, endpoint paths, cache dirs."""
    # NOTE(review): no route registration (e.g. a decorator) is visible for this
    # handler in this view - confirm it is mounted (the payload lists it at "/").
    logger.info("Health check endpoint called")
    pool_stats = ocr_pool.get_stats()

    # Paths advertised to clients.
    endpoint_paths = {
        "health": "/",
        "ocr": "/ocr",
        "stats": "/stats",
        "docs": "/docs",
        "test": "/test"
    }
    # Cache locations as currently set in the environment.
    cache_locations = {
        "XDG_CACHE_HOME": os.environ["XDG_CACHE_HOME"],
        "PADDLE_HOME": os.environ["PADDLE_HOME"],
        "PADDLEX_HOME": os.environ["PADDLEX_HOME"],
    }

    payload = {
        "status": "ok",
        "engine": "PaddleOCR 2.8.1 (GPU-Accelerated, Thread-Safe)",
        "version": "2.8.1-threadsafe",
        "paddlepaddle_version": "2.6.2",
        "cuda_version": "11.8",
        "source": "PyPI",
        "lang_default": "en",
        "gpu_enabled": True,
        "thread_safe": True,
        "ocr_pool": pool_stats,
        "endpoints": endpoint_paths,
        "cache": cache_locations,
    }
    return JSONResponse(payload)
def test_endpoint():
    """Return a fixed JSON payload; useful for checking that routing works."""
    # NOTE(review): no route registration is visible for this handler in this
    # view - confirm it is mounted.
    logger.info("Test endpoint called")
    body = {
        "message": "Test endpoint works! (GPU mode, thread-safe)",
        "timestamp": "2025-01-08",  # hard-coded value, not the current time
        "thread_safe": True
    }
    return JSONResponse(body)
def stats_endpoint():
    """Return the OCR pool statistics wrapped in a JSON response."""
    # NOTE(review): no route registration is visible for this handler in this
    # view - confirm it is mounted.
    logger.info("Stats endpoint called")
    body = {
        "ocr_pool": ocr_pool.get_stats(),
        "thread_safe": True,
        "gpu_serialization": "enabled"
    }
    return JSONResponse(body)
def _empty_ocr_payload(error: Union[str, None] = None) -> Dict[str, Any]:
    """Build the response body used when there are no detections (optionally with an error)."""
    payload: Dict[str, Any] = {}
    if error is not None:
        payload["error"] = error
    payload.update({
        "results": [],
        "markdown": "",
        "summary": {
            "total_detections": 0,
            "average_confidence": 0
        }
    })
    return payload


def _mean_confidence(detections: List[dict]) -> float:
    """Average the confidences that are present; detections without a score are ignored."""
    scores = [d["confidence"] for d in detections if d["confidence"] is not None]
    return sum(scores) / len(scores) if scores else 0


def _extract_detections(results: Iterable[Any], confidence_threshold: float) -> Tuple[List[dict], int]:
    """Flatten raw OCR output into result dicts and count the entries that were skipped.

    Each page in ``results`` is expected to be a list of ``[box, info]`` lines,
    where ``info`` is ``(text, confidence)`` or a bare text value. Entries with
    an unusable box, empty text, or a confidence below ``confidence_threshold``
    are skipped.

    Returns:
        ``(detections, skipped_count)`` where each detection is
        ``{"text", "confidence", "box"}``.
    """
    detections: List[dict] = []
    skipped_count = 0

    for page_idx, page_result in enumerate(results):
        # A None page means nothing was detected on it; not counted as skipped.
        if page_result is None:
            logger.debug(f"Page {page_idx}: No text detected")
            continue
        if not isinstance(page_result, list):
            logger.warning(f"Page {page_idx}: Unexpected type {type(page_result)}, skipping")
            skipped_count += 1
            continue

        logger.debug(f"Page {page_idx}: Processing {len(page_result)} detections")
        for line_idx, line in enumerate(page_result):
            if not (isinstance(line, (list, tuple)) and len(line) >= 2):
                logger.warning(f"Page {page_idx}, Line {line_idx}: Invalid format")
                skipped_count += 1
                continue

            box = _coerce_box(line[0])
            if box is None:
                logger.warning(f"Page {page_idx}, Line {line_idx}: Could not coerce box")
                skipped_count += 1
                continue

            # Extract text and confidence
            info = line[1]
            if isinstance(info, (list, tuple)) and len(info) >= 1:
                text = str(info[0])
                has_score = len(info) >= 2 and _is_number(info[1])
                conf = float(info[1]) if has_score else None
            else:
                text, conf = str(info), None

            # Skip empty text or low confidence
            if not text.strip():
                skipped_count += 1
                continue
            if conf is not None and conf < confidence_threshold:
                skipped_count += 1
                logger.debug(f"Skipping low confidence ({conf:.3f}): {text[:30]}")
                continue

            detections.append({"text": text, "confidence": conf, "box": box})

    return detections, skipped_count


async def ocr_endpoint(
    file: UploadFile = File(...),
    lang: str = Form("en"),
    confidence_threshold: float = Form(0.4),
    api_key: str = Security(verify_api_key),
):
    """
    OCR endpoint for text detection and recognition.

    The uploaded image is decoded, lightly enhanced (contrast and sharpness
    x1.2), and passed to ``ocr_pool.run_ocr``; concurrency control for the OCR
    engine is handled by the pool. Detections below ``confidence_threshold``
    are dropped and the remainder is also rendered as markdown.

    Args:
        file: Image file to process
        lang: Language code (default: "en")
        confidence_threshold: Minimum confidence score (0.0-1.0, default: 0.4)
        api_key: API key for authentication (required if API_KEY is set)

    Returns:
        JSON with "results" (text, confidence, box per detection), "markdown"
        and "summary". On failure the same shape is returned with an "error"
        field: status 400 when the upload cannot be decoded as an image,
        status 500 for any other error.
    """
    # NOTE(review): no route registration is visible for this handler in this
    # view - confirm it is mounted.
    logger.info(f"[THREAD-SAFE] OCR request - filename: {file.filename}, lang: {lang}, threshold: {confidence_threshold}")
    try:
        # PHASE 1: Image preprocessing (no shared state)
        logger.debug("Reading image file...")
        contents = await file.read()
        logger.debug(f"Image file read - size: {len(contents)} bytes")

        try:
            img = Image.open(io.BytesIO(contents)).convert("RGB")
        except (OSError, ValueError) as exc:
            # An empty, truncated or non-image upload is a client error, not a
            # server fault, so answer 400 rather than falling through to 500.
            logger.warning(f"Uploaded file is not a readable image: {exc}")
            return JSONResponse(
                _empty_ocr_payload(error=f"Invalid image file: {exc}"),
                status_code=400
            )
        logger.debug(f"Image opened - dimensions: {img.size}, mode: {img.mode}")

        # Optimal preprocessing for OCR text detection
        logger.debug("Applying OCR preprocessing...")
        img = ImageEnhance.Contrast(img).enhance(1.2)
        img = ImageEnhance.Sharpness(img).enhance(1.2)

        # The image was converted to RGB above, so the array is already
        # HxWx3; no grayscale/alpha fix-up is needed.
        arr = np.array(img)
        logger.debug(f"Image converted to array - shape: {arr.shape}, dtype: {arr.dtype}")

        # PHASE 2: OCR execution (serialized inside OCRPool)
        logger.info(f"Running thread-safe OCR with language: {lang}")
        results = await ocr_pool.run_ocr(lang, arr)
        logger.info("OCR processing complete")

        if not results:
            logger.warning("No results returned from OCR")
            return JSONResponse(_empty_ocr_payload())

        # PHASE 3: Result processing (no shared state)
        logger.debug("Processing OCR results...")
        out, skipped_count = _extract_detections(results, confidence_threshold)
        logger.info(f"Results: {len(out)} detections, {skipped_count} skipped")

        # Generate formatted markdown
        markdown_text = _format_as_markdown(out)
        logger.debug("Markdown generated")

        return JSONResponse({
            "results": out,
            "markdown": markdown_text,
            "summary": {
                "total_detections": len(out),
                # Averaged over detections that actually carry a score, so
                # entries with confidence=None do not drag the mean down.
                "average_confidence": _mean_confidence(out)
            }
        })
    except Exception as e:
        # Top-level boundary: log with traceback and report a 500.
        logger.error(f"Error processing OCR request: {str(e)}", exc_info=True)
        return JSONResponse(
            _empty_ocr_payload(error=str(e)),
            status_code=500
        )