# n1ocr-api / app.py
# handsme
# 修复 elapsed 处理 (Fix elapsed handling)
# a469cb9
import gradio as gr
import json
import base64
import io
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from rapidocr_onnxruntime import RapidOCR
from fastapi import FastAPI, Request
from fastapi.responses import RedirectResponse
import uvicorn
# Fetch the PP-OCRv5 ONNX model files (detection model, English recognition
# model and its character dictionary) via hf_hub_download.
_MODEL_REPO = "monkt/paddleocr-onnx"
det_path = hf_hub_download(_MODEL_REPO, "detection/v5/det.onnx")
rec_path = hf_hub_download(_MODEL_REPO, "languages/english/rec.onnx")
dict_path = hf_hub_download(_MODEL_REPO, "languages/english/dict.txt")

# Build the RapidOCR engine (PP-OCRv5 ONNX inference) from the downloaded files.
ocr_engine = RapidOCR(
    det_model_path=det_path,
    rec_model_path=rec_path,
    rec_keys_path=dict_path,
)
def ocr_recognize(image):
    """Recognise the text in an image and return the outcome as a JSON string.

    Args:
        image: A ``PIL.Image.Image`` (converted to a NumPy array first) or any
            other object, which is handed to the OCR engine unchanged.
            ``None`` means that no image was supplied.

    Returns:
        A JSON document as ``str``. On success it holds ``success`` (true),
        ``lines`` (recognised strings in the order the engine returned them),
        ``full_text`` (the lines joined with newlines), ``raw`` (one
        ``text`` / ``confidence`` / ``bbox`` entry per line) and ``elapsed``
        (rounded to 3 decimals). On failure it holds ``success`` (false) and
        ``error``. Any ``Exception`` raised while recognising is reported in
        ``error`` instead of propagating to the caller.
    """
    if image is None:
        return json.dumps({"success": False, "error": "未提供图片"}, ensure_ascii=False)
    try:
        img_array = np.array(image) if isinstance(image, Image.Image) else image

        # Run the RapidOCR engine.
        result, elapsed = ocr_engine(img_array)

        lines = []
        raw_results = []
        # A falsy result (None or empty) means nothing was recognised.
        for item in result or []:
            # Each item is laid out as (bbox, text, confidence); bbox holds the
            # four corner points [[x1, y1], [x2, y2], [x3, y3], [x4, y4]].
            bbox, text, confidence = item[0], item[1], item[2]
            lines.append(text)
            raw_results.append({
                "text": text,
                "confidence": round(float(confidence), 4),
                # Lists pass through untouched; arrays and tuples are turned
                # into plain nested lists so json.dumps can serialise them.
                "bbox": bbox if isinstance(bbox, list) else np.asarray(bbox).tolist(),
            })

        # elapsed may be a sequence of per-stage timings (summed here), a
        # single number, or None (e.g. when the engine reports no detections).
        # None must not reach round(), otherwise an image without text would
        # be reported as a failure instead of an empty success.
        if elapsed is None:
            total_elapsed = 0.0
        elif isinstance(elapsed, (list, tuple)):
            total_elapsed = sum(elapsed)
        else:
            total_elapsed = elapsed

        return json.dumps({
            "success": True,
            "lines": lines,
            "full_text": "\n".join(lines),
            "raw": raw_results,
            "elapsed": round(total_elapsed, 3),
        }, ensure_ascii=False, indent=2)
    except Exception as e:
        return json.dumps({"success": False, "error": str(e)}, ensure_ascii=False)
# ---- Plain FastAPI endpoint ----
app = FastAPI()
def _predict_error(message):
    """Wrap an error message in the response envelope used by /api/predict."""
    # ensure_ascii=False keeps non-ASCII messages readable, matching the JSON
    # that ocr_recognize produces.
    payload = json.dumps({"success": False, "error": message}, ensure_ascii=False)
    return {"data": [payload]}


@app.post("/api/predict")
async def api_predict(request: Request):
    """Accept a base64-encoded image and return the OCR result.

    The JSON request body is expected to look like ``{"data": [<image>]}``,
    where ``<image>`` is either a bare base64 string or a data URI; anything
    up to and including the last comma is discarded before decoding.

    Returns:
        ``{"data": [<json string>]}``. The inner string is the output of
        ``ocr_recognize`` or, when the request cannot be processed, a
        ``{"success": false, "error": ...}`` document. Failures are reported
        in the body rather than raised.
    """
    try:
        body = await request.json()
        items = body.get("data") if isinstance(body, dict) else None
        # Reject a missing, empty or non-string payload with a clear message
        # instead of letting an IndexError/TypeError text leak to the client.
        if not isinstance(items, list) or not items or not isinstance(items[0], str):
            return _predict_error("未提供图片")

        data_uri = items[0]
        base64_str = data_uri.rpartition(",")[2]
        image_bytes = base64.b64decode(base64_str)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Cap the longest side at 2000px so oversized images do not exhaust
        # memory downstream. The image is already decoded at this point, so
        # this bounds the OCR input, not the decode itself.
        max_side = 2000
        w, h = image.size
        longest = max(w, h)
        if longest > max_side:
            scale = max_side / longest
            # Clamp to 1px: an extreme aspect ratio would otherwise truncate
            # the short side to 0 and make resize() fail.
            new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
            image = image.resize(new_size, Image.LANCZOS)

        # NOTE(review): ocr_recognize runs synchronously inside this coroutine.
        result_json_str = ocr_recognize(image)
        return {"data": [result_json_str]}
    except Exception as e:
        return _predict_error(str(e))
# ---- Gradio UI ----
with gr.Blocks(
    title="n1payocr API - PP-OCRv5 ONNX",
    analytics_enabled=False,
) as demo:
    gr.Markdown("# 🔍 n1payocr 文字识别引擎 (PP-OCRv5 ONNX)")
    with gr.Row():
        # First column: image upload plus the button that starts recognition.
        with gr.Column():
            input_image = gr.Image(type="pil", label="上传图片")
            submit_btn = gr.Button("识别", variant="primary")
        # Second column: text box that receives the JSON string.
        with gr.Column():
            output_text = gr.Textbox(label="识别结果 (JSON)", lines=20)
    # Clicking the button feeds the uploaded image to ocr_recognize and shows
    # the returned JSON string in the text box.
    submit_btn.click(
        fn=ocr_recognize,
        inputs=input_image,
        outputs=output_text,
        api_name=False,
    )
# Mount the Gradio app under the /ui path
app = gr.mount_gradio_app(app, demo, path="/ui")
@app.get("/")
async def root():
    """Redirect requests for the site root to the Gradio UI at ``/ui``."""
    ui_redirect = RedirectResponse(url="/ui")
    return ui_redirect
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on 0.0.0.0:7860.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )