"""Detect layout regions with RT-DETR, crop them, OCR them, and export JSON.

Pipeline: detect -> crop -> OCR (notes and tables only) -> annotated image
and ``result.json``. Importing this module has two side effects: it replaces
``torch.load`` with a wrapper defaulting to ``weights_only=False`` and it
prints the selected device.
"""
import json
import os
from pathlib import Path

# Third-party import order is kept as in the original script
# (torch -> paddleocr -> cv2 -> ultralytics).
import torch
from paddleocr import PaddleOCR, PPStructureV3
import cv2

# ── torch.load compatibility patch ───────────────────────
# Applied BEFORE ultralytics is imported, as in the original file.
_orig_load = torch.load


def _safe_load(*args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only=False`` by default.

    SECURITY: ``weights_only=False`` lets the checkpoint unpickle arbitrary
    objects. Only load checkpoints that come from a trusted source.
    """
    kwargs.setdefault("weights_only", False)
    return _orig_load(*args, **kwargs)


torch.load = _safe_load
# ─────────────────────────────────────────────────────────

from ultralytics import RTDETR  # noqa: E402  (must follow the patch above)

# ── Device: use MPS when available (Apple Silicon), else CPU ──
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"[INFO] Device: {DEVICE}")

# ── Class config ─────────────────────────────────────────
# Index in this list must match the class index emitted by the detector.
CLASS_NAMES = ['note', 'part-drawing', 'table']

# Map raw class names to the display names required by the task statement.
CLASS_DISPLAY = {
    'note': 'Note',
    'part-drawing': 'PartDrawing',
    'table': 'Table',
}

# BGR colours used when drawing boxes with OpenCV.
COLORS = {
    'note': (0, 165, 255),      # orange
    'part-drawing': (0, 200, 0),  # green
    'table': (220, 0, 0),       # blue in BGR (original comment said "red")
}

# ── Load model ───────────────────────────────────────────
_model = None
_model_checkpoint = None


def get_model(checkpoint: str = "best.pt"):
    """Return a cached RT-DETR model for ``checkpoint``.

    The model is loaded on first use and reused on later calls. If a
    different ``checkpoint`` is requested, the cached model is replaced so
    callers never receive weights from a checkpoint they did not ask for.
    """
    global _model, _model_checkpoint
    if _model is None or _model_checkpoint != checkpoint:
        print(f"[INFO] Loading model from {checkpoint}...")
        _model = RTDETR(checkpoint)
        _model_checkpoint = checkpoint
    return _model


# ── OCR engines (lazy singletons) ────────────────────────
_ocr_engine = None
_table_engine = None


def get_ocr():
    """Return the shared plain-text OCR engine (used for Note regions)."""
    global _ocr_engine
    if _ocr_engine is None:
        _ocr_engine = PaddleOCR(
            use_textline_orientation=True,  # replaces the old use_angle_cls
            lang="vi",
        )
    return _ocr_engine


def get_table_engine():
    """Return the shared PPStructureV3 engine (used for Table regions)."""
    global _table_engine
    if _table_engine is None:
        _table_engine = PPStructureV3()
    return _table_engine


def _extract_ocr_text(result) -> str:
    """Join the recognised text lines of the first page of an OCR result.

    Two result layouts are handled:

    * dict-like page exposing ``rec_texts`` (a list of strings)
      # NOTE(review): believed to be the PaddleOCR 3.x layout - confirm
      # against the installed version.
    * legacy page: a list of ``[box, (text, score)]`` entries.

    Returns an empty string when nothing was recognised.
    """
    if not result:
        return ""
    page = result[0]
    if not page:
        return ""
    if isinstance(page, dict):
        texts = page.get("rec_texts") or []
        return "\n".join(str(text) for text in texts)
    # Legacy layout; skip empty entries defensively.
    return "\n".join(line[1][0] for line in page if line)


def ocr_note(img_path):
    """Run plain OCR on the image at ``img_path`` and return its text.

    Lines are joined with ``\\n``; an empty string is returned when no text
    is found.
    """
    ocr = get_ocr()
    result = ocr.ocr(img_path)  # no cls=True: not used with this API
    return _extract_ocr_text(result)


def ocr_table(img_path):
    """Run table-structure recognition on ``img_path``.

    Returns ``str(result)`` of the structure engine. If anything in the
    structure path fails, a warning is printed and the plain OCR text from
    :func:`ocr_note` is returned instead (deliberate best-effort fallback).
    """
    try:
        engine = get_table_engine()
        img = cv2.imread(img_path)
        if img is None:
            raise ValueError(f"cannot read image: {img_path}")
        # Prefer the explicit ``predict`` entry point when the engine has one;
        # otherwise call the engine directly as the original code did.
        # NOTE(review): confirm which entry point the installed
        # PPStructureV3 exposes.
        run = getattr(engine, "predict", None) or engine
        result = run(img)
        return str(result)
    except Exception as e:  # best-effort: fall back instead of crashing
        print(f"[WARN] Table structure failed: {e}, fallback to plain OCR")
        return ocr_note(img_path)


def _draw_detection(img_bgr, x1, y1, x2, y2, label, color):
    """Draw a bounding box and a filled label tag onto ``img_bgr`` in place.

    The tag sits directly above the box. When the box touches the top of the
    image the tag is clamped to y=0 so it stays visible instead of being
    drawn at negative coordinates.
    """
    cv2.rectangle(img_bgr, (x1, y1), (x2, y2), color, 2)
    (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
    tag_top = max(0, y1 - th - 8)
    tag_bottom = tag_top + th + 8  # equals y1 whenever no clamping occurred
    cv2.rectangle(img_bgr, (x1, tag_top), (x1 + tw + 4, tag_bottom), color, -1)
    cv2.putText(
        img_bgr, label, (x1 + 2, tag_bottom - 4),
        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2,
    )


# ── Main pipeline ─────────────────────────────────────────
def run_pipeline(
    image_path: str,
    output_dir: str = "outputs",
    checkpoint: str = "best.pt",
    conf: float = 0.3,
) -> tuple[dict, str]:
    """Run the full pipeline: detect -> crop -> OCR -> JSON.

    Files written under ``<output_dir>/<image stem>/``:
    ``crops/<Class>_<id>.jpg`` for every detection, ``result_vis.jpg`` with
    the boxes drawn, and ``result.json``.

    Args:
        image_path: Path of the input image.
        output_dir: Root directory for all outputs.
        checkpoint: Detector checkpoint passed to :func:`get_model`.
        conf: Confidence threshold passed to the detector.

    Returns:
        ``(result_dict, visualized_image_path)`` where ``result_dict`` is
        ``{"image": <file name>, "objects": [...]}``.

    Raises:
        ValueError: If OpenCV cannot read ``image_path``.
    """
    image_path = str(image_path)
    img_name = Path(image_path).name
    stem = Path(image_path).stem
    out_root = Path(output_dir) / stem
    crop_dir = out_root / "crops"
    crop_dir.mkdir(parents=True, exist_ok=True)

    # 1. Detect
    model = get_model(checkpoint)
    results = model(
        image_path,
        imgsz=1024,
        conf=conf,
        iou=0.5,
        device=DEVICE,
        verbose=False,
    )

    img_bgr = cv2.imread(image_path)
    if img_bgr is None:
        raise ValueError(f"Không đọc được ảnh: {image_path}")
    img_h, img_w = img_bgr.shape[:2]

    pad = 4  # small padding around each bbox when cropping
    objects = []
    for i, box in enumerate(results[0].boxes):
        x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
        cls_idx = int(box.cls[0])
        conf_val = round(float(box.conf[0]), 4)
        cls_raw = CLASS_NAMES[cls_idx]
        cls_show = CLASS_DISPLAY[cls_raw]

        # 2. Crop (padded, clamped to the image bounds). The crop is taken
        # before anything is drawn on this box, and saved to disk because
        # the OCR helpers take a file path.
        cx1 = max(0, x1 - pad)
        cy1 = max(0, y1 - pad)
        cx2 = min(img_w, x2 + pad)
        cy2 = min(img_h, y2 + pad)
        crop = img_bgr[cy1:cy2, cx1:cx2]
        crop_path = str(crop_dir / f"{cls_show}_{i+1}.jpg")
        cv2.imwrite(crop_path, crop, [cv2.IMWRITE_JPEG_QUALITY, 95])

        # 3. OCR (part drawings carry no text content -> None)
        ocr_content = None
        if cls_raw == 'note':
            ocr_content = ocr_note(crop_path)
        elif cls_raw == 'table':
            ocr_content = ocr_table(crop_path)

        objects.append({
            "id": i + 1,
            "class": cls_show,
            "confidence": conf_val,
            "bbox": {"x1": x1, "y1": y1, "x2": x2, "y2": y2},
            "crop_path": crop_path,
            "ocr_content": ocr_content,
        })

        # 4. Draw the bbox and its label onto the image
        _draw_detection(
            img_bgr, x1, y1, x2, y2,
            f"{cls_show} {conf_val:.2f}", COLORS[cls_raw],
        )

    # 5. Save the visualisation
    vis_path = str(out_root / "result_vis.jpg")
    cv2.imwrite(vis_path, img_bgr)

    # 6. Save the JSON report
    result = {"image": img_name, "objects": objects}
    json_path = str(out_root / "result.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=2)

    print(f"[✓] {img_name}: {len(objects)} objects → {json_path}")
    return result, vis_path


# ── Quick CLI test ────────────────────────────────────────
def _main() -> None:
    """Run the pipeline on ``argv[1]`` (default ``test.jpg``) and print JSON."""
    import sys

    img = sys.argv[1] if len(sys.argv) > 1 else "test.jpg"
    result, _vis = run_pipeline(img)
    print(json.dumps(result, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    _main()