import cv2 import os import matplotlib.pyplot as plt from detectron2.engine import DefaultPredictor from detectron2.config import get_cfg from detectron2 import model_zoo from detectron2.utils.visualizer import Visualizer from detectron2.data import MetadataCatalog # Danh sách danh mục class. # CHÚ Ý: Cần match với thứ tự các class trong annotation file COCO mà bạn có. classes = ["PartDrawing", "Note", "Table"] def get_inference_model(): cfg = get_cfg() cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")) # Load model Weights đã train cfg.MODEL.WEIGHTS = os.path.join("output_model", "model_final.pth") cfg.MODEL.ROI_HEADS.NUM_CLASSES = 3 # Đặt Threshold để lọc confidence (0.5 là gợi ý, có thể tuỳ chỉnh) cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # Yêu cầu CHẠY BẰNG CPU theo mong muốn của user cfg.MODEL.DEVICE = "cpu" predictor = DefaultPredictor(cfg) return predictor, cfg def run_inference(image_path): print(f"Đang inference file {image_path} ở chế độ CPU...") predictor, cfg = get_inference_model() # Đọc ảnh img = cv2.imread(image_path) if img is None: print("Lỗi không thể đọc ảnh!") return outputs = predictor(img) print("Dự đoán Bounding Box:", outputs["instances"].pred_boxes) print("Scores dự đoán:", outputs["instances"].scores) print("Class IDs:", outputs["instances"].pred_classes) # Hiển thị kết quả bằng Visualizer # Tạo một Custom Metadata thay vì đè lên mặc định (tránh lỗi cache của COCO train) MetadataCatalog.get("tech_draw_inference").set(thing_classes=classes) v = Visualizer(img[:, :, ::-1], MetadataCatalog.get("tech_draw_inference"), scale=1.2) out = v.draw_instance_predictions(outputs["instances"].to("cpu")) result_img = out.get_image()[:, :, ::-1] # Lưu file kết quả save_path = "result_" + os.path.basename(image_path) cv2.imwrite(save_path, result_img) print(f"Lưu kết quả tại: {save_path}") # Visualize lên màn hình bằng matplotlib (hiệu quả trên Linux/jupyter) plt.figure(figsize=(15, 10)) plt.imshow(cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB)) plt.axis('off') plt.title("Obeject Detection Result") plt.show() import argparse if __name__ == "__main__": import sys # Cho phép truyền đường dẫn ảnh trực tiếp qua command line if len(sys.argv) > 1: test_image = sys.argv[1] else: # Nếu không truyền, dùng ảnh test mặc định (bạn cần thay đường dẫn bên dưới bằng ảnh có thật) test_image = r"/media/quyet/01DAD374BE175C40/Technical_Draw_Detection/Datasets/Dataset_main/valid/25_jpg.rf.e364c072dc880644c0eee4ece910ac8a.jpg" # Edit here if os.path.exists(test_image): run_inference(test_image) else: print(f"Không tìm thấy ảnh tại đường dẫn: {test_image}") print("Sử dụng lệnh: python Detection/inference.py <đường_dẫn_tới_ảnh>")