Toughen1 committed on
Commit
aebed3f
·
verified ·
1 Parent(s): fd9728f
Files changed (1) hide show
  1. app.py +261 -65
app.py CHANGED
@@ -1,26 +1,43 @@
1
  import atexit
2
  import functools
 
 
 
3
  from queue import Queue
4
  from threading import Event, Thread
5
-
 
6
  from paddleocr import PaddleOCR, draw_ocr
7
  from PIL import Image
8
- from io import BytesIO
9
- import base64
10
  import gradio as gr
11
- from fastapi import FastAPI, UploadFile, Form
12
- from pydantic import BaseModel
13
 
14
- # ========== 模型配置 ==========
15
  LANG_CONFIG = {
16
  "ch": {"num_workers": 2},
17
  "en": {"num_workers": 2},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  }
 
19
  CONCURRENCY_LIMIT = 8
20
 
21
- # ========== 模型池管理类 ==========
22
  class PaddleOCRModelManager(object):
23
- def __init__(self, num_workers, model_factory):
 
 
24
  super().__init__()
25
  self._model_factory = model_factory
26
  self._queue = Queue()
@@ -34,6 +51,7 @@ class PaddleOCRModelManager(object):
34
  self._workers.append(worker)
35
 
36
  def infer(self, *args, **kwargs):
 
37
  result_queue = Queue(maxsize=1)
38
  self._queue.put((args, kwargs, result_queue))
39
  success, payload = result_queue.get()
@@ -64,83 +82,261 @@ class PaddleOCRModelManager(object):
64
  finally:
65
  self._queue.task_done()
66
 
67
- # ========== OCR 模型初始化 ==========
68
  def create_model(lang):
69
  return PaddleOCR(lang=lang, use_angle_cls=True, use_gpu=False)
70
 
71
- model_managers = {
72
- lang: PaddleOCRModelManager(cfg["num_workers"], functools.partial(create_model, lang=lang))
73
- for lang, cfg in LANG_CONFIG.items()
74
- }
 
 
75
 
76
  def close_model_managers():
77
  for manager in model_managers.values():
78
  manager.close()
79
 
 
 
80
  atexit.register(close_model_managers)
81
 
82
- # ========== 通用 OCR 推理函数 ==========
83
- def run_ocr(image: Image.Image, lang: str):
84
- ocr = model_managers[lang]
85
- buffered = BytesIO()
86
- image.save(buffered, format="PNG")
87
- buffered.seek(0)
88
- result = ocr.infer(buffered, cls=True)[0]
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  boxes = [line[0] for line in result]
91
  txts = [line[1][0] for line in result]
92
  scores = [line[1][1] for line in result]
93
- im_show = draw_ocr(image, boxes, txts, scores, font_path="./simfang.ttf")
94
- return im_show, txts
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
- # ========== Gradio UI ==========
97
- def gradio_inference(img_path, lang):
98
- image = Image.open(img_path).convert("RGB")
99
- result_image, _ = run_ocr(image, lang)
100
- return result_image
101
 
102
- title = "PaddleOCR"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  description = '''
104
- - Gradio demo for PaddleOCR with multi-language support.
105
- - Supports Chinese, English, French, German, Korean, and Japanese.
106
- - Upload an image or use the RESTful API below.
 
 
 
 
 
 
 
107
  '''
108
 
109
  examples = [
110
- ['en_example.jpg', 'en'],
111
- ['cn_example.jpg', 'ch'],
112
- ['jp_example.jpg', 'japan'],
113
  ]
114
 
115
- css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;}"
116
-
117
- gr_app = gr.Interface(
118
- gradio_inference,
119
- inputs=[
120
- gr.Image(type='filepath', label='Input'),
121
- gr.Dropdown(choices=list(LANG_CONFIG.keys()), value='en', label='language')
122
- ],
123
- outputs=gr.Image(type='pil', label='Output'),
124
- title=title,
125
- description=description,
126
- examples=examples,
127
- cache_examples=False,
128
- css=css,
129
- concurrency_limit=CONCURRENCY_LIMIT
130
- )
131
-
132
- # ========== FastAPI + REST OCR ==========
133
- app = FastAPI()
134
-
135
- @app.post("/api/ocr_base64")
136
- def ocr_base64(data: str = Form(...), lang: str = Form("ch")):
137
- try:
138
- content = base64.b64decode(data)
139
- image = Image.open(BytesIO(content)).convert("RGB")
140
- _, texts = run_ocr(image, lang)
141
- return {"success": True, "text": texts}
142
- except Exception as e:
143
- return {"success": False, "error": str(e)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
- # 挂载 Gradio 到 FastAPI
146
- app = gr.mount_gradio_app(app, gr_app, path="/")
 
1
  import atexit
2
  import functools
3
+ import base64
4
+ import io
5
+ import re
6
  from queue import Queue
7
  from threading import Event, Thread
8
+ import numpy as np
9
+ from langdetect import detect
10
  from paddleocr import PaddleOCR, draw_ocr
11
  from PIL import Image
 
 
12
  import gradio as gr
 
 
13
 
14
+
15
  LANG_CONFIG = {
16
  "ch": {"num_workers": 2},
17
  "en": {"num_workers": 2},
18
+ "fr": {"num_workers": 1},
19
+ "german": {"num_workers": 1},
20
+ "korean": {"num_workers": 1},
21
+ "japan": {"num_workers": 1},
22
+ }
23
+
24
# Maps language codes reported by langdetect to the PaddleOCR language keys
# used in LANG_CONFIG.
LANG_DETECT_MAP = {
    "zh": "ch",
    # langdetect reports Chinese with a region tag rather than a bare "zh".
    # Only the "ch" model is configured, so both variants map to it.
    "zh-cn": "ch",
    "zh-tw": "ch",
    "en": "en",
    "fr": "fr",
    "de": "german",
    "ko": "korean",
    "ja": "japan",
}
33
+
34
  CONCURRENCY_LIMIT = 8
35
 
36
+
37
  class PaddleOCRModelManager(object):
38
+ def __init__(self,
39
+ num_workers,
40
+ model_factory):
41
  super().__init__()
42
  self._model_factory = model_factory
43
  self._queue = Queue()
 
51
  self._workers.append(worker)
52
 
53
  def infer(self, *args, **kwargs):
54
+ # XXX: Should I use a more lightweight data structure, say, a future?
55
  result_queue = Queue(maxsize=1)
56
  self._queue.put((args, kwargs, result_queue))
57
  success, payload = result_queue.get()
 
82
  finally:
83
  self._queue.task_done()
84
 
85
+
86
  def create_model(lang):
87
  return PaddleOCR(lang=lang, use_angle_cls=True, use_gpu=False)
88
 
89
+
90
+ model_managers = {}
91
+ for lang, config in LANG_CONFIG.items():
92
+ model_manager = PaddleOCRModelManager(config["num_workers"], functools.partial(create_model, lang=lang))
93
+ model_managers[lang] = model_manager
94
+
95
 
96
  def close_model_managers():
97
  for manager in model_managers.values():
98
  manager.close()
99
 
100
+
101
+ # XXX: Not sure if gradio allows adding custom teardown logic
102
  atexit.register(close_model_managers)
103
 
 
 
 
 
 
 
 
104
 
105
def detect_language_from_text(text):
    """Guess the OCR language key for a piece of recognized text.

    Args:
        text: Text to classify, typically OCR output joined into one string.

    Returns:
        A value from LANG_DETECT_MAP (a PaddleOCR language key), or "en" when
        the text is empty, detection fails, or the detected language has no
        mapping.
    """
    # Nothing to classify: skip the detector and use the English fallback.
    if not isinstance(text, str) or not text.strip():
        return "en"

    try:
        detected = detect(text)
    except Exception:
        # Best effort: any detector failure falls back to English. Unlike a
        # bare ``except``, KeyboardInterrupt/SystemExit still propagate.
        return "en"

    if detected in LANG_DETECT_MAP:
        return LANG_DETECT_MAP[detected]
    # langdetect reports some languages with a region suffix (for example
    # "zh-cn"), which LANG_DETECT_MAP may only list under the bare language
    # code, so retry with the part before the hyphen.
    return LANG_DETECT_MAP.get(detected.split("-", 1)[0], "en")
112
+
113
+
114
def auto_detect_language(image):
    """Guess the language of the text contained in an image.

    Runs a first OCR pass with the English model, joins the recognized text
    and passes it to detect_language_from_text.

    Args:
        image: OCR input, forwarded unchanged to the English model manager's
            ``infer`` call.

    Returns:
        The language key chosen by detect_language_from_text, or "en" when no
        text is recognized or the OCR pass fails.
    """
    # Use the English model for the probing pass.
    ocr = model_managers["en"]
    try:
        result = ocr.infer(image, cls=True)[0]
        if not result:
            # No text found in the image.
            return "en"

        # Merge every recognized line into one string for language detection.
        all_text = " ".join(line[1][0] for line in result)
        if not all_text.strip():
            return "en"

        return detect_language_from_text(all_text)
    except Exception:
        # Best effort: a failed probe falls back to English. ``Exception``
        # (rather than a bare ``except``) lets KeyboardInterrupt/SystemExit
        # propagate.
        return "en"
133
+
134
+
135
def process_base64_image(base64_string):
    """Decode a Base64-encoded image into a PNG buffer and a PIL image.

    Args:
        base64_string: Raw Base64 text, or a data URL containing "base64,".

    Returns:
        A ``(png_buffer, image)`` tuple: an ``io.BytesIO`` rewound to its
        start and holding the image re-encoded as PNG, and the decoded PIL
        image converted to RGB.

    Raises:
        ValueError: If the input cannot be decoded or is not a readable
            image. The underlying error is chained as the cause.
    """
    try:
        # Drop a data-URL header such as "data:image/png;base64,".
        if "base64," in base64_string:
            base64_string = base64_string.split("base64,")[1]

        image_data = base64.b64decode(base64_string)
        # Convert to RGB so this path yields the same image mode as the
        # file-path branch of inference(), which also converts to RGB.
        image = Image.open(io.BytesIO(image_data)).convert("RGB")

        # Re-encode as PNG into an in-memory buffer.
        temp_io = io.BytesIO()
        image.save(temp_io, format='PNG')
        temp_io.seek(0)

        return temp_io, image
    except Exception as e:
        # Chain the cause so the original traceback is kept for debugging.
        raise ValueError(f"处理Base64图像时出错: {str(e)}") from e
154
+
155
+
156
def inference(img, return_text_only=True):
    """Run OCR on an image, choosing the OCR language automatically.

    Args:
        img: Either a path to an image file, or Base64 image data (a
            ``data:`` URL or a bare Base64 string).
        return_text_only: When True, skip drawing the annotated image.

    Returns:
        ``(text, lang)`` when ``return_text_only`` is True, otherwise
        ``(annotated_image, text, lang)``. ``text`` is the recognized lines
        joined with newlines; ``lang`` is the ``model_managers`` key used.

    Raises:
        ValueError: If no image is given, or Base64 input cannot be decoded
            (raised by process_base64_image).
    """
    if img is None:
        # Fail with a clear message instead of a TypeError from re.match.
        raise ValueError("请提供有效的图像")

    # Parenthesised on purpose: ``A and B or C`` would run the regex on
    # non-string input as well.
    if isinstance(img, str) and (
        img.startswith("data:") or re.match(r'^[A-Za-z0-9+/=]+$', img)
    ):
        # Base64 input.
        img_io, pil_img = process_base64_image(img)
        # Hand the OCR engine the PNG bytes rather than the BytesIO object:
        # bytes can be consumed by both OCR passes below.
        # NOTE(review): PaddleOCR's ocr() appears to accept str / ndarray /
        # bytes but not file objects — confirm against the pinned version.
        img_path = img_io.getvalue()
    else:
        # File-path input.
        img_path = img
        pil_img = Image.open(img_path).convert("RGB")

    # First pass: pick the language.
    lang = auto_detect_language(img_path)

    # Second pass: OCR with the model for the detected language.
    ocr = model_managers[lang]
    # An image without text can produce a falsy result (auto_detect_language
    # guards for the same case); treat it as "no lines".
    result = ocr.infer(img_path, cls=True)[0] or []

    # Split each recognized line into its box, text and confidence score.
    boxes = [line[0] for line in result]
    txts = [line[1][0] for line in result]
    scores = [line[1][1] for line in result]

    if return_text_only:
        return "\n".join(txts), lang

    # Draw the detections onto the image.
    im_show = draw_ocr(pil_img, boxes, txts, scores, font_path="./simfang.ttf")
    return im_show, "\n".join(txts), lang
187
+
188
+
189
def inference_with_image(img):
    """Run OCR and return the annotated image together with the text.

    Args:
        img: Image input, passed straight to inference().

    Returns:
        ``(annotated_image, text, lang)`` as produced by inference() with
        ``return_text_only=False``.
    """
    annotated, recognized, language = inference(img, return_text_only=False)
    return annotated, recognized, language
193
 
 
 
 
 
 
194
 
195
def inference_text_only(img):
    """Run OCR and return only the recognized text and language.

    Args:
        img: Image input, passed straight to inference().

    Returns:
        ``(text, lang)`` as produced by inference() with
        ``return_text_only=True``.
    """
    recognized, language = inference(img, return_text_only=True)
    return recognized, language
199
+
200
+
201
def inference_base64(base64_string):
    """Run OCR on a Base64-encoded image and report errors as text.

    Args:
        base64_string: Base64 image data entered by the user.

    Returns:
        ``(text, lang)`` on success. For missing/blank input or any failure
        during inference, a ``(message, "")`` pair where ``message`` describes
        the problem.
    """
    # Guard clause: reject missing or whitespace-only input up front.
    if not (base64_string and base64_string.strip() != ""):
        return "请提供有效的Base64图像字符串", ""

    try:
        recognized, language = inference(base64_string, return_text_only=True)
    except Exception as e:
        # Surface the failure in the output box instead of raising.
        return f"处理Base64图像时出错: {str(e)}", ""
    return recognized, language
211
+
212
+
213
+ title = '🔍 PaddleOCR 智能文字识别'
214
  description = '''
215
+ ### 功能特点
216
+ - 支持中文、英文、法语、德语、韩语和日语的智能文字识别
217
+ - 自动检测图像中的语言,无需手动选择
218
+ - 支持Base64编码图像识别
219
+ - 同时提供文本结果和标注图像
220
+
221
+ ### 使用方法
222
+ - 上传图像或提供Base64编码的图像数据
223
+ - 系统会自动检测语言并进行OCR识别
224
+ - 查看识别结果和标注图像
225
  '''
226
 
227
  examples = [
228
+ ['en_example.jpg'],
229
+ ['cn_example.jpg'],
230
+ ['jp_example.jpg'],
231
  ]
232
 
233
+ # 自定义CSS样式,优化界面
234
+ css = """
235
+ .gradio-container {
236
+ font-family: 'Roboto', 'Microsoft YaHei', sans-serif;
237
+ }
238
+ .output_image, .input_image {
239
+ height: 30rem !important;
240
+ width: 100% !important;
241
+ object-fit: contain;
242
+ border-radius: 8px;
243
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
244
+ }
245
+ .tabs {
246
+ margin-top: 0.5rem;
247
+ }
248
+ .output-text {
249
+ font-family: 'Courier New', monospace;
250
+ line-height: 1.5;
251
+ padding: 1rem;
252
+ border-radius: 8px;
253
+ background-color: #f8f9fa;
254
+ border: 1px solid #e9ecef;
255
+ }
256
+ .detected-lang {
257
+ font-weight: bold;
258
+ color: #4285f4;
259
+ margin-bottom: 0.5rem;
260
+ }
261
+ """
262
+
263
+ # 使用Gradio Blocks创建更丰富的界面
264
# Two-tab UI (file upload / Base64 text) plus a collapsible API note.
# NOTE(review): the nesting below was reconstructed from a flattened diff
# view — confirm the layout against the real file.
with gr.Blocks(title=title, css=css) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)

    with gr.Tabs() as tabs:
        # Tab 1: upload an image file.
        with gr.TabItem("图像上传识别"):
            with gr.Row():
                with gr.Column(scale=1):
                    image_input = gr.Image(label="上传图像", type="filepath")
                    image_submit = gr.Button("开始识别", variant="primary")

                with gr.Column(scale=2):
                    with gr.Row():
                        image_output = gr.Image(label="标注结果", type="pil")
                    with gr.Row():
                        detected_lang = gr.Textbox(label="检测到的语言", lines=1)
                    with gr.Row():
                        text_output = gr.Textbox(label="识别文本", lines=10, elem_classes=["output-text"])

        # Tab 2: paste Base64 image data.
        with gr.TabItem("Base64图像识别"):
            with gr.Row():
                with gr.Column(scale=1):
                    base64_input = gr.Textbox(
                        label="输入Base64编码的图像数据",
                        lines=8,
                        placeholder="在此粘贴Base64编码的图像数据..."
                    )
                    base64_submit = gr.Button("开始识别", variant="primary")

                with gr.Column(scale=2):
                    base64_lang = gr.Textbox(label="检测到的语言", lines=1)
                    base64_output = gr.Textbox(
                        label="识别文本",
                        lines=15,
                        elem_classes=["output-text"]
                    )

    # Collapsible API usage notes.
    with gr.Accordion("API使用说明", open=False):
        gr.Markdown("""
        ## API使用方法

        ### 1. 图像上传API

        ```bash
        curl -X POST "http://localhost:7860/api/predict" \\
             -F "fn_index=0" \\
             -F "data=@/path/to/your/image.jpg"
        ```

        ### 2. Base64图像API

        ```bash
        curl -X POST "http://localhost:7860/api/predict" \\
             -H "Content-Type: application/json" \\
             -d '{
               "fn_index": 1,
               "data": ["YOUR_BASE64_STRING_HERE"]
             }'
        ```
        """)

    # Event wiring. The module-level CONCURRENCY_LIMIT is passed explicitly:
    # without it each listener runs at Gradio's default concurrency limit.
    image_submit.click(
        fn=inference_with_image,
        inputs=[image_input],
        outputs=[image_output, text_output, detected_lang],
        concurrency_limit=CONCURRENCY_LIMIT
    )

    base64_submit.click(
        fn=inference_base64,
        inputs=[base64_input],
        outputs=[base64_output, base64_lang],
        concurrency_limit=CONCURRENCY_LIMIT
    )
340
 
341
+ # 启动Gradio应用
342
+ demo.launch(debug=False, share=False)