fdsgsfjsfg commited on
Commit
80c2529
·
verified ·
1 Parent(s): 54b86f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -24
app.py CHANGED
@@ -6,6 +6,7 @@ from PIL import Image
6
  import gc
7
  import os
8
  import spaces
 
9
  from transformers import Sam3Model, Sam3Processor
10
 
11
  HF_TOKEN = os.getenv("HF_TOKEN")
@@ -52,38 +53,42 @@ def overlay_masks(image, masks, alpha=0.6):
52
 
53
  @spaces.GPU
54
  def process_text_detection(image, text_query, threshold):
 
55
  if not image or not text_query: return None, "请输入图像和描述词"
56
  try:
57
  model, processor = get_model()
58
  inputs = processor(images=image, text=text_query, return_tensors="pt").to(device)
59
- with torch.no_grad():
60
- outputs = model(**inputs)
61
  results = processor.post_process_instance_segmentation(outputs, threshold=threshold, mask_threshold=0.5, target_sizes=inputs.get("original_sizes").tolist())[0]
62
  masks = results["masks"]
63
- scores = results["scores"].cpu().numpy() if "scores" in results else []
64
  result_img = overlay_masks(image, masks)
65
- if len(masks) > 0:
66
- status = f"✅ 检测到 {len(masks)} 个目标。"
67
- if len(scores) > 0: status += f" 置信度范围: {np.min(scores):.2f}-{np.max(scores):.2f}"
68
- else:
69
- status = "❓ 未找到匹配目标,请尝试调低阈值或修改提示词。"
70
  return result_img, status
71
  except Exception as e:
72
  return image, f"❌ 错误: {str(e)}"
73
 
74
- # 就在这一行,去掉了 theme 相关的参数
75
- with gr.Blocks() as demo:
76
- gr.Markdown("# 🚀 SAM 3 文本自动检测工具")
77
- with gr.Row():
78
- with gr.Column():
79
- t_img_in = gr.Image(type="pil", label="上传原图")
80
- t_query = gr.Textbox(label="输入检测内容(英文)", placeholder="例如: watermark, logo", value="watermark")
81
- t_thresh = gr.Slider(0.1, 0.9, value=0.3, step=0.05, label="灵敏度(越低得越多)")
82
- t_btn = gr.Button("开始自动检测", variant="primary")
83
- with gr.Column():
84
- t_img_out = gr.Image(type="pil", label="检测结果 (遮罩高亮)")
85
- t_info = gr.Textbox(label="状态信息")
86
- t_btn.click(process_text_detection, [t_img_in, t_query, t_thresh], [t_img_out, t_info])
87
-
88
- if __name__ == "__main__":
89
- demo.launch()
 
 
 
 
 
 
 
 
 
 
6
  import gc
7
  import os
8
  import spaces
9
+ import cv2 # 新增:用于图像样本的坐标定位
10
  from transformers import Sam3Model, Sam3Processor
11
 
12
  HF_TOKEN = os.getenv("HF_TOKEN")
 
53
 
54
  @spaces.GPU
55
  def process_text_detection(image, text_query, threshold):
56
+ """文本检测模式"""
57
  if not image or not text_query: return None, "请输入图像和描述词"
58
  try:
59
  model, processor = get_model()
60
  inputs = processor(images=image, text=text_query, return_tensors="pt").to(device)
61
+ with torch.no_grad(): outputs = model(**inputs)
 
62
  results = processor.post_process_instance_segmentation(outputs, threshold=threshold, mask_threshold=0.5, target_sizes=inputs.get("original_sizes").tolist())[0]
63
  masks = results["masks"]
 
64
  result_img = overlay_masks(image, masks)
65
+ status = f"✅ 文本检测完成!找到 {len(masks)} 个目标。" if len(masks) > 0 else "❓ 未找到目标,请调低阈值。"
 
 
 
 
66
  return result_img, status
67
  except Exception as e:
68
  return image, f"❌ 错误: {str(e)}"
69
 
70
+ @spaces.GPU
71
+ def process_sample_detection(main_image, sample_image):
72
+ """样本截图检测模式 (OpenCV 定位 + SAM3 分割)"""
73
+ if not main_image or not sample_image: return None, "请上传主图和样本截图"
74
+ try:
75
+ model, processor = get_model()
76
+
77
+ # 1. 使用 OpenCV 进行模板匹配到截图在主图中的坐标
78
+ main_cv = cv2.cvtColor(np.array(main_image), cv2.COLOR_RGB2BGR)
79
+ sample_cv = cv2.cvtColor(np.array(sample_image), cv2.COLOR_RGB2BGR)
80
+
81
+ # 检查样本是否比主图大
82
+ if sample_cv.shape[0] > main_cv.shape[0] or sample_cv.shape[1] > main_cv.shape[1]:
83
+ return main_image, "❌ 错误:样本截图不能比主图还大!"
84
+
85
+ result = cv2.matchTemplate(main_cv, sample_cv, cv2.TM_CCOEFF_NORMED)
86
+ min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(result)
87
+
88
+ # 如果相似度太低,说明没找到
89
+ if max_val < 0.4:
90
+ return main_image, f"❓ 未在主图中找到该样本 (最高匹配度: {max_val:.2f})。请确保截图来自该原图。"
91
+
92
+ # 计算 Bounding Box [x_min, y_min, x_max, y_max]
93
+ h, w = sample_cv.shape[:2]
94
+ box =