fdsgsfjsfg committed on
Commit
e4a9ffa
·
verified ·
1 Parent(s): 6253770

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -207
app.py CHANGED
@@ -1,141 +1,21 @@
1
- import subprocess
2
- import sys
3
- import os
4
- import re
5
- import glob
6
-
7
- # ============================================================
8
- # ✅ 终极修复:在 import transformers 之前,
9
- # 直接修改已安装的 transformers 包中的源码文件,
10
- # 让 initializer_range 字段同时接受 int 和 float。
11
- # 这样即使 ZeroGPU 重新反序列化模型也不会报错。
12
- # ============================================================
13
-
14
- def patch_transformers_source():
15
- """
16
- 找到 transformers 包中所有 Sam3 相关的 configuration 文件,
17
- 把 initializer_range 的类型注解从严格的 int 或 float 改为兼容的 Union 类型,
18
- 或者直接移除类型校验。
19
- """
20
- import transformers
21
- pkg_dir = os.path.dirname(transformers.__file__)
22
-
23
- # 找到所有可能的配置文件
24
- patterns = [
25
- os.path.join(pkg_dir, "models", "sam3", "*.py"),
26
- os.path.join(pkg_dir, "models", "sam3", "**", "*.py"),
27
- ]
28
-
29
- files_to_check = []
30
- for pattern in patterns:
31
- files_to_check.extend(glob.glob(pattern, recursive=True))
32
-
33
- if not files_to_check:
34
- print(f"⚠️ 未找到 sam3 模型文件,尝试搜索整个 transformers 目录...")
35
- # 搜索所有包含 Sam3 和 initializer_range 的文件
36
- result = subprocess.run(
37
- ["grep", "-rl", "initializer_range", os.path.join(pkg_dir, "models")],
38
- capture_output=True, text=True
39
- )
40
- if result.stdout:
41
- all_files = result.stdout.strip().split("\n")
42
- # 只处理 sam3 相关的
43
- files_to_check = [f for f in all_files if "sam3" in f.lower() or "sam_3" in f.lower()]
44
- if not files_to_check:
45
- # 如果没找到 sam3 特定的,搜索 configuration 文件
46
- files_to_check = [f for f in all_files if "configuration" in f.lower()]
47
-
48
- patched_count = 0
49
- for filepath in files_to_check:
50
- try:
51
- with open(filepath, "r") as f:
52
- content = f.read()
53
-
54
- if "initializer_range" not in content:
55
- continue
56
-
57
- original = content
58
-
59
- # 策略1: 把 initializer_range: int 改为 initializer_range: float
60
- content = re.sub(
61
- r'(initializer_range\s*:\s*)int(\s*=)',
62
- r'\1float\2',
63
- content
64
- )
65
-
66
- # 策略2: 把 initializer_range: int = 0 改为 initializer_range: float = 0.0
67
- content = re.sub(
68
- r'(initializer_range\s*:\s*\w+\s*=\s*)(\d+)(\s*[,\n\)])',
69
- lambda m: f'{m.group(1)}{float(int(m.group(2)))}{m.group(3)}',
70
- content
71
- )
72
-
73
- # 策略3: 如果有 validator 或 field_validator 针对 initializer_range 的严格类型检查
74
- # 注释掉相关校验行
75
-
76
- if content != original:
77
- with open(filepath, "w") as f:
78
- f.write(content)
79
- patched_count += 1
80
- print(f"✅ 已修补文件: {filepath}")
81
-
82
- except Exception as e:
83
- print(f"⚠️ 处理文件 {filepath} 时出错: {e}")
84
-
85
- if patched_count == 0:
86
- print("⚠️ 未找到需要修补的文件,尝试通用方案...")
87
- # 通用方案:patch PretrainedConfig 的 __init_subclass__
88
- patch_config_base_class()
89
- else:
90
- print(f"✅ 共修补了 {patched_count} 个文件")
91
-
92
-
93
- def patch_config_base_class():
94
- """
95
- 如果找不到具体文件可改,就 patch PretrainedConfig 基类,
96
- 让所有配置类在实例化时自动容忍 int/float 互转。
97
- """
98
- from transformers import PretrainedConfig
99
-
100
- original_init_subclass = PretrainedConfig.__init_subclass__
101
- original_init = PretrainedConfig.__init__
102
-
103
- # patch __setattr__ 让赋值时自动兼容
104
- original_setattr = PretrainedConfig.__setattr__ if hasattr(PretrainedConfig, '__setattr__') else object.__setattr__
105
-
106
- def tolerant_setattr(self, name, value):
107
- # 对 initializer_range 不做严格类型检查
108
- if name == "initializer_range":
109
- # 直接写入,跳过任何校验
110
- self.__dict__[name] = value
111
- return
112
- try:
113
- original_setattr(self, name, value)
114
- except TypeError:
115
- self.__dict__[name] = value
116
-
117
- PretrainedConfig.__setattr__ = tolerant_setattr
118
-
119
-
120
- # === 执行修复 ===
121
- print("🔧 正在修补 transformers 源码中的 initializer_range 类型问题...")
122
- patch_transformers_source()
123
- print("🔧 修补完成!")
124
-
125
- # === 现在才导入其他模块 ===
126
  import gradio as gr
127
  import torch
128
  import numpy as np
129
  import matplotlib.pyplot as plt
 
130
  from PIL import Image
131
  import gc
132
- import spaces
 
133
  import cv2
134
- from transformers import Sam3Model, Sam3Processor
 
 
 
135
 
136
  HF_TOKEN = os.getenv("HF_TOKEN")
137
  MODELS = {}
138
- device = "cuda" if torch.cuda.is_available() else "cpu"
139
 
140
  def cleanup_memory():
141
  if MODELS:
@@ -143,156 +23,174 @@ def cleanup_memory():
143
  gc.collect()
144
  torch.cuda.empty_cache()
145
 
146
- def get_model():
147
- model_id = "facebook/sam3"
148
- if model_id in MODELS:
149
- return MODELS[model_id]
 
 
 
 
150
 
151
  cleanup_memory()
152
- print("⏳ 正在加载 SAM 3 模型...")
153
-
154
- model = Sam3Model.from_pretrained(
155
- model_id,
156
- token=HF_TOKEN,
157
- torch_dtype=torch.float16 if device == "cuda" else torch.float32,
158
- ).to(device)
159
-
160
- processor = Sam3Processor.from_pretrained(
161
- model_id,
162
- token=HF_TOKEN
163
- )
164
 
165
- MODELS[model_id] = (model, processor)
166
- return MODELS[model_id]
 
167
 
168
  def overlay_masks(image, masks, alpha=0.6):
169
- if image is None:
170
  return None
171
- if isinstance(image, np.ndarray):
172
  image = Image.fromarray(image)
173
  image = image.convert("RGBA")
174
-
175
- if masks is None or len(masks) == 0:
176
  return image.convert("RGB")
177
-
178
- if isinstance(masks, torch.Tensor):
179
  masks = masks.cpu().numpy()
180
-
181
  masks = masks.astype(np.uint8)
182
  if masks.ndim == 4: masks = masks[0]
183
  if masks.ndim == 3 and masks.shape[0] == 1: masks = masks[0]
184
  if masks.ndim == 2: masks = [masks]
185
-
186
  n_masks = len(masks)
187
- cmap = plt.get_cmap("rainbow", max(n_masks, 1))
 
 
 
 
188
  overlay_layer = Image.new("RGBA", image.size, (0, 0, 0, 0))
189
-
190
  for i, mask in enumerate(masks):
191
  mask_img = Image.fromarray((mask * 255).astype(np.uint8))
192
- if mask_img.size != image.size:
193
  mask_img = mask_img.resize(image.size, resample=Image.NEAREST)
194
-
195
  rgb = [int(x * 255) for x in cmap(i)[:3]]
196
  color_layer = Image.new("RGBA", image.size, tuple(rgb) + (0,))
197
  mask_alpha = mask_img.point(lambda v: int(v * alpha) if v > 0 else 0)
198
  color_layer.putalpha(mask_alpha)
199
  overlay_layer = Image.alpha_composite(overlay_layer, color_layer)
200
-
201
  return Image.alpha_composite(image, overlay_layer).convert("RGB")
202
 
 
 
203
  @spaces.GPU
204
  def process_text_detection(image, text_query, threshold):
205
- if not image or not text_query:
206
  return None, "请输入图像和描述词"
207
-
208
  try:
209
- model, processor = get_model()
210
-
211
  inputs = processor(
212
- images=image,
213
- text=text_query,
214
  return_tensors="pt"
215
  ).to(device)
216
-
217
- with torch.no_grad():
218
  outputs = model(**inputs)
219
-
220
  results = processor.post_process_instance_segmentation(
221
- outputs,
222
- threshold=threshold,
223
- mask_threshold=0.5,
224
  target_sizes=inputs.get("original_sizes").tolist()
225
  )[0]
226
-
227
  masks = results["masks"]
228
  result_img = overlay_masks(image, masks)
229
-
230
  if len(masks) > 0:
231
  status = f"✅ 文本检测完成!找到 {len(masks)} 个目标。"
232
  else:
233
  status = "❓ 未找到目标,请调低阈值。"
234
-
235
  return result_img, status
236
-
237
  except Exception as e:
238
  return image, f"❌ 错误: {str(e)}"
239
 
 
 
240
  @spaces.GPU
241
  def process_sample_detection(main_image, sample_image):
242
- if not main_image or not sample_image:
243
  return None, "请上传主图和样本截图"
244
-
245
  try:
246
- model, processor = get_model()
247
-
 
 
248
  main_cv = cv2.cvtColor(np.array(main_image), cv2.COLOR_RGB2BGR)
249
  sample_cv = cv2.cvtColor(np.array(sample_image), cv2.COLOR_RGB2BGR)
250
-
251
  if sample_cv.shape[0] > main_cv.shape[0] or sample_cv.shape[1] > main_cv.shape[1]:
252
  return main_image, "❌ 错误:样本截图不能比主图还大!"
253
-
254
  result = cv2.matchTemplate(main_cv, sample_cv, cv2.TM_CCOEFF_NORMED)
255
  min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(result)
256
-
257
  if max_val < 0.4:
258
  return main_image, f"❓ 未在主图中找到该样本 (最高匹配度: {max_val:.2f})。"
259
-
260
  h, w = sample_cv.shape[:2]
261
-
262
  box = [
263
- max_loc[0],
264
- max_loc[1],
265
- max_loc[0] + w,
266
  max_loc[1] + h
267
  ]
268
-
 
269
  inputs = processor(
270
- images=main_image,
271
- input_boxes=[[[box]]],
272
  return_tensors="pt"
273
  ).to(device)
274
-
275
- with torch.no_grad():
276
  outputs = model(**inputs)
277
-
278
- results = processor.post_process_instance_segmentation(
279
- outputs,
280
- threshold=0.1,
281
- mask_threshold=0.5,
282
- target_sizes=inputs.get("original_sizes").tolist()
283
  )[0]
284
-
285
- masks = results["masks"]
 
 
 
 
 
 
 
 
 
286
  result_img = overlay_masks(main_image, masks)
287
-
288
  return result_img, f"✅ 样本检测成功!(匹配度: {max_val:.2f})"
289
-
290
  except Exception as e:
291
  return main_image, f"❌ 错误: {str(e)}"
292
 
 
 
293
  with gr.Blocks() as demo:
294
  gr.Markdown("# 🚀 SAM 3 自动检测工具 (双模式)")
295
-
296
  with gr.Tabs():
297
  with gr.Tab("📝 文本描述检测"):
298
  with gr.Row():
@@ -305,8 +203,8 @@ with gr.Blocks() as demo:
305
  t_img_out = gr.Image(type="pil", label="检测结果")
306
  t_info = gr.Textbox(label="状态信息")
307
  t_btn.click(
308
- process_text_detection,
309
- [t_img_in, t_query, t_thresh],
310
  [t_img_out, t_info]
311
  )
312
 
@@ -321,8 +219,8 @@ with gr.Blocks() as demo:
321
  s_img_out = gr.Image(type="pil", label="检测结果")
322
  s_info = gr.Textbox(label="状态信息")
323
  s_btn.click(
324
- process_sample_detection,
325
- [s_img_main, s_img_sample],
326
  [s_img_out, s_info]
327
  )
328
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
  import numpy as np
4
  import matplotlib.pyplot as plt
5
+ import matplotlib
6
  from PIL import Image
7
  import gc
8
+ import os
9
+ import spaces
10
  import cv2
11
+ from transformers import (
12
+ Sam3Model, Sam3Processor,
13
+ Sam3TrackerModel, Sam3TrackerProcessor,
14
+ )
15
 
16
  HF_TOKEN = os.getenv("HF_TOKEN")
17
  MODELS = {}
18
+ device = "cuda"
19
 
20
  def cleanup_memory():
21
  if MODELS:
 
23
  gc.collect()
24
  torch.cuda.empty_cache()
25
 
26
def get_model(model_type):
    """Lazily load and cache one of the SAM 3 model/processor pairs.

    Supported model types:
      - "sam3_image_text": text-prompted detection
        (Sam3Model + Sam3Processor)
      - "sam3_image_tracker": box-prompted segmentation
        (Sam3TrackerModel + Sam3TrackerProcessor)

    Returns:
        The cached ``(model, processor)`` tuple for ``model_type``.

    Raises:
        ValueError: for an unrecognized ``model_type``.
    """
    cached = MODELS.get(model_type)
    if cached is not None:
        return cached

    # Free whatever model is currently resident before loading a new one.
    cleanup_memory()
    print(f"⏳ 正在加载 {model_type} 模型...")

    if model_type == "sam3_image_text":
        model_cls, processor_cls = Sam3Model, Sam3Processor
    elif model_type == "sam3_image_tracker":
        model_cls, processor_cls = Sam3TrackerModel, Sam3TrackerProcessor
    else:
        raise ValueError(f"未知模型类型: {model_type}")

    model = model_cls.from_pretrained("facebook/sam3", token=HF_TOKEN).to(device)
    processor = processor_cls.from_pretrained("facebook/sam3", token=HF_TOKEN)

    MODELS[model_type] = (model, processor)
    print(f"✅ {model_type} 加载完成。")
    return MODELS[model_type]
50
 
51
def overlay_masks(image, masks, alpha=0.6):
    """Blend colored instance masks over an image.

    Args:
        image: a PIL image or numpy array; may be None.
        masks: masks as a torch tensor, numpy array (2-D/3-D/4-D), or a
            sequence of 2-D arrays; may be None or empty.
        alpha: opacity factor applied to each mask overlay.

    Returns:
        An RGB PIL image with one tinted layer per mask, or None when no
        image was given.
    """
    if image is None:
        return None
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = image.convert("RGBA")

    # Nothing to draw: hand back the untouched image.
    if masks is None or len(masks) == 0:
        return image.convert("RGB")

    if isinstance(masks, torch.Tensor):
        masks = masks.cpu().numpy()
    masks = masks.astype(np.uint8)

    # Normalize to a sequence of 2-D masks.
    if masks.ndim == 4:
        masks = masks[0]
    if masks.ndim == 3 and masks.shape[0] == 1:
        masks = masks[0]
    if masks.ndim == 2:
        masks = [masks]

    count = len(masks)
    # matplotlib >= 3.6 colormap API first; fall back to the legacy call.
    try:
        cmap = matplotlib.colormaps["rainbow"].resampled(max(count, 1))
    except AttributeError:
        cmap = plt.get_cmap("rainbow", max(count, 1))

    painted = Image.new("RGBA", image.size, (0, 0, 0, 0))
    for idx, m in enumerate(masks):
        layer_mask = Image.fromarray((m * 255).astype(np.uint8))
        if layer_mask.size != image.size:
            layer_mask = layer_mask.resize(image.size, resample=Image.NEAREST)

        r, g, b = (int(c * 255) for c in cmap(idx)[:3])
        tint = Image.new("RGBA", image.size, (r, g, b, 0))
        # Scale mask intensity by alpha so overlays stay translucent.
        tint.putalpha(layer_mask.point(lambda v: int(v * alpha) if v > 0 else 0))
        painted = Image.alpha_composite(painted, tint)

    return Image.alpha_composite(image, painted).convert("RGB")
87
 
88
+
89
+ # ========== 文本描述检测 ==========
90
@spaces.GPU
def process_text_detection(image, text_query, threshold):
    """Segment every instance matching a free-text prompt.

    Args:
        image: input PIL image from gradio (type="pil"); may be None.
        text_query: natural-language description of the target objects.
        threshold: detection score threshold for post-processing.

    Returns:
        (annotated PIL image or None, status message string).
    """
    # FIX: `not image` relies on truthiness, which raises
    # "truth value of an array is ambiguous" for numpy inputs; use an
    # explicit None check like overlay_masks does.
    if image is None or not text_query:
        return None, "请输入图像和描述词"

    try:
        model, processor = get_model("sam3_image_text")

        inputs = processor(
            images=image,
            text=text_query,
            return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs)

        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=0.5,
            target_sizes=inputs.get("original_sizes").tolist()
        )[0]

        masks = results["masks"]
        result_img = overlay_masks(image, masks)

        if len(masks) > 0:
            status = f"✅ 文本检测完成!找到 {len(masks)} 个目标。"
        else:
            status = "❓ 未找到目标,请调低阈值。"

        return result_img, status

    except Exception as e:
        # Surface the failure in the UI instead of crashing the request.
        return image, f"❌ 错误: {str(e)}"
124
 
125
+
126
+ # ========== 样本截图检测 ==========
127
@spaces.GPU
def process_sample_detection(main_image, sample_image):
    """Locate a cropped sample inside the main image and segment it.

    Pipeline:
      1. OpenCV template matching finds where the sample crop sits in the
         main image.
      2. The matched rectangle is fed to Sam3Tracker as a box prompt to
         produce a precise instance mask.

    Args:
        main_image: the full PIL image to search in; may be None.
        sample_image: a crop/screenshot of the object to find; may be None.

    Returns:
        (annotated PIL image, status message string); on failure the
        original main_image is returned with an error message.
    """
    # FIX: `not image` relies on truthiness, which raises for numpy arrays;
    # compare against None explicitly.
    if main_image is None or sample_image is None:
        return None, "请上传主图和样本截图"

    try:
        # Box prompts require the tracker variant, not the text model.
        model, processor = get_model("sam3_image_tracker")

        # Step 1: template matching to locate the sample in the main image.
        main_cv = cv2.cvtColor(np.array(main_image), cv2.COLOR_RGB2BGR)
        sample_cv = cv2.cvtColor(np.array(sample_image), cv2.COLOR_RGB2BGR)

        if sample_cv.shape[0] > main_cv.shape[0] or sample_cv.shape[1] > main_cv.shape[1]:
            return main_image, "❌ 错误:样本截图不能比主图还大!"

        result = cv2.matchTemplate(main_cv, sample_cv, cv2.TM_CCOEFF_NORMED)
        min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(result)

        # Reject weak matches; 0.4 is an empirical confidence floor.
        if max_val < 0.4:
            return main_image, f"❓ 未在主图中找到该样本 (最高匹配度: {max_val:.2f})。"

        h, w = sample_cv.shape[:2]
        # (x1, y1, x2, y2) rectangle of the best match.
        box = [
            max_loc[0],
            max_loc[1],
            max_loc[0] + w,
            max_loc[1] + h
        ]

        # Step 2: refine the rectangle into a mask via a box prompt.
        inputs = processor(
            images=main_image,
            input_boxes=[[[box]]],
            return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs)

        # Sam3Tracker uses post_process_masks, not
        # post_process_instance_segmentation.
        masks = processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"],
            binarize=True
        )[0]

        # masks is [num_objects, num_masks, H, W]; keep the best-scoring
        # mask of the first (only) prompted object.
        if masks.ndim == 4:
            if hasattr(outputs, 'iou_scores') and outputs.iou_scores is not None:
                scores = outputs.iou_scores.cpu().numpy()[0, 0]
                best_idx = np.argmax(scores)
                masks = masks[0, best_idx:best_idx + 1]
            else:
                masks = masks[0, 0:1]

        result_img = overlay_masks(main_image, masks)

        return result_img, f"✅ 样本检测成功!(匹配度: {max_val:.2f})"

    except Exception as e:
        # Surface the failure in the UI instead of crashing the request.
        return main_image, f"❌ 错误: {str(e)}"
188
 
189
+
190
+ # ========== Gradio 界面 ==========
191
  with gr.Blocks() as demo:
192
  gr.Markdown("# 🚀 SAM 3 自动检测工具 (双模式)")
193
+
194
  with gr.Tabs():
195
  with gr.Tab("📝 文本描述检测"):
196
  with gr.Row():
 
203
  t_img_out = gr.Image(type="pil", label="检测结果")
204
  t_info = gr.Textbox(label="状态信息")
205
  t_btn.click(
206
+ process_text_detection,
207
+ [t_img_in, t_query, t_thresh],
208
  [t_img_out, t_info]
209
  )
210
 
 
219
  s_img_out = gr.Image(type="pil", label="检测结果")
220
  s_info = gr.Textbox(label="状态信息")
221
  s_btn.click(
222
+ process_sample_detection,
223
+ [s_img_main, s_img_sample],
224
  [s_img_out, s_info]
225
  )
226