fdsgsfjsfg committed on
Commit
3ae5094
·
verified ·
1 Parent(s): 0a8b289

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -101
app.py CHANGED
@@ -1,44 +1,66 @@
 
 
 
1
  import gradio as gr
 
2
  import torch
3
  import numpy as np
4
  import matplotlib.pyplot as plt
5
  import matplotlib
6
  from PIL import Image, ImageDraw
7
- import gc
8
- import os
9
- import spaces
10
- import cv2
11
  from transformers import (
12
- Sam3Model, Sam3Processor,
13
- Sam3TrackerModel, Sam3TrackerProcessor,
 
 
14
  )
15
 
 
 
 
 
16
  HF_TOKEN = os.getenv("HF_TOKEN")
17
  MODELS = {}
18
- device = "cuda"
 
 
 
 
19
 
20
  def cleanup_memory():
21
  if MODELS:
22
  MODELS.clear()
23
  gc.collect()
24
- torch.cuda.empty_cache()
 
 
25
 
26
  def get_model(model_type):
27
  if model_type in MODELS:
28
  return MODELS[model_type]
 
29
  cleanup_memory()
30
  print(f"⏳ 正在加载 {model_type} 模型...")
31
- if model_type == "sam3_image_text":
32
- model = Sam3Model.from_pretrained("facebook/sam3", token=HF_TOKEN).to(device)
33
- processor = Sam3Processor.from_pretrained("facebook/sam3", token=HF_TOKEN)
34
- elif model_type == "sam3_image_tracker":
35
- model = Sam3TrackerModel.from_pretrained("facebook/sam3", token=HF_TOKEN).to(device)
36
- processor = Sam3TrackerProcessor.from_pretrained("facebook/sam3", token=HF_TOKEN)
37
- else:
38
- raise ValueError(f"未知模型类型: {model_type}")
39
- MODELS[model_type] = (model, processor)
40
- print(f" {model_type} 加载完成。")
41
- return MODELS[model_type]
 
 
 
 
 
 
 
42
 
43
  def overlay_masks(image, masks, alpha=0.6):
44
  if image is None:
@@ -46,22 +68,30 @@ def overlay_masks(image, masks, alpha=0.6):
46
  if isinstance(image, np.ndarray):
47
  image = Image.fromarray(image)
48
  image = image.convert("RGBA")
 
49
  if masks is None or len(masks) == 0:
50
  return image.convert("RGB")
 
51
  if isinstance(masks, torch.Tensor):
52
- masks = masks.cpu().numpy()
53
  masks = masks.astype(np.uint8)
54
- if masks.ndim == 4: masks = masks[0]
55
- if masks.ndim == 3 and masks.shape[0] == 1: masks = masks[0]
56
- if masks.ndim == 2: masks = [masks]
 
 
 
 
 
57
  n_masks = len(masks)
58
  try:
59
  cmap = matplotlib.colormaps["rainbow"].resampled(max(n_masks, 1))
60
  except AttributeError:
61
  cmap = plt.get_cmap("rainbow", max(n_masks, 1))
 
62
  overlay_layer = Image.new("RGBA", image.size, (0, 0, 0, 0))
63
  for i, mask in enumerate(masks):
64
- mask_img = Image.fromarray((mask * 255).astype(np.uint8))
65
  if mask_img.size != image.size:
66
  mask_img = mask_img.resize(image.size, resample=Image.NEAREST)
67
  rgb = [int(x * 255) for x in cmap(i)[:3]]
@@ -69,136 +99,155 @@ def overlay_masks(image, masks, alpha=0.6):
69
  mask_alpha = mask_img.point(lambda v: int(v * alpha) if v > 0 else 0)
70
  color_layer.putalpha(mask_alpha)
71
  overlay_layer = Image.alpha_composite(overlay_layer, color_layer)
 
72
  return Image.alpha_composite(image, overlay_layer).convert("RGB")
73
 
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  def multi_scale_template_match(main_cv, sample_cv, min_scale=0.5, max_scale=1.5, steps=15):
76
- """
77
- 多尺度 + 多方法模板匹配,大幅提升匹配成功率。
78
- 返回 (best_score, best_loc, best_w, best_h) 或 None。
79
- """
80
  main_gray = cv2.cvtColor(main_cv, cv2.COLOR_BGR2GRAY)
81
  sample_gray = cv2.cvtColor(sample_cv, cv2.COLOR_BGR2GRAY)
82
-
83
  sh, sw = sample_gray.shape[:2]
84
  mh, mw = main_gray.shape[:2]
85
-
86
- # 多种匹配方法
87
  methods = [
88
  cv2.TM_CCOEFF_NORMED,
89
  cv2.TM_CCORR_NORMED,
90
  ]
91
-
92
  best_score = -1
93
  best_loc = None
94
  best_w, best_h = sw, sh
95
-
96
  for scale in np.linspace(min_scale, max_scale, steps):
97
  new_w = int(sw * scale)
98
  new_h = int(sh * scale)
99
-
100
- # 跳过比主图还大的尺度
101
  if new_w > mw or new_h > mh:
102
  continue
103
- # 跳过太小的尺度
104
  if new_w < 10 or new_h < 10:
105
  continue
106
-
107
  resized_sample = cv2.resize(sample_gray, (new_w, new_h))
108
-
109
  for method in methods:
110
  result = cv2.matchTemplate(main_gray, resized_sample, method)
111
  _, max_val, _, max_loc = cv2.minMaxLoc(result)
112
-
113
  if max_val > best_score:
114
  best_score = max_val
115
  best_loc = max_loc
116
  best_w, best_h = new_w, new_h
117
-
118
  if best_loc is None:
119
  return None
120
-
121
- return best_score, best_loc, best_w, best_h
122
-
123
 
124
- def draw_box_on_image(image, box, color="lime", width=3):
125
- """在图像上画一个矩形框,用于预览匹配位置。"""
126
- if isinstance(image, np.ndarray):
127
- image = Image.fromarray(image)
128
- draw_img = image.copy()
129
- draw = ImageDraw.Draw(draw_img)
130
- x1, y1, x2, y2 = box
131
- draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
132
- return draw_img
133
 
134
 
135
- # ========== 文本描述检测 ==========
136
  @spaces.GPU
137
  def process_text_detection(image, text_query, threshold):
138
- if not image or not text_query:
139
- return None, "请输入图像和描述词"
140
  try:
141
  model, processor = get_model("sam3_image_text")
142
- inputs = processor(
143
- images=image, text=text_query, return_tensors="pt"
144
- ).to(device)
145
  with torch.no_grad():
146
  outputs = model(**inputs)
 
147
  results = processor.post_process_instance_segmentation(
148
- outputs, threshold=threshold, mask_threshold=0.5,
149
- target_sizes=inputs.get("original_sizes").tolist()
 
 
150
  )[0]
151
- masks = results["masks"]
152
- result_img = overlay_masks(image, masks)
153
- if len(masks) > 0:
154
- status = f"✅ 文本检测完成!找到 {len(masks)} 个目标。"
 
 
 
 
155
  else:
156
  status = "❓ 未找到目标,请调低阈值。"
157
- return result_img, status
 
158
  except Exception as e:
159
- return image, f"❌ 错误: {str(e)}"
160
 
161
 
162
- # ========== 样本截图检测 ==========
163
  @spaces.GPU
164
  def process_sample_detection(main_image, sample_image, match_threshold):
165
- if not main_image or not sample_image:
166
- return None, "请上传主图和样本截图"
167
  try:
168
  model, processor = get_model("sam3_image_tracker")
169
 
170
  main_cv = cv2.cvtColor(np.array(main_image), cv2.COLOR_RGB2BGR)
171
  sample_cv = cv2.cvtColor(np.array(sample_image), cv2.COLOR_RGB2BGR)
172
 
173
- # Step 1: 多尺度模板匹配
174
  match = multi_scale_template_match(main_cv, sample_cv)
175
-
176
  if match is None:
177
- return main_image, "❌ 样本图太大或无法匹配。"
178
-
179
  best_score, best_loc, best_w, best_h = match
 
180
 
181
  if best_score < match_threshold:
182
- # 即使匹配度不够,也显示最佳匹配位置供参考
183
- box = [best_loc[0], best_loc[1], best_loc[0] + best_w, best_loc[1] + best_h]
184
  preview = draw_box_on_image(main_image, box, color="red")
185
- return preview, (
186
  f"❓ 匹配度不足 (最高: {best_score:.2f},阈值: {match_threshold:.2f})。\n"
187
  f"红框为最佳匹配位置,可尝试降低阈值或使用更清晰的截图。"
188
  )
189
 
190
- box = [
191
- best_loc[0],
192
- best_loc[1],
193
- best_loc[0] + best_w,
194
- best_loc[1] + best_h
195
- ]
196
-
197
- # Step 2: 用 Sam3Tracker 的 box prompt 做精细分割
198
  inputs = processor(
199
  images=main_image,
200
  input_boxes=[[box]],
201
- return_tensors="pt"
202
  ).to(device)
203
 
204
  with torch.no_grad():
@@ -207,31 +256,29 @@ def process_sample_detection(main_image, sample_image, match_threshold):
207
  masks = processor.post_process_masks(
208
  outputs.pred_masks.cpu(),
209
  inputs["original_sizes"],
210
- binarize=True
211
  )[0]
212
 
213
  if masks.ndim == 4:
214
- if hasattr(outputs, 'iou_scores') and outputs.iou_scores is not None:
215
  scores = outputs.iou_scores.cpu().numpy()[0, 0]
216
  best_idx = np.argmax(scores)
217
- masks = masks[0, best_idx:best_idx+1]
218
  else:
219
  masks = masks[0, 0:1]
220
 
221
- result_img = overlay_masks(main_image, masks)
222
- # 在结果上也画出匹配框
223
- result_img = draw_box_on_image(result_img, box, color="lime")
224
-
225
- return result_img, (
226
  f"✅ 样本检测成功!\n"
227
  f"匹配度: {best_score:.2f} | 匹配位置: ({box[0]}, {box[1]}) → ({box[2]}, {box[3]})"
228
  )
229
-
230
  except Exception as e:
231
- return main_image, f"❌ 错误: {str(e)}"
232
 
233
 
234
- # ========== Gradio 界面 ==========
235
  with gr.Blocks() as demo:
236
  gr.Markdown("# 🚀 SAM 3 自动检测工具 (双模式)")
237
 
@@ -244,12 +291,14 @@ with gr.Blocks() as demo:
244
  t_thresh = gr.Slider(0.1, 0.9, value=0.3, step=0.05, label="灵敏度")
245
  t_btn = gr.Button("开始文本检测", variant="primary")
246
  with gr.Column():
247
- t_img_out = gr.Image(type="pil", label="检测结果")
 
248
  t_info = gr.Textbox(label="状态信息")
249
  t_btn.click(
250
  process_text_detection,
251
  [t_img_in, t_query, t_thresh],
252
- [t_img_out, t_info]
 
253
  )
254
 
255
  with gr.Tab("🖼️ 样本截图检测"):
@@ -267,13 +316,16 @@ with gr.Blocks() as demo:
267
  s_thresh = gr.Slider(0.1, 0.9, value=0.25, step=0.05, label="匹配阈值(越低越容易匹配)")
268
  s_btn = gr.Button("开始样本检测", variant="primary")
269
  with gr.Column():
270
- s_img_out = gr.Image(type="pil", label="检测结果")
 
271
  s_info = gr.Textbox(label="状态信息", lines=3)
272
  s_btn.click(
273
  process_sample_detection,
274
  [s_img_main, s_img_sample, s_thresh],
275
- [s_img_out, s_info]
 
276
  )
277
 
 
278
  if __name__ == "__main__":
279
- demo.launch()
 
1
+ import os
2
+ import gc
3
+ import cv2
4
  import gradio as gr
5
+ import spaces
6
  import torch
7
  import numpy as np
8
  import matplotlib.pyplot as plt
9
  import matplotlib
10
  from PIL import Image, ImageDraw
11
+ import transformers
12
+ import pydantic
 
 
13
  from transformers import (
14
+ Sam3Model,
15
+ Sam3Processor,
16
+ Sam3TrackerModel,
17
+ Sam3TrackerProcessor,
18
  )
19
 
20
# Startup diagnostics: log library versions so the Space build log shows
# exactly which environment the app is running in.
print("torch:", torch.__version__)
print("transformers:", transformers.__version__)
print("pydantic:", pydantic.__version__)

# Access token for the gated facebook/sam3 checkpoint (set in Space secrets).
HF_TOKEN = os.getenv("HF_TOKEN")

# Lazy model cache: model_type -> (model, processor). At most one entry is
# kept alive at a time (see cleanup_memory / get_model).
MODELS = {}

device = "cuda" if torch.cuda.is_available() else "cpu"

# SAM3 is only practical on GPU here — fail fast instead of limping on CPU.
if device != "cuda":
    raise RuntimeError("CUDA 不可用,SAM3 无法运行")
32
def cleanup_memory():
    """Release every cached model/processor pair and reclaim host + GPU memory."""
    # dict.clear() is a no-op on an empty dict, so no guard is needed.
    MODELS.clear()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
 
40
def get_model(model_type):
    """Return a cached (model, processor) pair, loading it on first use.

    Only one model is kept in memory at a time: loading a new type evicts
    whatever is currently cached.

    Args:
        model_type: "sam3_image_text" or "sam3_image_tracker".

    Returns:
        Tuple of (model, processor), with the model moved to the global
        ``device``.

    Raises:
        ValueError: for an unknown ``model_type`` (raised directly, without
            touching the cache).
        RuntimeError: when downloading/loading the checkpoint fails; the
            original exception is chained as the cause.
    """
    if model_type in MODELS:
        return MODELS[model_type]

    # Validate BEFORE evicting the cache and entering the try-block, so a
    # programming error (typo in model_type) surfaces as a plain ValueError
    # instead of being wrapped in a "load failed" RuntimeError.
    if model_type not in ("sam3_image_text", "sam3_image_tracker"):
        raise ValueError(f"未知模型类型: {model_type}")

    # Evict the previously cached model first — only one fits in memory.
    cleanup_memory()
    print(f"⏳ 正在加载 {model_type} 模型...")

    try:
        if model_type == "sam3_image_text":
            model = Sam3Model.from_pretrained("facebook/sam3", token=HF_TOKEN).to(device)
            processor = Sam3Processor.from_pretrained("facebook/sam3", token=HF_TOKEN)
        else:  # "sam3_image_tracker" — validated above
            model = Sam3TrackerModel.from_pretrained("facebook/sam3", token=HF_TOKEN).to(device)
            processor = Sam3TrackerProcessor.from_pretrained("facebook/sam3", token=HF_TOKEN)

        MODELS[model_type] = (model, processor)
        print(f"✅ {model_type} 加载完成。")
        return MODELS[model_type]
    except Exception as e:
        cleanup_memory()
        # Chain the original exception so the real cause stays in the traceback.
        raise RuntimeError(f"{model_type} 加载失败: {e}") from e
65
  def overlay_masks(image, masks, alpha=0.6):
66
  if image is None:
 
68
  if isinstance(image, np.ndarray):
69
  image = Image.fromarray(image)
70
  image = image.convert("RGBA")
71
+
72
  if masks is None or len(masks) == 0:
73
  return image.convert("RGB")
74
+
75
  if isinstance(masks, torch.Tensor):
76
+ masks = masks.detach().cpu().numpy()
77
  masks = masks.astype(np.uint8)
78
+
79
+ if masks.ndim == 4:
80
+ masks = masks[0]
81
+ if masks.ndim == 3 and masks.shape[0] == 1:
82
+ masks = masks[0]
83
+ if masks.ndim == 2:
84
+ masks = [masks]
85
+
86
  n_masks = len(masks)
87
  try:
88
  cmap = matplotlib.colormaps["rainbow"].resampled(max(n_masks, 1))
89
  except AttributeError:
90
  cmap = plt.get_cmap("rainbow", max(n_masks, 1))
91
+
92
  overlay_layer = Image.new("RGBA", image.size, (0, 0, 0, 0))
93
  for i, mask in enumerate(masks):
94
+ mask_img = Image.fromarray((mask > 0).astype(np.uint8) * 255)
95
  if mask_img.size != image.size:
96
  mask_img = mask_img.resize(image.size, resample=Image.NEAREST)
97
  rgb = [int(x * 255) for x in cmap(i)[:3]]
 
99
  mask_alpha = mask_img.point(lambda v: int(v * alpha) if v > 0 else 0)
100
  color_layer.putalpha(mask_alpha)
101
  overlay_layer = Image.alpha_composite(overlay_layer, color_layer)
102
+
103
  return Image.alpha_composite(image, overlay_layer).convert("RGB")
104
 
105
 
106
def masks_to_binary_mask(masks, image_size):
    """Merge any number of masks into one binary mask. White = target area.

    Args:
        masks: ``None``, a ``torch.Tensor``, or an array-like of masks with
            2–4 dims; leading singleton batch dims are squeezed off.
        image_size: ``(width, height)``, as given by PIL's ``Image.size``.

    Returns:
        A PIL mode-"L" image of size ``image_size`` (255 = target, 0 =
        background), or ``None`` when ``masks`` is ``None``.
    """
    if masks is None:
        return None

    if isinstance(masks, torch.Tensor):
        masks = masks.detach().float().cpu().numpy()
    masks = np.array(masks)

    # Squeeze leading batch dimensions down to (n, h, w) or (h, w).
    if masks.ndim == 4:
        masks = masks[0]
    if masks.ndim == 3 and masks.shape[0] == 1:
        masks = masks[0]
    # Promote a single 2-D mask to a 1-element stack so one loop handles
    # every case — this also gives the 2-D case the resize step that the
    # previous implementation only applied to 3-D input.
    if masks.ndim == 2:
        masks = masks[np.newaxis, ...]

    w, h = image_size
    combined = np.zeros((h, w), dtype=np.uint8)

    # Anything that is not a (n, h', w') stack by now (e.g. an empty list)
    # contributes nothing and yields an all-black canvas, as before.
    mask_stack = masks if masks.ndim == 3 else []

    for m in mask_stack:
        m = np.asarray(m)
        if m.shape != (h, w):
            # Nearest-neighbour resize keeps the mask strictly binary.
            m_img = Image.fromarray((m > 0).astype(np.uint8) * 255)
            m_img = m_img.resize((w, h), resample=Image.NEAREST)
            m = np.array(m_img) > 0
        combined = np.maximum(combined, (m > 0).astype(np.uint8) * 255)

    return Image.fromarray(combined, mode="L")
+
138
def draw_box_on_image(image, box, color="lime", width=3):
    """Return a copy of *image* with a rectangle drawn on it.

    Used to preview a template-match location; the input image is never
    modified in place.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    annotated = image.copy()
    left, top, right, bottom = box
    ImageDraw.Draw(annotated).rectangle(
        [left, top, right, bottom], outline=color, width=width
    )
    return annotated
148
+
149
def multi_scale_template_match(main_cv, sample_cv, min_scale=0.5, max_scale=1.5, steps=15):
    """Multi-scale, multi-method template matching.

    The sample is resized over ``steps`` evenly spaced scales in
    ``[min_scale, max_scale]``; each resized template is matched against the
    main image with several OpenCV methods, and the single best peak wins.

    Args:
        main_cv: main image as a BGR OpenCV array.
        sample_cv: sample/template image as a BGR OpenCV array.
        min_scale, max_scale: scale range applied to the sample.
        steps: number of scales to try.

    Returns:
        ``(best_score, best_loc, best_w, best_h)``, or ``None`` when no
        scale produced a usable template (sample too big or too small).
    """
    main_gray = cv2.cvtColor(main_cv, cv2.COLOR_BGR2GRAY)
    sample_gray = cv2.cvtColor(sample_cv, cv2.COLOR_BGR2GRAY)

    sample_h, sample_w = sample_gray.shape[:2]
    main_h, main_w = main_gray.shape[:2]

    methods = (
        cv2.TM_CCOEFF_NORMED,
        cv2.TM_CCORR_NORMED,
    )

    best_score = -1
    best_loc = None
    best_w, best_h = sample_w, sample_h

    for scale in np.linspace(min_scale, max_scale, steps):
        scaled_w = int(sample_w * scale)
        scaled_h = int(sample_h * scale)

        # A template larger than the search image cannot be matched.
        if scaled_w > main_w or scaled_h > main_h:
            continue
        # Tiny templates produce meaningless correlation peaks.
        if scaled_w < 10 or scaled_h < 10:
            continue

        template = cv2.resize(sample_gray, (scaled_w, scaled_h))

        for method in methods:
            heatmap = cv2.matchTemplate(main_gray, template, method)
            _, peak_val, _, peak_loc = cv2.minMaxLoc(heatmap)
            if peak_val > best_score:
                best_score = peak_val
                best_loc = peak_loc
                best_w, best_h = scaled_w, scaled_h

    if best_loc is None:
        return None

    return best_score, best_loc, best_w, best_h
189
 
190
 
 
191
@spaces.GPU
def process_text_detection(image, text_query, threshold):
    """Segment everything in *image* that matches a free-text prompt.

    Returns a 3-tuple for the Gradio outputs:
    ``(binary_mask_or_None, preview_image, status_message)``.
    Any failure is reported in the status string rather than raised, so the
    UI never crashes.
    """
    if image is None or not text_query:
        return None, None, "请输入图像和描述词"
    try:
        model, processor = get_model("sam3_image_text")
        inputs = processor(images=image, text=text_query, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)

        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=0.5,
            target_sizes=inputs.get("original_sizes").tolist(),
        )[0]

        masks = results.get("masks")
        binary_mask = masks_to_binary_mask(masks, image.size)
        preview = overlay_masks(image, masks)

        num_found = 0 if masks is None else len(masks)
        if num_found > 0:
            status = f"✅ 文本检测完成!找到 {num_found} 个目标。"
        else:
            status = "❓ 未找到目标,请调低阈值。"

        return binary_mask, preview, status
    except Exception as e:
        # Surface the failure in the UI; return the untouched input as preview.
        return None, image, f"❌ 错误: {str(e)}"
 
222
 
 
223
  @spaces.GPU
224
  def process_sample_detection(main_image, sample_image, match_threshold):
225
+ if main_image is None or sample_image is None:
226
+ return None, None, "请上传主图和样本截图"
227
  try:
228
  model, processor = get_model("sam3_image_tracker")
229
 
230
  main_cv = cv2.cvtColor(np.array(main_image), cv2.COLOR_RGB2BGR)
231
  sample_cv = cv2.cvtColor(np.array(sample_image), cv2.COLOR_RGB2BGR)
232
 
 
233
  match = multi_scale_template_match(main_cv, sample_cv)
 
234
  if match is None:
235
+ return None, main_image, "❌ 样本图太大或无法匹配。"
236
+
237
  best_score, best_loc, best_w, best_h = match
238
+ box = [best_loc[0], best_loc[1], best_loc[0] + best_w, best_loc[1] + best_h]
239
 
240
  if best_score < match_threshold:
 
 
241
  preview = draw_box_on_image(main_image, box, color="red")
242
+ return None, preview, (
243
  f"❓ 匹配度不足 (最高: {best_score:.2f},阈值: {match_threshold:.2f})。\n"
244
  f"红框为最佳匹配位置,可尝试降低阈值或使用更清晰的截图。"
245
  )
246
 
 
 
 
 
 
 
 
 
247
  inputs = processor(
248
  images=main_image,
249
  input_boxes=[[box]],
250
+ return_tensors="pt",
251
  ).to(device)
252
 
253
  with torch.no_grad():
 
256
  masks = processor.post_process_masks(
257
  outputs.pred_masks.cpu(),
258
  inputs["original_sizes"],
259
+ binarize=True,
260
  )[0]
261
 
262
  if masks.ndim == 4:
263
+ if hasattr(outputs, "iou_scores") and outputs.iou_scores is not None:
264
  scores = outputs.iou_scores.cpu().numpy()[0, 0]
265
  best_idx = np.argmax(scores)
266
+ masks = masks[0, best_idx:best_idx + 1]
267
  else:
268
  masks = masks[0, 0:1]
269
 
270
+ mask_img = masks_to_binary_mask(masks, main_image.size)
271
+ preview_img = overlay_masks(main_image, masks)
272
+ preview_img = draw_box_on_image(preview_img, box, color="lime")
273
+
274
+ return mask_img, preview_img, (
275
  f"✅ 样本检测成功!\n"
276
  f"匹配度: {best_score:.2f} | 匹配位置: ({box[0]}, {box[1]}) → ({box[2]}, {box[3]})"
277
  )
 
278
  except Exception as e:
279
+ return None, main_image, f"❌ 错误: {str(e)}"
280
 
281
 
 
282
  with gr.Blocks() as demo:
283
  gr.Markdown("# 🚀 SAM 3 自动检测工具 (双模式)")
284
 
 
291
  t_thresh = gr.Slider(0.1, 0.9, value=0.3, step=0.05, label="灵敏度")
292
  t_btn = gr.Button("开始文本检测", variant="primary")
293
  with gr.Column():
294
+ t_mask_out = gr.Image(type="pil", label="二值 Mask")
295
+ t_preview_out = gr.Image(type="pil", label="检测预览")
296
  t_info = gr.Textbox(label="状态信息")
297
  t_btn.click(
298
  process_text_detection,
299
  [t_img_in, t_query, t_thresh],
300
+ [t_mask_out, t_preview_out, t_info],
301
+ api_name="process_text_detection",
302
  )
303
 
304
  with gr.Tab("🖼️ 样本截图检测"):
 
316
  s_thresh = gr.Slider(0.1, 0.9, value=0.25, step=0.05, label="匹配阈值(越低越容易匹配)")
317
  s_btn = gr.Button("开始样本检测", variant="primary")
318
  with gr.Column():
319
+ s_mask_out = gr.Image(type="pil", label="二值 Mask")
320
+ s_preview_out = gr.Image(type="pil", label="检测预览")
321
  s_info = gr.Textbox(label="状态信息", lines=3)
322
  s_btn.click(
323
  process_sample_detection,
324
  [s_img_main, s_img_sample, s_thresh],
325
+ [s_mask_out, s_preview_out, s_info],
326
+ api_name="process_sample_detection",
327
  )
328
 
329
# Start the Gradio server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()