Harry Pham committed on
Commit
d80899e
·
1 Parent(s): f69131e

update OCR

Browse files
Files changed (1) hide show
  1. src/inference.py +95 -95
src/inference.py CHANGED
@@ -35,6 +35,22 @@ def get_det_model(checkpoint="best.pt"):
35
  _det_model = RTDETR(checkpoint)
36
  return _det_model
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  def get_paddle_reader(lang='vi'):
40
  """
@@ -239,26 +255,30 @@ def multi_pass_ocr(img_bgr, reader, ocr_type="note"):
239
  # ============================================================
240
  # DUAL-ENGINE OCR — PaddleOCR (vi) + PaddleOCR (en), chọn tốt hơn
241
  # ============================================================
242
- def dual_engine_ocr(img_bgr, ocr_type="note"):
243
  """
244
- Chạy PaddleOCR với cả lang='vi' và lang='en',
245
- chọn kết quả có confidence cao hơn.
246
- Nếu PaddleOCR fail fallback EasyOCR.
247
  """
248
- reader_vi = get_paddle_reader('vi')
249
- reader_en = get_paddle_reader('en')
 
 
 
 
 
 
250
 
251
- if reader_vi is None and reader_en is None:
252
- # Fallback to EasyOCR
253
  reader = get_easyocr_reader()
254
- texts, conf = multi_pass_ocr(img_bgr, reader, ocr_type)
255
- return texts, conf
256
 
257
  best_texts = []
258
  best_conf = 0.0
259
  best_lang = ""
260
 
261
- # Try Vietnamese
262
  if reader_vi:
263
  texts_vi, conf_vi = multi_pass_ocr(img_bgr, reader_vi, ocr_type)
264
  if conf_vi > best_conf:
@@ -266,7 +286,6 @@ def dual_engine_ocr(img_bgr, ocr_type="note"):
266
  best_texts = texts_vi
267
  best_lang = "vi"
268
 
269
- # Try English
270
  if reader_en:
271
  texts_en, conf_en = multi_pass_ocr(img_bgr, reader_en, ocr_type)
272
  if conf_en > best_conf:
@@ -274,7 +293,13 @@ def dual_engine_ocr(img_bgr, ocr_type="note"):
274
  best_texts = texts_en
275
  best_lang = "en"
276
 
277
- print(f" Best language: {best_lang} (conf={best_conf:.3f})")
 
 
 
 
 
 
278
  return best_texts, best_conf
279
 
280
 
@@ -313,22 +338,15 @@ def post_process_ocr_text(text):
313
  # OCR NOTE — Cải thiện
314
  # ============================================================
315
  def ocr_note(img_path, backend="paddle"):
316
- """
317
- OCR cho vùng Note — cải thiện:
318
- 1. Upscale mạnh (min 1500px width)
319
- 2. Multi-pass với nhiều preprocessing
320
- 3. Dual-engine (vi + en)
321
- 4. Post-processing
322
- """
323
  img = cv2.imread(img_path)
324
  if img is None:
325
  return ""
326
 
327
- texts, conf = dual_engine_ocr(img, ocr_type="note")
328
 
329
  # Post-process từng dòng
330
  processed = [post_process_ocr_text(t) for t in texts]
331
- processed = [t for t in processed if t] # remove empty
332
 
333
  return "\n".join(processed)
334
 
@@ -379,84 +397,52 @@ def parse_html_table(html_str):
379
 
380
 
381
  def ocr_table(img_path, backend="paddle"):
382
- """
383
- OCR cho vùng Table — cải thiện:
384
- 1. Thử PPStructure trước (table structure recognition tốt nhất)
385
- 2. Fallback: detect cells thủ công + OCR từng cell
386
- 3. Post-processing
387
- """
388
  img = cv2.imread(img_path)
389
  if img is None:
390
  return {"rows": [], "text": ""}
391
 
392
- # === Strategy 1: PPStructure (best for tables) ===
393
- pp_engine = get_pp_structure()
394
- if pp_engine is not None:
395
- try:
396
- # Upscale trước khi đưa vào PPStructure
397
- h, w = img.shape[:2]
398
- if w < 1200:
399
- scale = 1200 / w
400
- img_scaled = cv2.resize(img, None, fx=scale, fy=scale,
401
- interpolation=cv2.INTER_CUBIC)
402
- else:
403
- img_scaled = img
404
-
405
- result = pp_engine(img_scaled)
406
- for item in result:
407
- if item.get('type') == 'table':
408
- html = item.get('res', {}).get('html', '')
409
- if html:
410
- rows = parse_html_table(html)
411
- if rows:
412
- # Post-process mỗi cell
413
- rows = [[post_process_ocr_text(cell) for cell in row]
414
- for row in rows]
415
- text = "\n".join(" | ".join(r) for r in rows)
416
- print(f" PPStructure: {len(rows)} rows detected")
417
- return {"rows": rows, "text": text, "html": html}
418
-
419
- # PPStructure ran but no table found → extract text
420
- all_texts = []
421
- for item in result:
422
- res = item.get('res', [])
423
- if isinstance(res, list):
424
- for line in res:
425
- if isinstance(line, dict) and 'text' in line:
426
- all_texts.append(line['text'])
427
- elif isinstance(line, (list, tuple)) and len(line) >= 2:
428
- text_info = line[1]
429
- if isinstance(text_info, (list, tuple)):
430
- all_texts.append(str(text_info[0]))
431
- else:
432
- all_texts.append(str(text_info))
433
- if all_texts:
434
- return {"rows": [all_texts], "text": "\n".join(all_texts)}
435
 
436
- except Exception as e:
437
- print(f" PPStructure error: {e}, falling back to manual")
438
-
439
- # === Strategy 2: Manual cell detection + OCR ===
 
 
 
 
 
 
 
 
 
 
 
 
 
440
  return ocr_table_manual(img, img_path, backend)
441
 
442
-
443
  def ocr_table_manual(img, img_path, backend="paddle"):
444
- """
445
- Fallback: detect table cells thủ công + OCR từng cell.
446
- Cải thiện: upscale mỗi cell riêng, multi-pass OCR.
447
- """
448
  cells = detect_table_structure(img)
449
 
450
  if cells:
451
- reader = get_paddle_reader('vi') or get_easyocr_reader()
452
  ocr_results = []
453
-
454
  for (x1, y1, x2, y2) in cells:
455
- # Bỏ cell quá lớn (toàn bộ bảng) hoặc quá nhỏ
456
  cell_w, cell_h = x2 - x1, y2 - y1
457
  img_h, img_w = img.shape[:2]
458
  if cell_w > img_w * 0.9 and cell_h > img_h * 0.9:
459
- continue # Skip full-table contour
460
  if cell_w < 15 or cell_h < 15:
461
  continue
462
 
@@ -467,7 +453,7 @@ def ocr_table_manual(img, img_path, backend="paddle"):
467
  cx2 = min(img.shape[1], x2 + pad)
468
  cell_img = img[cy1:cy2, cx1:cx2]
469
 
470
- text = ocr_cell_improved(cell_img, reader)
471
  if text:
472
  ocr_results.append({
473
  "text": post_process_ocr_text(text),
@@ -483,31 +469,36 @@ def ocr_table_manual(img, img_path, backend="paddle"):
483
  "text": "\n".join(" | ".join(r) for r in rows)
484
  }
485
 
486
- # === Strategy 3: OCR toàn bộ ảnh table, group theo hàng ===
487
  return ocr_table_fullimage(img, backend)
488
 
489
 
490
- def ocr_cell_improved(img_cell, reader):
491
  """OCR 1 cell — upscale mạnh, multi-preprocessing."""
492
  if img_cell.size == 0:
493
  return ""
494
 
495
  h, w = img_cell.shape[:2]
496
-
497
- # Upscale cell nhỏ rất mạnh
498
  target_w = max(300, w)
499
  if w < target_w:
500
  scale = target_w / w
501
  img_cell = cv2.resize(img_cell, None, fx=scale, fy=scale,
502
  interpolation=cv2.INTER_CUBIC)
503
 
504
- # Try 2 variants
 
 
 
 
 
 
 
 
 
505
  best_text = ""
506
  best_conf = 0
507
 
508
  for variant in ["color", "binary"]:
509
  if variant == "color":
510
- # Gentle enhancement
511
  img_proc = cv2.bilateralFilter(img_cell, 5, 50, 50)
512
  lab = cv2.cvtColor(img_proc, cv2.COLOR_BGR2LAB)
513
  l, a, b = cv2.split(lab)
@@ -531,8 +522,18 @@ def ocr_cell_improved(img_cell, reader):
531
 
532
 
533
  def ocr_table_fullimage(img, backend="paddle"):
534
- """OCR toàn bộ ảnh table (không chia cell), group by rows."""
535
- reader = get_paddle_reader('vi') or get_easyocr_reader()
 
 
 
 
 
 
 
 
 
 
536
  img_proc = preprocess_for_ocr(img, min_width=1500, mode="table")
537
 
538
  items = []
@@ -571,7 +572,6 @@ def ocr_table_fullimage(img, backend="paddle"):
571
  rows = group_rows(items, vertical_thresh_ratio=0.6)
572
  return {"rows": rows, "text": "\n".join(" | ".join(r) for r in rows)}
573
 
574
-
575
  # ============================================================
576
  # TABLE STRUCTURE DETECTION (giữ nguyên, có cải thiện nhỏ)
577
  # ============================================================
@@ -717,5 +717,5 @@ def run_pipeline(image_path, output_dir="outputs",
717
  if __name__ == "__main__":
718
  import sys
719
  img = sys.argv[1] if len(sys.argv) > 1 else "test.jpg"
720
- result, _ = run_pipeline(img, ocr_backend="paddle")
721
  print(json.dumps(result, ensure_ascii=False, indent=2))
 
35
  _det_model = RTDETR(checkpoint)
36
  return _det_model
37
 
38
+ # Thêm Surya OCR làm engine thứ 3
39
+ from surya.ocr import run_ocr
40
+ from surya.model.detection.model import load_det_processor, load_det_model
41
+ from surya.model.recognition.model import load_rec_model
42
+ from surya.model.recognition.processor import load_rec_processor
43
+
44
# Lazily-built cache of the Surya models/processors. Loading them is expensive
# and ocr_with_surya() is called once per table cell, so load once per process.
_surya_models = None


def _get_surya_models():
    """Return (det_processor, det_model, rec_model, rec_processor), loading on first use."""
    global _surya_models
    if _surya_models is None:
        _surya_models = (
            load_det_processor(),
            load_det_model(),
            load_rec_model(),
            load_rec_processor(),
        )
    return _surya_models


def ocr_with_surya(img_bgr, langs=None):
    """Run Surya OCR on an image and return the recognised lines as one string.

    Args:
        img_bgr: Image array in OpenCV BGR channel order; it is converted to
            RGB before being handed to Surya.
        langs: Language codes passed to Surya. Defaults to ``["vi", "en"]``.

    Returns:
        The text of every detected line joined with ``"\n"``, or ``""`` when
        ``img_bgr`` is ``None`` or has no pixels.
    """
    if img_bgr is None or img_bgr.size == 0:
        return ""
    # Avoid a shared mutable default; build a fresh list per call.
    if langs is None:
        langs = ["vi", "en"]

    from PIL import Image

    det_processor, det_model, rec_model, rec_processor = _get_surya_models()
    pil_img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    predictions = run_ocr([pil_img], [list(langs)], det_model, det_processor,
                          rec_model, rec_processor)
    # One image in, so only the first prediction is relevant.
    return "\n".join(line.text for line in predictions[0].text_lines)
53
+
54
 
55
  def get_paddle_reader(lang='vi'):
56
  """
 
255
  # ============================================================
256
  # MULTI-BACKEND OCR — PaddleOCR (vi + en, keep the better result), EasyOCR, or Surya
257
  # ============================================================
258
+ def run_ocr_with_backend(img_bgr, backend="paddle", ocr_type="note"):
259
  """
260
+ Chạy OCR với backend được chọn.
261
+ backend: "paddle", "easyocr", "surya"
262
+ Trả về (list_of_texts, avg_confidence) - với surya, confidence luôn = 1.0
263
  """
264
+ if backend == "surya":
265
+ text = ocr_with_surya(img_bgr, langs=["vi", "en"])
266
+ lines = [line.strip() for line in text.split("\n") if line.strip()]
267
+ return lines, 1.0 # Surya không trả confidence, coi như 1.0
268
+
269
+ # logic cũ cho paddle + easyocr
270
+ reader_vi = get_paddle_reader('vi') if backend == "paddle" else None
271
+ reader_en = get_paddle_reader('en') if backend == "paddle" else None
272
 
273
+ if reader_vi is None and reader_en is None and backend == "paddle":
274
+ # fallback easyocr
275
  reader = get_easyocr_reader()
276
+ return multi_pass_ocr(img_bgr, reader, ocr_type)
 
277
 
278
  best_texts = []
279
  best_conf = 0.0
280
  best_lang = ""
281
 
 
282
  if reader_vi:
283
  texts_vi, conf_vi = multi_pass_ocr(img_bgr, reader_vi, ocr_type)
284
  if conf_vi > best_conf:
 
286
  best_texts = texts_vi
287
  best_lang = "vi"
288
 
 
289
  if reader_en:
290
  texts_en, conf_en = multi_pass_ocr(img_bgr, reader_en, ocr_type)
291
  if conf_en > best_conf:
 
293
  best_texts = texts_en
294
  best_lang = "en"
295
 
296
+ if best_lang:
297
+ print(f" Best language: {best_lang} (conf={best_conf:.3f})")
298
+ else:
299
+ # fallback easyocr
300
+ reader = get_easyocr_reader()
301
+ best_texts, best_conf = multi_pass_ocr(img_bgr, reader, ocr_type)
302
+
303
  return best_texts, best_conf
304
 
305
 
 
338
  # OCR NOTE — Cải thiện
339
  # ============================================================
340
  def ocr_note(img_path, backend="paddle"):
 
 
 
 
 
 
 
341
  img = cv2.imread(img_path)
342
  if img is None:
343
  return ""
344
 
345
+ texts, _ = run_ocr_with_backend(img, backend=backend, ocr_type="note")
346
 
347
  # Post-process từng dòng
348
  processed = [post_process_ocr_text(t) for t in texts]
349
+ processed = [t for t in processed if t]
350
 
351
  return "\n".join(processed)
352
 
 
397
 
398
 
399
  def ocr_table(img_path, backend="paddle"):
 
 
 
 
 
 
400
  img = cv2.imread(img_path)
401
  if img is None:
402
  return {"rows": [], "text": ""}
403
 
404
+ # Strategy 1: PPStructure (chỉ dùng nếu backend là paddle, vì PPStructure dùng PaddleOCR)
405
+ if backend == "paddle":
406
+ pp_engine = get_pp_structure()
407
+ if pp_engine is not None:
408
+ try:
409
+ h, w = img.shape[:2]
410
+ if w < 1200:
411
+ scale = 1200 / w
412
+ img_scaled = cv2.resize(img, None, fx=scale, fy=scale,
413
+ interpolation=cv2.INTER_CUBIC)
414
+ else:
415
+ img_scaled = img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
416
 
417
+ result = pp_engine(img_scaled)
418
+ for item in result:
419
+ if item.get('type') == 'table':
420
+ html = item.get('res', {}).get('html', '')
421
+ if html:
422
+ rows = parse_html_table(html)
423
+ if rows:
424
+ rows = [[post_process_ocr_text(cell) for cell in row]
425
+ for row in rows]
426
+ text = "\n".join(" | ".join(r) for r in rows)
427
+ print(f" PPStructure: {len(rows)} rows detected")
428
+ return {"rows": rows, "text": text, "html": html}
429
+ # Nếu không tìm thấy table, fallback
430
+ except Exception as e:
431
+ print(f" PPStructure error: {e}, falling back to manual")
432
+
433
+ # Strategy 2: Manual cell detection
434
  return ocr_table_manual(img, img_path, backend)
435
 
 
436
  def ocr_table_manual(img, img_path, backend="paddle"):
 
 
 
 
437
  cells = detect_table_structure(img)
438
 
439
  if cells:
 
440
  ocr_results = []
 
441
  for (x1, y1, x2, y2) in cells:
 
442
  cell_w, cell_h = x2 - x1, y2 - y1
443
  img_h, img_w = img.shape[:2]
444
  if cell_w > img_w * 0.9 and cell_h > img_h * 0.9:
445
+ continue
446
  if cell_w < 15 or cell_h < 15:
447
  continue
448
 
 
453
  cx2 = min(img.shape[1], x2 + pad)
454
  cell_img = img[cy1:cy2, cx1:cx2]
455
 
456
+ text = ocr_cell_improved(cell_img, backend=backend)
457
  if text:
458
  ocr_results.append({
459
  "text": post_process_ocr_text(text),
 
469
  "text": "\n".join(" | ".join(r) for r in rows)
470
  }
471
 
 
472
  return ocr_table_fullimage(img, backend)
473
 
474
 
475
+ def ocr_cell_improved(img_cell, backend="paddle"):
476
  """OCR 1 cell — upscale mạnh, multi-preprocessing."""
477
  if img_cell.size == 0:
478
  return ""
479
 
480
  h, w = img_cell.shape[:2]
 
 
481
  target_w = max(300, w)
482
  if w < target_w:
483
  scale = target_w / w
484
  img_cell = cv2.resize(img_cell, None, fx=scale, fy=scale,
485
  interpolation=cv2.INTER_CUBIC)
486
 
487
+ if backend == "surya":
488
+ # Chạy Surya trực tiếp
489
+ text = ocr_with_surya(img_cell, langs=["vi", "en"])
490
+ return text.strip()
491
+
492
+ # logic cũ với reader (paddle/easyocr)
493
+ reader = get_paddle_reader('vi') if backend == "paddle" else get_easyocr_reader()
494
+ if reader is None:
495
+ reader = get_easyocr_reader()
496
+
497
  best_text = ""
498
  best_conf = 0
499
 
500
  for variant in ["color", "binary"]:
501
  if variant == "color":
 
502
  img_proc = cv2.bilateralFilter(img_cell, 5, 50, 50)
503
  lab = cv2.cvtColor(img_proc, cv2.COLOR_BGR2LAB)
504
  l, a, b = cv2.split(lab)
 
522
 
523
 
524
  def ocr_table_fullimage(img, backend="paddle"):
525
+ if backend == "surya":
526
+ # Dùng Surya OCR trên toàn bộ ảnh table
527
+ text = ocr_with_surya(img, langs=["vi", "en"])
528
+ lines = [line.strip() for line in text.split("\n") if line.strip()]
529
+ # Với Surya, ta không có bounding box, chỉ trả về một cột
530
+ rows = [[line] for line in lines]
531
+ return {"rows": rows, "text": text}
532
+
533
+ # logic cũ với paddle/easyocr
534
+ reader = get_paddle_reader('vi') if backend == "paddle" else get_easyocr_reader()
535
+ if reader is None:
536
+ reader = get_easyocr_reader()
537
  img_proc = preprocess_for_ocr(img, min_width=1500, mode="table")
538
 
539
  items = []
 
572
  rows = group_rows(items, vertical_thresh_ratio=0.6)
573
  return {"rows": rows, "text": "\n".join(" | ".join(r) for r in rows)}
574
 
 
575
  # ============================================================
576
  # TABLE STRUCTURE DETECTION (giữ nguyên, có cải thiện nhỏ)
577
  # ============================================================
 
717
  if __name__ == "__main__":
718
  import sys
719
  img = sys.argv[1] if len(sys.argv) > 1 else "test.jpg"
720
+ result, _ = run_pipeline(img, ocr_backend="surya")
721
  print(json.dumps(result, ensure_ascii=False, indent=2))