Harry Pham commited on
Commit
f8fef9f
·
1 Parent(s): c87ac5f

update OCR

Browse files
Files changed (1) hide show
  1. src/inference.py +217 -77
src/inference.py CHANGED
@@ -11,6 +11,7 @@ import json
11
  import numpy as np
12
  from pathlib import Path
13
  from ultralytics import RTDETR
 
14
 
15
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
  print(f"[INFO] Device: {DEVICE}")
@@ -20,102 +21,240 @@ CLASS_DISPLAY = {"note": "Note", "part-drawing": "PartDrawing", "table": "Table"
20
  COLORS = {"note": (0,165,255), "part-drawing": (0,200,0), "table": (0,0,220)}
21
 
22
  _det_model = None
23
- _easy_reader = None
24
 
25
  def get_det_model(checkpoint="best.pt"):
26
  global _det_model
27
  if _det_model is None:
28
- print(f"[INFO] Loading model: {checkpoint}")
29
  _det_model = RTDETR(checkpoint)
30
  return _det_model
31
 
32
- def get_easy_reader():
33
- global _easy_reader
34
- if _easy_reader is None:
35
- import easyocr
36
- print("[INFO] Loading EasyOCR...")
37
- _easy_reader = easyocr.Reader(["vi","en"], gpu=False, verbose=False)
38
- return _easy_reader
39
 
40
- def preprocess_for_ocr(img_bgr):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  h, w = img_bgr.shape[:2]
 
42
  if w < 800:
43
  scale = 800 / w
44
- img_bgr = cv2.resize(img_bgr, (int(w*scale), int(h*scale)),
45
- interpolation=cv2.INTER_CUBIC)
46
- gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
47
- clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
48
- gray = clahe.apply(gray)
49
- gray = cv2.fastNlMeansDenoising(gray, h=15,
50
- templateWindowSize=7, searchWindowSize=21)
51
- kernel = np.array([[0,-1,0],[-1,5,-1],[0,-1,0]])
52
- gray = cv2.filter2D(gray, -1, kernel)
53
- return cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
54
-
55
- def ocr_note(img_path):
56
- img = cv2.imread(img_path)
57
- if img is None:
58
- return ""
59
- img_proc = preprocess_for_ocr(img)
60
- try:
61
- reader = get_easy_reader()
62
- results = reader.readtext(img_proc, detail=1, paragraph=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  width_ths=0.7, height_ths=0.7)
64
- lines = [t for (_,t,c) in results if c >= 0.2 and t.strip()]
65
- return "\n".join(lines)
66
- except Exception as e:
67
- print(f"[WARN] ocr_note: {e}")
68
- return ""
 
 
 
 
 
 
 
 
 
69
 
70
- def _group_rows(items):
 
 
 
 
71
  if not items:
72
  return []
73
- items = sorted(items, key=lambda x: x["y"])
74
- y_vals = [it["y"] for it in items]
 
 
 
 
75
  if len(y_vals) > 1:
76
- gaps = [y_vals[i+1]-y_vals[i] for i in range(len(y_vals)-1)]
77
- thresh = max(8, (sum(gaps)/len(gaps)) * 0.6)
 
78
  else:
79
  thresh = 12
80
- rows, cur = [], [items[0]]
81
- for item in items[1:]:
82
- if item["y"] - cur[-1]["y"] < thresh:
83
- cur.append(item)
 
 
84
  else:
85
- cur.sort(key=lambda x: x["x"])
86
- rows.append([i["text"] for i in cur])
87
- cur = [item]
88
- cur.sort(key=lambda x: x["x"])
89
- rows.append([i["text"] for i in cur])
90
- return rows
91
-
92
- def ocr_table(img_path):
 
 
 
 
93
  img = cv2.imread(img_path)
94
  if img is None:
95
- return {"rows":[], "text":""}
96
- img_proc = preprocess_for_ocr(img)
97
- items = []
98
- try:
99
- reader = get_easy_reader()
100
- results = reader.readtext(img_proc, detail=1, paragraph=False,
101
- width_ths=0.5, height_ths=0.5)
102
- for (pts, text, conf) in results:
103
- if conf < 0.2 or not text.strip():
104
- continue
105
- items.append({
106
- "text": text.strip(),
107
- "y": sum(p[1] for p in pts)/4,
108
- "x": sum(p[0] for p in pts)/4,
109
- })
110
- except Exception as e:
111
- print(f"[WARN] ocr_table: {e}")
112
  if not items:
113
- return {"rows":[], "text":""}
114
- rows = _group_rows(items)
115
- return {"rows": rows, "text": "\n".join(" | ".join(r) for r in rows)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
  def run_pipeline(image_path, output_dir="outputs",
118
- checkpoint="best.pt", conf_thresh=0.3):
 
 
 
 
119
  image_path = str(image_path)
120
  img_name = Path(image_path).name
121
  stem = Path(image_path).stem
@@ -147,12 +286,12 @@ def run_pipeline(image_path, output_dir="outputs",
147
  ocr_content = None
148
  if cls_raw == "note":
149
  print(f"[OCR] Note #{i+1}...")
150
- ocr_content = ocr_note(crop_path)
151
- print(f" → {repr(ocr_content[:80]) if ocr_content else 'EMPTY'}")
152
  elif cls_raw == "table":
153
  print(f"[OCR] Table #{i+1}...")
154
- ocr_content = ocr_table(crop_path)
155
- preview = ocr_content.get("text","")[:80]
156
  print(f" → {repr(preview) if preview else 'EMPTY'}")
157
 
158
  objects.append({
@@ -185,5 +324,6 @@ def run_pipeline(image_path, output_dir="outputs",
185
  if __name__ == "__main__":
186
  import sys
187
  img = sys.argv[1] if len(sys.argv) > 1 else "test.jpg"
188
- result, _ = run_pipeline(img)
 
189
  print(json.dumps(result, ensure_ascii=False, indent=2))
 
11
  import numpy as np
12
  from pathlib import Path
13
  from ultralytics import RTDETR
14
+ import re
15
 
16
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
17
  print(f"[INFO] Device: {DEVICE}")
 
21
  COLORS = {"note": (0,165,255), "part-drawing": (0,200,0), "table": (0,0,220)}
22
 
23
  _det_model = None
24
+ _ocr_reader = None # sẽ là PaddleOCR hoặc EasyOCR
25
 
26
  def get_det_model(checkpoint="best.pt"):
27
  global _det_model
28
  if _det_model is None:
29
+ print(f"[INFO] Loading detection model: {checkpoint}")
30
  _det_model = RTDETR(checkpoint)
31
  return _det_model
32
 
33
+ def get_ocr_reader(backend="paddle"):
34
+ """Khởi tạo OCR engine, ưu tiên PaddleOCR, fallback EasyOCR"""
35
+ global _ocr_reader
36
+ if _ocr_reader is not None:
37
+ return _ocr_reader
 
 
38
 
39
+ if backend == "paddle":
40
+ try:
41
+ from paddleocr import PaddleOCR
42
+ print("[INFO] Initializing PaddleOCR (lang: vi, en)...")
43
+ _ocr_reader = PaddleOCR(
44
+ lang='vi', # tiếng Việt + tiếng Anh
45
+ use_angle_cls=True, # tự động xoay ảnh
46
+ use_gpu=(DEVICE == "cuda"),
47
+ show_log=False,
48
+ det_db_thresh=0.3,
49
+ det_db_box_thresh=0.5,
50
+ rec_algorithm='SVTR_LCNet' # mạnh cho chữ in
51
+ )
52
+ return _ocr_reader
53
+ except ImportError:
54
+ print("[WARN] PaddleOCR not installed, falling back to EasyOCR.")
55
+ except Exception as e:
56
+ print(f"[WARN] PaddleOCR init failed: {e}, fallback to EasyOCR.")
57
+
58
+ # Fallback to EasyOCR
59
+ import easyocr
60
+ print("[INFO] Loading EasyOCR (vi, en)...")
61
+ _ocr_reader = easyocr.Reader(["vi", "en"], gpu=(DEVICE == "cuda"), verbose=False)
62
+ return _ocr_reader
63
+
64
+ def preprocess_image(img_bgr, ocr_type="note"):
65
+ """
66
+ Tiền xử lý ảnh phù hợp với từng loại:
67
+ - note: tăng độ tương phản, làm mờ nhẹ, sharpening
68
+ - table: nhị phân hóa, xóa đường kẻ ngang/dọc (tùy chọn)
69
+ """
70
  h, w = img_bgr.shape[:2]
71
+ # Resize nếu quá nhỏ (cải thiện OCR)
72
  if w < 800:
73
  scale = 800 / w
74
+ img_bgr = cv2.resize(img_bgr, (int(w*scale), int(h*scale)), interpolation=cv2.INTER_CUBIC)
75
+
76
+ gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
77
+
78
+ if ocr_type == "note":
79
+ # CLAHE + Denoising + Sharpening
80
+ clahe = cv2.createCLAHE(clipLimit=2.5, tileGridSize=(8,8))
81
+ gray = clahe.apply(gray)
82
+ gray = cv2.fastNlMeansDenoising(gray, h=10, templateWindowSize=7, searchWindowSize=21)
83
+ kernel = np.array([[0,-1,0],[-1,5,-1],[0,-1,0]])
84
+ gray = cv2.filter2D(gray, -1, kernel)
85
+ # Chuyển về BGR cho PaddleOCR/EasyOCR
86
+ return cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
87
+
88
+ else: # table
89
+ # Nhị phân hóa thích ứng (giữ chữ, xóa bớt nhiễu nền)
90
+ binary = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
91
+ cv2.THRESH_BINARY, 11, 2)
92
+ # Loại bỏ đường kẻ ngang/dọc (tùy chọn, giúp OCR dễ hơn)
93
+ horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (25,1))
94
+ vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1,25))
95
+ detected_lines_h = cv2.morphologyEx(binary, cv2.MORPH_OPEN, horizontal_kernel, iterations=2)
96
+ detected_lines_v = cv2.morphologyEx(binary, cv2.MORPH_OPEN, vertical_kernel, iterations=2)
97
+ # Xóa đường kẻ khỏi ảnh nhị phân
98
+ binary = cv2.bitwise_and(binary, cv2.bitwise_not(detected_lines_h))
99
+ binary = cv2.bitwise_and(binary, cv2.bitwise_not(detected_lines_v))
100
+ # Làm dày chữ một chút
101
+ kernel_dilate = np.ones((2,2), np.uint8)
102
+ binary = cv2.dilate(binary, kernel_dilate, iterations=1)
103
+ return cv2.cvtColor(binary, cv2.COLOR_GRAY2BGR)
104
+
105
+ def ocr_with_backend(img_bgr, backend="paddle", ocr_type="note"):
106
+ """Gọi OCR engine tương ứng, trả về list các (text, conf, center_x, center_y)"""
107
+ reader = get_ocr_reader(backend)
108
+ img_for_ocr = preprocess_image(img_bgr, ocr_type)
109
+
110
+ if backend == "paddle":
111
+ # PaddleOCR trả về list: [ [[box], (text, confidence)], ... ]
112
+ result = reader.ocr(img_for_ocr, cls=True)
113
+ if not result or not result[0]:
114
+ return []
115
+ items = []
116
+ for line in result[0]:
117
+ box, (text, conf) = line
118
+ if conf < 0.3 or not text.strip():
119
+ continue
120
+ # Tính trung tâm bounding box
121
+ xs = [p[0] for p in box]
122
+ ys = [p[1] for p in box]
123
+ cx, cy = np.mean(xs), np.mean(ys)
124
+ items.append({
125
+ "text": text.strip(),
126
+ "conf": conf,
127
+ "x": cx,
128
+ "y": cy,
129
+ "box": box
130
+ })
131
+ return items
132
+ else:
133
+ # EasyOCR
134
+ results = reader.readtext(img_for_ocr, detail=1, paragraph=False,
135
  width_ths=0.7, height_ths=0.7)
136
+ items = []
137
+ for (pts, text, conf) in results:
138
+ if conf < 0.2 or not text.strip():
139
+ continue
140
+ cx = sum(p[0] for p in pts) / 4
141
+ cy = sum(p[1] for p in pts) / 4
142
+ items.append({
143
+ "text": text.strip(),
144
+ "conf": conf,
145
+ "x": cx,
146
+ "y": cy,
147
+ "box": pts
148
+ })
149
+ return items
150
 
151
+ def group_rows(items, vertical_thresh_ratio=0.6):
152
+ """
153
+ Nhóm các item theo hàng dựa trên tọa độ y.
154
+ Dùng DBSCAN nếu có sklearn, nếu không thì dùng heuristic.
155
+ """
156
  if not items:
157
  return []
158
+
159
+ # Sắp xếp theo y tăng dần
160
+ items_sorted = sorted(items, key=lambda x: x["y"])
161
+ y_vals = [it["y"] for it in items_sorted]
162
+
163
+ # Tự động ước lượng ngưỡng dựa trên khoảng cách trung bình
164
  if len(y_vals) > 1:
165
+ gaps = [y_vals[i+1] - y_vals[i] for i in range(len(y_vals)-1)]
166
+ median_gap = np.median(gaps)
167
+ thresh = max(8, median_gap * vertical_thresh_ratio)
168
  else:
169
  thresh = 12
170
+
171
+ rows = []
172
+ current_row = [items_sorted[0]]
173
+ for it in items_sorted[1:]:
174
+ if it["y"] - current_row[-1]["y"] < thresh:
175
+ current_row.append(it)
176
  else:
177
+ # Sắp xếp các item trong cùng hàng theo x
178
+ current_row.sort(key=lambda x: x["x"])
179
+ rows.append(current_row)
180
+ current_row = [it]
181
+ current_row.sort(key=lambda x: x["x"])
182
+ rows.append(current_row)
183
+
184
+ # Chuyển thành list text theo hàng
185
+ return [[it["text"] for it in row] for row in rows]
186
+
187
+ def ocr_note(img_path, backend="paddle"):
188
+ """OCR cho vùng Note, trả về chuỗi văn bản."""
189
  img = cv2.imread(img_path)
190
  if img is None:
191
+ return ""
192
+
193
+ items = ocr_with_backend(img, backend, ocr_type="note")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  if not items:
195
+ # Thử lại với preprocessing khác (bỏ sharpen, chỉ CLAHE)
196
+ gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
197
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
198
+ gray = clahe.apply(gray)
199
+ img2 = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
200
+ items = ocr_with_backend(img2, backend, ocr_type="note")
201
+ if not items:
202
+ return ""
203
+
204
+ # Sắp xếp theo y rồi x để tạo đoạn văn bản
205
+ items_sorted = sorted(items, key=lambda x: (x["y"], x["x"]))
206
+ lines = []
207
+ current_line = []
208
+ y_thresh = 12 # ngưỡng dòng
209
+ for i, it in enumerate(items_sorted):
210
+ if i == 0:
211
+ current_line.append(it["text"])
212
+ else:
213
+ if abs(it["y"] - items_sorted[i-1]["y"]) < y_thresh:
214
+ current_line.append(it["text"])
215
+ else:
216
+ lines.append(" ".join(current_line))
217
+ current_line = [it["text"]]
218
+ if current_line:
219
+ lines.append(" ".join(current_line))
220
+
221
+ # Post-processing: loại bỏ ký tự lạ, chuẩn hóa khoảng trắng
222
+ clean_lines = []
223
+ for line in lines:
224
+ line = re.sub(r'[^\w\s\.\,\-\/\(\)]', '', line) # giữ chữ, số, dấu câu cơ bản
225
+ line = re.sub(r'\s+', ' ', line).strip()
226
+ if len(line) > 1:
227
+ clean_lines.append(line)
228
+ return "\n".join(clean_lines)
229
+
230
+ def ocr_table(img_path, backend="paddle"):
231
+ """OCR cho vùng Table, trả về dict rows và text."""
232
+ img = cv2.imread(img_path)
233
+ if img is None:
234
+ return {"rows": [], "text": ""}
235
+
236
+ items = ocr_with_backend(img, backend, ocr_type="table")
237
+ if not items:
238
+ # Thử lại với ảnh gốc (không xóa đường kẻ) vì đôi khi đường kẻ giúp định vị ô
239
+ img2 = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
240
+ img2 = cv2.adaptiveThreshold(img2, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
241
+ cv2.THRESH_BINARY, 15, 5)
242
+ img2 = cv2.cvtColor(img2, cv2.COLOR_GRAY2BGR)
243
+ items = ocr_with_backend(img2, backend, ocr_type="table")
244
+ if not items:
245
+ return {"rows": [], "text": ""}
246
+
247
+ rows = group_rows(items, vertical_thresh_ratio=0.6)
248
+ # Chuyển rows thành text (các cột cách nhau bằng ' | ')
249
+ text_lines = [" | ".join(row) for row in rows if row]
250
+ return {"rows": rows, "text": "\n".join(text_lines)}
251
 
252
  def run_pipeline(image_path, output_dir="outputs",
253
+ checkpoint="best.pt", conf_thresh=0.3,
254
+ ocr_backend="paddle"):
255
+ """
256
+ ocr_backend: "paddle" (khuyến nghị) hoặc "easyocr"
257
+ """
258
  image_path = str(image_path)
259
  img_name = Path(image_path).name
260
  stem = Path(image_path).stem
 
286
  ocr_content = None
287
  if cls_raw == "note":
288
  print(f"[OCR] Note #{i+1}...")
289
+ ocr_content = ocr_note(crop_path, backend=ocr_backend)
290
+ print(f" → {repr(ocr_content[:100]) if ocr_content else 'EMPTY'}")
291
  elif cls_raw == "table":
292
  print(f"[OCR] Table #{i+1}...")
293
+ ocr_content = ocr_table(crop_path, backend=ocr_backend)
294
+ preview = ocr_content.get("text", "")[:100]
295
  print(f" → {repr(preview) if preview else 'EMPTY'}")
296
 
297
  objects.append({
 
324
  if __name__ == "__main__":
325
  import sys
326
  img = sys.argv[1] if len(sys.argv) > 1 else "test.jpg"
327
+ # thể chọn backend: "paddle" (mặc định) hoặc "easyocr"
328
+ result, _ = run_pipeline(img, ocr_backend="paddle")
329
  print(json.dumps(result, ensure_ascii=False, indent=2))