# (removed: Hugging Face Spaces page residue — "Spaces / Sleeping" status text
#  captured by the extractor; not part of the program)
| """ | |
| PaddleOCR-VL-1.5 Bridge Server (HF Spaces Edition) | |
| ==================================================== | |
| Returns full JSON response matching the official Baidu API format, including: | |
| - layoutParsingResults[].prunedResult (blocks, labels, bboxes, polygon points) | |
| - layoutParsingResults[].markdown (text + images) | |
| - layoutParsingResults[].outputImages (visualization URLs) | |
| - layoutParsingResults[].inputImage | |
| - preprocessedImages | |
| - dataInfo | |
| Architecture: | |
| Gradio App β This Bridge (port 7860) β vLLM Docker (117.54.141.62:8000) | |
| """ | |
import base64
import json
import os
import shutil
import tempfile
import traceback
import uuid
from typing import Any, Dict, List, Optional

import uvicorn
from fastapi import FastAPI, File, Header, HTTPException, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from openai import OpenAI
from PIL import Image
| # ============================================================================= | |
| # Configuration | |
| # ============================================================================= | |
| VLLM_SERVER_URL = os.environ.get("VLLM_SERVER_URL", "http://117.54.141.62:8000/v1") | |
| VLLM_MODEL_NAME = os.environ.get("VLLM_MODEL_NAME", "PaddleOCR-VL-1.5-0.9B") | |
| BRIDGE_PORT = int(os.environ.get("PORT", "7860")) | |
| API_KEY = os.environ.get("API_KEY", "") | |
| SPACE_HOST = os.environ.get("SPACE_HOST", "") | |
| if SPACE_HOST: | |
| PUBLIC_BASE_URL = f"https://{SPACE_HOST}" | |
| else: | |
| PUBLIC_BASE_URL = os.environ.get("PUBLIC_BASE_URL", f"http://localhost:{BRIDGE_PORT}") | |
| STATIC_DIR = "/tmp/ocr_outputs" | |
| os.makedirs(STATIC_DIR, exist_ok=True) | |
| # ============================================================================= | |
| # Initialize clients | |
| # ============================================================================= | |
| openai_client = OpenAI( | |
| api_key="EMPTY", | |
| base_url=VLLM_SERVER_URL, | |
| timeout=600 | |
| ) | |
| pipeline = None | |
def get_pipeline():
    """Return the shared PaddleOCRVL pipeline, constructing it on first use.

    The paddleocr import is deferred so the module loads even when the
    dependency is only needed for full-document parsing.
    """
    global pipeline
    if pipeline is not None:
        return pipeline
    from paddleocr import PaddleOCRVL
    pipeline = PaddleOCRVL(
        vl_rec_backend="vllm-server",
        vl_rec_server_url=VLLM_SERVER_URL
    )
    return pipeline
| # ============================================================================= | |
| # FastAPI App | |
| # ============================================================================= | |
| app = FastAPI( | |
| title="PaddleOCR-VL-1.5 Bridge API", | |
| description="Full document parsing API matching official Baidu API format", | |
| version="1.0.0" | |
| ) | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") | |
| # ============================================================================= | |
| # Auth | |
| # ============================================================================= | |
| def verify_auth(authorization: Optional[str] = None): | |
| if API_KEY and API_KEY.strip(): | |
| if not authorization or authorization != f"Bearer {API_KEY}": | |
| raise HTTPException(status_code=401, detail="Unauthorized") | |
| # ============================================================================= | |
| # Helpers | |
| # ============================================================================= | |
| TASK_PROMPTS = { | |
| "ocr": "OCR:", | |
| "formula": "Formula Recognition:", | |
| "table": "Table Recognition:", | |
| "chart": "Chart Recognition:", | |
| "spotting": "Spotting:", | |
| "seal": "Seal Recognition:", | |
| } | |
| IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"} | |
def save_temp_image(file_data: str) -> str:
    """Persist an input image to a temp file and return its path.

    *file_data* is either an http(s) URL (downloaded) or a base64-encoded
    image (decoded).  The caller is responsible for deleting the file.
    """
    if file_data.startswith(("http://", "https://")):
        import requests as req
        resp = req.get(file_data, timeout=120)
        resp.raise_for_status()
        content = resp.content
        # Pick an extension from the response content type; default to .png.
        ctype = resp.headers.get("content-type", "image/png")
        if "jpeg" in ctype or "jpg" in ctype:
            ext = ".jpg"
        elif "webp" in ctype:
            ext = ".webp"
        elif "bmp" in ctype:
            ext = ".bmp"
        else:
            ext = ".png"
    else:
        # Base64 payload; extension is cosmetic, .png by convention.
        content = base64.b64decode(file_data)
        ext = ".png"
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=ext)
    tmp.write(content)
    tmp.close()
    return tmp.name
def serve_file(src_path: str, request_id: str, filename: str) -> str:
    """Copy a file into the static dir and return its public URL.

    The file is copied to ``STATIC_DIR/<request_id>/<filename>`` which is
    served by the ``/static`` mount, so the returned URL resolves to the copy.
    """
    static_subdir = os.path.join(STATIC_DIR, request_id)
    os.makedirs(static_subdir, exist_ok=True)
    dst_path = os.path.join(static_subdir, filename)
    shutil.copy2(src_path, dst_path)
    # Bug fix: the URL previously ended in a literal "(unknown)" placeholder
    # instead of the actual filename, so returned links never resolved.
    return f"{PUBLIC_BASE_URL}/static/{request_id}/{filename}"
def collect_images_from_dir(directory: str, request_id: str) -> Dict[str, str]:
    """Serve every image found under *directory*; return {relative_path: url}."""
    served: Dict[str, str] = {}
    if not os.path.exists(directory):
        return served
    for root, _subdirs, filenames in os.walk(directory):
        for filename in filenames:
            if os.path.splitext(filename)[1].lower() not in IMAGE_EXTENSIONS:
                continue
            src = os.path.join(root, filename)
            rel_path = os.path.relpath(src, directory)
            # Flatten any subdirectory structure into the served filename.
            served[rel_path] = serve_file(src, request_id, rel_path.replace(os.sep, "_"))
    return served
def extract_pruned_result(res_obj, page_index: int = 0) -> Dict[str, Any]:
    """
    Build a ``prunedResult`` dict (official Baidu API shape) from a PaddleOCR
    result object, probing a few known attribute names defensively.
    """
    def _first_attr(names, default):
        # First attribute of res_obj found among *names*, else *default*.
        for name in names:
            if hasattr(res_obj, name):
                return getattr(res_obj, name, default)
        return default

    try:
        # Raw dict form of the result, when the object exposes one.
        raw = {}
        if hasattr(res_obj, 'json'):
            candidate = res_obj.json
            raw = candidate if isinstance(candidate, dict) else {}
        elif hasattr(res_obj, '_result'):
            candidate = res_obj._result
            raw = candidate if isinstance(candidate, dict) else {}
        elif hasattr(res_obj, 'to_dict'):
            raw = res_obj.to_dict()

        parsing_res_list = _first_attr(
            ['parsing_res_list', 'parsing_result', 'blocks'], [])
        layout_det_res = _first_attr(
            ['layout_det_res', 'layout_result', 'det_res'], {"boxes": []})
        width = _first_attr(['img_width', 'width'], 0)
        height = _first_attr(['img_height', 'height'], 0)

        # Fall back to the raw dict when no attribute carried the blocks.
        if raw and not parsing_res_list:
            parsing_res_list = raw.get('parsing_res_list', raw.get('blocks', []))
            layout_det_res = raw.get('layout_det_res', {"boxes": []})
            width = raw.get('width', width)
            height = raw.get('height', height)

        return {
            "page_count": 1,
            "width": width,
            "height": height,
            "model_settings": {
                "use_doc_preprocessor": False,
                "use_layout_detection": True,
                "use_chart_recognition": False,
                "use_seal_recognition": True,
                "use_ocr_for_image_block": False,
                "format_block_content": True,
                "merge_layout_blocks": True,
                "markdown_ignore_labels": [
                    "number", "footnote", "header",
                    "header_image", "footer", "footer_image", "aside_text"
                ],
                "return_layout_polygon_points": True
            },
            "parsing_res_list": parsing_res_list if isinstance(parsing_res_list, list) else [],
            "layout_det_res": layout_det_res if isinstance(layout_det_res, dict) else {"boxes": []}
        }
    except Exception as e:
        print(f"Warning: Could not extract prunedResult: {e}")
        traceback.print_exc()
        return {
            "page_count": 1,
            "width": 0,
            "height": 0,
            "model_settings": {},
            "parsing_res_list": [],
            "layout_det_res": {"boxes": []}
        }
def full_document_parsing(file_data: str, use_chart_recognition: bool = False,
                          use_doc_unwarping: bool = True,
                          use_doc_orientation_classify: bool = True) -> Dict[str, Any]:
    """Full document parsing -> response matching the official Baidu API format.

    Args:
        file_data: Base64-encoded image bytes, or an http(s) URL.
        use_chart_recognition: Echoed into the fallback ``model_settings``.
        use_doc_unwarping, use_doc_orientation_classify: Accepted for API
            compatibility.  NOTE(review): none of the three flags are
            currently forwarded to ``pipe.predict`` — confirm whether the
            pipeline is supposed to receive them.

    Returns:
        Dict in the official shape: ``errorCode`` plus ``result`` with
        ``layoutParsingResults`` / ``preprocessedImages`` / ``dataInfo``.
    """
    tmp_path = save_temp_image(file_data)
    request_id = str(uuid.uuid4())[:12]
    try:
        # Input image dimensions (best effort; 0x0 when unreadable).
        try:
            with Image.open(tmp_path) as img:
                img_width, img_height = img.size
        except Exception:
            img_width, img_height = 0, 0

        pipe = get_pipeline()
        output = pipe.predict(tmp_path)

        layout_parsing_results = []
        preprocessed_images = []
        data_info_pages = []

        for i, res in enumerate(output):
            page_id = f"{request_id}_p{i}"
            output_dir = tempfile.mkdtemp()
            try:
                # Persist all per-page artifacts (json, markdown, images).
                res.save_to_json(save_path=output_dir)
                res.save_to_markdown(save_path=output_dir)
                try:
                    res.save_to_img(save_path=output_dir)
                except Exception:
                    pass  # visualization images are best-effort

                # --- Read markdown ---
                md_text = ""
                md_files = [f for f in os.listdir(output_dir) if f.endswith(".md")]
                if md_files:
                    with open(os.path.join(output_dir, md_files[0]), "r", encoding="utf-8") as f:
                        md_text = f.read()

                # --- Read JSON (contains prunedResult data) ---
                json_data = {}
                json_files = [f for f in os.listdir(output_dir) if f.endswith(".json")]
                if json_files:
                    with open(os.path.join(output_dir, json_files[0]), "r", encoding="utf-8") as f:
                        json_data = json.load(f)

                # --- Collect and serve all images ---
                all_images = collect_images_from_dir(output_dir, page_id)

                # --- Build outputImages ---
                output_images = {}
                for rel_path, url in all_images.items():
                    name = os.path.splitext(os.path.basename(rel_path))[0]
                    # Heuristic: layout/detection visualizations get a fixed key.
                    if "layout" in name.lower() or "det" in name.lower() or "vis" in name.lower():
                        output_images["layout_det_res"] = url
                    else:
                        output_images[name] = url

                # --- Build markdown images map, rewriting local refs to URLs ---
                md_images = {}
                imgs_dir = os.path.join(output_dir, "imgs")
                if os.path.exists(imgs_dir):
                    for fname in os.listdir(imgs_dir):
                        ext = os.path.splitext(fname)[1].lower()
                        if ext in IMAGE_EXTENSIONS:
                            src = os.path.join(imgs_dir, fname)
                            url = serve_file(src, page_id, fname)
                            local_ref = f"imgs/{fname}"
                            md_images[local_ref] = url
                            # Point markdown references at the served copies.
                            md_text = md_text.replace(f'src="{local_ref}"', f'src="{url}"')
                            md_text = md_text.replace(f']({local_ref})', f']({url})')

                # --- Serve input image ---
                input_image_url = serve_file(tmp_path, page_id, f"input_img_{i}.jpg")

                # --- Build prunedResult from saved JSON, else from the object ---
                if json_data:
                    pruned_result = {
                        "page_count": json_data.get("page_count", 1),
                        "width": json_data.get("width", img_width),
                        "height": json_data.get("height", img_height),
                        "model_settings": json_data.get("model_settings", {
                            "use_doc_preprocessor": False,
                            "use_layout_detection": True,
                            "use_chart_recognition": use_chart_recognition,
                            "use_seal_recognition": True,
                            "use_ocr_for_image_block": False,
                            "format_block_content": True,
                            "merge_layout_blocks": True,
                            "markdown_ignore_labels": [
                                "number", "footnote", "header",
                                "header_image", "footer", "footer_image", "aside_text"
                            ],
                            "return_layout_polygon_points": True
                        }),
                        "parsing_res_list": json_data.get("parsing_res_list",
                                                          json_data.get("blocks", [])),
                        "layout_det_res": json_data.get("layout_det_res",
                                                        json_data.get("det_res", {"boxes": []}))
                    }
                else:
                    pruned_result = extract_pruned_result(res, i)

                # Ensure dimensions are set.
                if not pruned_result.get("width"):
                    pruned_result["width"] = img_width
                if not pruned_result.get("height"):
                    pruned_result["height"] = img_height

                layout_parsing_results.append({
                    "prunedResult": pruned_result,
                    "markdown": {
                        "text": md_text,
                        "images": md_images
                    },
                    "outputImages": output_images,
                    "inputImage": input_image_url
                })
                preprocessed_images.append(input_image_url)
                data_info_pages.append({
                    "width": img_width,
                    "height": img_height
                })
            finally:
                # Bug fix: per-page temp dirs were never removed (disk leak).
                # Everything needed has been read or copied to STATIC_DIR.
                shutil.rmtree(output_dir, ignore_errors=True)

        return {
            "errorCode": 0,
            "result": {
                # Empty pipeline output still yields one well-formed empty page.
                "layoutParsingResults": layout_parsing_results if layout_parsing_results else [{
                    "prunedResult": {
                        "page_count": 0,
                        "width": 0,
                        "height": 0,
                        "parsing_res_list": [],
                        "layout_det_res": {"boxes": []}
                    },
                    "markdown": {"text": "", "images": {}},
                    "outputImages": {},
                    "inputImage": ""
                }],
                "preprocessedImages": preprocessed_images,
                "dataInfo": {
                    "type": "image",
                    "numPages": len(layout_parsing_results),
                    "pages": data_info_pages
                }
            }
        }
    finally:
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
def element_level_recognition(file_data: str, prompt_label: str) -> Dict[str, Any]:
    """Element-level recognition via a direct vLLM chat-completions call."""
    is_url = file_data.startswith(("http://", "https://"))
    image_url = file_data if is_url else f"data:image/png;base64,{file_data}"
    task_prompt = TASK_PROMPTS.get(prompt_label, "OCR:")

    response = openai_client.chat.completions.create(
        model=VLLM_MODEL_NAME,
        messages=[{
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": image_url}},
                {"type": "text", "text": task_prompt}
            ]
        }],
        temperature=0.0
    )
    result_text = response.choices[0].message.content

    spotting = _parse_spotting(result_text) if prompt_label == "spotting" else {}
    block = {
        "block_label": prompt_label,
        "block_content": result_text,
        "block_bbox": [],
        "block_id": 0,
        "block_order": 0,
        "group_id": 0,
        "global_block_id": 0,
        "global_group_id": 0,
        "block_polygon_points": []
    }
    page = {
        "prunedResult": {
            "page_count": 1,
            "width": 0,
            "height": 0,
            "parsing_res_list": [block],
            "layout_det_res": {"boxes": []}
        },
        "markdown": {"text": result_text, "images": {}},
        "outputImages": {},
        # NOTE(review): flat dotted key reproduced as-is; it looks like it was
        # meant to be nested under prunedResult — confirm with API consumers.
        "prunedResult.spotting_res": spotting
    }
    return {"errorCode": 0, "result": {"layoutParsingResults": [page]}}
| def _parse_spotting(text: str) -> dict: | |
| try: | |
| return json.loads(text) | |
| except (json.JSONDecodeError, TypeError): | |
| return {"raw_text": text} | |
| # ============================================================================= | |
| # Endpoints | |
| # ============================================================================= | |
| async def root(): | |
| return { | |
| "service": "PaddleOCR-VL-1.5 Bridge API", | |
| "status": "running", | |
| "endpoints": ["/health", "/api/ocr", "/api/parse", "/api/parse/markdown", "/v1/chat/completions", "/docs"] | |
| } | |
# Bug fix: missing route decorator — /health was never registered.
@app.get("/health")
async def health():
    """Liveness probe reporting the configured model and backend URL."""
    return {"status": "ok", "model": VLLM_MODEL_NAME, "vllm_url": VLLM_SERVER_URL}
# Bug fix: missing route decorator — /api/ocr was never registered.
@app.post("/api/ocr")
async def ocr_endpoint(request: Request, authorization: Optional[str] = Header(None)):
    """
    Main OCR endpoint — compatible with the Gradio app.
    Returns full JSON matching the official Baidu API format.

    Body:
        {
            "file": "base64_or_url",
            "useLayoutDetection": true/false,
            "promptLabel": "ocr|formula|table|chart|spotting|seal",
            "useChartRecognition": false,
            "useDocUnwarping": true,
            "useDocOrientationClassify": true
        }
    """
    verify_auth(authorization)
    try:
        body = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid JSON body")
    file_data = body.get("file", "")
    if not file_data:
        raise HTTPException(status_code=400, detail="Missing 'file' field")

    use_layout = body.get("useLayoutDetection", False)
    prompt_label = body.get("promptLabel", "ocr")
    use_chart = body.get("useChartRecognition", False)
    use_unwarp = body.get("useDocUnwarping", True)
    use_orient = body.get("useDocOrientationClassify", True)
    try:
        if use_layout:
            return full_document_parsing(file_data, use_chart, use_unwarp, use_orient)
        return element_level_recognition(file_data, prompt_label)
    except Exception as e:
        # Errors are reported in-band (errorCode != 0) to match the API shape.
        traceback.print_exc()
        return {"errorCode": -1, "errorMsg": str(e)}
# Bug fix: missing route decorator — /api/parse was never registered.
@app.post("/api/parse")
async def parse_file(
    file: UploadFile = File(...),
    use_layout_detection: bool = True,
    prompt_label: str = "ocr",
    authorization: Optional[str] = Header(None)
):
    """File-upload endpoint: base64-encodes the upload and parses it."""
    verify_auth(authorization)
    content = await file.read()
    b64 = base64.b64encode(content).decode("utf-8")
    try:
        if use_layout_detection:
            return full_document_parsing(b64)
        return element_level_recognition(b64, prompt_label)
    except Exception as e:
        traceback.print_exc()
        return {"errorCode": -1, "errorMsg": str(e)}
# Bug fix: missing route decorator — /api/parse/markdown was never registered.
@app.post("/api/parse/markdown")
async def parse_to_markdown(
    file: UploadFile = File(...),
    authorization: Optional[str] = Header(None)
):
    """Parse an uploaded file and return only the concatenated markdown text."""
    verify_auth(authorization)
    content = await file.read()
    b64 = base64.b64encode(content).decode("utf-8")
    try:
        result = full_document_parsing(b64)
        pages = result.get("result", {}).get("layoutParsingResults", [])
        # Join non-empty page markdowns with a horizontal-rule separator.
        markdown_parts = [p.get("markdown", {}).get("text", "") for p in pages if p.get("markdown", {}).get("text")]
        return {
            "status": "ok",
            "markdown": "\n\n---\n\n".join(markdown_parts),
            "page_count": len(pages)
        }
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
# Bug fix: missing route decorator — /v1/chat/completions was never registered.
@app.post("/v1/chat/completions")
async def proxy_chat_completions(request: Request, authorization: Optional[str] = Header(None)):
    """Proxy raw OpenAI-format chat requests straight to the vLLM backend."""
    verify_auth(authorization)
    import httpx
    body = await request.json()
    async with httpx.AsyncClient(timeout=600) as client:
        resp = await client.post(
            f"{VLLM_SERVER_URL}/chat/completions",
            json=body,
            headers={"Content-Type": "application/json"}
        )
        return resp.json()
| # ============================================================================= | |
| # Entry point | |
| # ============================================================================= | |
| if __name__ == "__main__": | |
| print(f""" | |
| ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| β PaddleOCR-VL-1.5 Bridge Server (HF Spaces) β | |
| β βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ£ | |
| β Bridge API: http://0.0.0.0:{BRIDGE_PORT} β | |
| β vLLM backend: {VLLM_SERVER_URL:<44s}β | |
| β Model: {VLLM_MODEL_NAME:<44s}β | |
| β Auth: {"ENABLED" if API_KEY else "DISABLED":<44s}β | |
| β Static URL: {PUBLIC_BASE_URL:<44s}β | |
| β βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ£ | |
| β Endpoints: β | |
| β GET /health - Health check β | |
| β GET /docs - Swagger UI β | |
| β POST /api/ocr - Gradio-compatible API β | |
| β POST /api/parse - File upload API β | |
| β POST /api/parse/markdown - Simple markdown output β | |
| β POST /v1/chat/completions - vLLM proxy (OpenAI format) β | |
| β GET /static/... - Output images β | |
| ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| """) | |
| uvicorn.run(app, host="0.0.0.0", port=BRIDGE_PORT) |