Toughen1 commited on
Commit
c049531
·
verified ·
1 Parent(s): 2f92370
Files changed (1) hide show
  1. app.py +45 -85
app.py CHANGED
@@ -1,20 +1,17 @@
 
1
  import functools
2
- import io
3
- import base64
4
  from queue import Queue
5
- from threading import Thread, Event
6
- from typing import List
7
 
8
- import atexit
9
- from fastapi import FastAPI, HTTPException
10
- from pydantic import BaseModel
11
- from PIL import Image
12
  from paddleocr import PaddleOCR, draw_ocr
 
 
 
13
  import gradio as gr
14
- import uvicorn
15
- import threading
16
 
17
- # ---------- 配置 ----------
18
  LANG_CONFIG = {
19
  "ch": {"num_workers": 2},
20
  "en": {"num_workers": 2},
@@ -23,20 +20,18 @@ LANG_CONFIG = {
23
  "korean": {"num_workers": 1},
24
  "japan": {"num_workers": 1},
25
  }
26
-
27
  CONCURRENCY_LIMIT = 8
28
 
29
-
30
- # ---------- 模型池管理 ----------
31
- class PaddleOCRModelManager:
32
  def __init__(self, num_workers, model_factory):
 
33
  self._model_factory = model_factory
34
  self._queue = Queue()
35
  self._workers = []
36
  self._model_initialized_event = Event()
37
-
38
  for _ in range(num_workers):
39
- worker = Thread(target=self._worker, daemon=True)
40
  worker.start()
41
  self._model_initialized_event.wait()
42
  self._model_initialized_event.clear()
@@ -73,44 +68,46 @@ class PaddleOCRModelManager:
73
  finally:
74
  self._queue.task_done()
75
 
76
-
77
  def create_model(lang):
78
  return PaddleOCR(lang=lang, use_angle_cls=True, use_gpu=False)
79
 
80
-
81
- # ---------- 初始化模型池 ----------
82
  model_managers = {
83
  lang: PaddleOCRModelManager(cfg["num_workers"], functools.partial(create_model, lang=lang))
84
  for lang, cfg in LANG_CONFIG.items()
85
  }
86
 
87
-
88
  def close_model_managers():
89
  for manager in model_managers.values():
90
  manager.close()
91
 
92
-
93
  atexit.register(close_model_managers)
94
 
95
- # ---------- Gradio 推理函数 ----------
96
- def inference(img, lang):
97
  ocr = model_managers[lang]
98
- result = ocr.infer(img, cls=True)[0]
 
 
 
99
 
100
- image = Image.open(img).convert("RGB")
101
  boxes = [line[0] for line in result]
102
  txts = [line[1][0] for line in result]
103
  scores = [line[1][1] for line in result]
104
  im_show = draw_ocr(image, boxes, txts, scores, font_path="./simfang.ttf")
105
- return im_show
106
 
 
 
 
 
 
107
 
108
- # ---------- Gradio Web UI ----------
109
- title = 'PaddleOCR'
110
  description = '''
111
- - PaddleOCR Gradio demo 支持中、英、法、德、韩、日文图像文字识别。
112
- - 上传图像并选择语言即可识别;也可以通过 API 接口以 base64 图片方式调用。
113
- - 文档见:https://github.com/PaddlePaddle/PaddleOCR
114
  '''
115
 
116
  examples = [
@@ -121,70 +118,33 @@ examples = [
121
 
122
  css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;}"
123
 
124
- gr.Interface(
125
- inference,
126
- [
127
  gr.Image(type='filepath', label='Input'),
128
  gr.Dropdown(choices=list(LANG_CONFIG.keys()), value='en', label='language')
129
  ],
130
- gr.Image(type='pil', label='Output'),
131
  title=title,
132
  description=description,
133
  examples=examples,
134
  cache_examples=False,
135
  css=css,
136
- concurrency_limit=CONCURRENCY_LIMIT,
137
- ).launch(share=False, debug=False, prevent_thread_lock=True)
138
-
139
-
140
- # ---------- FastAPI 接口(Base64) ----------
141
- app = FastAPI(
142
- title="PaddleOCR REST API",
143
- description="Support base64 image OCR with multi-language",
144
- version="1.0.0"
145
  )
146
 
 
 
147
 
148
- class PredictRequest(BaseModel):
149
- image_base64: str
150
- lang: str
151
-
152
-
153
- @app.post("/predict")
154
- async def predict(request: PredictRequest):
155
- lang = request.lang.lower()
156
- if lang not in model_managers:
157
- raise HTTPException(status_code=400, detail=f"Unsupported language: {lang}")
158
-
159
  try:
160
- image_data = base64.b64decode(request.image_base64.split(",")[-1])
161
- image = Image.open(io.BytesIO(image_data)).convert("RGB")
162
- temp_path = "/tmp/temp_image.png"
163
- image.save(temp_path)
164
  except Exception as e:
165
- raise HTTPException(status_code=400, detail=f"Invalid base64 image: {str(e)}")
166
-
167
- ocr = model_managers[lang]
168
- result = ocr.infer(temp_path, cls=True)[0]
169
- boxes = [line[0] for line in result]
170
- txts = [line[1][0] for line in result]
171
- scores = [line[1][1] for line in result]
172
-
173
- im_show = draw_ocr(image, boxes, txts, scores, font_path="./simfang.ttf")
174
- buf = io.BytesIO()
175
- im_show.save(buf, format="PNG")
176
- image_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
177
-
178
- return {
179
- "texts": txts,
180
- "scores": scores,
181
- "image_base64": "data:image/png;base64," + image_base64
182
- }
183
-
184
-
185
- # ---------- 后台启动 FastAPI ----------
186
- def run_api():
187
- uvicorn.run(app, host="0.0.0.0", port=7861)
188
-
189
 
190
- threading.Thread(target=run_api, daemon=True).start()
 
 
1
+ import atexit
2
  import functools
 
 
3
  from queue import Queue
4
+ from threading import Event, Thread
 
5
 
 
 
 
 
6
  from paddleocr import PaddleOCR, draw_ocr
7
+ from PIL import Image
8
+ from io import BytesIO
9
+ import base64
10
  import gradio as gr
11
+ from fastapi import FastAPI, UploadFile, Form
12
+ from pydantic import BaseModel
13
 
14
+ # ========== 模型配置 ==========
15
  LANG_CONFIG = {
16
  "ch": {"num_workers": 2},
17
  "en": {"num_workers": 2},
 
20
  "korean": {"num_workers": 1},
21
  "japan": {"num_workers": 1},
22
  }
 
23
  CONCURRENCY_LIMIT = 8
24
 
25
+ # ========== 模型池管理类 ==========
26
+ class PaddleOCRModelManager(object):
 
27
  def __init__(self, num_workers, model_factory):
28
+ super().__init__()
29
  self._model_factory = model_factory
30
  self._queue = Queue()
31
  self._workers = []
32
  self._model_initialized_event = Event()
 
33
  for _ in range(num_workers):
34
+ worker = Thread(target=self._worker, daemon=False)
35
  worker.start()
36
  self._model_initialized_event.wait()
37
  self._model_initialized_event.clear()
 
68
  finally:
69
  self._queue.task_done()
70
 
71
+ # ========== OCR 模型初始化 ==========
72
  def create_model(lang):
73
  return PaddleOCR(lang=lang, use_angle_cls=True, use_gpu=False)
74
 
 
 
75
  model_managers = {
76
  lang: PaddleOCRModelManager(cfg["num_workers"], functools.partial(create_model, lang=lang))
77
  for lang, cfg in LANG_CONFIG.items()
78
  }
79
 
 
80
  def close_model_managers():
81
  for manager in model_managers.values():
82
  manager.close()
83
 
 
84
  atexit.register(close_model_managers)
85
 
86
+ # ========== 通用 OCR 推理函数 ==========
87
+ def run_ocr(image: Image.Image, lang: str):
88
  ocr = model_managers[lang]
89
+ buffered = BytesIO()
90
+ image.save(buffered, format="PNG")
91
+ buffered.seek(0)
92
+ result = ocr.infer(buffered, cls=True)[0]
93
 
 
94
  boxes = [line[0] for line in result]
95
  txts = [line[1][0] for line in result]
96
  scores = [line[1][1] for line in result]
97
  im_show = draw_ocr(image, boxes, txts, scores, font_path="./simfang.ttf")
98
+ return im_show, txts
99
 
100
+ # ========== Gradio UI ==========
101
+ def gradio_inference(img_path, lang):
102
+ image = Image.open(img_path).convert("RGB")
103
+ result_image, _ = run_ocr(image, lang)
104
+ return result_image
105
 
106
+ title = "PaddleOCR"
 
107
  description = '''
108
+ - Gradio demo for PaddleOCR with multi-language support.
109
+ - Supports Chinese, English, French, German, Korean, and Japanese.
110
+ - Upload an image or use the RESTful API below.
111
  '''
112
 
113
  examples = [
 
118
 
119
  css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;}"
120
 
121
+ gr_app = gr.Interface(
122
+ gradio_inference,
123
+ inputs=[
124
  gr.Image(type='filepath', label='Input'),
125
  gr.Dropdown(choices=list(LANG_CONFIG.keys()), value='en', label='language')
126
  ],
127
+ outputs=gr.Image(type='pil', label='Output'),
128
  title=title,
129
  description=description,
130
  examples=examples,
131
  cache_examples=False,
132
  css=css,
133
+ concurrency_limit=CONCURRENCY_LIMIT
 
 
 
 
 
 
 
 
134
  )
135
 
136
+ # ========== FastAPI + REST OCR ==========
137
+ app = FastAPI()
138
 
139
+ @app.post("/api/ocr_base64")
140
+ def ocr_base64(data: str = Form(...), lang: str = Form("ch")):
 
 
 
 
 
 
 
 
 
141
  try:
142
+ content = base64.b64decode(data)
143
+ image = Image.open(BytesIO(content)).convert("RGB")
144
+ _, texts = run_ocr(image, lang)
145
+ return {"success": True, "text": texts}
146
  except Exception as e:
147
+ return {"success": False, "error": str(e)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
+ # 挂载 Gradio 到 FastAPI
150
+ app = gr.mount_gradio_app(app, gr_app, path="/")