Toughen1 committed on
Commit
47716d3
·
verified ·
1 Parent(s): 5dcc55e

支持Base64接口调用

Browse files
Files changed (1) hide show
  1. app.py +88 -28
app.py CHANGED
@@ -1,13 +1,20 @@
1
- import atexit
2
  import functools
 
 
3
  from queue import Queue
4
- from threading import Event, Thread
 
5
 
6
- from paddleocr import PaddleOCR, draw_ocr
 
 
7
  from PIL import Image
 
8
  import gradio as gr
 
 
9
 
10
-
11
  LANG_CONFIG = {
12
  "ch": {"num_workers": 2},
13
  "en": {"num_workers": 2},
@@ -16,27 +23,26 @@ LANG_CONFIG = {
16
  "korean": {"num_workers": 1},
17
  "japan": {"num_workers": 1},
18
  }
 
19
  CONCURRENCY_LIMIT = 8
20
 
21
 
22
- class PaddleOCRModelManager(object):
23
- def __init__(self,
24
- num_workers,
25
- model_factory):
26
- super().__init__()
27
  self._model_factory = model_factory
28
  self._queue = Queue()
29
  self._workers = []
30
  self._model_initialized_event = Event()
 
31
  for _ in range(num_workers):
32
- worker = Thread(target=self._worker, daemon=False)
33
  worker.start()
34
  self._model_initialized_event.wait()
35
  self._model_initialized_event.clear()
36
  self._workers.append(worker)
37
 
38
  def infer(self, *args, **kwargs):
39
- # XXX: Should I use a more lightweight data structure, say, a future?
40
  result_queue = Queue(maxsize=1)
41
  self._queue.put((args, kwargs, result_queue))
42
  success, payload = result_queue.get()
@@ -72,10 +78,11 @@ def create_model(lang):
72
  return PaddleOCR(lang=lang, use_angle_cls=True, use_gpu=False)
73
 
74
 
75
- model_managers = {}
76
- for lang, config in LANG_CONFIG.items():
77
- model_manager = PaddleOCRModelManager(config["num_workers"], functools.partial(create_model, lang=lang))
78
- model_managers[lang] = model_manager
 
79
 
80
 
81
  def close_model_managers():
@@ -83,37 +90,37 @@ def close_model_managers():
83
  manager.close()
84
 
85
 
86
- # XXX: Not sure if gradio allows adding custom teardown logic
87
  atexit.register(close_model_managers)
88
 
89
-
90
  def inference(img, lang):
91
  ocr = model_managers[lang]
92
  result = ocr.infer(img, cls=True)[0]
93
- img_path = img
94
- image = Image.open(img_path).convert("RGB")
95
  boxes = [line[0] for line in result]
96
  txts = [line[1][0] for line in result]
97
  scores = [line[1][1] for line in result]
98
- im_show = draw_ocr(image, boxes, txts, scores,
99
- font_path="./simfang.ttf")
100
  return im_show
101
 
102
 
 
103
  title = 'PaddleOCR'
104
  description = '''
105
- - Gradio demo for PaddleOCR. PaddleOCR demo supports Chinese, English, French, German, Korean and Japanese.
106
- - To use it, simply upload your image and choose a language from the dropdown menu, or click one of the examples to load them. Read more at the links below.
107
- - [Docs](https://paddlepaddle.github.io/PaddleOCR/), [Github Repository](https://github.com/PaddlePaddle/PaddleOCR).
108
  '''
109
 
110
  examples = [
111
- ['en_example.jpg','en'],
112
- ['cn_example.jpg','ch'],
113
- ['jp_example.jpg','japan'],
114
  ]
115
 
116
  css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;}"
 
117
  gr.Interface(
118
  inference,
119
  [
@@ -127,4 +134,57 @@ gr.Interface(
127
  cache_examples=False,
128
  css=css,
129
  concurrency_limit=CONCURRENCY_LIMIT,
130
- ).launch(debug=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Standard library
import atexit
import base64
import functools
import io
import os
import tempfile
import threading
from queue import Queue
from threading import Event, Thread
from typing import List

# Third-party
import gradio as gr
import uvicorn
from fastapi import FastAPI, HTTPException
from paddleocr import PaddleOCR, draw_ocr
from PIL import Image
from pydantic import BaseModel
16
 
17
+ # ---------- 配置 ----------
18
  LANG_CONFIG = {
19
  "ch": {"num_workers": 2},
20
  "en": {"num_workers": 2},
 
23
  "korean": {"num_workers": 1},
24
  "japan": {"num_workers": 1},
25
  }
26
+
27
  CONCURRENCY_LIMIT = 8
28
 
29
 
30
+ # ---------- 模型池管理 ----------
31
+ class PaddleOCRModelManager:
32
+ def __init__(self, num_workers, model_factory):
 
 
33
  self._model_factory = model_factory
34
  self._queue = Queue()
35
  self._workers = []
36
  self._model_initialized_event = Event()
37
+
38
  for _ in range(num_workers):
39
+ worker = Thread(target=self._worker, daemon=True)
40
  worker.start()
41
  self._model_initialized_event.wait()
42
  self._model_initialized_event.clear()
43
  self._workers.append(worker)
44
 
45
  def infer(self, *args, **kwargs):
 
46
  result_queue = Queue(maxsize=1)
47
  self._queue.put((args, kwargs, result_queue))
48
  success, payload = result_queue.get()
 
78
  return PaddleOCR(lang=lang, use_angle_cls=True, use_gpu=False)
79
 
80
 
81
# ---------- Build the per-language model pool ----------
# One PaddleOCRModelManager per language, each owning the number of worker
# threads configured in LANG_CONFIG.
model_managers = {}
for _lang, _cfg in LANG_CONFIG.items():
    model_managers[_lang] = PaddleOCRModelManager(
        _cfg["num_workers"],
        functools.partial(create_model, lang=_lang),
    )
86
 
87
 
88
  def close_model_managers():
 
90
  manager.close()
91
 
92
 
 
93
  atexit.register(close_model_managers)
94
 
95
# ---------- Gradio inference function ----------
def inference(img, lang):
    """Run OCR on an uploaded image and return an annotated image.

    Args:
        img: Filesystem path to the input image (gradio supplies a filepath).
        lang: Language key into ``model_managers`` (e.g. "ch", "en", "japan").

    Returns:
        The image with detected boxes and recognized text drawn on it, as
        produced by ``draw_ocr``; the plain RGB image if nothing was detected.
    """
    ocr = model_managers[lang]
    result = ocr.infer(img, cls=True)[0]

    image = Image.open(img).convert("RGB")
    # NOTE(review): PaddleOCR is known to yield None (not []) for an image
    # with no detected text; without this guard the comprehensions below
    # raise TypeError. Confirm against the installed paddleocr version.
    if not result:
        return image

    boxes = [line[0] for line in result]
    txts = [line[1][0] for line in result]
    scores = [line[1][1] for line in result]
    im_show = draw_ocr(image, boxes, txts, scores, font_path="./simfang.ttf")
    return im_show
106
 
107
 
108
# ---------- Gradio Web UI ----------
title = 'PaddleOCR'

# UI copy shown under the title (Chinese, kept verbatim).
description = '''
- PaddleOCR Gradio demo 支持中、英、法、德、韩、日文图像文字识别。
- 上传图像并选择语言即可识别;也可以通过 API 接口以 base64 图片方式调用。
- 文档见:https://github.com/PaddlePaddle/PaddleOCR
'''

# Force a tall, full-width image area for both the input and output panes.
css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;}"

# (image path, language key) pairs selectable from the demo page.
examples = [
    ['en_example.jpg', 'en'],
    ['cn_example.jpg', 'ch'],
    ['jp_example.jpg', 'japan'],
]
123
+
124
  gr.Interface(
125
  inference,
126
  [
 
134
  cache_examples=False,
135
  css=css,
136
  concurrency_limit=CONCURRENCY_LIMIT,
137
+ ).launch(share=False, debug=False, prevent_thread_lock=True)
138
+
139
+
140
# ---------- FastAPI service (Base64 interface) ----------
_API_TITLE = "PaddleOCR REST API"
_API_DESCRIPTION = "Support base64 image OCR with multi-language"
_API_VERSION = "1.0.0"

app = FastAPI(
    title=_API_TITLE,
    description=_API_DESCRIPTION,
    version=_API_VERSION,
)
146
+
147
+
148
class PredictRequest(BaseModel):
    """Request body for /predict: a base64-encoded image plus a language key."""

    # Raw base64 payload; a leading "data:image/...;base64," prefix is tolerated.
    image_base64: str
    # One of the LANG_CONFIG keys, e.g. "ch", "en", "japan" (case-insensitive).
    lang: str
151
+
152
+
153
@app.post("/predict")
def predict(request: PredictRequest):
    """OCR a base64-encoded image.

    Accepts a raw base64 string or a ``data:image/...;base64,`` data URL.

    Returns:
        dict with ``texts`` (recognized strings), ``scores`` (confidences as
        plain floats) and ``image_base64`` (annotated image as a PNG data URL).

    Raises:
        HTTPException(400): unknown language or undecodable image payload.
    """
    # Deliberately a sync handler: FastAPI runs it in a worker thread, so the
    # blocking OCR call below does not stall the async event loop.
    lang = request.lang.lower()
    if lang not in model_managers:
        raise HTTPException(status_code=400, detail=f"Unsupported language: {lang}")

    try:
        # split(",")[-1] keeps only the payload of a data-URL, and is a no-op
        # for a bare base64 string.
        image_data = base64.b64decode(request.image_base64.split(",")[-1])
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid base64 image: {str(e)}")

    # Unique temp file per request: the previous fixed "/tmp/temp_image.png"
    # was clobbered by concurrent requests (and was not portable).
    fd, temp_path = tempfile.mkstemp(suffix=".png")
    try:
        os.close(fd)
        image.save(temp_path)
        ocr = model_managers[lang]
        result = ocr.infer(temp_path, cls=True)[0]
    finally:
        os.unlink(temp_path)

    # NOTE(review): PaddleOCR yields None for an image with no detected text;
    # fall back to returning the plain image rather than crashing.
    if result:
        boxes = [line[0] for line in result]
        txts = [line[1][0] for line in result]
        # float() keeps scores JSON-serializable even when they are numpy scalars.
        scores = [float(line[1][1]) for line in result]
        im_show = draw_ocr(image, boxes, txts, scores, font_path="./simfang.ttf")
        # draw_ocr returns a numpy array, which has no .save(); convert first.
        if not isinstance(im_show, Image.Image):
            im_show = Image.fromarray(im_show)
    else:
        txts, scores = [], []
        im_show = image

    buf = io.BytesIO()
    im_show.save(buf, format="PNG")
    image_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")

    return {
        "texts": txts,
        "scores": scores,
        "image_base64": "data:image/png;base64," + image_base64
    }
183
+
184
+
185
# ---------- Background FastAPI launcher ----------
def run_api():
    """Serve the FastAPI app on 0.0.0.0:7861 (blocking; meant for a thread)."""
    uvicorn.run(app, host="0.0.0.0", port=7861)
188
+
189
+
190
# NOTE(review): this is a daemon thread and launch() above was called with
# prevent_thread_lock=True, so if nothing keeps the main thread alive the
# process (and this API) exits immediately — confirm the hosting environment
# blocks elsewhere.
_api_thread = threading.Thread(target=run_api, daemon=True)
_api_thread.start()