import os
import sys

import cv2
import numpy as np
import torch
from PIL import Image

# ==========================================
# 1. Environment setup
# ==========================================
try:
    from transformers import pipeline
except ImportError:
    print("错误: 缺少 transformers 库。")
    print("请运行: pip install transformers accelerate")
    sys.exit(1)


class Config:
    """Static settings for a single run of this script."""

    # Local model directory (holds config.json and the safetensors weights)
    model_path = "/data/test/four_corn/sam3"
    # Input image path
    input_path = "/data/test/four_corn/Pakistan_card_img/ss_20260127153633_47_2.jpg"
    output_dir = "output_sam3_auto"
    # Use the GPU when torch can see one, otherwise fall back to CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Inference parameter forwarded to the mask-generation pipeline
    points_per_batch = 64


# ==========================================
# 2. Mask filtering
# ==========================================
def _to_bool_mask(raw_seg):
    """Convert a raw mask (Tensor, PIL image, ndarray or array-like) to a 2-D bool array.

    Returns None when the mask cannot be reduced to exactly two dimensions.
    """
    if isinstance(raw_seg, torch.Tensor):
        # Tensors may live on the GPU; detach and move to CPU before converting.
        seg = raw_seg.detach().cpu().numpy()
    else:
        # PIL images, ndarrays and other array-likes all convert via asarray.
        seg = np.asarray(raw_seg)

    if seg.ndim > 2:
        # Drop singleton axes such as a leading (1, H, W) batch/channel dim,
        # which would otherwise break cv2.findContours and boolean indexing.
        seg = np.squeeze(seg)
    if seg.ndim != 2:
        return None

    # Some outputs are 0/1 or 0/255 integers rather than booleans.
    if seg.dtype != bool:
        seg = seg > 0
    return seg


def filter_masks(masks, img_area):
    """Pick the mask that looks most like an ID card.

    Each mask is scored by rectangularity (how fully it fills its minimum-area
    rotated rectangle) times log(area), after discarding masks that cover less
    than 5% or more than 90% of the image.

    Args:
        masks: iterable of dicts holding a raw mask under 'segmentation'
            (torch Tensor, PIL Image, numpy array, or array-like).
        img_area: image area in pixels, used by the relative-size filter.

    Returns:
        The best candidate as a dict with keys 'mask' (2-D bool ndarray),
        'rect' (cv2.minAreaRect result ((cx, cy), (w, h), angle)), 'score'
        and 'box_points' (cv2.boxPoints corners), or None if nothing passes.
    """
    best = None
    for mask_data in masks:
        seg = _to_bool_mask(mask_data['segmentation'])
        if seg is None:
            continue

        # 1. Size filter: drop masks below 5% or above 90% of the image area.
        area = int(np.count_nonzero(seg))
        if area < img_area * 0.05 or area > img_area * 0.90:
            continue

        # 2. OpenCV needs a uint8 image.
        mask_uint8 = (seg * 255).astype(np.uint8)

        # 3. Outer contours; [-2] selects the contour list under both
        #    OpenCV 3.x (3 return values) and 4.x (2 return values).
        contours = cv2.findContours(
            mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )[-2]
        if not contours:
            continue
        cnt = max(contours, key=cv2.contourArea)

        # Minimum-area rotated rectangle around the largest contour.
        rect = cv2.minAreaRect(cnt)
        box_area = rect[1][0] * rect[1][1]
        if box_area == 0:
            continue

        # 4. Rectangularity must be measured on the same region the rectangle
        #    was fitted to: only mask pixels inside the largest contour. Using
        #    the whole mask area would let fragmented (multi-blob) masks exceed
        #    1.0 and outrank a real card. "& seg" keeps holes counted as empty.
        component = np.zeros_like(mask_uint8)
        cv2.drawContours(component, [cnt], -1, 255, thickness=cv2.FILLED)
        component_area = int(np.count_nonzero(seg & (component > 0)))
        if component_area == 0:
            continue
        rectangularity = component_area / box_area
        score = rectangularity * np.log(component_area)

        # Strict ">" keeps the earliest mask on ties, like a stable sort would.
        if best is None or score > best['score']:
            best = {
                'mask': seg,                       # bool mask used for the overlay
                'rect': rect,                      # ((cx, cy), (w, h), angle)
                'score': score,
                'box_points': cv2.boxPoints(rect),
            }
    return best


# ==========================================
# 3. Main program
# ==========================================
def _load_generator(cfg):
    """Build the mask-generation pipeline; print the reason and return None on failure."""
    print(f">>> 正在加载本地 SAM 3 模型: {cfg.model_path} ...")
    if not os.path.exists(cfg.model_path):
        print(f"错误: 路径不存在 {cfg.model_path}")
        return None
    try:
        generator = pipeline(
            "mask-generation",
            model=cfg.model_path,
            device=cfg.device,
            points_per_batch=cfg.points_per_batch,
        )
    except Exception as e:  # top-level boundary: report any loader failure
        print(f"模型加载失败: {e}")
        return None
    print(" -> 模型加载完成")
    return generator


def _normalize_outputs(outputs):
    """Flatten pipeline output into a list of {'segmentation': mask} dicts."""
    # Standard format: {'masks': [...], 'scores': [...]}
    if isinstance(outputs, dict) and 'masks' in outputs:
        return [{'segmentation': mask} for mask in outputs['masks']]

    formatted = []
    # Legacy format: a list of dicts, PIL images, tensors or arrays.
    if isinstance(outputs, list):
        for item in outputs:
            if isinstance(item, dict):
                seg = item.get('mask', item.get('segmentation'))
                if seg is not None:
                    formatted.append({'segmentation': seg})
            elif isinstance(item, (Image.Image, torch.Tensor, np.ndarray)):
                formatted.append({'segmentation': item})
    return formatted


def _draw_result(img_cv2, best_candidate):
    """Return a copy of img_cv2 with the candidate's mask and rotated box drawn on it."""
    vis_img = img_cv2.copy()
    if best_candidate is None:
        print(">>> 未找到符合条件的证件区域")
        return vis_img

    print(">>> 找到最佳证件区域!")
    mask = best_candidate['mask']
    # Round (not truncate) the float corners; OpenCV point arrays are int32.
    box_points = np.round(best_candidate['box_points']).astype(np.int32)

    # Green 50% overlay on the mask pixels.
    vis_img[mask] = (vis_img[mask] * 0.5 + np.array([0, 255, 0]) * 0.5).astype(np.uint8)
    # Red rotated rectangle.
    cv2.drawContours(vis_img, [box_points], 0, (0, 0, 255), 3)

    angle = best_candidate['rect'][-1]
    print(f" -> 旋转角度: {angle:.2f}")
    return vis_img


def main():
    """Segment the configured image, pick the card-like mask and save a visualization.

    Returns:
        0 on success, 1 if the model, the input image, or the output write fails.
    """
    cfg = Config()

    # 1. Load the model
    generator = _load_generator(cfg)
    if generator is None:
        return 1

    # 2. Read the image
    img_cv2 = cv2.imread(cfg.input_path)
    if img_cv2 is None:
        print(f"图片读取失败: {cfg.input_path}")
        return 1
    img_h, img_w = img_cv2.shape[:2]
    img_area = img_h * img_w
    # The pipeline takes an RGB PIL image; OpenCV loads BGR.
    img_pil = Image.fromarray(cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB))

    # 3. Automatic whole-image segmentation
    print(f">>> 正在全图分割: {os.path.basename(cfg.input_path)} ...")
    outputs = generator(img_pil)
    formatted_masks = _normalize_outputs(outputs)
    print(f" -> 生成了 {len(formatted_masks)} 个掩膜片段")

    # 4. Pick the card region
    best_candidate = filter_masks(formatted_masks, img_area)

    # 5. Draw and save
    vis_img = _draw_result(img_cv2, best_candidate)
    os.makedirs(cfg.output_dir, exist_ok=True)
    save_path = os.path.join(cfg.output_dir, "auto_sam3_" + os.path.basename(cfg.input_path))
    # cv2.imwrite reports failure via its return value rather than raising.
    if not cv2.imwrite(save_path, vis_img):
        print(f"结果保存失败: {save_path}")
        return 1
    print(f"结果已保存: {save_path}")
    return 0


if __name__ == "__main__":
    sys.exit(main())