from __future__ import annotations

from typing import Any, Dict

import numpy as np
import torch
from PIL import Image
import gradio as gr
from transformers import (
    AutoModelForZeroShotObjectDetection,
    AutoProcessor,
    SamModel,
    SamProcessor,
)

GROUNDING_MODEL_ID = "IDEA-Research/grounding-dino-base"
SAM_MODEL_ID = "facebook/sam-vit-base"

# Load the models (CPU only)
device = "cpu"
grounding_processor = AutoProcessor.from_pretrained(GROUNDING_MODEL_ID)
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(GROUNDING_MODEL_ID).to(device)
sam_processor = SamProcessor.from_pretrained(SAM_MODEL_ID)
sam_model = SamModel.from_pretrained(SAM_MODEL_ID).to(device)
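# When a GPU is available the models could be moved to it instead; a minimal
# sketch (this demo intentionally stays on CPU):
#     device = "cuda" if torch.cuda.is_available() else "cpu"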


def segment_image_with_text(
    image: Image.Image,
    text_prompt: str,
) -> tuple[np.ndarray, Dict[str, Any]]:
    """
    Detect objects with Grounding DINO, then segment them with SAM.

    Args:
        image: PIL Image
        text_prompt: text description of the objects to find

    Returns:
        tuple: (segmentation overlay, debug info)
    """
    try:
        # Format the text prompt: make sure it ends with a period.
        # Grounding DINO expects "object1. object2. object3."; without the
        # trailing period the last phrase may not be detected.
        formatted_prompt = text_prompt.strip()
        if not formatted_prompt.endswith('.'):
            formatted_prompt += '.'
        
        # Step 1: detect objects with Grounding DINO
        inputs = grounding_processor(images=image, text=formatted_prompt, return_tensors="pt").to(device)
        
        with torch.no_grad():
            outputs = grounding_model(**inputs)
        
        # Post-process the detection results
        results = grounding_processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            threshold=0.15,  # lowered threshold so more objects are detected
            target_sizes=[image.size[::-1]]
        )
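        # results holds one entry per image; each entry is a dict with "boxes"
        # (N, 4) tensors, "scores" (N,), and the matched text "labels".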
        
        boxes = results[0]["boxes"].cpu().numpy()
        scores = results[0]["scores"].cpu().numpy()
        labels = results[0]["labels"]
        
        debug_info = {
            "num_detections": len(boxes),
            "scores": scores.tolist(),
            "labels": labels,
            "boxes": boxes.tolist(),
            "original_prompt": text_prompt
        }
        
        # Prepare the image array
        image_array = np.array(image)
        if image_array.shape[-1] == 4:  # drop the alpha channel if present
            image_array = image_array[..., :3]

        if len(boxes) == 0:
            # Nothing detected: return the original image unchanged
            return image_array, debug_info
        
        # Step 2: segment each detected box with SAM
        overlay = image_array.copy()

        # Give each detected object a differently colored mask
        colors = [
            [255, 0, 0],    # red
            [0, 255, 0],    # green
            [0, 0, 255],    # blue
            [255, 255, 0],  # yellow
        ]
        
        for idx, (box, label, score) in enumerate(zip(boxes, labels, scores)):
            # Run SAM once per detected box; input_boxes is nested as [[box]]
            # (a batch of one image, one box per image).
            # Box format: [x_min, y_min, x_max, y_max]
            sam_inputs = sam_processor(
                image,
                input_boxes=[[box.tolist()]],
                return_tensors="pt"
            ).to(device)
            
            with torch.no_grad():
                sam_outputs = sam_model(**sam_inputs)
            
            # Retrieve the segmentation masks at the original resolution
            masks = sam_processor.image_processor.post_process_masks(
                sam_outputs.pred_masks.cpu(),
                sam_inputs["original_sizes"].cpu(),
                sam_inputs["reshaped_input_sizes"].cpu()
            )
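            # masks is a list with one tensor per image, roughly shaped
            # (num_boxes, num_candidate_masks, H, W) at the original image size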
            
            # Take the first mask and convert it to a boolean numpy array
            mask = masks[0].squeeze().numpy()
            if mask.ndim == 3:
                mask = mask[0]  # SAM returns several candidate masks; keep the first
            mask = mask > 0
            
            # Mark each object with its own color
            color = colors[idx % len(colors)]
            mask_overlay = np.zeros_like(image_array)
            mask_overlay[mask] = color
            
            # Blend the colored mask into the result image (50/50)
            overlay[mask] = overlay[mask] * 0.5 + mask_overlay[mask] * 0.5
        
        return overlay.astype(np.uint8), debug_info
        
    except Exception as e:
        debug_info = {"error": str(e)}
        return np.zeros((480, 640, 3), dtype=np.uint8), debug_info

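# Example standalone usage of segment_image_with_text (a sketch; the sample
# image below is the one bundled with this demo):
#     img = Image.open("sample_images/car.jpg").convert("RGB")
#     overlay, info = segment_image_with_text(img, "car. sky")
#     Image.fromarray(overlay).save("overlay.png")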

def segment_and_display(image, text):
    if image is None:
        return None, "Please upload an image"
    if not text or text.strip() == "":
        return None, "Please enter a text prompt"

    # Convert to a PIL Image if needed
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    
    result_image, debug_info = segment_image_with_text(image, text)
    
    output_text = f"Detected {debug_info.get('num_detections', 0)} object(s)\n---\n"
    output_text += f"Input text: '{debug_info.get('original_prompt', text)}'\n---\n"

    if 'labels' in debug_info and len(debug_info['labels']) > 0:
        output_text += "Detections:\n"
        color_map = ['red', 'green', 'blue', 'yellow']
        for i, (label, score) in enumerate(zip(debug_info['labels'], debug_info.get('scores', []))):
            color_name = color_map[i % len(color_map)]
            output_text += f"  {i+1}. {label} (confidence: {score:.2f}, color: {color_name})\n"
    if 'error' in debug_info:
        output_text += f"\nError: {debug_info['error']}"
    
    return result_image, output_text


with gr.Blocks() as demo:
    gr.Markdown("# Text-Guided Image Segmentation Demo")
    gr.Markdown("""
    ### Instructions
    1. Upload an image (a default image is provided for the demo)
    2. Enter a text description (e.g. car, sky, road)
    3. Separate multiple objects with periods (e.g. car. sky. road.)
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # Default image: car.jpg
            image_input = gr.Image(label="Input Image", value="sample_images/car.jpg")
            text_input = gr.Textbox(
                label="Text Prompt", 
                placeholder="e.g. 'car. sky. road.'",
                lines=1
            )
            with gr.Row():
                segment_button = gr.Button("Segment", variant="primary")
                clear_button = gr.Button("Clear")
        with gr.Column(scale=2):
            output_mask = gr.Image(label="Segmentation Mask", type="numpy")
            debug_output = gr.Textbox(
                label="Summary", 
                interactive=False,
                lines=3,
                max_lines=20,
                show_copy_button=True
            )

    text_input.submit(segment_and_display, inputs=[image_input, text_input], outputs=[output_mask, debug_output])
    segment_button.click(segment_and_display, inputs=[image_input, text_input], outputs=[output_mask, debug_output])
    clear_button.click(lambda: (None, ""), inputs=None, outputs=[image_input, text_input])


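# launch() serves the app locally; demo.launch(share=True) would additionally
# create a temporary public link (an optional Gradio feature, not used here).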
demo.launch()