ranbac commited on
Commit
fd2f280
·
verified ·
1 Parent(s): 6a65705

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -108
app.py CHANGED
@@ -9,12 +9,11 @@ from PIL import Image, ImageDraw, ImageFont
9
  from paddleocr import PaddleOCR
10
 
11
  # ==========================================
12
- # 1. CẤU HÌNH & TẢI MODEL (SERVER VERSION)
13
  # ==========================================
14
  os.environ["FLAGS_use_mkldnn"] = "0"
15
  os.environ["CPP_MIN_LOG_LEVEL"] = "3"
16
 
17
- # Hàm tải font chữ Trung Quốc (Giữ lại từ code cũ của bạn vì rất tốt)
18
  def check_and_download_font():
19
  font_path = "./simfang.ttf"
20
  if not os.path.exists(font_path):
@@ -30,14 +29,12 @@ def check_and_download_font():
30
 
31
  FONT_PATH = check_and_download_font()
32
 
33
- # Hàm tải Model Server (Độ chính xác cao)
34
  def download_model_server(save_dir="./server_models"):
35
  urls = {
36
  "det": "https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_det_server_infer.tar",
37
  "rec": "https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_rec_server_infer.tar",
38
  "cls": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar"
39
  }
40
-
41
  paths = {}
42
  if not os.path.exists(save_dir): os.makedirs(save_dir)
43
 
@@ -45,42 +42,51 @@ def download_model_server(save_dir="./server_models"):
45
  filename = url.split("/")[-1]
46
  extract_name = filename.replace('.tar', '')
47
  full_path = os.path.join(save_dir, extract_name)
48
-
49
  if not os.path.exists(full_path):
50
  print(f"Đang tải model {key.upper()} High-Accuracy...")
51
  tar_path = os.path.join(save_dir, filename)
52
- r = requests.get(url, stream=True)
53
- with open(tar_path, 'wb') as f:
54
- for chunk in r.iter_content(chunk_size=1024):
55
- if chunk: f.write(chunk)
56
- with tarfile.open(tar_path) as tar:
57
- tar.extractall(path=save_dir)
58
- os.remove(tar_path)
 
 
 
59
  paths[key] = full_path
60
  return paths
61
 
62
- # Khởi tạo OCR
63
- print("Đang khởi tạo PaddleOCR Server Mode...")
64
  try:
65
  models = download_model_server()
66
- ocr = PaddleOCR(
67
- use_angle_cls=True, lang='ch',
68
- det_model_dir=models['det'],
69
- rec_model_dir=models['rec'],
70
- cls_model_dir=models['cls'],
71
- use_textline_orientation=True
72
- )
73
  print("Model Server đã sẵn sàng!")
74
- except Exception as e:
75
- print(f"Lỗi tải model server: {e}. Dùng Mobile model.")
76
  ocr = PaddleOCR(use_angle_cls=True, lang='ch')
77
 
78
  # ==========================================
79
- # 2. XỬ LÝ HÌNH ẢNH & KẾT QUẢ
80
  # ==========================================
81
 
 
 
 
 
 
 
 
 
 
 
 
82
  def draw_results(image, result, font_path):
83
- # Convert sang PIL để vẽ đẹp hơn
84
  if isinstance(image, np.ndarray):
85
  image = Image.fromarray(image)
86
  draw = ImageDraw.Draw(image)
@@ -90,122 +96,78 @@ def draw_results(image, result, font_path):
90
  except:
91
  font = ImageFont.load_default()
92
 
93
- if result and result[0]:
94
- for line in result[0]:
 
 
95
  box = np.array(line[0]).astype(np.int32)
96
  txt = line[1][0]
97
  conf = line[1][1]
98
-
99
- # Vẽ box
100
  tuples = [tuple(p) for p in box]
101
  draw.polygon(tuples, outline="red", width=2)
102
-
103
- # Vẽ nền chữ
104
  txt_pos = (box[0][0], box[0][1] - 25)
105
- bbox = draw.textbbox(txt_pos, f"{txt} ({conf:.2f})", font=font)
106
  draw.rectangle(bbox, fill="red")
107
-
108
- # Vẽ chữ
109
  draw.text(txt_pos, txt, fill="white", font=font)
110
-
111
  return image
112
 
113
  def format_output(result):
114
- """Chuyển đổi kết quả sang Markdown và JSON sạch"""
115
- if not result or not result[0]:
116
- return "Không tìm thấy văn bản.", "[]"
117
 
118
- # Tạo Markdown
119
  md_lines = []
120
  json_data = []
121
 
122
- # Sắp xếp theo trục Y để tạo dòng văn bản tự nhiên
123
- boxes = sorted(result[0], key=lambda x: x[0][0][1])
124
-
125
- for item in boxes:
126
- text = item[1][0]
127
- conf = float(item[1][1])
128
- box = item[0]
129
-
130
- md_lines.append(f"- **{text}** (Độ tin cậy: {conf:.1%})")
131
-
132
- json_data.append({
133
- "text": text,
134
- "confidence": conf,
135
- "box": box
136
- })
137
 
138
- md_output = "\n".join(md_lines)
139
- json_output = json.dumps(json_data, ensure_ascii=False, indent=2)
140
-
141
- return md_output, json_output
142
 
143
- # ==========================================
144
- # 3. HÀM DỰ ĐOÁN CHÍNH
145
- # ==========================================
146
  def predict_pipeline(image_file):
147
- if image_file is None:
148
- return None, "", ""
149
-
150
- # Đọc ảnh
151
  img = np.array(Image.open(image_file).convert('RGB'))
152
 
153
- # OCR
154
- result = ocr.ocr(img)
155
 
156
- # 1. Vẽ Visualization
157
  vis_img = draw_results(img.copy(), result, FONT_PATH)
158
-
159
- # 2. Format dữ liệu
160
  md_out, json_out = format_output(result)
161
-
162
  return vis_img, md_out, json_out
163
 
164
  # ==========================================
165
- # 4. GIAO DIỆN GRADIO (Custom CSS giống bản Demo)
166
  # ==========================================
167
- custom_css = """
168
- body, .gradio-container { font-family: "Noto Sans SC", sans-serif; }
169
- .gradio-container { max-width: 1200px !important; margin: auto; }
170
- .header-area { text-align: center; margin-bottom: 20px; }
171
- .header-area h1 { margin-bottom: 5px; color: #2d3748; }
172
- .notice { background: #f0f9ff; border: 1px solid #bae6fd; padding: 10px; border-radius: 8px; color: #0369a1; font-size: 14px; margin-bottom: 15px; }
173
- """
174
-
175
- with gr.Blocks(title="PaddleOCR Pro Local", css=custom_css, theme=gr.themes.Soft()) as app:
176
-
177
- with gr.Column(elem_classes="header-area"):
178
- gr.Markdown("# 🇨🇳 PaddleOCR Professional (Local Version)")
179
- gr.HTML("<div class='notice'>⚡ Phiên bản Server-Mode: Chạy offline với độ chính xác cao hơn bản Mobile mặc định.</div>")
180
 
 
 
 
181
  with gr.Row():
182
- # Cột TRÁI: Input
183
  with gr.Column(scale=4):
184
  input_image = gr.Image(type="filepath", label="Tải ảnh lên", height=400)
185
- submit_btn = gr.Button("🚀 CHẠY NHẬN DIỆN", variant="primary", size="lg")
186
-
187
- gr.Markdown("### 💡 Ghi chú:")
188
- gr.Markdown("- Model sẽ tự động tải phiên bản **Server (High Accuracy)** (~200MB) trong lần chạy đầu.")
189
- gr.Markdown("- Hỗ trợ tốt cho tài liệu scan, hóa đơn và văn bản tiếng Trung.")
190
-
191
- # Cột PHẢI: Output (Tabbed UI)
192
  with gr.Column(scale=6):
193
  with gr.Tabs():
194
- with gr.TabItem("🖼️ Trực quan hóa (Visualization)"):
195
- output_image = gr.Image(type="pil", label="Kết quả")
196
-
197
- with gr.TabItem("📝 Văn bản (Markdown)"):
198
- output_md = gr.Markdown(label="Nội dung trích xuất")
199
-
200
- with gr.TabItem("📊 Dữ liệu thô (JSON)"):
201
- output_json = gr.Code(language="json", label="Chi tiết tọa độ & Confidence")
202
-
203
- # Xử lý sự kiện
204
- submit_btn.click(
205
- fn=predict_pipeline,
206
- inputs=[input_image],
207
- outputs=[output_image, output_md, output_json]
208
- )
209
 
210
  if __name__ == "__main__":
211
  app.launch(server_name="0.0.0.0", server_port=7860)
 
9
  from paddleocr import PaddleOCR
10
 
11
  # ==========================================
12
+ # 1. CẤU HÌNH & TẢI MODEL
13
  # ==========================================
14
  os.environ["FLAGS_use_mkldnn"] = "0"
15
  os.environ["CPP_MIN_LOG_LEVEL"] = "3"
16
 
 
17
  def check_and_download_font():
18
  font_path = "./simfang.ttf"
19
  if not os.path.exists(font_path):
 
29
 
30
  FONT_PATH = check_and_download_font()
31
 
 
32
  def download_model_server(save_dir="./server_models"):
33
  urls = {
34
  "det": "https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_det_server_infer.tar",
35
  "rec": "https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_rec_server_infer.tar",
36
  "cls": "https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar"
37
  }
 
38
  paths = {}
39
  if not os.path.exists(save_dir): os.makedirs(save_dir)
40
 
 
42
  filename = url.split("/")[-1]
43
  extract_name = filename.replace('.tar', '')
44
  full_path = os.path.join(save_dir, extract_name)
 
45
  if not os.path.exists(full_path):
46
  print(f"Đang tải model {key.upper()} High-Accuracy...")
47
  tar_path = os.path.join(save_dir, filename)
48
+ try:
49
+ r = requests.get(url, stream=True)
50
+ with open(tar_path, 'wb') as f:
51
+ for chunk in r.iter_content(chunk_size=1024):
52
+ if chunk: f.write(chunk)
53
+ with tarfile.open(tar_path) as tar:
54
+ tar.extractall(path=save_dir)
55
+ os.remove(tar_path)
56
+ except Exception as e:
57
+ print(f"Lỗi tải {filename}: {e}")
58
  paths[key] = full_path
59
  return paths
60
 
61
+ print("Đang khởi tạo PaddleOCR...")
 
62
  try:
63
  models = download_model_server()
64
+ ocr = PaddleOCR(use_angle_cls=True, lang='ch',
65
+ det_model_dir=models.get('det'),
66
+ rec_model_dir=models.get('rec'),
67
+ cls_model_dir=models.get('cls'),
68
+ use_textline_orientation=True)
 
 
69
  print("Model Server đã sẵn sàng!")
70
+ except:
71
+ print("Lỗi tải model server. Dùng Mobile model.")
72
  ocr = PaddleOCR(use_angle_cls=True, lang='ch')
73
 
74
  # ==========================================
75
+ # 2. XỬ LÝ HÌNH ẢNH & KẾT QUẢ (ĐÃ FIX)
76
  # ==========================================
77
 
78
+ def get_lines_from_result(result):
79
+ """Hàm phụ trợ để chuẩn hóa đầu ra của PaddleOCR"""
80
+ if not result: return []
81
+ # Nếu là list phẳng [Line1, Line2] (cấu trúc mới)
82
+ if isinstance(result[0], list) and len(result[0]) == 2 and \
83
+ isinstance(result[0][1], (tuple, list)) and \
84
+ isinstance(result[0][1][0], str):
85
+ return result
86
+ # Nếu là batch [[Line1, Line2]] (cấu trúc cũ)
87
+ return result[0]
88
+
89
  def draw_results(image, result, font_path):
 
90
  if isinstance(image, np.ndarray):
91
  image = Image.fromarray(image)
92
  draw = ImageDraw.Draw(image)
 
96
  except:
97
  font = ImageFont.load_default()
98
 
99
+ lines = get_lines_from_result(result)
100
+
101
+ for line in lines:
102
+ try:
103
  box = np.array(line[0]).astype(np.int32)
104
  txt = line[1][0]
105
  conf = line[1][1]
 
 
106
  tuples = [tuple(p) for p in box]
107
  draw.polygon(tuples, outline="red", width=2)
 
 
108
  txt_pos = (box[0][0], box[0][1] - 25)
109
+ bbox = draw.textbbox(txt_pos, f"{txt}", font=font)
110
  draw.rectangle(bbox, fill="red")
 
 
111
  draw.text(txt_pos, txt, fill="white", font=font)
112
+ except: continue
113
  return image
114
 
115
  def format_output(result):
116
+ lines = get_lines_from_result(result)
117
+ if not lines: return "Không tìm thấy văn bản.", "[]"
 
118
 
 
119
  md_lines = []
120
  json_data = []
121
 
122
+ # Sort top-down
123
+ try: sorted_lines = sorted(lines, key=lambda x: x[0][0][1])
124
+ except: sorted_lines = lines
125
+
126
+ for item in sorted_lines:
127
+ try:
128
+ text = item[1][0]
129
+ conf = float(item[1][1])
130
+ box = item[0]
131
+ md_lines.append(f"- **{text}** ({conf:.1%})")
132
+ json_data.append({"text": text, "confidence": conf, "box": box})
133
+ except: continue
 
 
 
134
 
135
+ return "\n".join(md_lines), json.dumps(json_data, ensure_ascii=False, indent=2)
 
 
 
136
 
 
 
 
137
  def predict_pipeline(image_file):
138
+ if image_file is None: return None, "", ""
 
 
 
139
  img = np.array(Image.open(image_file).convert('RGB'))
140
 
141
+ # Gọi OCR (cls=True giúp nhận diện chiều văn bản tốt hơn)
142
+ result = ocr.ocr(img, cls=True)
143
 
 
144
  vis_img = draw_results(img.copy(), result, FONT_PATH)
 
 
145
  md_out, json_out = format_output(result)
 
146
  return vis_img, md_out, json_out
147
 
148
  # ==========================================
149
+ # 3. GIAO DIỆN GRADIO
150
  # ==========================================
151
+ custom_css = "body, .gradio-container { font-family: 'Noto Sans SC', sans-serif; }"
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
+ with gr.Blocks(title="PaddleOCR Pro Fixed", css=custom_css, theme=gr.themes.Soft()) as app:
154
+ gr.Markdown("# 🇨🇳 PaddleOCR Pro (Server Mode - Fixed)")
155
+
156
  with gr.Row():
 
157
  with gr.Column(scale=4):
158
  input_image = gr.Image(type="filepath", label="Tải ảnh lên", height=400)
159
+ submit_btn = gr.Button("CHẠY NHẬN DIỆN", variant="primary")
160
+
 
 
 
 
 
161
  with gr.Column(scale=6):
162
  with gr.Tabs():
163
+ with gr.TabItem("Kết quả"):
164
+ output_image = gr.Image(type="pil", label="Visualization")
165
+ with gr.TabItem("Markdown"):
166
+ output_md = gr.Markdown()
167
+ with gr.TabItem("JSON"):
168
+ output_json = gr.Code(language="json")
169
+
170
+ submit_btn.click(predict_pipeline, inputs=[input_image], outputs=[output_image, output_md, output_json])
 
 
 
 
 
 
 
171
 
172
  if __name__ == "__main__":
173
  app.launch(server_name="0.0.0.0", server_port=7860)