from __future__ import annotations from typing import Any, Dict import numpy as np import torch from PIL import Image import gradio as gr from transformers import ( AutoModelForZeroShotObjectDetection, AutoProcessor, SamModel, SamProcessor, ) GROUNDING_MODEL_ID = "IDEA-Research/grounding-dino-base" SAM_MODEL_ID = "facebook/sam-vit-base" # 載入模型(使用 CPU) device = "cpu" grounding_processor = AutoProcessor.from_pretrained(GROUNDING_MODEL_ID) grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(GROUNDING_MODEL_ID).to(device) sam_processor = SamProcessor.from_pretrained(SAM_MODEL_ID) sam_model = SamModel.from_pretrained(SAM_MODEL_ID).to(device) def segment_image_with_text( image: Image.Image, text_prompt: str) -> tuple[np.ndarray, Dict[str, Any]]: """ 使用 Grounding DINO 檢測物件,然後使用 SAM 進行分割 Args: image: PIL Image text_prompt: 文字提示 Returns: tuple: (分割遮罩, 除錯資訊) """ try: # 格式化文字提示:確保每個詞後面都有句號 # Grounding DINO 期望的格式是 "object1. object2. object3.",沒有句號會有偵測不到的問題 formatted_prompt = text_prompt.strip() if not formatted_prompt.endswith('.'): formatted_prompt += '.' # 步驟 1: 使用 Grounding DINO 檢測物件 inputs = grounding_processor(images=image, text=formatted_prompt, return_tensors="pt").to(device) with torch.no_grad(): outputs = grounding_model(**inputs) # 後處理檢測結果 results = grounding_processor.post_process_grounded_object_detection( outputs, inputs.input_ids, threshold=0.15, # 降低閾值以檢測更多物件 target_sizes=[image.size[::-1]] ) boxes = results[0]["boxes"].cpu().numpy() scores = results[0]["scores"].cpu().numpy() labels = results[0]["labels"] debug_info = { "num_detections": len(boxes), "scores": scores.tolist(), "labels": labels, "boxes": boxes.tolist(), "original_prompt": text_prompt } # 準備圖像陣列 image_array = np.array(image) if image_array.shape[-1] == 4: # 如果有 alpha 通道 image_array = image_array[..., :3] if len(boxes) == 0: # 沒有檢測到物件,返回原圖 return image_array, debug_info # 步驟 2: 為每個檢測框使用 SAM 進行分割 overlay = image_array.copy() # 為每個檢測到的物件生成不同顏色的遮罩 colors = [ [255, 0, 0], # 紅色 [0, 255, 0], # 綠色 [0, 0, 255], # 藍色 [255, 255, 0], # 黃色 ] for idx, (box, label, score) in enumerate(zip(boxes, labels, scores)): # 為每個框單獨使用 SAM # box 格式應該是 [x_min, y_min, x_max, y_max] sam_inputs = sam_processor( image, input_boxes=[[box.tolist()]], return_tensors="pt" ).to(device) with torch.no_grad(): sam_outputs = sam_model(**sam_inputs) # 取得分割遮罩 masks = sam_processor.image_processor.post_process_masks( sam_outputs.pred_masks.cpu(), sam_inputs["original_sizes"].cpu(), sam_inputs["reshaped_input_sizes"].cpu() ) # 取得第一個遮罩並轉換為 numpy mask = masks[0].squeeze().numpy() if mask.ndim == 3: mask = mask[0] # 取第一個遮罩 mask = mask > 0 # 使用不同顏色標示不同物件 color = colors[idx % len(colors)] mask_overlay = np.zeros_like(image_array) mask_overlay[mask] = color # 混合到結果圖像 overlay[mask] = overlay[mask] * 0.5 + mask_overlay[mask] * 0.5 return overlay.astype(np.uint8), debug_info except Exception as e: debug_info = {"error": str(e)} return np.zeros((480, 640, 3), dtype=np.uint8), debug_info def segment_and_display(image, text): if image is None: return None, "請上傳圖片" if not text or text.strip() == "": return None, "請輸入文字提示" # 轉換為 PIL Image(如果需要) if isinstance(image, np.ndarray): image = Image.fromarray(image) result_image, debug_info = segment_image_with_text(image, text) output_text = f"檢測到 {debug_info.get('num_detections', 0)} 個物件\n---\n" output_text += f"輸入文字: '{debug_info.get('original_prompt', text)}'\n---\n" if 'labels' in debug_info and len(debug_info['labels']) > 0: output_text += "檢測結果:\n" for i, (label, score) in enumerate(zip(debug_info['labels'], debug_info.get('scores', []))): color_map = ['紅色', '綠色', '藍色', '黃色'] color_name = color_map[i % len(color_map)] output_text += f" {i+1}. {label} (信心度: {score:.2f}, 顏色: {color_name})\n" if 'error' in debug_info: output_text += f"\n錯誤: {debug_info['error']}" return result_image, output_text with gr.Blocks() as demo: gr.Markdown("# Text-Guided Image Segmentation Demo") gr.Markdown(""" ### 使用說明 1. 上傳一張圖片 (有提供預設圖片方便 demo) 2. 輸入文字描述(例如:car、sky、road) 3. 多個物件請用句號分隔(例如:car. sky. road.) """) with gr.Row(): with gr.Column(scale=1): # 預設圖片 car.jpg image_input = gr.Image(label="Input Image", value="sample_images/car.jpg") text_input = gr.Textbox( label="Text Prompt", placeholder="e.g. 'car. sky. road.'", lines=1 ) with gr.Row(): segment_button = gr.Button("Segment", variant="primary") clear_button = gr.Button("Clear") with gr.Column(scale=2): output_mask = gr.Image(label="Segmentation Mask", type="numpy") summary = gr.Textbox( label="Summary", interactive=False, lines=3, max_lines=20, show_copy_button=True ) text_input.submit(segment_and_display, inputs=[image_input, text_input], outputs=[output_mask, summary]) segment_button.click(segment_and_display, inputs=[image_input, text_input], outputs=[output_mask, summary]) clear_button.click(lambda: (None, "", None, ""), inputs=None, outputs=[image_input, text_input, output_mask, summary]) demo.launch()