| |
| """ |
| SAM 3 目标检测与分割演示的 Hugging Face Spaces 版本。 |
| |
| 适配 Hugging Face Spaces 部署环境: |
| 1. 直接从 Hugging Face Hub 下载模型和资源 |
| 2. 支持 ZeroGPU 或 CPU 推理 |
| 3. 无需本地上传额外文件 |
| |
| 支持功能: |
| 1. 文本提示分割 |
| 2. 单框/多框提示分割 |
| 3. 正框/负框交互式标注(Multi Box 模式下可切换绘制正框或负框) |
| """ |
|
|
| import os |
| import torch |
| import numpy as np |
| import gradio as gr |
| from PIL import Image, ImageDraw, ImageFont |
| import matplotlib.pyplot as plt |
| import io |
| import random |
| from typing import List, Dict, Any, Tuple |
|
|
| |
| from huggingface_hub import hf_hub_download, snapshot_download |
|
|
| |
# True when running inside a Hugging Face Space (the platform sets SPACE_ID).
IS_HF_SPACES = "SPACE_ID" in os.environ

# Optional box/point prompter widget; the UI falls back to gr.Image without it.
try:
    from gradio_image_prompter import ImagePrompter
except ImportError as import_err:
    print(f"ImagePrompter 不可用: {import_err}")
    IMAGE_PROMPTER_AVAILABLE = False
else:
    IMAGE_PROMPTER_AVAILABLE = True

# ZeroGPU helper module (`spaces`); only importable on Hugging Face Spaces.
try:
    import spaces
except ImportError:
    SPACES_GPU_AVAILABLE = False
    print("Hugging Face Spaces GPU 模块不可用,将使用标准推理")
else:
    SPACES_GPU_AVAILABLE = True

# Hub repository holding the SAM 3 checkpoint and assets; overridable via env var.
SAM3_HF_REPO_ID = os.environ.get("SAM3_HF_REPO_ID", "facebook/sam3")
|
|
| |
| |
# Placeholders so the rest of the module can reference these names even when
# the sam3 package (or parts of it) cannot be imported.
SAM3_INSTALLED = False
sam3 = None
build_sam3_image_model = None
box_xywh_to_cxcywh = None
Sam3Processor = None
normalize_bbox = None
draw_box_on_image = None
plot_mask = None
plot_bbox = None
COLORS = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
plot_results = None


try:
    import sam3
    from sam3 import build_sam3_image_model
    SAM3_INSTALLED = True
    print("✅ sam3 库已安装")

    try:
        from sam3.model.box_ops import box_xywh_to_cxcywh
    except ImportError as e:
        print(f"⚠️ box_ops 导入失败: {e}")

        def box_xywh_to_cxcywh(boxes):
            """Convert boxes from (x, y, w, h) to (cx, cy, w, h) along the last dim."""
            x, y, w, h = boxes.unbind(-1)
            return torch.stack([x + w / 2, y + h / 2, w, h], dim=-1)

    try:
        from sam3.model.sam3_image_processor import Sam3Processor
    except ImportError as e:
        print(f"⚠️ Sam3Processor 导入失败: {e}")
        Sam3Processor = None

    try:
        from sam3.visualization_utils import (
            normalize_bbox, draw_box_on_image, plot_mask, plot_bbox, COLORS, plot_results,
        )
    except ImportError as e:
        print(f"⚠️ visualization_utils 导入失败: {e}")

        def normalize_bbox(boxes, width, height):
            """Divide x/w components by ``width`` and y/h components by ``height``.

            Works on the last dimension of a tensor, ndarray or nested list of
            4-element boxes. Tensors come back as float tensors, ndarrays as float
            ndarrays, and any other sequence as a nested list.
            """
            if isinstance(boxes, torch.Tensor):
                # In-place division needs a floating dtype; integer tensors would raise.
                normalized = boxes.clone() if boxes.is_floating_point() else boxes.to(torch.float32)
                normalized[..., 0] /= width
                normalized[..., 1] /= height
                normalized[..., 2] /= width
                normalized[..., 3] /= height
                return normalized
            arr = np.array(boxes, dtype=np.float64)
            arr[..., 0] /= width
            arr[..., 1] /= height
            arr[..., 2] /= width
            arr[..., 3] /= height
            return arr if isinstance(boxes, np.ndarray) else arr.tolist()

        def plot_mask(mask, color=(1, 0, 0), alpha=0.5):
            """Overlay a binary mask on the current axes, tinting only mask pixels.

            An RGBA overlay is used so that pixels outside the mask stay fully
            transparent instead of being darkened.
            """
            import matplotlib.pyplot as plt
            from matplotlib.colors import to_rgb

            if isinstance(mask, torch.Tensor):
                mask = mask.detach().float().cpu().numpy()
            mask = np.asarray(mask, dtype=np.float32)
            h, w = mask.shape[-2:]
            mask = mask.reshape(h, w)
            overlay = np.zeros((h, w, 4), dtype=np.float32)
            overlay[..., :3] = to_rgb(color)
            overlay[..., 3] = mask * alpha
            plt.imshow(overlay)

        def plot_bbox(h, w, box, text="", box_format="XYXY", color=(1, 0, 0), relative_coords=False):
            """Draw one box (and optional label) on the current axes.

            Args:
                h, w: Image height and width, used when ``relative_coords`` is True.
                box: Four numbers (list or tensor) in ``box_format``.
                text: Optional label drawn at the top-left corner.
                box_format: "XYXY", "XYWH" or "CXCYWH" (case-insensitive);
                    unrecognised formats are treated as XYXY.
                color: Matplotlib color for the outline and label.
                relative_coords: If True, box values are in [0, 1] and are scaled
                    by (w, h) before drawing.
            """
            import matplotlib.pyplot as plt
            import matplotlib.patches as patches

            if isinstance(box, torch.Tensor):
                box = box.tolist()
            a, b, c, d = (float(v) for v in box)
            fmt = str(box_format).upper()
            if fmt == "XYWH":
                x0, y0, bw, bh = a, b, c, d
            elif fmt == "CXCYWH":
                x0, y0, bw, bh = a - c / 2, b - d / 2, c, d
            else:
                x0, y0, bw, bh = a, b, c - a, d - b
            if relative_coords:
                x0, bw = x0 * w, bw * w
                y0, bh = y0 * h, bh * h
            rect = patches.Rectangle((x0, y0), bw, bh, linewidth=2, edgecolor=color, facecolor='none')
            plt.gca().add_patch(rect)
            if text:
                plt.text(x0, y0, text, color=color, fontsize=8)

        COLORS = [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 0), (1, 0, 1), (0, 1, 1)]
        plot_results = None
        draw_box_on_image = None

except ImportError as e:
    print(f"❌ sam3 库导入失败: {e}")
    print("请确保 requirements.txt 中包含: git+https://github.com/facebookresearch/sam3.git")
|
|
|
|
| |
|
|
def _load_label_font(size: int = 16):
    """Return a bold TrueType font if one can be opened, else Pillow's default font."""
    for candidate in ("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", "Arial.ttf"):
        try:
            return ImageFont.truetype(candidate, size)
        except (OSError, ImportError):
            # OSError: font file missing/unreadable; ImportError: no FreeType support.
            continue
    return ImageFont.load_default()


def draw_boxes_with_labels(
    image: Image.Image,
    xyxy_boxes: List[List[float]],
    box_labels: List[bool]
) -> Image.Image:
    """
    Draw colored, captioned prompt boxes on a copy of ``image``.

    Positive boxes (label True) are drawn in red and negative boxes (label False)
    in green, matching the legend of the box-preview panel.

    Args:
        image: Source PIL image; it is not modified. ``None`` is passed through.
        xyxy_boxes: Box coordinates ``[[x_min, y_min, x_max, y_max], ...]``.
        box_labels: One flag per box, True = positive, False = negative. Boxes
            and labels are paired with ``zip``, so extras in the longer list are ignored.

    Returns:
        A new image with the boxes and captions drawn, or ``None`` if ``image`` is None.
    """
    if image is None:
        return None

    # Palette/greyscale images cannot take RGB color tuples, so convert them.
    if image.mode in ("RGB", "RGBA"):
        img_draw = image.copy()
    else:
        img_draw = image.convert("RGB")
    draw = ImageDraw.Draw(img_draw)
    font = _load_label_font(16)

    for i, (box, is_positive) in enumerate(zip(xyxy_boxes, box_labels)):
        x_min, y_min, x_max, y_max = (int(coord) for coord in box)
        # Pillow rejects rectangles whose corners are out of order.
        x_min, x_max = min(x_min, x_max), max(x_min, x_max)
        y_min, y_max = min(y_min, y_max), max(y_min, y_max)

        if is_positive:
            color = (255, 0, 0)
            label_text = f"Box {i}: True (正框)"
        else:
            color = (0, 255, 0)
            label_text = f"Box {i}: False (负框)"

        draw.rectangle([x_min, y_min, x_max, y_max], outline=color, width=3)

        # Put the caption just above the box; if it would touch the top edge,
        # place it just below the box instead.
        text_y = y_min - 22
        if text_y <= 0:
            text_y = y_max + 2

        text_bbox = draw.textbbox((x_min, text_y), label_text, font=font)
        draw.rectangle(text_bbox, fill=color)
        draw.text((x_min, text_y), label_text, fill="white", font=font)

    return img_draw
|
|
|
|
def _parse_box_prompt(points: Any):
    """Return ``[x_min, y_min, x_max, y_max]`` for one ImagePrompter box entry, else ``None``.

    A box entry looks like ``[x1, y1, 2, x2, y2, 3]``; a click looks like
    ``[x, y, 1 or 0, _, _, 4]`` (the same markers parse_visual_prompt checks).
    Only box entries are accepted; clicks and malformed entries yield ``None``.
    """
    if not isinstance(points, (list, tuple, np.ndarray)) or len(points) < 6:
        return None
    try:
        x1, y1, start_tag, x2, y2, end_tag = (float(v) for v in points[:6])
    except (ValueError, TypeError):
        return None
    if start_tag != 2.0 or end_tag != 3.0:
        return None
    return [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)]


def process_imageprompter_data(
    data: Any,
    box_mode_history: List[Tuple[int, str]] = None,
    verbose: bool = False
) -> Tuple[List[List[float]], List[bool]]:
    """
    Extract XYXY boxes and their positive/negative labels from ImagePrompter data.

    ImagePrompter returns ``{'image': <PIL Image>, 'points': [[x1, y1, t1, x2, y2, t2], ...]}``.
    Only box entries (``t1 == 2`` and ``t2 == 3``) become boxes; click entries and
    malformed entries are skipped.

    Args:
        data: The dict returned by ImagePrompter; anything else yields ``([], [])``.
        box_mode_history: Mode switch history ``[(prompt_index, mode), ...]`` where
            mode is "positive" or "negative". For example ``[(0, "positive"), (2, "negative")]``
            means entries 0-1 are positive and entries from index 2 on are negative.
            Indices count every ``points`` entry (this is what on_box_mode_change
            records). When None or empty, all boxes are positive.
        verbose: Print debugging information.

    Returns:
        tuple: (xyxy_boxes, box_labels)
            - xyxy_boxes: ``[[x_min, y_min, x_max, y_max], ...]`` as floats
            - box_labels: one bool per box, True = positive, False = negative
    """
    if data is None or not isinstance(data, dict):
        return [], []

    if verbose:
        print(f"\n--- Shape Parsing Debug START ---")
        print(f"Debug: Data keys = {list(data.keys())}")
        print(f"Debug: Box mode history = {box_mode_history}")

    # Later entries for the same index win, like repeated assignment would.
    mode_switch_points = {}
    if box_mode_history:
        for prompt_idx, mode in box_mode_history:
            mode_switch_points[prompt_idx] = mode

    points_list = data.get('points')
    if points_list is None:
        points_list = []

    xyxy_boxes = []
    box_labels = []
    current_mode = "positive"
    for i, points in enumerate(points_list):
        # Advance the mode before skipping non-box entries so that labels stay
        # aligned with the prompt indices stored in the history.
        if i in mode_switch_points:
            current_mode = mode_switch_points[i]
        box = _parse_box_prompt(points)
        if box is None:
            continue
        xyxy_boxes.append(box)
        box_labels.append(current_mode == "positive")

    if verbose:
        print(f"Total boxes: {len(xyxy_boxes)} (正框: {sum(box_labels) if box_labels else 0}, 负框: {len(box_labels) - sum(box_labels) if box_labels else 0})")
        print(f"--- Shape Parsing Debug END ---\n")

    return xyxy_boxes, box_labels
|
|
|
|
| |
|
|
def plot_boxes_to_image(
    image_pil: Image,
    tgt: Dict,
    return_point: bool = False,
    point_width: float = 1.0,
    return_score=True,
) -> Tuple[Image.Image, Image.Image]:
    """Plot bounding boxes (or centre points) and scores on an image, in place.

    Args:
        image_pil: PIL image to draw on; it is modified in place.
        tgt: Mapping with ``"boxes"`` (pixel XYXY) and ``"scores"`` (one per box).
        return_point: Draw a filled dot at each box centre instead of a rectangle.
        point_width: Dot radius and outline width. Pillow widths must be integers,
            so the width is truncated with ``int()``.
        return_score: Write each score (two decimals) at the box's top-left corner.

    Returns:
        ``(image_pil, mask)``, where ``mask`` is an "L" image with every box area
        filled with 255.
    """
    boxes = tgt["boxes"]
    scores = tgt["scores"]

    draw = ImageDraw.Draw(image_pil)
    mask = Image.new("L", image_pil.size, 0)
    mask_draw = ImageDraw.Draw(mask)
    # Loop invariants: one font, and an integer width (Pillow raises on floats).
    font = ImageFont.load_default()
    line_width = int(point_width)

    for box, score in zip(boxes, scores):
        color = tuple(np.random.randint(0, 255, size=3).tolist())
        x0, y0, x1, y1 = (int(v) for v in box)

        if return_point:
            center_x = int((x0 + x1) / 2)
            center_y = int((y0 + y1) / 2)
            draw.ellipse(
                (
                    center_x - point_width,
                    center_y - point_width,
                    center_x + point_width,
                    center_y + point_width,
                ),
                fill=color,
                width=line_width,
            )
        else:
            draw.rectangle([x0, y0, x1, y1], outline=color, width=line_width)

        text = f"{score:.2f}" if return_score else ""
        if text:
            if hasattr(font, "getbbox"):
                bbox = draw.textbbox((x0, y0), text, font)
            else:
                # Very old Pillow without getbbox/textbbox.
                w, h = draw.textsize(text, font)
                bbox = (x0, y0, w + x0, y0 + h)
            if not return_point:
                draw.rectangle(bbox, fill=color)
            draw.text((x0, y0), text, fill="white", font=font)

        mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)
    return image_pil, mask
|
|
|
|
def parse_visual_prompt(points: List):
    """Split ImagePrompter entries into boxes, positive clicks and negative clicks.

    Each entry is a 6-element list ``[a, b, kind, c, d, tag]``:
    ``kind == 2 and tag == 3`` is a box with corners (a, b) and (c, d);
    ``tag == 4`` is a click at (a, b), positive when ``kind == 1`` and negative
    when ``kind == 0``. Any other entry is ignored.

    Returns:
        (boxes, pos_points, neg_points): lists of ``[x1, y1, x2, y2]`` and ``[x, y]``.
    """
    boxes, pos_points, neg_points = [], [], []
    for entry in points:
        kind, tag = entry[2], entry[-1]
        if kind == 2 and tag == 3:
            x1, y1, _, x2, y2, _ = entry
            boxes.append([x1, y1, x2, y2])
        elif tag == 4 and kind in (0, 1):
            x, y, _, _, _, _ = entry
            target = pos_points if kind == 1 else neg_points
            target.append([x, y])
    return boxes, pos_points, neg_points
|
|
|
|
| |
|
|
| |
# Resolved resource paths; filled in by download_resources_from_hf().
bpe_path = None
sam3_checkpoint = None
example_image_hf_path = None


def _hub_download(filename):
    """Download ``filename`` from SAM3_HF_REPO_ID and return the local path.

    ``cache_dir`` is taken from HF_HOME when set (None otherwise), exactly as the
    individual download calls did before.
    """
    return hf_hub_download(
        repo_id=SAM3_HF_REPO_ID,
        filename=filename,
        cache_dir=os.environ.get("HF_HOME", None),
    )


def download_resources_from_hf():
    """Resolve the BPE vocabulary, model checkpoint and example image.

    Sets the module globals ``bpe_path``, ``sam3_checkpoint`` and
    ``example_image_hf_path``, using None for anything that could not be found.
    Each failure is printed rather than raised, so the app can still start.

    Returns:
        True when both the BPE vocabulary and the checkpoint are available.
    """
    global bpe_path, sam3_checkpoint, example_image_hf_path

    if not SAM3_INSTALLED:
        print("❌ sam3 库未安装,无法下载资源")
        return False

    # 1) BPE vocabulary: Hub first, then the copy shipped with the sam3 package.
    try:
        bpe_path = _hub_download("assets/bpe_simple_vocab_16e6.txt.gz")
        print(f"✅ BPE 词汇表: {bpe_path}")
    except Exception as e:
        print(f"⚠️ 无法下载 BPE 词汇表: {e}")
        bpe_path = None
        if sam3 is not None:
            sam3_root = os.path.join(os.path.dirname(sam3.__file__), "..")
            local_bpe = os.path.join(sam3_root, "assets", "bpe_simple_vocab_16e6.txt.gz")
            bpe_path = local_bpe if os.path.exists(local_bpe) else None

    # 2) Checkpoint: an explicit SAM3_CHECKPOINT_PATH wins; otherwise try two Hub layouts.
    env_checkpoint = os.environ.get("SAM3_CHECKPOINT_PATH")
    if env_checkpoint and os.path.exists(env_checkpoint):
        sam3_checkpoint = env_checkpoint
        print(f"✅ 使用环境变量指定的模型: {sam3_checkpoint}")
    else:
        if env_checkpoint:
            print(f"⚠️ SAM3_CHECKPOINT_PATH 指向的文件不存在: {env_checkpoint},改为从 Hub 下载")
        try:
            sam3_checkpoint = _hub_download("checkpoints/sam3.pt")
            print(f"✅ 模型检查点: {sam3_checkpoint}")
        except Exception as e:
            print(f"⚠️ 无法下载模型检查点: {e}")
            try:
                sam3_checkpoint = _hub_download("sam3.pt")
                print(f"✅ 模型检查点(备选): {sam3_checkpoint}")
            except Exception as fallback_err:
                # Previously swallowed silently by a bare `except:`.
                print(f"⚠️ 备选模型检查点也无法下载: {fallback_err}")
                sam3_checkpoint = None

    # 3) Example image (optional; the UI falls back to local copies or a placeholder).
    try:
        example_image_hf_path = _hub_download("assets/images/test_image.jpg")
        print(f"✅ 示例图片: {example_image_hf_path}")
    except Exception as e:
        print(f"⚠️ 无法下载示例图片: {e}")
        example_image_hf_path = None

    return bpe_path is not None and sam3_checkpoint is not None
|
|
| |
# Resolve model resources once at import time and announce the source repo.
_banner = "=" * 50
for _banner_line in (
    f"\n{_banner}",
    "正在从 Hugging Face Hub 下载资源...",
    f"仓库 ID: {SAM3_HF_REPO_ID}",
    f"{_banner}\n",
):
    print(_banner_line)
download_resources_from_hf()
|
|
| |
# Use the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

# Lazily initialised model state, populated by load_model().
model = processor = autocast_ctx = None
|
|
|
|
def load_model():
    """Build the SAM 3 model and processor on first use (ZeroGPU friendly).

    Safe to call repeatedly: it returns True immediately once a model exists.
    On CUDA it enables TF32 and enters a bfloat16 ``torch.autocast`` context,
    which is left open for subsequent work on the calling thread.

    Returns:
        True when ``model`` and ``processor`` are ready; False otherwise. The
        reason for a failure is printed.
    """
    global model, processor, autocast_ctx

    if model is not None:
        return True
    if not SAM3_INSTALLED:
        print("❌ sam3 库未安装")
        return False
    if sam3_checkpoint is None:
        print("❌ 模型检查点路径未配置")
        return False
    if bpe_path is None:
        print("❌ BPE 词汇表路径未配置")
        return False
    if Sam3Processor is None:
        # Otherwise this would surface as "'NoneType' object is not callable".
        print("❌ Sam3Processor 不可用,无法创建推理处理器")
        return False

    try:
        if DEVICE == "cuda":
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            # NOTE(review): autocast state is thread-local; entering it here only
            # affects the thread that loads the model. Confirm inference runs on
            # that thread, or wrap inference in its own autocast block.
            autocast_ctx = torch.autocast("cuda", dtype=torch.bfloat16)
            autocast_ctx.__enter__()
        else:
            autocast_ctx = None
        model = build_sam3_image_model(bpe_path=bpe_path, checkpoint_path=sam3_checkpoint).to(DEVICE)
        processor = Sam3Processor(model, confidence_threshold=0.5)
    except Exception as e:
        print(f"❌ 模型加载失败: {e}")
        import traceback
        traceback.print_exc()
        # Leave the autocast context entered above so a retry does not stack
        # another one on top of it.
        if autocast_ctx is not None:
            autocast_ctx.__exit__(None, None, None)
            autocast_ctx = None
        model = None
        processor = None
        return False

    print("✅ 模型加载成功")
    return True


# Without the ZeroGPU `spaces` module, load eagerly at import time. Under ZeroGPU,
# loading is deferred to the first sam3_segmentation_core call (presumably so
# CUDA is only touched inside the GPU-decorated handler).
if not SPACES_GPU_AVAILABLE:
    load_model()
|
|
|
|
| |
|
|
def plot_to_pil(fig):
    """Render a Matplotlib figure to an RGB PIL image, then close the figure.

    The figure is saved as a tightly cropped PNG with no padding.
    """
    with io.BytesIO() as png_buffer:
        fig.savefig(png_buffer, format='png', bbox_inches='tight', pad_inches=0)
        png_buffer.seek(0)
        plt.close(fig)
        # convert() decodes into a new, buffer-independent image.
        with Image.open(png_buffer) as png:
            return png.convert("RGB")
|
|
|
|
def get_result_figure(
    img: Image.Image,
    results: dict,
    return_point: bool = False,
    point_width: float = 3.0,
    return_score: bool = True
) -> Tuple[plt.Figure, int]:
    """Render SAM 3 results over ``img`` as a Matplotlib figure.

    Object ``i`` is drawn in ``COLORS[i % len(COLORS)]``. Its mask is overlaid with
    ``plot_mask``. Its box is drawn as a rectangle via ``plot_bbox``, or, when
    ``return_point`` is set, as a filled circle at the box centre.

    Args:
        img: Background image.
        results: Inference state from the processor. Reads the optional ``"scores"``,
            ``"masks"`` and ``"boxes"`` entries; boxes are passed to plot_bbox as
            pixel XYXY (``relative_coords=False``).
        return_point: Draw centre points instead of rectangles.
        point_width: Rectangle line width, or circle size (radius ``2 * point_width``).
        return_score: Include each object's confidence in its annotation.

    Returns:
        ``(figure, number_of_objects)``. The caller must close the figure
        (plot_to_pil does).
    """
    fig = plt.figure(figsize=(12, 8))
    plt.imshow(img)
    plt.axis("off")
    ax = plt.gca()

    scores = results.get("scores", [])
    masks = results.get("masks")
    boxes = results.get("boxes")
    nb_objects = len(scores)
    print(f"found {nb_objects} object(s)")

    img_w, img_h = img.size  # PIL reports (width, height)
    for i in range(nb_objects):
        color = COLORS[i % len(COLORS)]

        if masks is not None and i < len(masks):
            mask_data = masks[i]
            if mask_data.ndim == 3:
                mask_data = mask_data.squeeze(0)
            plot_mask(mask_data.cpu(), color=color)

        if boxes is None or i >= len(boxes):
            continue

        box = boxes[i].cpu()
        prob = scores[i].item()

        if return_point:
            x0, y0, x1, y1 = box.tolist()
            center_x = (x0 + x1) / 2
            center_y = (y0 + y1) / 2
            ax.add_patch(plt.Circle((center_x, center_y), point_width * 2, color=color, fill=True))
            if return_score:
                plt.text(
                    center_x + point_width * 3,
                    center_y,
                    f"{prob:.2f}",
                    color=color,
                    fontsize=10,
                    fontweight='bold'
                )
        else:
            text = f"(id={i}, {prob:.2f})" if return_score else f"(id={i})"
            patches_before = len(ax.patches)
            plot_bbox(
                img_h,
                img_w,
                box,
                text=text,
                box_format="XYXY",
                color=color,
                relative_coords=False,
            )
            # plot_bbox is not given a line width, so apply the user's
            # "线条/点宽度" setting to the patch(es) it just added. Before this
            # change, point_width had no effect in box mode.
            for patch in list(ax.patches)[patches_before:]:
                patch.set_linewidth(point_width)

    return fig, nb_objects
|
|
|
|
| |
|
|
def _to_rgb_image(image: Any) -> Image.Image:
    """Return an RGB copy of a PIL image, numpy array, or image file path."""
    if isinstance(image, (str, os.PathLike)):
        with Image.open(image) as opened:
            return opened.convert("RGB")
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # convert() always returns a new image, so the caller's input is untouched.
    return image.convert("RGB")


def _add_box_prompts(inference_state, xyxy_boxes, labels, width, height):
    """Register pixel XYXY boxes with the processor as normalized CxCyWH geometric prompts."""
    box_inputs = [
        [x_min, y_min, x_max - x_min, y_max - y_min]
        for x_min, y_min, x_max, y_max in xyxy_boxes
    ]
    box_input_xywh = torch.tensor(box_inputs, dtype=torch.float32).view(-1, 4).to(DEVICE)
    box_input_cxcywh = box_xywh_to_cxcywh(box_input_xywh)
    norm_boxes_cxcywh = normalize_bbox(box_input_cxcywh, width, height).tolist()
    for i, (norm_box, label) in enumerate(zip(norm_boxes_cxcywh, labels)):
        label_str = "正框" if label else "负框"
        print(f"Adding box {i}: {norm_box}, label={label} ({label_str})")
        inference_state = processor.add_geometric_prompt(
            state=inference_state, box=norm_box, label=label
        )
    return inference_state


def sam3_segmentation_core(
    unified_image_input: Any,
    prompt_text: str,
    box_type: str,
    box_mode_history: List[Tuple[int, str]],
    return_point: bool = False,
    point_width: float = 3.0,
    return_score: bool = True
):
    """Run SAM 3 on one image with a text or box prompt and render the result.

    Prompts are validated before the image is encoded, so invalid requests do not
    pay for the encoder.

    Args:
        unified_image_input: The ImagePrompter dict ``{'image': ..., 'points': [...]}``
            when that widget is available; otherwise an image (PIL image, numpy
            array or file path).
        prompt_text: Text prompt; used only when ``box_type == "Text"``.
        box_type: "Text", "Single Box" or "Multi Box".
        box_mode_history: Positive/negative switch history (see
            process_imageprompter_data); returned unchanged.
        return_point, point_width, return_score: Display options forwarded to
            get_result_figure.

    Returns:
        ``(result_image_or_None, status_message, box_mode_history)``
    """
    if not SAM3_INSTALLED:
        return None, "❌ sam3 库未安装,请检查 requirements.txt 或 HF 仓库配置。", box_mode_history

    if (model is None or processor is None) and not load_model():
        return None, "❌ 模型未加载,请检查模型配置。可能需要设置 SAM3_HF_REPO_ID 环境变量。", box_mode_history

    if IMAGE_PROMPTER_AVAILABLE and isinstance(unified_image_input, dict):
        image = unified_image_input.get('image')
        visual_prompter_data = unified_image_input
    else:
        image = unified_image_input
        visual_prompter_data = None

    if image is None:
        return None, "请上传图像。", box_mode_history

    # --- Validate the prompt first (cheap) ---
    xyxy_boxes, labels = [], []
    if box_type == "Text":
        prompt_text = (prompt_text or "").strip()
        if not prompt_text:
            return None, "文本模式下,请提供文本提示。", box_mode_history
    elif box_type in ("Single Box", "Multi Box"):
        if not IMAGE_PROMPTER_AVAILABLE:
            return None, "当前环境不支持 ImagePrompter,Box 模式无法运行。", box_mode_history

        box_labels = []
        if visual_prompter_data:
            xyxy_boxes, box_labels = process_imageprompter_data(visual_prompter_data, box_mode_history, verbose=True)
            print(f"Boxes: {xyxy_boxes}")
            print(f"Labels: {box_labels}")

        if not xyxy_boxes:
            return None, f"请在图像上绘制至少一个矩形框作为提示(当前模式: {box_type})。", box_mode_history

        if box_type == "Single Box":
            xyxy_boxes = xyxy_boxes[:1]
        # One label per box; any missing label defaults to positive.
        labels = [box_labels[i] if i < len(box_labels) else True for i in range(len(xyxy_boxes))]
    else:
        return None, "请选择有效的提示类型 (Text, Single Box, 或 Multi Box)。", box_mode_history

    # --- Encode the image; SAM 3 expects 3-channel RGB (uploads may be RGBA/L/P) ---
    try:
        img0 = _to_rgb_image(image)
    except (OSError, TypeError, ValueError, AttributeError) as e:
        return None, f"图像处理失败: {e}", box_mode_history
    width, height = img0.size

    try:
        inference_state = processor.set_image(img0)
    except Exception as e:
        return None, f"图像处理失败: {e}", box_mode_history

    processor.reset_all_prompts(inference_state)

    # --- Apply the prompt ---
    if box_type == "Text":
        try:
            inference_state = processor.set_text_prompt(
                state=inference_state,
                prompt=prompt_text
            )
        except Exception as e:
            return None, f"文本提示处理失败: {e}", box_mode_history
        caption_base = "文本提示分割"
    else:
        try:
            inference_state = _add_box_prompts(inference_state, xyxy_boxes, labels, width, height)
        except Exception as e:
            print(f"Error during box conversion/prompt setting: {e}")
            return None, f"框提示处理失败: {e}", box_mode_history
        num_positive = sum(labels)
        num_negative = len(labels) - num_positive
        caption_base = f"使用 {len(xyxy_boxes)} 个提示框分割(正框: {num_positive}, 负框: {num_negative})"

    fig, found_objects = get_result_figure(
        img0.copy(),
        inference_state,
        return_point=return_point,
        point_width=point_width,
        return_score=return_score
    )
    result_image = plot_to_pil(fig)

    return result_image, f"{caption_base}。找到 {found_objects} 个对象。", box_mode_history
|
|
|
|
| |
def sam3_segmentation(
    unified_image_input: Any,
    prompt_text: str,
    box_type: str,
    box_mode_history: List[Tuple[int, str]],
    return_point: bool = False,
    point_width: float = 3.0,
    return_score: bool = True
):
    """Gradio entry point for segmentation; delegates to sam3_segmentation_core.

    When the ZeroGPU ``spaces`` module is importable, this function is wrapped
    with ``spaces.GPU`` right below. That is equivalent to decorating it.
    """
    return sam3_segmentation_core(
        unified_image_input, prompt_text, box_type,
        box_mode_history, return_point, point_width, return_score
    )


if SPACES_GPU_AVAILABLE:
    sam3_segmentation = spaces.GPU(sam3_segmentation)
|
|
|
|
| |
|
|
def on_box_mode_change(
    new_mode: str,
    unified_image_input: Any,
    current_history: List[Tuple[int, str]]
) -> Tuple[List[Tuple[int, str]], str, Image.Image]:
    """
    Record a positive/negative mode switch at the current prompt count and refresh the preview.

    Args:
        new_mode: Selected radio value ("正框 (Positive)" or "负框 (Negative)").
        unified_image_input: Current ImagePrompter value.
        current_history: Existing switch history; it is copied, not mutated.

    Returns:
        tuple: (updated history, status text, preview image or None)
    """
    history = [] if current_history is None else current_history.copy()

    has_prompter_data = unified_image_input and isinstance(unified_image_input, dict)

    # The new mode applies from the next drawn prompt onward.
    drawn = unified_image_input.get('points', []) if has_prompter_data else []
    box_count = len(drawn) if drawn else 0

    wants_positive = "Positive" in new_mode or "正框" in new_mode
    mode_internal = "positive" if wants_positive else "negative"

    history.append((box_count, mode_internal))

    mode_display = "正框" if mode_internal == "positive" else "负框"
    status = f"✅ 已切换到 {mode_display} 模式。从第 {box_count + 1} 个框开始将被标记为{mode_display}。"

    print(f"Box mode changed: {new_mode} -> {mode_internal}, at box index {box_count}")
    print(f"Updated history: {history}")

    preview_image = None
    base_image = unified_image_input.get('image') if has_prompter_data else None
    if base_image is not None:
        boxes, labels = process_imageprompter_data(unified_image_input, history, verbose=False)
        preview_image = draw_boxes_with_labels(base_image, boxes, labels) if boxes else base_image

    return history, status, preview_image
|
|
|
|
def reset_box_mode_history(
    unified_image_input: Any
) -> Tuple[List[Tuple[int, str]], str, Image.Image]:
    """Reset the mode history so every box is positive, and redraw the preview.

    NOTE(review): the box_mode_selector radio is not reset by this handler (its
    outputs are only history/status/preview), so the radio may still show 负框.

    Returns:
        (new history, status text, preview image or None)
    """
    fresh_history = [(0, "positive")]
    preview = None

    if unified_image_input and isinstance(unified_image_input, dict):
        source_image = unified_image_input.get('image')
        if source_image is not None:
            boxes, labels = process_imageprompter_data(unified_image_input, fresh_history, verbose=False)
            preview = draw_boxes_with_labels(source_image, boxes, labels) if boxes else source_image

    return fresh_history, "已重置,所有框将默认为正框。", preview
|
|
|
|
def get_current_box_status(
    unified_image_input: Any,
    box_mode_history: List[Tuple[int, str]]
) -> str:
    """Summarise how many positive and negative boxes are currently drawn.

    Boxes and labels come from process_imageprompter_data. The counts therefore
    match the preview and the prompts sent to the model, and entries that
    function does not treat as boxes are not counted.

    Args:
        unified_image_input: ImagePrompter value; anything else means "no boxes".
        box_mode_history: Positive/negative switch history.

    Returns:
        A human-readable status string.
    """
    if not unified_image_input or not isinstance(unified_image_input, dict):
        return "尚未绘制框"

    xyxy_boxes, box_labels = process_imageprompter_data(unified_image_input, box_mode_history, verbose=False)
    if not xyxy_boxes:
        return "尚未绘制框"

    num_boxes = len(xyxy_boxes)
    if not box_mode_history:
        return f"已绘制 {num_boxes} 个框(全部为正框)"

    positive_count = sum(box_labels)
    negative_count = num_boxes - positive_count
    return f"已绘制 {num_boxes} 个框(正框: {positive_count}, 负框: {negative_count})"
|
|
|
|
def update_box_preview(
    unified_image_input: Any,
    box_mode_history: List[Tuple[int, str]]
) -> Tuple[Image.Image, str, str]:
    """
    Build the labelled box preview, the status line and the box-parameter text.

    Args:
        unified_image_input: ImagePrompter value.
        box_mode_history: Positive/negative switch history.

    Returns:
        tuple: (preview image or None, status text, box-parameter text — "None" when there are no boxes)
    """
    status_text = get_current_box_status(unified_image_input, box_mode_history)

    if not (unified_image_input and isinstance(unified_image_input, dict)):
        return None, status_text, "None"

    base_image = unified_image_input.get('image')
    if base_image is None:
        return None, status_text, "None"

    boxes, labels = process_imageprompter_data(unified_image_input, box_mode_history, verbose=False)
    if not boxes:
        return base_image, status_text, "None"

    rounded = [[int(value) for value in b] for b in boxes]
    if len(rounded) == 1:
        info_text = f"Box: {rounded[0]}\nLabel: {labels[0]}"
    else:
        info_text = f"Boxes: {rounded}\nLabels: {labels}"

    return draw_boxes_with_labels(base_image, boxes, labels), status_text, info_text
|
|
|
|
| |
|
|
| |
example_image_path = None
example_image = None

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))


def _example_image_candidates():
    """Yield ``(path, log_prefix)`` for each possible example-image location, in priority order.

    Order: the file downloaded from the Hub, then ``assets/images/test_image.jpg``
    next to this script, then the same relative path under the installed sam3 package.
    """
    if example_image_hf_path:
        yield example_image_hf_path, "✅ 使用 HF Hub 下载的示例图片"
    yield os.path.abspath(os.path.join(SCRIPT_DIR, "assets", "images", "test_image.jpg")), "✅ 使用本地示例图片"
    if sam3 is not None:
        sam3_root = os.path.join(os.path.dirname(sam3.__file__), "..")
        yield os.path.abspath(os.path.join(sam3_root, "assets", "images", "test_image.jpg")), "✅ 使用 sam3 模块示例图片"


for _candidate_path, _log_prefix in _example_image_candidates():
    if not os.path.exists(_candidate_path):
        continue
    try:
        _candidate_image = Image.open(_candidate_path)
        # Decode now: a corrupt file is skipped instead of crashing the app at
        # import, and Pillow can close the file it opened (single-frame images).
        _candidate_image.load()
    except OSError as err:
        print(f"⚠️ 无法读取示例图片 {_candidate_path}: {err}")
        continue
    example_image_path = _candidate_path
    example_image = _candidate_image
    print(f"{_log_prefix}: {example_image_path}")
    break

if example_image is None:
    print(f"⚠️ 示例图片未找到,使用占位图")
    example_image_path = None
    example_image = Image.new('RGB', (512, 512), color='lightgray')
|
|
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
# Read-only box-parameter text shown for each example row.
example_prompts_info = {
    "Text": "None",
    "Single Box": "Box: [487, 302, 591, 641]\nLabel: True",
    "Multi Box": "Boxes: [[487, 302, 591, 641], [341, 275, 495, 662]]\nLabels: [False, True]",
}

# Prefer the example file path; fall back to the in-memory image (possibly the placeholder).
_example_source = example_image_path if example_image_path else example_image
_example_order = ("Text", "Single Box", "Multi Box")
_example_texts = {"Text": "shoe"}

if IMAGE_PROMPTER_AVAILABLE:
    # ImagePrompter box entries: [x1, y1, 2, x2, y2, 3].
    _example_points = {
        "Text": [],
        "Single Box": [[487.0, 302.0, 2, 591.0, 641.0, 3]],
        "Multi Box": [[487.0, 302.0, 2, 591.0, 641.0, 3], [341.0, 275.0, 2, 495.0, 662.0, 3]],
    }
    example_data_corrected = [
        [
            {"image": _example_source, "points": _example_points[kind]},
            kind,
            example_prompts_info[kind],
            _example_texts.get(kind, ""),
        ]
        for kind in _example_order
    ]
    example_multi_box_history = [(0, "negative"), (1, "positive")]
else:
    example_data_corrected = [
        [_example_source, kind, example_prompts_info[kind], _example_texts.get(kind, "")]
        for kind in _example_order
    ]
    example_multi_box_history = [(0, "positive")]
|
|
|
|
def on_example_select(
    unified_image_input: Any,
    prompt_type: str
) -> Tuple[List[Tuple[int, str]], Image.Image, str]:
    """
    Choose the box-mode history for a selected example and redraw the preview.

    Args:
        unified_image_input: ImagePrompter value of the example.
        prompt_type: "Text", "Single Box" or "Multi Box".

    Returns:
        tuple: (box-mode history, preview image or None, status text)
    """
    # The bundled Multi Box example is labelled "Labels: [False, True]" (first box
    # negative, second positive). run_example and update_inputs_and_preview use
    # the same history; the previous [(0, positive), (1, negative)] contradicted both.
    if prompt_type == "Multi Box":
        box_history = [(0, "negative"), (1, "positive")]
    else:
        box_history = [(0, "positive")]

    preview_image = None
    status_text = "尚未绘制框"

    if unified_image_input and isinstance(unified_image_input, dict):
        image = unified_image_input.get('image')
        if image is not None:
            xyxy_boxes, box_labels = process_imageprompter_data(unified_image_input, box_history, verbose=False)
            if xyxy_boxes:
                preview_image = draw_boxes_with_labels(image, xyxy_boxes, box_labels)
                num_positive = sum(box_labels)
                num_negative = len(box_labels) - num_positive
                status_text = f"已绘制 {len(xyxy_boxes)} 个框(正框: {num_positive}, 负框: {num_negative})"
            else:
                preview_image = image

    return box_history, preview_image, status_text
|
|
|
|
| |
| |
# Gradio UI. Component creation order inside the context managers defines the layout.
with gr.Blocks() as demo:
    # Per-session positive/negative switch history as (prompt_index, mode) pairs;
    # starts with every box positive.
    box_mode_history_state = gr.State([(0, "positive")])

    gr.Markdown(
        """
        # 🎯 SAM 3 Demo
        **Segment Anything Model 3 - 目标检测与分割**

        > 🚀 Powered by Hugging Face Spaces
        """
    )

    with gr.Row():
        # Left column: image input, prompt controls and display options.
        with gr.Column(scale=1):
            # NOTE(review): the guide below refers to a "运行 SAM 3 分割" button, but
            # the button created further down is labelled "Run SAM3".
            with gr.Accordion("📋 使用说明", open=False):
                gr.Markdown("""
                **使用方法:**

                📝 **Text 模式**
                1. 选择 "Text" 模式
                2. 上传图像
                3. 输入文本提示词(如 "shoe", "person")
                4. 点击"运行 SAM 3 分割"

                ⬜ **Single Box 模式**
                1. 选择 "Single Box" 模式
                2. 上传图像
                3. 在图像上绘制一个矩形框
                4. 点击"运行 SAM 3 分割"

                🔲 **Multi Box 模式(支持正/负框)**
                1. 选择 "Multi Box" 模式
                2. 上传图像
                3. **默认为正框模式**,绘制的框将包含目标
                4. 如需绘制负框(排除区域):
                - 先绘制正框
                - 点击切换到 "负框 (Negative)" 模式
                - 继续绘制负框
                5. 点击「🔄 刷新预览」按钮查看框标签预览
                6. 点击"运行 SAM 3 分割"

                💡 **正框 vs 负框**
                - **正框(红色)**: 告诉模型"包含这个区域的目标"
                - **负框(绿色)**: 告诉模型"排除这个区域",用于去除误检
                - **注意**: 负框需要配合正框使用才能生效

                ⚙️ **显示选项**
                - **显示中心点**: 用圆点代替边框显示检测结果中心位置
                - **显示置信度**: 在结果中显示模型的置信度分数
                - **线条/点宽度**: 调整边框线宽或中心点大小
                """)

            # ImagePrompter lets users draw boxes. Without it a plain image input is
            # used, and sam3_segmentation_core rejects the Box modes.
            if IMAGE_PROMPTER_AVAILABLE:
                unified_image_input = ImagePrompter(
                    label="🖼️ 示例图像",
                    type="pil"
                )
            else:
                unified_image_input = gr.Image(
                    label="🖼️ 示例图像",
                    type="pil"
                )

            prompt_type = gr.Radio(
                ["Text", "Single Box", "Multi Box"],
                label="提示类型",
                value="Text"
            )

            text_prompt_input = gr.Textbox(
                label="文本提示参数",
                value="shoe",
                visible=True
            )

            # Read-only box-parameter summary; filled by the examples and by the
            # refresh-preview button.
            example_prompt_info_display = gr.Textbox(
                label="框提示参数",
                value="",
                interactive=False,
                lines=2,
                visible=True
            )

            # Positive/negative drawing-mode controls; made visible only in Multi Box mode.
            with gr.Group(visible=False) as box_mode_group:
                gr.Markdown("### 🎯 框模式设置")
                box_mode_selector = gr.Radio(
                    ["正框 (Positive)", "负框 (Negative)"],
                    label="当前绘制模式",
                    value="正框 (Positive)",
                    info="正框=包含目标,负框=排除区域"
                )
                with gr.Row():
                    reset_history_btn = gr.Button("🔄 重置框标签", size="sm")

            # Labelled box preview; made visible in Single Box and Multi Box modes.
            with gr.Group(visible=False) as box_preview_group:
                gr.Markdown("### 📦 框预览(红色=正框 True,绿色=负框 False)")
                box_status_text = gr.Textbox(
                    label="框状态",
                    value="尚未绘制框",
                    interactive=False
                )
                refresh_preview_btn = gr.Button("🔄 刷新预览", size="sm", variant="secondary")
                gr.Markdown("*绘制框后点击「刷新预览」按钮查看标注效果*")
                box_preview_image = gr.Image(
                    label="框标签预览",
                    type="pil",
                    interactive=False
                )

            gr.Markdown("### ⚙️ 显示选项")
            with gr.Row():
                return_point = gr.Checkbox(label="显示中心点", value=False)
                return_score = gr.Checkbox(label="显示置信度", value=True)
            point_width = gr.Slider(
                label="线条/点宽度",
                value=3.0,
                minimum=0.0,
                maximum=20.0,
                step=0.1,
            )

            run_button = gr.Button("Run SAM3", variant="primary")

        # Right column: segmentation output and examples.
        with gr.Column(scale=2):
            output_image = gr.Image(label="分割结果", type="pil")
            result_info = gr.Textbox(label="结果信息", lines=2)

            def run_example(img, ptype, prompt_info, text):
                """Run an example using the box-mode history that matches its labels.

                NOTE(review): with cache_examples=False, Gradio may only call this fn
                on click if run_on_click is enabled; verify it is ever invoked.
                """
                if ptype == "Multi Box":
                    # Matches the example's "Labels: [False, True]".
                    history = [(0, "negative"), (1, "positive")]
                else:
                    history = [(0, "positive")]
                result_img, result_text, _ = sam3_segmentation(img, text, ptype, history, False, 3.0, True)
                return result_img, result_text

            gr.Examples(
                examples=example_data_corrected,
                inputs=[unified_image_input, prompt_type, example_prompt_info_display, text_prompt_input],
                outputs=[output_image, result_info],
                fn=run_example,
                cache_examples=False,
                label="示例"
            )

    # Main inference; the handler also returns the (unchanged) history state.
    run_button.click(
        fn=sam3_segmentation,
        inputs=[unified_image_input, text_prompt_input, prompt_type, box_mode_history_state, return_point, point_width, return_score],
        outputs=[output_image, result_info, box_mode_history_state]
    )

    # Switching positive/negative mode records a switch point in the history.
    box_mode_selector.change(
        fn=on_box_mode_change,
        inputs=[box_mode_selector, unified_image_input, box_mode_history_state],
        outputs=[box_mode_history_state, box_status_text, box_preview_image]
    )

    # Reset labels so that all boxes are positive again.
    reset_history_btn.click(
        fn=reset_box_mode_history,
        inputs=[unified_image_input],
        outputs=[box_mode_history_state, box_status_text, box_preview_image]
    )

    # Manual preview refresh; also fills the box-parameter textbox.
    refresh_preview_btn.click(
        fn=update_box_preview,
        inputs=[unified_image_input, box_mode_history_state],
        outputs=[box_preview_image, box_status_text, example_prompt_info_display]
    )

    def update_inputs(p_type):
        """Return visibility updates for (text prompt, mode group, preview group)."""
        is_text = p_type == "Text"
        is_multi_box = p_type == "Multi Box"
        is_box_mode = p_type in ["Single Box", "Multi Box"]
        return (
            gr.update(visible=is_text),
            gr.update(visible=is_multi_box),
            gr.update(visible=is_box_mode)
        )

    def update_inputs_and_preview(p_type, img_input):
        """
        Toggle input visibility for the chosen prompt type and refresh the box
        preview (for example when an example is loaded).

        NOTE(review): choosing "Multi Box" manually also applies the example
        history [(0, "negative"), (1, "positive")]. The first drawn box therefore
        becomes negative, even though the selector shows 正框 and the guide says
        boxes are positive by default.
        """
        is_text = p_type == "Text"
        is_multi_box = p_type == "Multi Box"
        is_box_mode = p_type in ["Single Box", "Multi Box"]

        if p_type == "Multi Box":
            box_history = [(0, "negative"), (1, "positive")]
        elif p_type == "Single Box":
            box_history = [(0, "positive")]
        else:
            box_history = [(0, "positive")]

        preview_image = None
        status_text = "尚未绘制框"

        if img_input and isinstance(img_input, dict):
            image = img_input.get('image')
            if image is not None:
                xyxy_boxes, box_labels = process_imageprompter_data(img_input, box_history, verbose=False)
                if xyxy_boxes:
                    preview_image = draw_boxes_with_labels(image, xyxy_boxes, box_labels)
                    num_positive = sum(box_labels)
                    num_negative = len(box_labels) - num_positive
                    status_text = f"已绘制 {len(xyxy_boxes)} 个框(正框: {num_positive}, 负框: {num_negative})"
                else:
                    preview_image = image

        return (
            gr.update(visible=is_text),
            gr.update(visible=is_multi_box),
            gr.update(visible=is_box_mode),
            box_history,
            preview_image,
            status_text
        )

    # Changing the prompt type (including via an example) toggles visibility and resets the history.
    prompt_type.change(
        fn=update_inputs_and_preview,
        inputs=[prompt_type, unified_image_input],
        outputs=[text_prompt_input, box_mode_group, box_preview_group, box_mode_history_state, box_preview_image, box_status_text]
    )

    # Apply the initial visibility for the default prompt type on page load.
    demo.load(
        fn=update_inputs,
        inputs=[prompt_type],
        outputs=[text_prompt_input, box_mode_group, box_preview_group]
    )
|
|
|
|
| |
if __name__ == "__main__":
    # Script entry point; show_error=True surfaces handler exceptions in the UI.
    demo.launch(show_error=True)
|
|