Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import importlib.util | |
| import os | |
| from functools import lru_cache | |
| from types import ModuleType | |
| from typing import Optional | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image | |
# Hub repository and artifact file names for the OCR model.
# Every value can be overridden through an environment variable so the same
# app can serve a different model/revision without a code change.
DEFAULT_MODEL_REPO = os.environ.get("OCR_MODEL_REPO_ID", "metythorn/ViTOCR-base")
DEFAULT_MODEL_FILENAME = os.environ.get("OCR_MODEL_FILENAME", "model.onnx")
DEFAULT_CONFIG_FILENAME = os.environ.get("OCR_CONFIG_FILENAME", "config.json")
DEFAULT_INFERENCE_FILENAME = os.environ.get("OCR_INFERENCE_FILENAME", "onnx_inference.py")
# Optional git revision (branch/tag/commit) to pin hub downloads to;
# None means the repo's default branch.
DEFAULT_REVISION = os.environ.get("OCR_MODEL_REVISION")
def _resolve_providers() -> list[str]:
    """Return the ONNX Runtime execution providers to use.

    Reads the comma-separated ``OCR_ONNX_PROVIDERS`` environment variable;
    blank entries are dropped, and an unset/empty value falls back to
    CPU-only execution.
    """
    configured = os.environ.get("OCR_ONNX_PROVIDERS", "CPUExecutionProvider")
    names: list[str] = []
    for part in configured.split(","):
        name = part.strip()
        if name:
            names.append(name)
    if not names:
        return ["CPUExecutionProvider"]
    return names
def _download_from_hub(filename: str) -> str:
    """Fetch *filename* from the configured model repo and return its local path.

    An ``HF_TOKEN`` environment variable, when present, is forwarded so
    gated or private repositories can be accessed.
    """
    return hf_hub_download(
        repo_id=DEFAULT_MODEL_REPO,
        filename=filename,
        repo_type="model",
        revision=DEFAULT_REVISION,
        token=os.environ.get("HF_TOKEN"),
    )
def _load_inference_module() -> ModuleType:
    """Import the standalone ONNX inference script as a Python module.

    ``OCR_ONNX_INFERENCE_PATH`` (when set) must point to an existing local
    file; otherwise the default inference script is pulled from the Hub.

    Raises:
        FileNotFoundError: if the configured local path does not exist.
        ImportError: if no import spec/loader can be built for the file.
    """
    script_path = os.environ.get("OCR_ONNX_INFERENCE_PATH")
    if not script_path:
        script_path = _download_from_hub(DEFAULT_INFERENCE_FILENAME)
    elif not os.path.exists(script_path):
        raise FileNotFoundError(f"OCR_ONNX_INFERENCE_PATH does not exist: {script_path}")
    spec = importlib.util.spec_from_file_location("onnx_inference_runtime", script_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not load module spec from: {script_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
@lru_cache(maxsize=1)
def _get_predictor():
    """Build (once per process) and return the ONNX OCR predictor.

    Resolution order for the model/config files:
    1. Explicit local paths via ``OCR_ONNX_PATH`` / ``OCR_CONFIG_PATH``
       (each must exist when set).
    2. Download from the Hub repo configured by the module constants.

    The result is memoized with ``lru_cache`` — ``lru_cache`` was already
    imported at the top of the file but never applied, so previously every
    OCR request re-resolved the files, re-executed the inference script,
    and rebuilt the ONNX session. Exceptions are not cached by
    ``lru_cache``, so a transient download failure is retried on the next
    call.

    Raises:
        FileNotFoundError: if an explicitly configured path does not exist.
        AttributeError: if the inference script lacks an ``ONNXPredictor``.
    """
    model_path = os.environ.get("OCR_ONNX_PATH")
    config_path = os.environ.get("OCR_CONFIG_PATH")
    # Fail fast on misconfigured explicit paths before any network work.
    if model_path and not os.path.exists(model_path):
        raise FileNotFoundError(f"OCR_ONNX_PATH does not exist: {model_path}")
    if config_path and not os.path.exists(config_path):
        raise FileNotFoundError(f"OCR_CONFIG_PATH does not exist: {config_path}")
    if not model_path:
        model_path = _download_from_hub(DEFAULT_MODEL_FILENAME)
    if not config_path:
        config_path = _download_from_hub(DEFAULT_CONFIG_FILENAME)
    inference_module = _load_inference_module()
    predictor_cls = getattr(inference_module, "ONNXPredictor", None)
    if predictor_cls is None:
        raise AttributeError("onnx_inference.py does not define ONNXPredictor")
    return predictor_cls(
        model_path=model_path,
        config_path=config_path,
        providers=_resolve_providers(),
    )
def ocr_image(image: Optional[Image.Image]) -> str:
    """Run OCR on *image* and return the recognized text.

    Returns an empty string when no image is supplied, and an error
    description (instead of raising) on failure so the Gradio UI always
    receives displayable text.
    """
    if image is None:
        return ""
    try:
        text = _get_predictor().predict(image)
    except Exception as exc:  # boundary handler: surface errors in the UI
        return f"OCR error: {exc}"
    return text
def build_app() -> gr.Blocks:
    """Assemble and return the Gradio Blocks interface for the OCR demo."""
    with gr.Blocks(title="ViTOCR ONNX Demo") as demo:
        gr.Markdown(
            "Upload an image to test Khmer OCR. Inference on CPU can be slow, so please be patient."
        )
        with gr.Row():
            image_input = gr.Image(
                label="Input Image",
                type="pil",
                sources=["upload", "clipboard", "webcam"],
                image_mode="RGB",
            )
            text_output = gr.Textbox(label="Recognized Text", lines=8)
        run_button = gr.Button("Run OCR", variant="primary")
        # Both an explicit button press and a changed upload trigger OCR.
        for event in (run_button.click, image_input.change):
            event(fn=ocr_image, inputs=image_input, outputs=text_output)
    return demo
if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable from outside the
    # container (required on Hugging Face Spaces).
    build_app().launch(server_name="0.0.0.0")