Spaces:
Sleeping
Sleeping
Harry Pham commited on
Commit ·
d80899e
1
Parent(s): f69131e
update OCR
Browse files- src/inference.py +95 -95
src/inference.py
CHANGED
|
@@ -35,6 +35,22 @@ def get_det_model(checkpoint="best.pt"):
|
|
| 35 |
_det_model = RTDETR(checkpoint)
|
| 36 |
return _det_model
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
def get_paddle_reader(lang='vi'):
|
| 40 |
"""
|
|
@@ -239,26 +255,30 @@ def multi_pass_ocr(img_bgr, reader, ocr_type="note"):
|
|
| 239 |
# ============================================================
|
| 240 |
# DUAL-ENGINE OCR — PaddleOCR (vi) + PaddleOCR (en), chọn tốt hơn
|
| 241 |
# ============================================================
|
| 242 |
-
def
|
| 243 |
"""
|
| 244 |
-
Chạy
|
| 245 |
-
|
| 246 |
-
|
| 247 |
"""
|
| 248 |
-
|
| 249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
|
| 251 |
-
if reader_vi is None and reader_en is None:
|
| 252 |
-
#
|
| 253 |
reader = get_easyocr_reader()
|
| 254 |
-
|
| 255 |
-
return texts, conf
|
| 256 |
|
| 257 |
best_texts = []
|
| 258 |
best_conf = 0.0
|
| 259 |
best_lang = ""
|
| 260 |
|
| 261 |
-
# Try Vietnamese
|
| 262 |
if reader_vi:
|
| 263 |
texts_vi, conf_vi = multi_pass_ocr(img_bgr, reader_vi, ocr_type)
|
| 264 |
if conf_vi > best_conf:
|
|
@@ -266,7 +286,6 @@ def dual_engine_ocr(img_bgr, ocr_type="note"):
|
|
| 266 |
best_texts = texts_vi
|
| 267 |
best_lang = "vi"
|
| 268 |
|
| 269 |
-
# Try English
|
| 270 |
if reader_en:
|
| 271 |
texts_en, conf_en = multi_pass_ocr(img_bgr, reader_en, ocr_type)
|
| 272 |
if conf_en > best_conf:
|
|
@@ -274,7 +293,13 @@ def dual_engine_ocr(img_bgr, ocr_type="note"):
|
|
| 274 |
best_texts = texts_en
|
| 275 |
best_lang = "en"
|
| 276 |
|
| 277 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
return best_texts, best_conf
|
| 279 |
|
| 280 |
|
|
@@ -313,22 +338,15 @@ def post_process_ocr_text(text):
|
|
| 313 |
# OCR NOTE — Cải thiện
|
| 314 |
# ============================================================
|
| 315 |
def ocr_note(img_path, backend="paddle"):
|
| 316 |
-
"""
|
| 317 |
-
OCR cho vùng Note — cải thiện:
|
| 318 |
-
1. Upscale mạnh (min 1500px width)
|
| 319 |
-
2. Multi-pass với nhiều preprocessing
|
| 320 |
-
3. Dual-engine (vi + en)
|
| 321 |
-
4. Post-processing
|
| 322 |
-
"""
|
| 323 |
img = cv2.imread(img_path)
|
| 324 |
if img is None:
|
| 325 |
return ""
|
| 326 |
|
| 327 |
-
texts,
|
| 328 |
|
| 329 |
# Post-process từng dòng
|
| 330 |
processed = [post_process_ocr_text(t) for t in texts]
|
| 331 |
-
processed = [t for t in processed if t]
|
| 332 |
|
| 333 |
return "\n".join(processed)
|
| 334 |
|
|
@@ -379,84 +397,52 @@ def parse_html_table(html_str):
|
|
| 379 |
|
| 380 |
|
| 381 |
def ocr_table(img_path, backend="paddle"):
|
| 382 |
-
"""
|
| 383 |
-
OCR cho vùng Table — cải thiện:
|
| 384 |
-
1. Thử PPStructure trước (table structure recognition tốt nhất)
|
| 385 |
-
2. Fallback: detect cells thủ công + OCR từng cell
|
| 386 |
-
3. Post-processing
|
| 387 |
-
"""
|
| 388 |
img = cv2.imread(img_path)
|
| 389 |
if img is None:
|
| 390 |
return {"rows": [], "text": ""}
|
| 391 |
|
| 392 |
-
#
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
result = pp_engine(img_scaled)
|
| 406 |
-
for item in result:
|
| 407 |
-
if item.get('type') == 'table':
|
| 408 |
-
html = item.get('res', {}).get('html', '')
|
| 409 |
-
if html:
|
| 410 |
-
rows = parse_html_table(html)
|
| 411 |
-
if rows:
|
| 412 |
-
# Post-process mỗi cell
|
| 413 |
-
rows = [[post_process_ocr_text(cell) for cell in row]
|
| 414 |
-
for row in rows]
|
| 415 |
-
text = "\n".join(" | ".join(r) for r in rows)
|
| 416 |
-
print(f" PPStructure: {len(rows)} rows detected")
|
| 417 |
-
return {"rows": rows, "text": text, "html": html}
|
| 418 |
-
|
| 419 |
-
# PPStructure ran but no table found → extract text
|
| 420 |
-
all_texts = []
|
| 421 |
-
for item in result:
|
| 422 |
-
res = item.get('res', [])
|
| 423 |
-
if isinstance(res, list):
|
| 424 |
-
for line in res:
|
| 425 |
-
if isinstance(line, dict) and 'text' in line:
|
| 426 |
-
all_texts.append(line['text'])
|
| 427 |
-
elif isinstance(line, (list, tuple)) and len(line) >= 2:
|
| 428 |
-
text_info = line[1]
|
| 429 |
-
if isinstance(text_info, (list, tuple)):
|
| 430 |
-
all_texts.append(str(text_info[0]))
|
| 431 |
-
else:
|
| 432 |
-
all_texts.append(str(text_info))
|
| 433 |
-
if all_texts:
|
| 434 |
-
return {"rows": [all_texts], "text": "\n".join(all_texts)}
|
| 435 |
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 440 |
return ocr_table_manual(img, img_path, backend)
|
| 441 |
|
| 442 |
-
|
| 443 |
def ocr_table_manual(img, img_path, backend="paddle"):
|
| 444 |
-
"""
|
| 445 |
-
Fallback: detect table cells thủ công + OCR từng cell.
|
| 446 |
-
Cải thiện: upscale mỗi cell riêng, multi-pass OCR.
|
| 447 |
-
"""
|
| 448 |
cells = detect_table_structure(img)
|
| 449 |
|
| 450 |
if cells:
|
| 451 |
-
reader = get_paddle_reader('vi') or get_easyocr_reader()
|
| 452 |
ocr_results = []
|
| 453 |
-
|
| 454 |
for (x1, y1, x2, y2) in cells:
|
| 455 |
-
# Bỏ cell quá lớn (toàn bộ bảng) hoặc quá nhỏ
|
| 456 |
cell_w, cell_h = x2 - x1, y2 - y1
|
| 457 |
img_h, img_w = img.shape[:2]
|
| 458 |
if cell_w > img_w * 0.9 and cell_h > img_h * 0.9:
|
| 459 |
-
continue
|
| 460 |
if cell_w < 15 or cell_h < 15:
|
| 461 |
continue
|
| 462 |
|
|
@@ -467,7 +453,7 @@ def ocr_table_manual(img, img_path, backend="paddle"):
|
|
| 467 |
cx2 = min(img.shape[1], x2 + pad)
|
| 468 |
cell_img = img[cy1:cy2, cx1:cx2]
|
| 469 |
|
| 470 |
-
text = ocr_cell_improved(cell_img,
|
| 471 |
if text:
|
| 472 |
ocr_results.append({
|
| 473 |
"text": post_process_ocr_text(text),
|
|
@@ -483,31 +469,36 @@ def ocr_table_manual(img, img_path, backend="paddle"):
|
|
| 483 |
"text": "\n".join(" | ".join(r) for r in rows)
|
| 484 |
}
|
| 485 |
|
| 486 |
-
# === Strategy 3: OCR toàn bộ ảnh table, group theo hàng ===
|
| 487 |
return ocr_table_fullimage(img, backend)
|
| 488 |
|
| 489 |
|
| 490 |
-
def ocr_cell_improved(img_cell,
|
| 491 |
"""OCR 1 cell — upscale mạnh, multi-preprocessing."""
|
| 492 |
if img_cell.size == 0:
|
| 493 |
return ""
|
| 494 |
|
| 495 |
h, w = img_cell.shape[:2]
|
| 496 |
-
|
| 497 |
-
# Upscale cell nhỏ rất mạnh
|
| 498 |
target_w = max(300, w)
|
| 499 |
if w < target_w:
|
| 500 |
scale = target_w / w
|
| 501 |
img_cell = cv2.resize(img_cell, None, fx=scale, fy=scale,
|
| 502 |
interpolation=cv2.INTER_CUBIC)
|
| 503 |
|
| 504 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 505 |
best_text = ""
|
| 506 |
best_conf = 0
|
| 507 |
|
| 508 |
for variant in ["color", "binary"]:
|
| 509 |
if variant == "color":
|
| 510 |
-
# Gentle enhancement
|
| 511 |
img_proc = cv2.bilateralFilter(img_cell, 5, 50, 50)
|
| 512 |
lab = cv2.cvtColor(img_proc, cv2.COLOR_BGR2LAB)
|
| 513 |
l, a, b = cv2.split(lab)
|
|
@@ -531,8 +522,18 @@ def ocr_cell_improved(img_cell, reader):
|
|
| 531 |
|
| 532 |
|
| 533 |
def ocr_table_fullimage(img, backend="paddle"):
|
| 534 |
-
|
| 535 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 536 |
img_proc = preprocess_for_ocr(img, min_width=1500, mode="table")
|
| 537 |
|
| 538 |
items = []
|
|
@@ -571,7 +572,6 @@ def ocr_table_fullimage(img, backend="paddle"):
|
|
| 571 |
rows = group_rows(items, vertical_thresh_ratio=0.6)
|
| 572 |
return {"rows": rows, "text": "\n".join(" | ".join(r) for r in rows)}
|
| 573 |
|
| 574 |
-
|
| 575 |
# ============================================================
|
| 576 |
# TABLE STRUCTURE DETECTION (giữ nguyên, có cải thiện nhỏ)
|
| 577 |
# ============================================================
|
|
@@ -717,5 +717,5 @@ def run_pipeline(image_path, output_dir="outputs",
|
|
| 717 |
if __name__ == "__main__":
|
| 718 |
import sys
|
| 719 |
img = sys.argv[1] if len(sys.argv) > 1 else "test.jpg"
|
| 720 |
-
result, _ = run_pipeline(img, ocr_backend="
|
| 721 |
print(json.dumps(result, ensure_ascii=False, indent=2))
|
|
|
|
| 35 |
_det_model = RTDETR(checkpoint)
|
| 36 |
return _det_model
|
| 37 |
|
| 38 |
+
# Thêm Surya OCR làm engine thứ 3
|
| 39 |
+
from surya.ocr import run_ocr
|
| 40 |
+
from surya.model.detection.model import load_det_processor, load_det_model
|
| 41 |
+
from surya.model.recognition.model import load_rec_model
|
| 42 |
+
from surya.model.recognition.processor import load_rec_processor
|
| 43 |
+
|
| 44 |
+
def ocr_with_surya(img_bgr, langs=["vi", "en"]):
|
| 45 |
+
det_processor, det_model = load_det_processor(), load_det_model()
|
| 46 |
+
rec_model, rec_processor = load_rec_model(), load_rec_processor()
|
| 47 |
+
from PIL import Image
|
| 48 |
+
pil_img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
|
| 49 |
+
predictions = run_ocr([pil_img], [langs], det_model, det_processor,
|
| 50 |
+
rec_model, rec_processor)
|
| 51 |
+
texts = [line.text for line in predictions[0].text_lines]
|
| 52 |
+
return "\n".join(texts)
|
| 53 |
+
|
| 54 |
|
| 55 |
def get_paddle_reader(lang='vi'):
|
| 56 |
"""
|
|
|
|
| 255 |
# ============================================================
|
| 256 |
# DUAL-ENGINE OCR — PaddleOCR (vi) + PaddleOCR (en), chọn tốt hơn
|
| 257 |
# ============================================================
|
| 258 |
+
def run_ocr_with_backend(img_bgr, backend="paddle", ocr_type="note"):
|
| 259 |
"""
|
| 260 |
+
Chạy OCR với backend được chọn.
|
| 261 |
+
backend: "paddle", "easyocr", "surya"
|
| 262 |
+
Trả về (list_of_texts, avg_confidence) - với surya, confidence luôn = 1.0
|
| 263 |
"""
|
| 264 |
+
if backend == "surya":
|
| 265 |
+
text = ocr_with_surya(img_bgr, langs=["vi", "en"])
|
| 266 |
+
lines = [line.strip() for line in text.split("\n") if line.strip()]
|
| 267 |
+
return lines, 1.0 # Surya không trả confidence, coi như 1.0
|
| 268 |
+
|
| 269 |
+
# logic cũ cho paddle + easyocr
|
| 270 |
+
reader_vi = get_paddle_reader('vi') if backend == "paddle" else None
|
| 271 |
+
reader_en = get_paddle_reader('en') if backend == "paddle" else None
|
| 272 |
|
| 273 |
+
if reader_vi is None and reader_en is None and backend == "paddle":
|
| 274 |
+
# fallback easyocr
|
| 275 |
reader = get_easyocr_reader()
|
| 276 |
+
return multi_pass_ocr(img_bgr, reader, ocr_type)
|
|
|
|
| 277 |
|
| 278 |
best_texts = []
|
| 279 |
best_conf = 0.0
|
| 280 |
best_lang = ""
|
| 281 |
|
|
|
|
| 282 |
if reader_vi:
|
| 283 |
texts_vi, conf_vi = multi_pass_ocr(img_bgr, reader_vi, ocr_type)
|
| 284 |
if conf_vi > best_conf:
|
|
|
|
| 286 |
best_texts = texts_vi
|
| 287 |
best_lang = "vi"
|
| 288 |
|
|
|
|
| 289 |
if reader_en:
|
| 290 |
texts_en, conf_en = multi_pass_ocr(img_bgr, reader_en, ocr_type)
|
| 291 |
if conf_en > best_conf:
|
|
|
|
| 293 |
best_texts = texts_en
|
| 294 |
best_lang = "en"
|
| 295 |
|
| 296 |
+
if best_lang:
|
| 297 |
+
print(f" Best language: {best_lang} (conf={best_conf:.3f})")
|
| 298 |
+
else:
|
| 299 |
+
# fallback easyocr
|
| 300 |
+
reader = get_easyocr_reader()
|
| 301 |
+
best_texts, best_conf = multi_pass_ocr(img_bgr, reader, ocr_type)
|
| 302 |
+
|
| 303 |
return best_texts, best_conf
|
| 304 |
|
| 305 |
|
|
|
|
| 338 |
# OCR NOTE — Cải thiện
|
| 339 |
# ============================================================
|
| 340 |
def ocr_note(img_path, backend="paddle"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
img = cv2.imread(img_path)
|
| 342 |
if img is None:
|
| 343 |
return ""
|
| 344 |
|
| 345 |
+
texts, _ = run_ocr_with_backend(img, backend=backend, ocr_type="note")
|
| 346 |
|
| 347 |
# Post-process từng dòng
|
| 348 |
processed = [post_process_ocr_text(t) for t in texts]
|
| 349 |
+
processed = [t for t in processed if t]
|
| 350 |
|
| 351 |
return "\n".join(processed)
|
| 352 |
|
|
|
|
| 397 |
|
| 398 |
|
| 399 |
def ocr_table(img_path, backend="paddle"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
img = cv2.imread(img_path)
|
| 401 |
if img is None:
|
| 402 |
return {"rows": [], "text": ""}
|
| 403 |
|
| 404 |
+
# Strategy 1: PPStructure (chỉ dùng nếu backend là paddle, vì PPStructure dùng PaddleOCR)
|
| 405 |
+
if backend == "paddle":
|
| 406 |
+
pp_engine = get_pp_structure()
|
| 407 |
+
if pp_engine is not None:
|
| 408 |
+
try:
|
| 409 |
+
h, w = img.shape[:2]
|
| 410 |
+
if w < 1200:
|
| 411 |
+
scale = 1200 / w
|
| 412 |
+
img_scaled = cv2.resize(img, None, fx=scale, fy=scale,
|
| 413 |
+
interpolation=cv2.INTER_CUBIC)
|
| 414 |
+
else:
|
| 415 |
+
img_scaled = img
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
|
| 417 |
+
result = pp_engine(img_scaled)
|
| 418 |
+
for item in result:
|
| 419 |
+
if item.get('type') == 'table':
|
| 420 |
+
html = item.get('res', {}).get('html', '')
|
| 421 |
+
if html:
|
| 422 |
+
rows = parse_html_table(html)
|
| 423 |
+
if rows:
|
| 424 |
+
rows = [[post_process_ocr_text(cell) for cell in row]
|
| 425 |
+
for row in rows]
|
| 426 |
+
text = "\n".join(" | ".join(r) for r in rows)
|
| 427 |
+
print(f" PPStructure: {len(rows)} rows detected")
|
| 428 |
+
return {"rows": rows, "text": text, "html": html}
|
| 429 |
+
# Nếu không tìm thấy table, fallback
|
| 430 |
+
except Exception as e:
|
| 431 |
+
print(f" PPStructure error: {e}, falling back to manual")
|
| 432 |
+
|
| 433 |
+
# Strategy 2: Manual cell detection
|
| 434 |
return ocr_table_manual(img, img_path, backend)
|
| 435 |
|
|
|
|
| 436 |
def ocr_table_manual(img, img_path, backend="paddle"):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
cells = detect_table_structure(img)
|
| 438 |
|
| 439 |
if cells:
|
|
|
|
| 440 |
ocr_results = []
|
|
|
|
| 441 |
for (x1, y1, x2, y2) in cells:
|
|
|
|
| 442 |
cell_w, cell_h = x2 - x1, y2 - y1
|
| 443 |
img_h, img_w = img.shape[:2]
|
| 444 |
if cell_w > img_w * 0.9 and cell_h > img_h * 0.9:
|
| 445 |
+
continue
|
| 446 |
if cell_w < 15 or cell_h < 15:
|
| 447 |
continue
|
| 448 |
|
|
|
|
| 453 |
cx2 = min(img.shape[1], x2 + pad)
|
| 454 |
cell_img = img[cy1:cy2, cx1:cx2]
|
| 455 |
|
| 456 |
+
text = ocr_cell_improved(cell_img, backend=backend)
|
| 457 |
if text:
|
| 458 |
ocr_results.append({
|
| 459 |
"text": post_process_ocr_text(text),
|
|
|
|
| 469 |
"text": "\n".join(" | ".join(r) for r in rows)
|
| 470 |
}
|
| 471 |
|
|
|
|
| 472 |
return ocr_table_fullimage(img, backend)
|
| 473 |
|
| 474 |
|
| 475 |
+
def ocr_cell_improved(img_cell, backend="paddle"):
|
| 476 |
"""OCR 1 cell — upscale mạnh, multi-preprocessing."""
|
| 477 |
if img_cell.size == 0:
|
| 478 |
return ""
|
| 479 |
|
| 480 |
h, w = img_cell.shape[:2]
|
|
|
|
|
|
|
| 481 |
target_w = max(300, w)
|
| 482 |
if w < target_w:
|
| 483 |
scale = target_w / w
|
| 484 |
img_cell = cv2.resize(img_cell, None, fx=scale, fy=scale,
|
| 485 |
interpolation=cv2.INTER_CUBIC)
|
| 486 |
|
| 487 |
+
if backend == "surya":
|
| 488 |
+
# Chạy Surya trực tiếp
|
| 489 |
+
text = ocr_with_surya(img_cell, langs=["vi", "en"])
|
| 490 |
+
return text.strip()
|
| 491 |
+
|
| 492 |
+
# logic cũ với reader (paddle/easyocr)
|
| 493 |
+
reader = get_paddle_reader('vi') if backend == "paddle" else get_easyocr_reader()
|
| 494 |
+
if reader is None:
|
| 495 |
+
reader = get_easyocr_reader()
|
| 496 |
+
|
| 497 |
best_text = ""
|
| 498 |
best_conf = 0
|
| 499 |
|
| 500 |
for variant in ["color", "binary"]:
|
| 501 |
if variant == "color":
|
|
|
|
| 502 |
img_proc = cv2.bilateralFilter(img_cell, 5, 50, 50)
|
| 503 |
lab = cv2.cvtColor(img_proc, cv2.COLOR_BGR2LAB)
|
| 504 |
l, a, b = cv2.split(lab)
|
|
|
|
| 522 |
|
| 523 |
|
| 524 |
def ocr_table_fullimage(img, backend="paddle"):
|
| 525 |
+
if backend == "surya":
|
| 526 |
+
# Dùng Surya OCR trên toàn bộ ảnh table
|
| 527 |
+
text = ocr_with_surya(img, langs=["vi", "en"])
|
| 528 |
+
lines = [line.strip() for line in text.split("\n") if line.strip()]
|
| 529 |
+
# Với Surya, ta không có bounding box, chỉ trả về một cột
|
| 530 |
+
rows = [[line] for line in lines]
|
| 531 |
+
return {"rows": rows, "text": text}
|
| 532 |
+
|
| 533 |
+
# logic cũ với paddle/easyocr
|
| 534 |
+
reader = get_paddle_reader('vi') if backend == "paddle" else get_easyocr_reader()
|
| 535 |
+
if reader is None:
|
| 536 |
+
reader = get_easyocr_reader()
|
| 537 |
img_proc = preprocess_for_ocr(img, min_width=1500, mode="table")
|
| 538 |
|
| 539 |
items = []
|
|
|
|
| 572 |
rows = group_rows(items, vertical_thresh_ratio=0.6)
|
| 573 |
return {"rows": rows, "text": "\n".join(" | ".join(r) for r in rows)}
|
| 574 |
|
|
|
|
| 575 |
# ============================================================
|
| 576 |
# TABLE STRUCTURE DETECTION (giữ nguyên, có cải thiện nhỏ)
|
| 577 |
# ============================================================
|
|
|
|
| 717 |
if __name__ == "__main__":
|
| 718 |
import sys
|
| 719 |
img = sys.argv[1] if len(sys.argv) > 1 else "test.jpg"
|
| 720 |
+
result, _ = run_pipeline(img, ocr_backend="surya")
|
| 721 |
print(json.dumps(result, ensure_ascii=False, indent=2))
|