# Hugging Face Space app (Space status when this file was captured: Running)
from __future__ import annotations

import logging
from typing import Any, Dict

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import (
    AutoModelForZeroShotObjectDetection,
    AutoProcessor,
    SamModel,
    SamProcessor,
)
# Hugging Face Hub checkpoints for the two-stage pipeline:
# Grounding DINO (text -> boxes), then SAM (boxes -> masks).
GROUNDING_MODEL_ID = "IDEA-Research/grounding-dino-base"
SAM_MODEL_ID = "facebook/sam-vit-base"

# All inference in this app runs on the CPU.
device = "cpu"


def _load_pair(processor_cls, model_cls, model_id):
    """Load the processor and model for ``model_id``, moving the model to ``device``.

    The processor is loaded before the model, so download/load order is
    processor, then model for each checkpoint.
    """
    processor = processor_cls.from_pretrained(model_id)
    model = model_cls.from_pretrained(model_id).to(device)
    return processor, model


grounding_processor, grounding_model = _load_pair(
    AutoProcessor, AutoModelForZeroShotObjectDetection, GROUNDING_MODEL_ID
)
sam_processor, sam_model = _load_pair(SamProcessor, SamModel, SAM_MODEL_ID)
def segment_image_with_text(
    image: Image.Image | np.ndarray,
    text_prompt: str,
    threshold: float = 0.15,
) -> tuple[np.ndarray, Dict[str, Any]]:
    """Detect the objects named in a text prompt and paint their SAM masks onto the image.

    Grounding DINO turns ``text_prompt`` into bounding boxes. All boxes are then
    fed to SAM as prompts, and each resulting mask is blended 50/50 onto the
    image in its own color (red, green, blue, yellow, then repeating).

    Args:
        image: Input image. Any PIL mode is accepted and converted to RGB; a
            numpy array is also accepted and wrapped with ``Image.fromarray``.
        text_prompt: Free-text prompt. Separate multiple objects with periods,
            e.g. ``"car. sky. road."``.
        threshold: Minimum Grounding DINO box score for a detection to be kept
            (kept low by default so more objects are detected).

    Returns:
        ``(overlay, debug_info)``. ``overlay`` is a uint8 RGB array of the
        image with masks blended in, or the unmodified image when nothing is
        detected. ``debug_info`` holds ``num_detections``, ``scores``,
        ``labels``, ``boxes`` and ``original_prompt``. On any exception, the
        error is logged and a black 480x640x3 image plus ``{"error": message}``
        is returned instead.
    """
    # Overlay colors, used in detection order and then cycled. Keep the order
    # red, green, blue, yellow: the UI summary names the colors in this order.
    colors = (
        (255, 0, 0),    # red
        (0, 255, 0),    # green
        (0, 0, 255),    # blue
        (255, 255, 0),  # yellow
    )
    try:
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # Both models and the overlay below expect 3 channels. RGBA,
        # grayscale ("L") or palette ("P") images would otherwise break them.
        image = image.convert("RGB")

        # Grounding DINO expects lowercase phrases, each terminated by a period
        # ("car. sky. road."). Without the trailing period, objects are often missed.
        formatted_prompt = text_prompt.strip().lower()
        if not formatted_prompt.endswith("."):
            formatted_prompt += "."

        # Step 1: text-conditioned box detection with Grounding DINO.
        inputs = grounding_processor(
            images=image, text=formatted_prompt, return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            outputs = grounding_model(**inputs)
        detection = grounding_processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            threshold=threshold,
            target_sizes=[image.size[::-1]],  # PIL size is (W, H); targets want (H, W)
        )[0]

        boxes = detection["boxes"].cpu().numpy()  # [x_min, y_min, x_max, y_max] per row
        scores = detection["scores"].cpu().numpy()
        # Newer transformers releases return the phrase strings under
        # "text_labels" and deprecate "labels"; older ones only have "labels".
        labels = list(
            detection["text_labels"] if "text_labels" in detection else detection["labels"]
        )

        debug_info = {
            "num_detections": len(boxes),
            "scores": scores.tolist(),
            "labels": labels,
            "boxes": boxes.tolist(),
            "original_prompt": text_prompt,
        }

        image_array = np.array(image)
        if len(boxes) == 0:
            # Nothing detected: return the original image unchanged.
            return image_array, debug_info

        # Step 2: segment every box with a single SAM call, so the expensive
        # image encoder runs once instead of once per detection.
        sam_inputs = sam_processor(
            image, input_boxes=[boxes.tolist()], return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            # A box is an unambiguous prompt, so request SAM's single-mask output.
            sam_outputs = sam_model(**sam_inputs, multimask_output=False)
        pred_masks = sam_outputs.pred_masks.cpu()  # (1, num_boxes, 1, h, w) low-res logits
        original_sizes = sam_inputs["original_sizes"].cpu()
        reshaped_sizes = sam_inputs["reshaped_input_sizes"].cpu()

        overlay = image_array.copy()
        for idx in range(pred_masks.shape[1]):
            # Upscale one mask at a time so peak memory stays at a single
            # full-resolution mask, even when there are many detections.
            upscaled = sam_processor.image_processor.post_process_masks(
                pred_masks[:, idx : idx + 1], original_sizes, reshaped_sizes
            )[0]
            mask = upscaled[0, 0].numpy() > 0
            color = np.asarray(colors[idx % len(colors)])
            overlay[mask] = overlay[mask] * 0.5 + color * 0.5
        return overlay, debug_info
    except Exception as e:
        # UI boundary: keep the app responsive, but don't lose the traceback.
        logging.getLogger(__name__).exception(
            "Segmentation failed for prompt %r", text_prompt
        )
        return np.zeros((480, 640, 3), dtype=np.uint8), {"error": str(e)}
def segment_and_display(image, text):
    """Gradio callback: validate the inputs, run segmentation and build a summary.

    Returns ``(result_image, summary_text)``. If the image or the prompt is
    missing, the image slot is ``None`` and the summary asks the user for the
    missing input.
    """
    # Guard clauses: missing image first, then an empty or whitespace-only prompt.
    if image is None:
        return None, "請上傳圖片"
    if not text or not text.strip():
        return None, "請輸入文字提示"

    # Gradio hands over numpy arrays; the segmentation step works on PIL images.
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    result_image, info = segment_image_with_text(pil_image, text)

    # Names for the overlay colors, in the order the masks are painted.
    color_names = ('紅色', '綠色', '藍色', '黃色')

    parts = [
        f"檢測到 {info.get('num_detections', 0)} 個物件\n---\n",
        f"輸入文字: '{info.get('original_prompt', text)}'\n---\n",
    ]
    labels = info.get('labels', ())
    if len(labels) > 0:
        parts.append("檢測結果:\n")
        scores = info.get('scores', [])
        parts.extend(
            f" {n}. {label} (信心度: {score:.2f}, 顏色: {color_names[(n - 1) % len(color_names)]})\n"
            for n, (label, score) in enumerate(zip(labels, scores), start=1)
        )
    if 'error' in info:
        parts.append(f"\n錯誤: {info['error']}")
    return result_image, "".join(parts)
# Gradio UI: image + prompt on the left, segmentation result + summary on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Text-Guided Image Segmentation Demo")
    gr.Markdown("""
### 使用說明
1. 上傳一張圖片 (有提供預設圖片方便 demo)
2. 輸入文字描述(例如:car、sky、road)
3. 多個物件請用句號分隔(例如:car. sky. road.)
""")
    with gr.Row():
        with gr.Column(scale=1):
            # Preloaded sample image (car.jpg) so the demo works out of the box.
            image_input = gr.Image(label="Input Image", value="sample_images/car.jpg")
            text_input = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g. 'car. sky. road.'",
                lines=1,
            )
            with gr.Row():
                segment_button = gr.Button("Segment", variant="primary")
                clear_button = gr.Button("Clear")
        with gr.Column(scale=2):
            output_mask = gr.Image(label="Segmentation Mask", type="numpy")
            summary = gr.Textbox(
                label="Summary",
                interactive=False,
                lines=3,
                max_lines=20,
                show_copy_button=True,
            )

    # Pressing Enter in the prompt box and clicking "Segment" trigger the same
    # handler, so share one spec instead of repeating the wiring.
    segment_event = dict(
        fn=segment_and_display,
        inputs=[image_input, text_input],
        outputs=[output_mask, summary],
    )
    text_input.submit(**segment_event)
    segment_button.click(**segment_event)
    clear_button.click(
        lambda: (None, "", None, ""),
        inputs=None,
        outputs=[image_input, text_input, output_mask, summary],
    )

# Launch only when run as a script (as Spaces does), not when imported.
if __name__ == "__main__":
    demo.launch()