Spaces:
Configuration error
Configuration error
File size: 6,508 Bytes
ea9cf0f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 | import torch
from paddleocr import PaddleOCR
# ── Load model ───────────────────────────────────────────
# Most recently loaded detector and the checkpoint path it was loaded from.
_model = None
_model_checkpoint = None


def get_model(checkpoint: str = "best.pt"):
    """Return the RT-DETR detector for ``checkpoint``, loading it on demand.

    The loaded model is cached at module level so repeated calls do not
    re-read the weights. The cache remembers which checkpoint it holds:
    requesting a different checkpoint loads that one instead of silently
    returning the previously cached model.

    Args:
        checkpoint: Path of the weights file handed to ``RTDETR``.

    Returns:
        The ``RTDETR`` instance loaded from ``checkpoint``.
    """
    global _model, _model_checkpoint
    if _model is None or _model_checkpoint != checkpoint:
        print(f"[INFO] Loading model from {checkpoint}...")
        _model = RTDETR(checkpoint)
        _model_checkpoint = checkpoint
    return _model
# Keep a handle on the original loader so the wrapper below can delegate to it.
_orig_load = torch.load


def _safe_load(*args, **kwargs):
    """Call the original ``torch.load``, defaulting ``weights_only`` to False.

    A ``weights_only`` value passed explicitly by the caller is left untouched;
    every other argument is forwarded unchanged.

    NOTE(review): despite the name, ``weights_only=False`` is the permissive
    setting -- it allows full pickle deserialization, which can execute
    arbitrary code from the file. Only load checkpoints from trusted sources.
    """
    if "weights_only" not in kwargs:
        kwargs["weights_only"] = False
    return _orig_load(*args, **kwargs)


# From here on, every torch.load call in this process goes through the wrapper.
torch.load = _safe_load
# ─────────────────────────────────────────────────────────
import cv2, json, os
from pathlib import Path
from ultralytics import RTDETR
# ── Device: use MPS on Apple M1 ──────────────────────────
# Use the "mps" backend when PyTorch reports it as available, otherwise the CPU.
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"[INFO] Device: {DEVICE}")
# ── Class config ─────────────────────────────────────────
# Raw detector class names, looked up by the class index the model predicts.
CLASS_NAMES = ['note', 'part-drawing', 'table']

# Map raw class names to the standard names required by the assignment spec.
CLASS_DISPLAY = {
    'note': 'Note',
    'part-drawing': 'PartDrawing',
    'table': 'Table',
}

# Box colours per class. These are passed to OpenCV drawing calls on an image
# loaded with cv2.imread, so the channel order is BGR, not RGB.
COLORS = {
    'note': (0, 165, 255),        # orange
    'part-drawing': (0, 200, 0),  # green
    'table': (0, 0, 220),         # red (was (220, 0, 0), which is blue in BGR)
}
# ================== NEW OCR - WORKS ON MAC M1 + PP-OCRv5 ==================
from paddleocr import PaddleOCR, PPStructureV3 # β SỬA α» ΔΓY: PPStructure β PPStructureV3
import cv2
# Lazily created engines; each is built on first request and then reused.
_ocr_engine = None
_table_engine = None


def get_ocr():
    """Return the shared plain-text OCR engine used for Note regions.

    The ``PaddleOCR`` instance is constructed on the first call and the same
    object is returned on every later call.
    """
    global _ocr_engine
    if _ocr_engine is not None:
        return _ocr_engine
    _ocr_engine = PaddleOCR(
        lang="vi",
        use_textline_orientation=True,  # replaces the old use_angle_cls option
    )
    return _ocr_engine
def get_table_engine():
    """Return the shared table-structure recognition engine.

    Intended to preserve the rows/columns of a table. The ``PPStructureV3``
    instance is constructed on the first call and reused afterwards.
    """
    global _table_engine
    if _table_engine is not None:
        return _table_engine
    _table_engine = PPStructureV3()
    return _table_engine
def _texts_from_ocr_result(result):
    """Extract the recognized text lines from a PaddleOCR result (3.x or legacy)."""
    if not result:
        return []
    page = result[0]
    if not page:
        return []
    if hasattr(page, "keys"):
        # NOTE(review): PaddleOCR 3.x returns one dict-like result per image
        # with the recognized strings under "rec_texts" -- confirm against the
        # installed version. Iterating such an object yields its keys, so the
        # legacy indexing below would produce garbage for it.
        texts = page["rec_texts"] if "rec_texts" in page else []
        return [str(text) for text in texts or []]
    # Legacy format: each entry is [box, (text, score)].
    return [line[1][0] for line in page]


def ocr_note(img_path):
    """Run plain OCR on a Note crop and return its text.

    Args:
        img_path: Path of the cropped image to read.

    Returns:
        The recognized lines joined with newlines, or "" when nothing was
        recognized.
    """
    ocr = get_ocr()
    result = ocr.ocr(img_path)  # cls=True is no longer passed
    return "\n".join(_texts_from_ocr_result(result))
def _table_html(result):
    """Join the HTML of every table found in a structure result, or return ""."""
    pages = result if isinstance(result, (list, tuple)) else [result]
    chunks = []
    for page in pages:
        # NOTE(review): assumes PP-StructureV3 results are dict-like with a
        # "table_res_list" whose items carry "pred_html" -- confirm against
        # the installed PaddleOCR version.
        if not hasattr(page, "keys") or "table_res_list" not in page:
            continue
        for table in page["table_res_list"] or []:
            if hasattr(table, "keys") and "pred_html" in table and table["pred_html"]:
                chunks.append(str(table["pred_html"]))
    return "\n".join(chunks)


def ocr_table(img_path):
    """OCR a Table crop, preferring output that keeps the table structure.

    Runs the table-structure engine on the crop. When the result exposes table
    HTML, that HTML is returned; otherwise the stringified raw result is
    returned. Any failure is logged and the function falls back to plain OCR
    via ``ocr_note`` (deliberate best-effort behaviour).

    Args:
        img_path: Path of the cropped table image.

    Returns:
        A string describing the table content.
    """
    try:
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals an unreadable file by returning None.
            raise ValueError(f"cannot read image: {img_path}")
        engine = get_table_engine()
        # NOTE(review): PP-StructureV3 pipelines run inference through
        # ``predict``; calling the object directly is kept only as a fallback
        # for engines that are themselves callable.
        run = getattr(engine, "predict", engine)
        result = run(img)
        html = _table_html(result)
        return html if html else str(result)
    except Exception as e:
        print(f"[WARN] Table structure failed: {e}, fallback to plain OCR")
        return ocr_note(img_path)
# ── Main pipeline ─────────────────────────────────────────
def _draw_detection(canvas, label, color, x1, y1, x2, y2):
    """Draw one labelled bounding box onto ``canvas`` in place."""
    cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)
    (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
    # The label normally sits just above the box; when that would fall outside
    # the image, place it just inside the box's top edge instead.
    top = y1 - th - 8
    if top < 0:
        top = max(y1, 0)
    cv2.rectangle(canvas, (x1, top), (x1 + tw + 4, top + th + 8), color, -1)
    cv2.putText(canvas, label,
                (x1 + 2, top + th + 4),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6,
                (255, 255, 255), 2)


def run_pipeline(
    image_path: str,
    output_dir: str = "outputs",
    checkpoint: str = "best.pt",
    conf: float = 0.3,
) -> tuple[dict, str]:
    """Run the full pipeline: detect -> crop -> OCR -> JSON.

    Outputs are written under ``<output_dir>/<image stem>/``: one JPEG per
    detection in ``crops/``, the annotated image ``result_vis.jpg`` and the
    structured result ``result.json``.

    Args:
        image_path: Path of the input image.
        output_dir: Root directory for the outputs.
        checkpoint: Detector weights path, forwarded to ``get_model``.
        conf: Confidence threshold forwarded to the detector.

    Returns:
        ``(result_dict, visualized_image_path)``. ``result_dict`` has the keys
        ``"image"`` (file name) and ``"objects"`` (one dict per detection).

    Raises:
        ValueError: If the image cannot be read.
    """
    image_path = str(image_path)
    src = Path(image_path)
    img_name = src.name
    stem = src.stem

    # Validate the input before inference or creating any output directory.
    img_bgr = cv2.imread(image_path)
    if img_bgr is None:
        raise ValueError(f"Không đọc được ảnh: {image_path}")

    out_root = Path(output_dir) / stem
    crop_dir = out_root / "crops"
    crop_dir.mkdir(parents=True, exist_ok=True)

    # 1. Detect
    model = get_model(checkpoint)
    results = model(
        image_path,
        imgsz=1024,
        conf=conf,
        iou=0.5,
        device=DEVICE,
        verbose=False,
    )

    # Boxes are drawn on a copy so that crops are always cut from clean
    # pixels; drawing on the source image would leak earlier boxes and labels
    # into later crops that overlap them.
    vis_bgr = img_bgr.copy()
    height, width = img_bgr.shape[:2]
    pad = 4  # small padding around each bbox

    objects = []
    for obj_id, box in enumerate(results[0].boxes, start=1):
        x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
        cls_idx = int(box.cls[0])
        conf_val = round(float(box.conf[0]), 4)
        cls_raw = CLASS_NAMES[cls_idx]
        cls_show = CLASS_DISPLAY[cls_raw]

        # 2. Crop (padded box, clamped to the image bounds)
        cx1 = max(0, x1 - pad)
        cy1 = max(0, y1 - pad)
        cx2 = min(width, x2 + pad)
        cy2 = min(height, y2 + pad)
        crop = img_bgr[cy1:cy2, cx1:cx2]
        crop_path = str(crop_dir / f"{cls_show}_{obj_id}.jpg")
        cv2.imwrite(crop_path, crop, [cv2.IMWRITE_JPEG_QUALITY, 95])

        # 3. OCR (only notes and tables carry text; other classes stay None)
        ocr_content = None
        if cls_raw == 'note':
            ocr_content = ocr_note(crop_path)
        elif cls_raw == 'table':
            ocr_content = ocr_table(crop_path)

        objects.append({
            "id": obj_id,
            "class": cls_show,
            "confidence": conf_val,
            "bbox": {"x1": x1, "y1": y1, "x2": x2, "y2": y2},
            "crop_path": crop_path,
            "ocr_content": ocr_content,
        })

        # 4. Draw the bbox on the visualization copy
        _draw_detection(
            vis_bgr, f"{cls_show} {conf_val:.2f}", COLORS[cls_raw],
            x1, y1, x2, y2,
        )

    # 5. Save the visualization image
    vis_path = str(out_root / "result_vis.jpg")
    cv2.imwrite(vis_path, vis_bgr)

    # 6. Save the JSON
    result = {"image": img_name, "objects": objects}
    json_path = str(out_root / "result.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=2)

    print(f"[✓] {img_name}: {len(objects)} objects → {json_path}")
    return result, vis_path
# ── Quick CLI test ────────────────────────────────────────
if __name__ == "__main__":
    import sys

    # The image path is the first command-line argument; "test.jpg" otherwise.
    cli_args = sys.argv[1:]
    img = cli_args[0] if cli_args else "test.jpg"
    result, vis = run_pipeline(img)
    print(json.dumps(result, ensure_ascii=False, indent=2))
|