| import os |
| import sys |
| import cv2 |
| import torch |
| import numpy as np |
| from PIL import Image |
|
|
| |
| |
| |
| try: |
| from transformers import pipeline |
| except ImportError: |
| print("错误: 缺少 transformers 库。") |
| print("请运行: pip install transformers accelerate") |
| sys.exit(1) |
|
|
class Config:
    """Fixed run settings for the SAM 3 automatic card-segmentation demo.

    Every setting is a class attribute, so it can be read from the class
    itself or from an instance (``Config().model_path``).
    """

    # Local directory passed to the transformers pipeline as ``model``.
    model_path: str = "/data/test/four_corn/sam3"

    # Image that gets segmented, and the folder the visualisation is written to.
    input_path: str = (
        "/data/test/four_corn/Pakistan_card_img/ss_20260127153633_47_2.jpg"
    )
    output_dir: str = "output_sam3_auto"

    # Use the GPU only when PyTorch reports that CUDA is available.
    device: str = "cpu" if not torch.cuda.is_available() else "cuda"

    # Forwarded to the mask-generation pipeline's ``points_per_batch`` option.
    points_per_batch: int = 64
|
|
|
|
| |
| |
| |
def _to_bool_mask(raw_seg):
    """Convert one raw segmentation into a 2-D boolean array, or None if impossible.

    Accepts a ``torch.Tensor``, a ``numpy.ndarray``, a PIL image, or anything
    else ``np.asarray`` understands. Singleton axes (e.g. a ``(1, H, W)``
    tensor) are squeezed away; anything that is still not 2-D is rejected
    because OpenCV's contour functions need a single-channel 2-D image.
    """
    if isinstance(raw_seg, torch.Tensor):
        # .numpy() only works on CPU tensors without autograd history.
        raw_seg = raw_seg.detach().cpu().numpy()
    seg = np.asarray(raw_seg)
    if seg.ndim > 2:
        seg = seg.squeeze()
    if seg.ndim != 2:
        return None
    if seg.dtype != bool:
        # Any strictly positive value counts as foreground.
        seg = seg > 0
    return seg


def filter_masks(masks, img_area, min_area_ratio=0.05, max_area_ratio=0.90):
    """Pick the mask that looks most like an ID card.

    A mask is a candidate only if its foreground pixel count lies within
    ``[img_area * min_area_ratio, img_area * max_area_ratio]``. Candidates are
    scored by the rectangularity of their largest connected component (its
    contour area divided by the area of its minimum-area rotated rectangle),
    multiplied by ``log(pixel count)`` so that larger regions are preferred.

    Args:
        masks: Iterable of dicts, each holding a ``'segmentation'`` entry
            (tensor, ndarray, PIL image, or array-like).
        img_area: Image area in pixels (height * width).
        min_area_ratio: Smallest accepted mask area, as a fraction of img_area.
        max_area_ratio: Largest accepted mask area, as a fraction of img_area.

    Returns:
        A dict with ``'mask'`` (2-D bool array), ``'rect'`` (the
        ``cv2.minAreaRect`` result), ``'score'`` and ``'box_points'`` (the
        ``cv2.boxPoints`` corners) for the best candidate, or ``None`` when no
        mask qualifies.
    """
    min_area = img_area * min_area_ratio
    max_area = img_area * max_area_ratio
    best = None

    for mask_data in masks:
        seg = _to_bool_mask(mask_data['segmentation'])
        if seg is None:
            continue

        area = int(np.count_nonzero(seg))
        if area < min_area or area > max_area:
            continue

        # Fresh uint8 copy: findContours needs 8-bit input.
        contours, _ = cv2.findContours(
            seg.astype(np.uint8) * 255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        if not contours:
            continue
        cnt = max(contours, key=cv2.contourArea)

        rect = cv2.minAreaRect(cnt)
        box_area = rect[1][0] * rect[1][1]
        if box_area <= 0:
            continue

        # Compare the largest component with its *own* bounding box. Using the
        # whole mask's pixel count here would let fragmented masks (several
        # blobs, only one of which is boxed) score a rectangularity above 1.
        rectangularity = cv2.contourArea(cnt) / box_area
        score = rectangularity * np.log(area)

        # Strict '>' keeps the earliest mask on ties (same as a stable sort).
        if best is None or score > best['score']:
            best = {
                'mask': seg,
                'rect': rect,
                'score': score,
                'box_points': cv2.boxPoints(rect),
            }

    return best
|
|
| |
| |
| |
def _load_generator(cfg):
    """Build the transformers mask-generation pipeline; return None on failure."""
    print(f">>> 正在加载本地 SAM 3 模型: {cfg.model_path} ...")
    if not os.path.exists(cfg.model_path):
        print(f"错误: 路径不存在 {cfg.model_path}")
        return None

    try:
        generator = pipeline(
            "mask-generation",
            model=cfg.model_path,
            device=cfg.device,
            points_per_batch=cfg.points_per_batch,
        )
    except Exception as e:  # script boundary: report any loading failure and stop
        print(f"模型加载失败: {e}")
        return None

    print(" -> 模型加载完成")
    return generator


def _collect_masks(outputs):
    """Normalise the pipeline output into a list of {'segmentation': mask} dicts."""
    if isinstance(outputs, dict) and 'masks' in outputs:
        return [{'segmentation': mask} for mask in outputs['masks']]

    formatted = []
    if isinstance(outputs, list):
        for item in outputs:
            if isinstance(item, dict) and 'mask' in item:
                formatted.append({'segmentation': item['mask']})
            elif isinstance(item, (Image.Image, torch.Tensor, np.ndarray)):
                formatted.append({'segmentation': item})
    return formatted


def _draw_result(img_bgr, candidate):
    """Return a copy of img_bgr with the candidate mask tinted green and its box outlined in red."""
    vis_img = img_bgr.copy()
    mask = candidate['mask']
    # OpenCV drawing functions expect int32 points; round instead of truncating.
    box_points = np.round(candidate['box_points']).astype(np.int32)

    # 50/50 blend with pure green (BGR), cast back to uint8 explicitly.
    blended = vis_img[mask] * 0.5 + np.array([0, 255, 0]) * 0.5
    vis_img[mask] = blended.astype(np.uint8)

    cv2.drawContours(vis_img, [box_points], 0, (0, 0, 255), 3)
    return vis_img


def main():
    """Segment ``Config.input_path`` with SAM 3 and save a visualisation of the card-like region.

    Returns:
        0 on success (including when no card-like region is found); 1 when the
        model or input image cannot be loaded or the result cannot be written.
    """
    cfg = Config()

    generator = _load_generator(cfg)
    if generator is None:
        return 1

    img_cv2 = cv2.imread(cfg.input_path)
    if img_cv2 is None:
        print(f"图片读取失败: {cfg.input_path}")
        return 1

    img_h, img_w = img_cv2.shape[:2]
    img_area = img_h * img_w

    # OpenCV loads BGR; the pipeline is given an RGB PIL image.
    img_pil = Image.fromarray(cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB))

    print(f">>> 正在全图分割: {os.path.basename(cfg.input_path)} ...")
    outputs = generator(img_pil)

    formatted_masks = _collect_masks(outputs)
    print(f" -> 生成了 {len(formatted_masks)} 个掩膜片段")

    best_candidate = filter_masks(formatted_masks, img_area)

    if best_candidate is not None:
        print(">>> 找到最佳证件区域!")
        vis_img = _draw_result(img_cv2, best_candidate)
        angle = best_candidate['rect'][-1]
        print(f" -> 旋转角度: {angle:.2f}")
    else:
        print(">>> 未找到符合条件的证件区域")
        vis_img = img_cv2.copy()

    os.makedirs(cfg.output_dir, exist_ok=True)
    save_path = os.path.join(cfg.output_dir, "auto_sam3_" + os.path.basename(cfg.input_path))
    # imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(save_path, vis_img):
        print(f"结果保存失败: {save_path}")
        return 1
    print(f"结果已保存: {save_path}")
    return 0


if __name__ == "__main__":
    # Propagate the status so callers/shell scripts can detect failures.
    sys.exit(main())