ranbac committed on
Commit
7b9a396
·
verified ·
1 Parent(s): 755b8bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +183 -135
app.py CHANGED
@@ -1,24 +1,38 @@
1
  import os
2
- import cv2
3
- import json
4
- import tarfile
5
- import requests
6
- import numpy as np
7
- import gradio as gr
8
- from PIL import Image, ImageDraw, ImageFont
9
- from paddleocr import PaddleOCR
10
 
11
- # ==========================================
12
- # 1. CẤU HÌNH & TẢI MODEL
13
- # ==========================================
14
  os.environ["FLAGS_use_mkldnn"] = "0"
 
 
15
  os.environ["CPP_MIN_LOG_LEVEL"] = "3"
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def check_and_download_font():
18
  font_path = "./simfang.ttf"
19
  if not os.path.exists(font_path):
20
  try:
21
- print("Đang tải font SimFang...")
22
  url = "https://github.com/StellarCN/scp_zh/raw/master/fonts/SimFang.ttf"
23
  r = requests.get(url, allow_redirects=True)
24
  with open(font_path, 'wb') as f:
@@ -29,145 +43,179 @@ def check_and_download_font():
29
 
30
  FONT_PATH = check_and_download_font()
31
 
32
- def download_model_server(save_dir="./server_models"):
33
- urls = {
34
- "det": "https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_det_server_infer.tar",
35
- "rec": "https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_rec_server_infer.tar",
36
- "cls": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar"
37
- }
38
- paths = {}
39
- if not os.path.exists(save_dir): os.makedirs(save_dir)
40
-
41
- for key, url in urls.items():
42
- filename = url.split("/")[-1]
43
- extract_name = filename.replace('.tar', '')
44
- full_path = os.path.join(save_dir, extract_name)
45
- if not os.path.exists(full_path):
46
- print(f"Đang tải model {key.upper()} High-Accuracy...")
47
- tar_path = os.path.join(save_dir, filename)
48
- try:
49
- r = requests.get(url, stream=True)
50
- with open(tar_path, 'wb') as f:
51
- for chunk in r.iter_content(chunk_size=1024):
52
- if chunk: f.write(chunk)
53
- with tarfile.open(tar_path) as tar:
54
- tar.extractall(path=save_dir)
55
- os.remove(tar_path)
56
- except Exception as e:
57
- print(f"Lỗi tải {filename}: {e}")
58
- paths[key] = full_path
59
- return paths
60
-
61
- print("Đang khởi tạo PaddleOCR...")
62
- try:
63
- models = download_model_server()
64
- ocr = PaddleOCR(use_angle_cls=True, lang='ch',
65
- det_model_dir=models.get('det'),
66
- rec_model_dir=models.get('rec'),
67
- cls_model_dir=models.get('cls'),
68
- use_textline_orientation=True)
69
- print("Model Server đã sẵn sàng!")
70
- except:
71
- print("Lỗi tải model server. Dùng Mobile model.")
72
- ocr = PaddleOCR(use_angle_cls=True, lang='ch')
73
-
74
- # ==========================================
75
- # 2. XỬ LÝ HÌNH ẢNH & KẾT QUẢ (ĐÃ FIX)
76
- # ==========================================
77
-
78
- def get_lines_from_result(result):
79
- """Hàm phụ trợ để chuẩn hóa đầu ra của PaddleOCR"""
80
- if not result: return []
81
- # Nếu là list phẳng [Line1, Line2] (cấu trúc mới)
82
- if isinstance(result[0], list) and len(result[0]) == 2 and \
83
- isinstance(result[0][1], (tuple, list)) and \
84
- isinstance(result[0][1][0], str):
85
- return result
86
- # Nếu là batch [[Line1, Line2]] (cấu trúc cũ)
87
- return result[0]
88
-
89
- def draw_results(image, result, font_path):
90
  if isinstance(image, np.ndarray):
91
  image = Image.fromarray(image)
92
- draw = ImageDraw.Draw(image)
 
 
 
93
 
94
  try:
95
- font = ImageFont.truetype(font_path, 20) if font_path else ImageFont.load_default()
 
96
  except:
97
  font = ImageFont.load_default()
98
 
99
- lines = get_lines_from_result(result)
100
-
101
- for line in lines:
102
  try:
103
- box = np.array(line[0]).astype(np.int32)
104
- txt = line[1][0]
105
- conf = line[1][1]
106
- tuples = [tuple(p) for p in box]
107
- draw.polygon(tuples, outline="red", width=2)
108
- txt_pos = (box[0][0], box[0][1] - 25)
109
- bbox = draw.textbbox(txt_pos, f"{txt}", font=font)
110
- draw.rectangle(bbox, fill="red")
111
- draw.text(txt_pos, txt, fill="white", font=font)
112
- except: continue
113
- return image
114
-
115
- def format_output(result):
116
- lines = get_lines_from_result(result)
117
- if not lines: return "Không tìm thấy văn bản.", "[]"
118
-
119
- md_lines = []
120
- json_data = []
121
-
122
- # Sort top-down
123
- try: sorted_lines = sorted(lines, key=lambda x: x[0][0][1])
124
- except: sorted_lines = lines
125
 
126
- for item in sorted_lines:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  try:
128
- text = item[1][0]
129
- conf = float(item[1][1])
130
- box = item[0]
131
- md_lines.append(f"- **{text}** ({conf:.1%})")
132
- json_data.append({"text": text, "confidence": conf, "box": box})
 
 
 
 
 
133
  except: continue
134
-
135
- return "\n".join(md_lines), json.dumps(json_data, ensure_ascii=False, indent=2)
136
 
137
- def predict_pipeline(image_file):
138
- if image_file is None: return None, "", ""
139
- img = np.array(Image.open(image_file).convert('RGB'))
140
-
141
- # Gọi OCR (cls=True giúp nhận diện chiều văn bản tốt hơn)
142
- result = ocr.ocr(img)
143
-
144
- vis_img = draw_results(img.copy(), result, FONT_PATH)
145
- md_out, json_out = format_output(result)
146
- return vis_img, md_out, json_out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
- # ==========================================
149
- # 3. GIAO DIỆN GRADIO
150
- # ==========================================
151
- custom_css = "body, .gradio-container { font-family: 'Noto Sans SC', sans-serif; }"
152
-
153
- with gr.Blocks(title="PaddleOCR Pro Fixed", css=custom_css, theme=gr.themes.Soft()) as app:
154
- gr.Markdown("# 🇨🇳 PaddleOCR Pro (Server Mode - Fixed)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  with gr.Row():
157
- with gr.Column(scale=4):
158
- input_image = gr.Image(type="filepath", label="Tải ảnh lên", height=400)
159
- submit_btn = gr.Button("CHẠY NHẬN DIỆN", variant="primary")
160
 
161
- with gr.Column(scale=6):
162
  with gr.Tabs():
163
- with gr.TabItem("Kết quả"):
164
- output_image = gr.Image(type="pil", label="Visualization")
165
- with gr.TabItem("Markdown"):
166
- output_md = gr.Markdown()
167
- with gr.TabItem("JSON"):
168
- output_json = gr.Code(language="json")
169
-
170
- submit_btn.click(predict_pipeline, inputs=[input_image], outputs=[output_image, output_md, output_json])
 
 
 
 
171
 
172
  if __name__ == "__main__":
173
- app.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import os
 
 
 
 
 
 
 
 
2
 
3
+ # --- CẤU HÌNH HỆ THỐNG ---
 
 
4
  os.environ["FLAGS_use_mkldnn"] = "0"
5
+ os.environ["FLAGS_enable_mkldnn"] = "0"
6
+ os.environ["DN_ENABLE_MKLDNN"] = "0"
7
  os.environ["CPP_MIN_LOG_LEVEL"] = "3"
8
 
9
+ import logging
10
+ import re
11
+ import gradio as gr
12
+ from paddleocr import PaddleOCR
13
+ from PIL import Image, ImageDraw, ImageFont
14
+ import numpy as np
15
+ import requests
16
+
17
+ # Tắt log thừa
18
+ logging.getLogger("ppocr").setLevel(logging.WARNING)
19
+
20
+ print("Đang khởi tạo PaddleOCR (Coordinate Sync Mode)...")
21
+
22
+ try:
23
+ ocr = PaddleOCR(use_textline_orientation=True, use_doc_orientation_classify=False,
24
+ use_doc_unwarping=False, lang='ch')
25
+ except Exception as e:
26
+ print(f"Lỗi khởi tạo: {e}. Chuyển về chế độ mặc định.")
27
+ ocr = PaddleOCR(lang='ch')
28
+
29
+ print("Model đã sẵn sàng!")
30
+
31
+ # --- TẢI FONT ---
32
  def check_and_download_font():
33
  font_path = "./simfang.ttf"
34
  if not os.path.exists(font_path):
35
  try:
 
36
  url = "https://github.com/StellarCN/scp_zh/raw/master/fonts/SimFang.ttf"
37
  r = requests.get(url, allow_redirects=True)
38
  with open(font_path, 'wb') as f:
 
43
 
44
  FONT_PATH = check_and_download_font()
45
 
46
+ # --- HÀM VẼ ĐA NĂNG ---
47
def _parse_box(b):
    """Normalise one OCR box into a list of ``(x, y)`` tuples.

    Accepts either a polygon (a sequence of 2-D points, given as lists,
    tuples or a numpy array) or a flat ``[x0, y0, x1, y1]`` rectangle, which
    is expanded clockwise from the top-left corner.

    Returns:
        A list of point tuples, or ``None`` when *b* is not a numeric box.
    """
    try:
        arr = np.asarray(b)
    except (ValueError, TypeError):
        # Ragged or otherwise non-array-like input.
        return None
    if arr.dtype.kind not in "iuf":
        # Strings, None, objects, booleans: not coordinates.
        return None
    if arr.ndim == 2 and arr.shape[0] >= 2 and arr.shape[1] == 2:
        # tolist() turns numpy scalars back into plain int/float.
        return [tuple(point) for point in arr.tolist()]
    if arr.shape == (4,):
        x0, y0, x1, y1 = arr.tolist()
        return [(x0, y0), (x1, y0), (x1, y1), (x0, y1)]
    return None


def _collect_ocr_items(raw_data):
    """Extract ``(polygon, text)`` pairs from an OCR result of unknown shape.

    Two layouts are understood:

    * a list whose first element is a dict carrying ``rec_texts`` plus one of
      the box lists (``rec_polys``, ``dt_polys``, ``rec_boxes``, ``dt_boxes``);
    * anything else, which is searched recursively for dicts holding a
      box key and a text key, or for legacy ``[box, (text, score)]`` pairs.
    """
    items = []

    if isinstance(raw_data, list) and raw_data and isinstance(raw_data[0], dict):
        page = raw_data[0]
        texts = page.get('rec_texts')

        # Prefer the recognition-side boxes: per the PaddleOCR 3.x result
        # documentation they are the score-filtered list that runs parallel
        # to rec_texts, whereas dt_polys holds every detected region and can
        # therefore be longer (and shifted) relative to the texts.
        boxes = None
        for key in ('rec_polys', 'dt_polys', 'rec_boxes', 'dt_boxes'):
            candidate = page.get(key)
            try:
                usable = candidate is not None and len(candidate) > 0
            except TypeError:
                usable = False
            if usable:
                boxes = candidate
                break

        # Explicit length/None checks: `bool(ndarray)` raises for arrays with
        # more than one element, so truthiness must not be used on boxes.
        if isinstance(texts, (list, tuple)) and texts and boxes is not None:
            for txt, raw_box in zip(texts, boxes):
                box = _parse_box(raw_box)
                if box and txt:
                    items.append((box, txt))
            # Structured data was found; do not run the generic search even
            # if every entry was rejected.
            return items

    box_keys = ('points', 'box', 'dt_boxes', 'poly')
    text_keys = ('transcription', 'text', 'rec_text', 'label')

    def hunt(data):
        if isinstance(data, dict):
            # Only the first key present is consulted in each family.
            box = next((_parse_box(data[k]) for k in box_keys if k in data), None)
            text = next((data[k] for k in text_keys if k in data), None)
            if box and text:
                items.append((box, text))
                return
            for value in data.values():
                hunt(value)
        elif isinstance(data, (list, tuple)):
            if (len(data) == 2
                    and isinstance(data[0], (list, tuple, np.ndarray))
                    and len(data[0]) == 4):
                box = _parse_box(data[0])
                txt_obj = data[1]
                # Legacy entries look like (text, score); guard the empty case.
                if isinstance(txt_obj, (list, tuple)) and txt_obj:
                    text = txt_obj[0]
                else:
                    text = txt_obj
                if box and isinstance(text, str):
                    items.append((box, text))
                    return
            for item in data:
                hunt(item)

    hunt(raw_data)
    return items


def universal_draw(image, raw_data, font_path):
    """Draw OCR boxes and their recognised text onto a copy of *image*.

    Args:
        image: A PIL image or a numpy array; ``None`` is returned unchanged.
        raw_data: The OCR result in any layout understood by
            ``_collect_ocr_items``.
        font_path: Path of a TrueType font for the labels; when falsy or
            unloadable, Pillow's default font is used.

    Returns:
        A new PIL image with red outlines and white-on-red labels anchored at
        the first point of every box. The input image is not modified.
    """
    if image is None:
        return image

    # Make sure we are working with a PIL image.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Draw on a copy so the caller's image stays untouched.
    canvas = image.copy()
    draw = ImageDraw.Draw(canvas)

    font_size = 24
    try:
        font = ImageFont.truetype(font_path, font_size) if font_path else ImageFont.load_default()
    except (OSError, ValueError, ImportError):
        font = ImageFont.load_default()

    for box, txt in _collect_ocr_items(raw_data):
        label = txt if isinstance(txt, str) else str(txt)
        try:
            # Red outline.
            draw.polygon(box, outline="red", width=3)

            # Label, bottom-left aligned on the first point of the box.
            txt_x, txt_y = box[0]
            drawn = False
            if hasattr(draw, "textbbox"):
                try:
                    text_bbox = draw.textbbox((txt_x, txt_y), label, font=font, anchor="lb")
                    draw.rectangle(text_bbox, fill="red")
                    draw.text((txt_x, txt_y), label, fill="white", font=font, anchor="lb")
                    drawn = True
                except ValueError:
                    # Bitmap fallback fonts reject textbbox/anchors; use the
                    # plain positioning below instead of dropping the label.
                    drawn = False
            if not drawn:
                draw.text((txt_x, txt_y - font_size), label, fill="white", font=font)
        except Exception as exc:
            # Best effort: one undrawable item must not abort the overlay.
            print(f"universal_draw: skipped one item ({exc})")
            continue

    return canvas
127
+
128
+ # --- HÀM XỬ LÝ TEXT ---
129
def deep_extract_text(data):
    """Collect every non-blank string reachable inside a nested structure.

    Lists, tuples and dict values are walked depth-first in iteration order;
    any other object exposing ``__dict__`` is searched through that
    attribute. Strings are returned unstripped, and a string that is empty or
    whitespace-only is ignored.

    A container that (directly or indirectly) contains itself is visited only
    once per path, so cyclic results no longer recurse without bound. Shared
    but acyclic sub-structures are still reported once per occurrence.

    Args:
        data: Arbitrary nested value (e.g. a raw OCR result).

    Returns:
        A list of the strings found, in traversal order.
    """
    found_texts = []
    on_path = set()  # ids of the containers currently being expanded

    def walk(node):
        if isinstance(node, str):
            if node.strip():
                found_texts.append(node)
            return

        if isinstance(node, (list, tuple)):
            children = node
        elif isinstance(node, dict):
            children = node.values()
        elif hasattr(node, '__dict__'):
            # Re-dispatch on the attribute mapping itself so only real dicts
            # (instances, modules) are expanded, not class mappingproxies.
            children = (node.__dict__,)
        else:
            return

        marker = id(node)
        if marker in on_path:
            return  # cycle: this container is already being expanded
        on_path.add(marker)
        try:
            for child in children:
                walk(child)
        finally:
            on_path.discard(marker)

    walk(data)
    return found_texts
140
+
141
def clean_text_result(text_list):
    """Filter extracted strings down to plausible OCR text.

    Each entry is stripped and then dropped when it is:

    * shorter than two characters and contains no character in the
      U+4E00..U+9FFF range;
    * a file-like name ending in ``.ttf``, ``.json``, ``.pdparams``,
      ``.yml`` or ``.log`` (case-insensitive);
    * one of the blocked metadata words (case-insensitive);
    * made up solely of characters that are neither word characters nor in
      the U+4E00..U+9FFF range.

    Non-string entries are skipped and a ``None`` input yields an empty list,
    instead of raising.

    Args:
        text_list: Iterable of candidate strings, or ``None``.

    Returns:
        The surviving strings, stripped, in their original order.
    """
    blocked_words = {'min', 'max', 'general', 'header', 'footer', 'structure'}
    file_suffixes = ('.ttf', '.json', '.pdparams', '.yml', '.log')

    cleaned = []
    for raw in text_list or ():
        if not isinstance(raw, str):
            continue
        t = raw.strip()
        if len(t) < 2 and not any('\u4e00' <= c <= '\u9fff' for c in t):
            continue
        lowered = t.lower()
        if lowered.endswith(file_suffixes) or lowered in blocked_words:
            continue
        if not re.search(r'[\w\u4e00-\u9fff]', t):
            continue
        cleaned.append(t)
    return cleaned
152
+
153
+ # --- MAIN PREDICT ---
154
def predict(image):
    """Run OCR on *image* and build the three outputs shown in the UI.

    Args:
        image: A PIL image or numpy array, or ``None`` when nothing was
            uploaded.

    Returns:
        A tuple ``(annotated_image, text_output, debug_info)``. With no input
        the tuple is ``(None, "Chưa có ảnh.", "No Data")``. If anything fails
        during processing, the unmodified input image is returned together
        with the error message and the formatted traceback.
    """
    if image is None:
        return None, "Chưa có ảnh.", "No Data"

    try:
        # Prepare the input image.
        original_pil = image.copy() if isinstance(image, Image.Image) else Image.fromarray(image).copy()
        image_np = np.array(image)

        # 1. OCR
        raw_result = ocr.ocr(image_np)

        # 2. Pick the image to draw on: when the result carries a
        #    preprocessed image, the boxes refer to that image rather than to
        #    the upload, so the overlay must be drawn on it.
        target_image_for_drawing = original_pil
        used_preprocessed = False

        first_page = raw_result[0] if isinstance(raw_result, list) and raw_result else None
        if isinstance(first_page, dict):
            proc_res = first_page.get('doc_preprocessor_res')
            # The entry may be absent or None; only a mapping with an actual
            # image is usable.
            processed_img = proc_res.get('output_img') if isinstance(proc_res, dict) else None
            if processed_img is not None:
                print("Phát hiện ảnh đã qua xử lý hình học. Đang đồng bộ tọa độ...")
                target_image_for_drawing = Image.fromarray(processed_img)
                used_preprocessed = True

        # 3. Draw on the selected image.
        annotated_image = universal_draw(target_image_for_drawing, raw_result, FONT_PATH)

        # 4. Text output.
        all_texts = deep_extract_text(raw_result)
        final_texts = clean_text_result(all_texts)
        text_output = "\n".join(final_texts) if final_texts else "Không tìm thấy văn bản."

        # Debug info. An explicit flag is used instead of comparing the two
        # images: `!=` on PIL images compares their contents, which is costly
        # and reports "Original" whenever the pixels happen to match.
        debug_str = str(raw_result)[:1000]
        debug_info = f"Used Image Source: {'Preprocessed' if used_preprocessed else 'Original'}\nData Preview:\n{debug_str}..."

        return annotated_image, text_output, debug_info

    except Exception as e:
        # UI boundary: report the failure in the output panes, do not raise.
        import traceback
        return image, f"Lỗi: {str(e)}", traceback.format_exc()
195
+
196
+ # --- GIAO DIỆN ---
197
+ with gr.Blocks(title="PaddleOCR Perfect Overlay") as iface:
198
+ gr.Markdown("## PaddleOCR Chinese - High Precision Overlay")
199
 
200
  with gr.Row():
201
+ with gr.Column():
202
+ input_img = gr.Image(type="pil", label="Input Image")
203
+ submit_btn = gr.Button("RUN OCR", variant="primary")
204
 
205
+ with gr.Column():
206
  with gr.Tabs():
207
+ with gr.TabItem("🖼️ Kết quả Khớp Tọa Độ"):
208
+ output_img = gr.Image(type="pil", label="Overlay Result")
209
+ with gr.TabItem("📝 Văn bản"):
210
+ output_txt = gr.Textbox(label="Text Content", lines=15)
211
+ with gr.TabItem("🐞 Debug"):
212
+ output_debug = gr.Textbox(label="Debug Info", lines=15)
213
+
214
+ submit_btn.click(
215
+ fn=predict,
216
+ inputs=input_img,
217
+ outputs=[output_img, output_txt, output_debug]
218
+ )
219
 
220
  if __name__ == "__main__":
221
+ iface.launch(server_name="0.0.0.0", server_port=7860)