from __future__ import annotations
import importlib.util
import os
from functools import lru_cache
from types import ModuleType
from typing import Optional
import gradio as gr
from huggingface_hub import hf_hub_download
from PIL import Image
# Hub artifact locations for the OCR model. Every value can be overridden
# through its corresponding environment variable; the literals are fallbacks.
DEFAULT_MODEL_REPO = os.environ.get("OCR_MODEL_REPO_ID", "metythorn/ViTOCR-base")  # Hugging Face repo id
DEFAULT_MODEL_FILENAME = os.environ.get("OCR_MODEL_FILENAME", "model.onnx")  # ONNX weights file in the repo
DEFAULT_CONFIG_FILENAME = os.environ.get("OCR_CONFIG_FILENAME", "config.json")  # model config file in the repo
DEFAULT_INFERENCE_FILENAME = os.environ.get("OCR_INFERENCE_FILENAME", "onnx_inference.py")  # inference code shipped with the model
DEFAULT_REVISION = os.environ.get("OCR_MODEL_REVISION")  # optional git revision; None lets hf_hub_download pick its default
def _resolve_providers() -> list[str]:
raw = os.environ.get("OCR_ONNX_PROVIDERS", "CPUExecutionProvider")
providers = [p.strip() for p in raw.split(",") if p.strip()]
return providers or ["CPUExecutionProvider"]
def _download_from_hub(filename: str) -> str:
    """Download *filename* from the configured Hub repo and return its local cache path.

    Uses HF_TOKEN from the environment when present (needed for private repos).
    """
    return hf_hub_download(
        repo_id=DEFAULT_MODEL_REPO,
        repo_type="model",
        filename=filename,
        revision=DEFAULT_REVISION,
        token=os.environ.get("HF_TOKEN"),
    )
@lru_cache(maxsize=1)
def _load_inference_module() -> ModuleType:
inference_path = os.environ.get("OCR_ONNX_INFERENCE_PATH")
if inference_path:
if not os.path.exists(inference_path):
raise FileNotFoundError(f"OCR_ONNX_INFERENCE_PATH does not exist: {inference_path}")
else:
inference_path = _download_from_hub(DEFAULT_INFERENCE_FILENAME)
spec = importlib.util.spec_from_file_location("onnx_inference_runtime", inference_path)
if spec is None or spec.loader is None:
raise ImportError(f"Could not load module spec from: {inference_path}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
@lru_cache(maxsize=1)
def _get_predictor():
    """Build the ONNX predictor exactly once and reuse it for every request.

    Model and config paths may be overridden via OCR_ONNX_PATH and
    OCR_CONFIG_PATH; anything not overridden is downloaded from the Hub.
    Raises FileNotFoundError for a bad override and AttributeError when the
    inference module does not expose an ONNXPredictor class.
    """
    model_path = os.environ.get("OCR_ONNX_PATH")
    config_path = os.environ.get("OCR_CONFIG_PATH")
    # Validate every explicit override up front, before any download happens.
    for env_name, candidate in (
        ("OCR_ONNX_PATH", model_path),
        ("OCR_CONFIG_PATH", config_path),
    ):
        if candidate and not os.path.exists(candidate):
            raise FileNotFoundError(f"{env_name} does not exist: {candidate}")
    if not model_path:
        model_path = _download_from_hub(DEFAULT_MODEL_FILENAME)
    if not config_path:
        config_path = _download_from_hub(DEFAULT_CONFIG_FILENAME)
    runtime = _load_inference_module()
    predictor_cls = getattr(runtime, "ONNXPredictor", None)
    if predictor_cls is None:
        raise AttributeError("onnx_inference.py does not define ONNXPredictor")
    return predictor_cls(
        model_path=model_path,
        config_path=config_path,
        providers=_resolve_providers(),
    )
def ocr_image(image: Optional[Image.Image]) -> str:
    """Run OCR on *image* and return the recognized text.

    Returns an empty string when no image is provided. Any failure is
    reported as a human-readable string instead of raising, so the Gradio
    UI shows the error rather than crashing.
    """
    if image is None:
        return ""
    try:
        predictor = _get_predictor()
        text = predictor.predict(image)
    except Exception as exc:  # boundary handler: surface errors in the textbox
        return f"OCR error: {exc}"
    return text
def build_app() -> gr.Blocks:
    """Assemble the Gradio UI: image input, OCR output, and a run button.

    OCR fires both on an explicit button press and automatically whenever
    the input image changes.
    """
    with gr.Blocks(title="ViTOCR ONNX Demo") as demo:
        gr.Markdown(
            "Upload an image to test Khmer OCR. Inference on CPU can be slow, so please be patient."
        )
        with gr.Row():
            input_image = gr.Image(
                label="Input Image",
                type="pil",
                sources=["upload", "clipboard", "webcam"],
                image_mode="RGB",
            )
            output_text = gr.Textbox(label="Recognized Text", lines=8)
        run_btn = gr.Button("Run OCR", variant="primary")
        # Wire both triggers to the same handler.
        for trigger in (run_btn.click, input_image.change):
            trigger(fn=ocr_image, inputs=input_image, outputs=output_text)
    return demo
# Script entry point: bind to all interfaces so the app is reachable
# from outside the container (the standard setup for hosted Spaces).
if __name__ == "__main__":
    build_app().launch(server_name="0.0.0.0")