Spaces: Running on Zero
| import os | |
| import gc | |
| import cv2 | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import matplotlib | |
| from PIL import Image, ImageDraw | |
| import transformers | |
| import pydantic | |
| from transformers import ( | |
| Sam3Model, | |
| Sam3Processor, | |
| Sam3TrackerModel, | |
| Sam3TrackerProcessor, | |
| ) | |
# Log library versions at startup to simplify debugging dependency mismatches.
print("torch:", torch.__version__)
print("transformers:", transformers.__version__)
print("pydantic:", pydantic.__version__)

# Optional Hugging Face token used to download the (gated) facebook/sam3 weights.
HF_TOKEN = os.getenv("HF_TOKEN")
# Lazy model cache: model_type -> (model, processor). cleanup_memory() keeps
# at most one entry resident to fit GPU memory.
MODELS = {}

device = "cuda" if torch.cuda.is_available() else "cpu"
if device != "cuda":
    # Fail fast on CPU-only hosts: this app only runs SAM3 on CUDA.
    # NOTE(review): `spaces` is imported but `@spaces.GPU` is never applied to
    # any handler — on ZeroGPU hardware CUDA is typically only visible inside
    # decorated functions, so this import-time check may abort there; confirm.
    raise RuntimeError("CUDA 不可用,SAM3 无法运行")
def cleanup_memory():
    """Drop every cached model, run the GC, and release cached CUDA memory."""
    if len(MODELS) > 0:
        MODELS.clear()
    gc.collect()
    # Returning VRAM to the driver matters here because models are swapped
    # in and out of a single-GPU budget.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def get_model(model_type):
    """Return (model, processor) for *model_type*, loading and caching on demand.

    Parameters
    ----------
    model_type : str
        "sam3_image_text" for text-prompted segmentation, or
        "sam3_image_tracker" for box-prompted mask prediction.

    Returns
    -------
    tuple
        ``(model, processor)`` with the model moved to the global ``device``.

    Raises
    ------
    RuntimeError
        If loading fails for any reason; the original exception is chained
        as ``__cause__`` so the full traceback is preserved.
    """
    if model_type in MODELS:
        return MODELS[model_type]
    # Evict any previously loaded model first — only one pair is kept
    # resident at a time to stay inside the GPU memory budget.
    cleanup_memory()
    print(f"⏳ 正在加载 {model_type} 模型...")
    try:
        if model_type == "sam3_image_text":
            model = Sam3Model.from_pretrained("facebook/sam3", token=HF_TOKEN).to(device)
            processor = Sam3Processor.from_pretrained("facebook/sam3", token=HF_TOKEN)
        elif model_type == "sam3_image_tracker":
            model = Sam3TrackerModel.from_pretrained("facebook/sam3", token=HF_TOKEN).to(device)
            processor = Sam3TrackerProcessor.from_pretrained("facebook/sam3", token=HF_TOKEN)
        else:
            raise ValueError(f"未知模型类型: {model_type}")
        pair = (model, processor)
        MODELS[model_type] = pair
        print(f"✅ {model_type} 加载完成。")
        return pair
    except Exception as e:
        # Loading may have partially populated GPU memory; clean up before
        # surfacing the error. Chain the cause instead of discarding it.
        cleanup_memory()
        raise RuntimeError(f"{model_type} 加载失败: {e}") from e
def overlay_masks(image, masks, alpha=0.6):
    """Blend segmentation masks over *image* with per-mask rainbow colors.

    Returns an RGB PIL image; ``None`` when *image* is None, and the
    unmodified image (converted to RGB) when *masks* is None or empty.
    """
    if image is None:
        return None
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    base = image.convert("RGBA")

    if masks is None or len(masks) == 0:
        return base.convert("RGB")
    if isinstance(masks, torch.Tensor):
        masks = masks.detach().cpu().numpy()
    masks = masks.astype(np.uint8)

    # Normalize shapes step by step: (1, N, H, W) -> (N, H, W),
    # (1, H, W) -> (H, W), and a bare (H, W) becomes a one-element list.
    if masks.ndim == 4:
        masks = masks[0]
    if masks.ndim == 3 and masks.shape[0] == 1:
        masks = masks[0]
    if masks.ndim == 2:
        masks = [masks]

    count = len(masks)
    try:
        # matplotlib >= 3.6 API.
        palette = matplotlib.colormaps["rainbow"].resampled(max(count, 1))
    except AttributeError:
        # Fallback for older matplotlib releases.
        palette = plt.get_cmap("rainbow", max(count, 1))

    composite = Image.new("RGBA", base.size, (0, 0, 0, 0))
    for idx, m in enumerate(masks):
        binary = Image.fromarray((m > 0).astype(np.uint8) * 255)
        if binary.size != base.size:
            # NEAREST keeps the mask strictly binary when rescaling.
            binary = binary.resize(base.size, resample=Image.NEAREST)
        r, g, b = (int(channel * 255) for channel in palette(idx)[:3])
        layer = Image.new("RGBA", base.size, (r, g, b, 0))
        # Scale mask intensity by alpha so the colors stay translucent.
        layer.putalpha(binary.point(lambda v: int(v * alpha) if v > 0 else 0))
        composite = Image.alpha_composite(composite, layer)

    return Image.alpha_composite(base, composite).convert("RGB")
def masks_to_binary_mask(masks, image_size):
    """Merge one or more masks into a single binary image. White = target area.

    Parameters
    ----------
    masks : torch.Tensor | np.ndarray | list | None
        Masks shaped (H, W), (N, H, W) or (1, N, H, W); nonzero = foreground.
    image_size : tuple
        ``(width, height)`` — PIL's ``Image.size`` ordering.

    Returns
    -------
    PIL.Image.Image | None
        An ``L``-mode image of size (width, height) with 255 on target
        pixels, or ``None`` when *masks* is None.
    """
    if masks is None:
        return None
    if isinstance(masks, torch.Tensor):
        masks = masks.detach().float().cpu().numpy()
    masks = np.array(masks)
    if masks.ndim == 4:
        masks = masks[0]
    if masks.ndim == 3 and masks.shape[0] == 1:
        masks = masks[0]
    w, h = image_size
    combined = np.zeros((h, w), dtype=np.uint8)
    # Promote a single 2-D mask to a stack of one so it goes through the same
    # resize-and-merge path below. (Previously the 2-D branch skipped the
    # resize, producing a wrong-sized mask when model output resolution
    # differed from the image.)
    if masks.ndim == 2:
        masks = masks[np.newaxis, ...]
    if masks.ndim == 3:
        for m in masks:
            m = np.array(m)
            if m.shape != (h, w):
                # NEAREST preserves the hard 0/255 binarization.
                m_img = Image.fromarray((m > 0).astype(np.uint8) * 255)
                m_img = m_img.resize((w, h), resample=Image.NEAREST)
                m = np.array(m_img) > 0
            combined = np.maximum(combined, (m > 0).astype(np.uint8) * 255)
    return Image.fromarray(combined, mode="L")
def draw_box_on_image(image, box, color="lime", width=3):
    """Return a copy of *image* with a rectangle at *box* = (x1, y1, x2, y2)."""
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Draw on a copy so the caller's image stays untouched.
    annotated = image.copy()
    x1, y1, x2, y2 = box
    ImageDraw.Draw(annotated).rectangle([x1, y1, x2, y2], outline=color, width=width)
    return annotated
def multi_scale_template_match(main_cv, sample_cv, min_scale=0.5, max_scale=1.5, steps=15):
    """Locate *sample_cv* inside *main_cv* using multi-scale, multi-method
    template matching.

    Both inputs are BGR OpenCV images. The sample is tried at `steps`
    linearly spaced scales in [min_scale, max_scale] with both
    TM_CCOEFF_NORMED and TM_CCORR_NORMED.

    Returns ``(best_score, best_loc, best_w, best_h)``, or ``None`` when the
    sample never fits inside the main image at any usable scale.
    """
    main_gray = cv2.cvtColor(main_cv, cv2.COLOR_BGR2GRAY)
    sample_gray = cv2.cvtColor(sample_cv, cv2.COLOR_BGR2GRAY)
    sample_h, sample_w = sample_gray.shape[:2]
    main_h, main_w = main_gray.shape[:2]

    # (score, top-left location, template width, template height)
    best = (-1, None, sample_w, sample_h)
    for scale in np.linspace(min_scale, max_scale, steps):
        new_w = int(sample_w * scale)
        new_h = int(sample_h * scale)
        # Skip scales where the template would not fit, or degenerates
        # to something too small to match meaningfully.
        if new_w > main_w or new_h > main_h or new_w < 10 or new_h < 10:
            continue
        candidate = cv2.resize(sample_gray, (new_w, new_h))
        for method in (cv2.TM_CCOEFF_NORMED, cv2.TM_CCORR_NORMED):
            scores = cv2.matchTemplate(main_gray, candidate, method)
            _, peak, _, peak_loc = cv2.minMaxLoc(scores)
            if peak > best[0]:
                best = (peak, peak_loc, new_w, new_h)

    if best[1] is None:
        return None
    return best
def process_text_detection(image, text_query, threshold):
    """Gradio handler: segment everything in *image* matching *text_query*.

    Returns ``(binary_mask, preview_image, status_message)``. Failures are
    reported through the status string rather than raised, so the UI never
    crashes.
    """
    if image is None or not text_query:
        return None, None, "请输入图像和描述词"
    try:
        model, processor = get_model("sam3_image_text")
        inputs = processor(images=image, text=text_query, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=0.5,
            target_sizes=inputs.get("original_sizes").tolist(),
        )[0]
        masks = results.get("masks")
        found = 0 if masks is None else len(masks)
        status = (
            f"✅ 文本检测完成!找到 {found} 个目标。"
            if found > 0
            else "❓ 未找到目标,请调低阈值。"
        )
        return (
            masks_to_binary_mask(masks, image.size),
            overlay_masks(image, masks),
            status,
        )
    except Exception as e:
        # Surface the error in the status textbox instead of breaking the UI.
        return None, image, f"❌ 错误: {str(e)}"
def process_sample_detection(main_image, sample_image, match_threshold):
    """Gradio handler: find *sample_image* inside *main_image* and segment it.

    Pipeline: multi-scale template matching locates the sample's bounding
    box; if the match score clears *match_threshold*, the SAM3 tracker turns
    that box into a pixel mask.

    Returns (binary_mask, preview_image, status_message); errors are reported
    via the status string rather than raised.
    """
    if main_image is None or sample_image is None:
        return None, None, "请上传主图和样本截图"
    try:
        model, processor = get_model("sam3_image_tracker")
        # PIL delivers RGB; OpenCV expects BGR.
        main_cv = cv2.cvtColor(np.array(main_image), cv2.COLOR_RGB2BGR)
        sample_cv = cv2.cvtColor(np.array(sample_image), cv2.COLOR_RGB2BGR)
        match = multi_scale_template_match(main_cv, sample_cv)
        if match is None:
            return None, main_image, "❌ 样本图太大或无法匹配。"
        best_score, best_loc, best_w, best_h = match
        # Top-left corner + matched size -> (x1, y1, x2, y2) box.
        box = [best_loc[0], best_loc[1], best_loc[0] + best_w, best_loc[1] + best_h]
        if best_score < match_threshold:
            # Below threshold: show the best candidate in red and skip segmentation.
            preview = draw_box_on_image(main_image, box, color="red")
            return None, preview, (
                f"❓ 匹配度不足 (最高: {best_score:.2f},阈值: {match_threshold:.2f})。\n"
                f"红框为最佳匹配位置,可尝试降低阈值或使用更清晰的截图。"
            )
        # Prompt the tracker with the matched box ([[box]] = one box for one image).
        inputs = processor(
            images=main_image,
            input_boxes=[[box]],
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        masks = processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"],
            binarize=True,
        )[0]
        # The tracker may return several mask proposals for the box; when IoU
        # scores are available keep only the highest-scoring proposal,
        # otherwise fall back to the first one. Slicing with best_idx:best_idx+1
        # keeps a leading dimension of 1 for the downstream helpers.
        if masks.ndim == 4:
            if hasattr(outputs, "iou_scores") and outputs.iou_scores is not None:
                scores = outputs.iou_scores.cpu().numpy()[0, 0]
                best_idx = np.argmax(scores)
                masks = masks[0, best_idx:best_idx + 1]
            else:
                masks = masks[0, 0:1]
        mask_img = masks_to_binary_mask(masks, main_image.size)
        preview_img = overlay_masks(main_image, masks)
        preview_img = draw_box_on_image(preview_img, box, color="lime")
        return mask_img, preview_img, (
            f"✅ 样本检测成功!\n"
            f"匹配度: {best_score:.2f} | 匹配位置: ({box[0]}, {box[1]}) → ({box[2]}, {box[3]})"
        )
    except Exception as e:
        # Report errors through the UI status field instead of raising.
        return None, main_image, f"❌ 错误: {str(e)}"
# ---------------------------------------------------------------------------
# Gradio UI: two tabs sharing the same output layout (mask / preview / status).
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# 🚀 SAM 3 自动检测工具 (双模式)")
    with gr.Tabs():
        # Tab 1: free-text prompt -> instance masks (process_text_detection).
        with gr.Tab("📝 文本描述检测"):
            with gr.Row():
                with gr.Column():
                    t_img_in = gr.Image(type="pil", label="上传原图")
                    t_query = gr.Textbox(label="输入检测内容(英文)", value="watermark")
                    t_thresh = gr.Slider(0.1, 0.9, value=0.3, step=0.05, label="灵敏度")
                    t_btn = gr.Button("开始文本检测", variant="primary")
                with gr.Column():
                    t_mask_out = gr.Image(type="pil", label="二值 Mask")
                    t_preview_out = gr.Image(type="pil", label="检测预览")
                    t_info = gr.Textbox(label="状态信息")
            t_btn.click(
                process_text_detection,
                [t_img_in, t_query, t_thresh],
                [t_mask_out, t_preview_out, t_info],
                api_name="process_text_detection",
            )
        # Tab 2: sample crop -> template match -> box prompt -> mask
        # (process_sample_detection).
        with gr.Tab("🖼️ 样本截图检测"):
            gr.Markdown(
                "⚠️ **使用说明:**\n"
                "1. 上传主图(完整大图)\n"
                "2. 上传样本截图(你要找的目标的截图)\n"
                "3. 样本最好是从主图中**原比例截取**的,支持一定程度的缩放\n"
                "4. 如果匹配失败,可以降低「匹配阈值」"
            )
            with gr.Row():
                with gr.Column():
                    s_img_main = gr.Image(type="pil", label="上传主图")
                    s_img_sample = gr.Image(type="pil", label="上传样本截图")
                    s_thresh = gr.Slider(0.1, 0.9, value=0.25, step=0.05, label="匹配阈值(越低越容易匹配)")
                    s_btn = gr.Button("开始样本检测", variant="primary")
                with gr.Column():
                    s_mask_out = gr.Image(type="pil", label="二值 Mask")
                    s_preview_out = gr.Image(type="pil", label="检测预览")
                    s_info = gr.Textbox(label="状态信息", lines=3)
            s_btn.click(
                process_sample_detection,
                [s_img_main, s_img_sample, s_thresh],
                [s_mask_out, s_preview_out, s_info],
                api_name="process_sample_detection",
            )

if __name__ == "__main__":
    demo.launch()