Toughen1 committed on
Commit
d22018e
·
verified ·
1 Parent(s): cbfa32e
Files changed (1) hide show
  1. app.py +123 -70
app.py CHANGED
@@ -3,10 +3,11 @@ import functools
3
  import base64
4
  import io
5
  import re
 
 
6
  from queue import Queue
7
  from threading import Event, Thread
8
  import numpy as np
9
- from langdetect import detect
10
  from paddleocr import PaddleOCR, draw_ocr
11
  from PIL import Image
12
  import gradio as gr
@@ -21,14 +22,24 @@ LANG_CONFIG = {
21
  "japan": {"num_workers": 1},
22
  }
23
 
24
# Maps the language codes reported by ``langdetect`` to the PaddleOCR
# language keys used by this app.
# langdetect reports Chinese as "zh-cn" / "zh-tw" rather than bare "zh", so
# both regional codes are listed; "zh" is kept for backward compatibility.
# NOTE(review): "zh-tw" is routed to the "ch" model because no traditional
# Chinese model key is visible here — confirm against LANG_CONFIG.
LANG_DETECT_MAP = {
    "zh": "ch",
    "zh-cn": "ch",
    "zh-tw": "ch",
    "en": "en",
    "fr": "fr",
    "de": "german",
    "ko": "korean",
    "ja": "japan",
}
33
 
34
  CONCURRENCY_LIMIT = 8
@@ -102,38 +113,64 @@ def close_model_managers():
102
  atexit.register(close_model_managers)
103
 
104
 
105
def detect_language_from_text(text):
    """Return the OCR language key that best matches *text*.

    The text is classified with ``detect`` and the reported code is looked
    up in ``LANG_DETECT_MAP``. Region-suffixed codes (for example
    ``"zh-cn"``) that are not listed verbatim fall back to their base code
    (``"zh"``) before giving up.

    Args:
        text: Text to classify.

    Returns:
        A language key from ``LANG_DETECT_MAP``; ``"en"`` when detection
        fails or the detected language has no mapping.
    """
    try:
        detected = detect(text)
    except Exception:
        # Detection is best-effort: any detector failure means English.
        # ``Exception`` (not a bare except) keeps Ctrl-C / SystemExit working.
        return "en"

    if detected in LANG_DETECT_MAP:
        return LANG_DETECT_MAP[detected]

    # Strip a region suffix such as "-cn" and retry with the base code.
    base_code = str(detected).split("-", 1)[0]
    return LANG_DETECT_MAP.get(base_code, "en")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
 
114
def auto_detect_language(image):
    """Guess the language of the text in *image*.

    The image is first read with the English model from ``model_managers``;
    the recognised text is then classified by ``detect_language_from_text``.

    Args:
        image: Image input forwarded unchanged to the model manager's
            ``infer`` method.

    Returns:
        A language key; ``"en"`` when no text is recognised or any step
        fails.
    """
    ocr = model_managers["en"]
    try:
        result = ocr.infer(image, cls=True)[0]
        if not result:
            # Nothing recognised: fall back to English.
            return "en"

        # Each OCR line is indexed as (box, (text, score)); join the texts.
        all_text = " ".join(line[1][0] for line in result)
        if not all_text.strip():
            return "en"

        return detect_language_from_text(all_text)
    except Exception:
        # Best-effort fallback; ``Exception`` rather than a bare except so
        # KeyboardInterrupt / SystemExit still propagate.
        return "en"
 
 
 
 
 
 
133
 
134
 
135
- def process_base64_image(base64_string):
136
- """处理Base64编码的图像"""
137
  try:
138
  # 移除可能的前缀
139
  if "base64," in base64_string:
@@ -141,49 +178,65 @@ def process_base64_image(base64_string):
141
 
142
  # 解码Base64
143
  image_data = base64.b64decode(base64_string)
144
- image = Image.open(io.BytesIO(image_data))
145
 
146
- # 将PIL图像转换为临时文件
147
- temp_io = io.BytesIO()
148
- image.save(temp_io, format='PNG')
149
- temp_io.seek(0)
150
 
151
- return temp_io, image
152
  except Exception as e:
153
  raise ValueError(f"处理Base64图像时出错: {str(e)}")
154
 
155
 
156
def inference(img, return_text_only=True):
    """Run OCR on *img* using an automatically detected language.

    Args:
        img: Either a Base64 string (optionally a ``data:`` URI) or a value
            that ``Image.open`` and the OCR model manager both accept, such
            as a file path.
        return_text_only: When true, return only the text and language key;
            otherwise also return an annotated image from ``draw_ocr``.

    Returns:
        ``(text, lang)`` when *return_text_only* is true, otherwise
        ``(annotated_image, text, lang)``. ``text`` is the recognised lines
        joined with newlines; it is empty when nothing is recognised.

    Raises:
        ValueError: Propagated from ``process_base64_image`` for bad Base64.
    """
    # Parenthesised so the regex only runs on strings; the previous
    # ``a and b or c`` form called re.match on non-string inputs.
    looks_like_base64 = isinstance(img, str) and (
        img.startswith("data:")
        or re.match(r'^[A-Za-z0-9+/=]+$', img) is not None
    )

    if looks_like_base64:
        img_io, pil_img = process_base64_image(img)
        # Keep the draw path consistent with the file-path branch below.
        pil_img = pil_img.convert("RGB")
        # Hand the OCR engine encoded bytes rather than the BytesIO object.
        # NOTE(review): assumes model_managers[...].infer forwards its input
        # to PaddleOCR, which decodes raw image bytes — confirm.
        img_path = img_io.getvalue()
    else:
        img_path = img
        pil_img = Image.open(img_path).convert("RGB")

    lang = auto_detect_language(img_path)

    ocr = model_managers[lang]
    # The first element is falsy when no text is found; treat it as empty.
    result = ocr.infer(img_path, cls=True)[0] or []

    # Each OCR line is indexed as (box, (text, score)).
    boxes = [line[0] for line in result]
    txts = [line[1][0] for line in result]
    scores = [line[1][1] for line in result]

    if return_text_only:
        return "\n".join(txts), lang

    im_show = draw_ocr(pil_img, boxes, txts, scores, font_path="./simfang.ttf")
    return im_show, "\n".join(txts), lang
187
 
188
 
189
  def inference_with_image(img):
 
3
  import base64
4
  import io
5
  import re
6
+ import os
7
+ import tempfile
8
  from queue import Queue
9
  from threading import Event, Thread
10
  import numpy as np
 
11
  from paddleocr import PaddleOCR, draw_ocr
12
  from PIL import Image
13
  import gradio as gr
 
22
  "japan": {"num_workers": 1},
23
  }
24
 
25
# Display name (in Chinese) for each OCR language key.
LANG_MAP = dict(
    ch="中文",
    en="英文",
    fr="法语",
    german="德语",
    korean="韩语",
    japan="日语",
)

# Raw character inventories behind LANG_FEATURES, kept as plain strings so
# the table below stays readable.
_CH_FEATURE_CHARS = "的一是不了人我在有他这为之大来以个中上们到国说和地也子时道出而要于就下得可你年生自会那后能对着事其里所去行过家十用发天如然作方成者多日都三小军二公无同么经法当起与好看学进种将还分此心前面又定见只主没公从年可着同时至理化物现并提直题党性好它头应主实向当把几十用表已近万第调音真打太办现做感次带北林里无从化性相将应间手专这见民候深院查表化何南器声点今建月正机北装分十注位被反革力量门反象并果更系求把治取入总些形度持制管即及西做先将才结共接目路至城北口山战世强先产革律较本群决使见治及造百规热领即集什积六县接必照住治准革复每设始术精专向变团便石从按却代光命即保达干统持运复程究造何革命即系统计或设总色律象即物线划几领按更系院转些即总导度济深求传界拉干着真示制干提克度几管见导传命即总系具引势持使结构论完联常达设战表南究利世结构论完联常达设战表南究利世"
_EN_FEATURE_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
_FR_FEATURE_CHARS = "àâäæçéèêëîïôœùûüÿÀÂÄÆÇÉÈÊËÎÏÔŒÙÛÜŸ"
_GERMAN_FEATURE_CHARS = "äöüßÄÖÜ"
_JAPAN_FEATURE_CHARS = "あいうえおかきくけこさしすせそたちつてとなにぬねのはひふへほまみむめもやゆよらりるれろわをんがぎぐげござじずぜぞだぢづでどばびぶべぼぱぴぷぺぽアイウエオカキクケコサシスセソタチツテトナニヌネノハヒフヘホマミムメモヤユヨラリルレロワヲンガギグゲゴザジズゼゾダヂヅデドバビブベボパピプペポ"

# Characteristic characters for each OCR language key.
LANG_FEATURES = {
    "ch": set(_CH_FEATURE_CHARS),
    "en": set(_EN_FEATURE_CHARS),
    "fr": set(_FR_FEATURE_CHARS),
    "german": set(_GERMAN_FEATURE_CHARS),
    # Korean is deliberately empty: its character set is too large to list,
    # so it is detected by other means.
    "korean": set(),
    "japan": set(_JAPAN_FEATURE_CHARS),
}
44
 
45
  CONCURRENCY_LIMIT = 8
 
113
  atexit.register(close_model_managers)
114
 
115
 
116
def detect_language_by_features(text):
    """Guess the OCR language key of *text* from its characters.

    Every non-empty entry of ``LANG_FEATURES`` is scored by how many
    characters of *text* belong to its set; Korean is scored by the Hangul
    syllable range U+AC00..U+D7A3 instead. The highest score wins, with two
    corrections for scripts shared between languages:

    * ``"ch"`` yields to ``"japan"`` whenever any ``"japan"`` feature
      character is present, because Han characters alone cannot separate
      the two.
    * ``"en"`` yields to ``"fr"`` or ``"german"`` whenever their accented
      feature characters are present. The plain letters in the ``"en"`` set
      make up most of any French or German text, so without this step those
      two languages could never be returned. Characters listed for both
      languages count for neither; a tie goes to ``"german"``.

    Args:
        text: Text to classify.

    Returns:
        A language key; ``"en"`` when *text* is empty or contains no
        feature characters at all.
    """
    if not text:
        return "en"

    # All scores share the same denominator (len(text)), so raw hit counts
    # rank the languages exactly as ratios would.
    hits = {}
    for lang, char_set in LANG_FEATURES.items():
        if not char_set:  # empty set: this language is scored elsewhere
            continue
        matched = sum(1 for char in text if char in char_set)
        if matched:
            hits[lang] = matched

    hangul = sum(1 for char in text if '\uac00' <= char <= '\ud7a3')
    if hangul:
        hits["korean"] = hangul

    if not hits:
        return "en"

    best = max(hits, key=hits.get)

    if best == "ch" and "japan" in hits:
        return "japan"

    if best == "en":
        fr_chars = LANG_FEATURES.get("fr", set())
        german_chars = LANG_FEATURES.get("german", set())
        fr_only = sum(
            1 for char in text if char in fr_chars and char not in german_chars
        )
        german_only = sum(
            1 for char in text if char in german_chars and char not in fr_chars
        )
        if fr_only > german_only:
            return "fr"
        if "german" in hits:
            return "german"

    return best
143
 
144
 
145
def auto_detect_language(image_path):
    """Pick a language for *image_path* by letting two OCR models vote.

    The image is read with the ``"ch"`` and ``"en"`` models in turn. Each
    run's recognised text is classified by ``detect_language_by_features``
    and counts as one vote. A model that fails or recognises no text does
    not vote.

    Args:
        image_path: Image input forwarded unchanged to each model manager's
            ``infer`` method.

    Returns:
        The language key with the most votes (the first one seen on a tie),
        or ``"en"`` when no model voted.
    """
    votes = {}

    for candidate in ("ch", "en"):
        try:
            ocr_lines = model_managers[candidate].infer(image_path, cls=True)[0]
            if not ocr_lines:
                continue

            # Each OCR line is indexed as (box, (text, score)).
            joined = " ".join([entry[1][0] for entry in ocr_lines])
            if not joined.strip():
                continue

            guess = detect_language_by_features(joined)
            votes[guess] = votes.get(guess, 0) + 1
        except Exception:
            # A failing model simply abstains.
            continue

    if not votes:
        return "en"

    return max(votes, key=votes.get)
170
 
171
 
172
+ def save_base64_to_temp_file(base64_string):
173
+ """Base64图像保存为临时文件"""
174
  try:
175
  # 移除可能的前缀
176
  if "base64," in base64_string:
 
178
 
179
  # 解码Base64
180
  image_data = base64.b64decode(base64_string)
 
181
 
182
+ # 创建临时文件
183
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
184
+ temp_file.write(image_data)
185
+ temp_file.close()
186
 
187
+ return temp_file.name
188
  except Exception as e:
189
  raise ValueError(f"处理Base64图像时出错: {str(e)}")
190
 
191
 
192
def inference(img, return_text_only=True):
    """Run OCR on *img* using an automatically detected language.

    Args:
        img: A file path, a Base64 string (optionally a ``data:`` URI), or
            any other value, which is forwarded unchanged to the OCR model
            manager. A string naming an existing file is always treated as
            a path, even if it contains only Base64-alphabet characters.
        return_text_only: When true, return only the text and language
            name; otherwise also return an annotated image from
            ``draw_ocr``.

    Returns:
        ``(text, language_name)`` when *return_text_only* is true, otherwise
        ``(annotated_image, text, language_name)``. ``text`` is the
        recognised lines joined with newlines (empty when nothing is
        recognised). ``language_name`` comes from ``LANG_MAP``, falling back
        to the raw language key.

    Raises:
        ValueError: Propagated from ``save_base64_to_temp_file`` for bad
            Base64 input.
    """
    temp_file = None

    try:
        img_path = img
        # Paths such as "/data/img" consist solely of Base64-alphabet
        # characters, so check for a real file before decoding as Base64.
        if isinstance(img, str) and not os.path.exists(img):
            if img.startswith("data:") or re.match(r'^[A-Za-z0-9+/=]+$', img):
                temp_file = save_base64_to_temp_file(img)
                img_path = temp_file

        lang = auto_detect_language(img_path)

        ocr = model_managers[lang]
        # The first element is falsy when no text is found; treat it as empty.
        result = ocr.infer(img_path, cls=True)[0] or []

        # Each OCR line is indexed as (box, (text, score)).
        boxes = [line[0] for line in result]
        txts = [line[1][0] for line in result]
        scores = [line[1][1] for line in result]

        text = "\n".join(txts)
        lang_name = LANG_MAP.get(lang, lang)

        if return_text_only:
            return text, lang_name

        # Open the image only when it is drawn on, and close the file
        # handle as soon as the RGB copy exists.
        with Image.open(img_path) as source_img:
            pil_img = source_img.convert("RGB")

        im_show = draw_ocr(pil_img, boxes, txts, scores, font_path="./simfang.ttf")
        return im_show, text, lang_name

    finally:
        # Best-effort removal of the temporary file created for Base64 input.
        if temp_file:
            try:
                os.unlink(temp_file)
            except OSError:
                pass
240
 
241
 
242
  def inference_with_image(img):