tingul4 committed on
Commit 08d6fec · 1 Parent(s): 5f0608c

Add initial implementation for text-guided image segmentation

- Introduced app.py for image segmentation using Grounding DINO and SAM models
- Added .gitignore to exclude virtual environment and cache files
- Updated .gitattributes to include support for image files
- Added car.jpg as a sample input image
- Created requirements.txt for project dependencies
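
To try this commit locally, install the dependencies and start the app. A minimal sketch, assuming a CPU-only Python environment; Gradio serves on http://localhost:7860 by default:

    pip install -r requirements.txt
    python app.py  # downloads both model checkpoints on first run, then launches the demo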

Files changed (5)
  1. .gitattributes +3 -0
  2. .gitignore +3 -0
  3. app.py +197 -0
  4. car.jpg +3 -0
  5. requirements.txt +5 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
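
For reference, the three added patterns are what `git lfs track` writes into .gitattributes; the equivalent commands would be:

    git lfs track "*.jpg"
    git lfs track "*.png"
    git lfs track "*.jpeg"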
.gitignore ADDED
@@ -0,0 +1,3 @@
+ .venv/
+ __pycache__/
+ .gradio/
app.py ADDED
@@ -0,0 +1,197 @@
+ from __future__ import annotations
+
+ from typing import Any, Dict
+
+ import numpy as np
+ import torch
+ from PIL import Image
+ import gradio as gr
+ from transformers import (
+     AutoModelForZeroShotObjectDetection,
+     AutoProcessor,
+     SamModel,
+     SamProcessor,
+ )
+
+ GROUNDING_MODEL_ID = "IDEA-Research/grounding-dino-base"
+ SAM_MODEL_ID = "facebook/sam-vit-base"
+
+ # Load the models (on CPU)
+ device = "cpu"
+ grounding_processor = AutoProcessor.from_pretrained(GROUNDING_MODEL_ID)
+ grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(GROUNDING_MODEL_ID).to(device)
+ sam_processor = SamProcessor.from_pretrained(SAM_MODEL_ID)
+ sam_model = SamModel.from_pretrained(SAM_MODEL_ID).to(device)
+
+
+ def segment_image_with_text(
+     image: Image.Image,
+     text_prompt: str,
+ ) -> tuple[np.ndarray, Dict[str, Any]]:
+     """
+     Detect objects with Grounding DINO, then segment them with SAM.
+
+     Args:
+         image: PIL Image
+         text_prompt: the text prompt
+
+     Returns:
+         tuple: (segmentation overlay, debug info)
+     """
+     try:
+         # Format the text prompt: make sure it ends with a period.
+         # Grounding DINO expects "object1. object2. object3."; without the periods, objects can go undetected.
+         formatted_prompt = text_prompt.strip()
+         if not formatted_prompt.endswith('.'):
+             formatted_prompt += '.'
+
+         # Step 1: detect objects with Grounding DINO
+         inputs = grounding_processor(images=image, text=formatted_prompt, return_tensors="pt").to(device)
+
+         with torch.no_grad():
+             outputs = grounding_model(**inputs)
+
+         # Post-process the detection results
+         results = grounding_processor.post_process_grounded_object_detection(
+             outputs,
+             inputs.input_ids,
+             threshold=0.15,  # lower the threshold to detect more objects
+             target_sizes=[image.size[::-1]]
+         )
+
+         boxes = results[0]["boxes"].cpu().numpy()
+         scores = results[0]["scores"].cpu().numpy()
+         labels = results[0]["labels"]
+
+         debug_info = {
+             "num_detections": len(boxes),
+             "scores": scores.tolist(),
+             "labels": labels,
+             "boxes": boxes.tolist(),
+             "original_prompt": text_prompt
+         }
+
+         # Prepare the image array
+         image_array = np.array(image)
+         if image_array.shape[-1] == 4:  # drop the alpha channel if present
+             image_array = image_array[..., :3]
+
+         if len(boxes) == 0:
+             # No objects detected; return the original image
+             return image_array, debug_info
+
+         # Step 2: segment each detected box with SAM
+         overlay = image_array.copy()
+
+         # Give each detected object a differently colored mask
+         colors = [
+             [255, 0, 0],    # red
+             [0, 255, 0],    # green
+             [0, 0, 255],    # blue
+             [255, 255, 0],  # yellow
+         ]
+
+         for idx, (box, label, score) in enumerate(zip(boxes, labels, scores)):
+             # Run SAM separately on each box
+             # box format is [x_min, y_min, x_max, y_max]
+             sam_inputs = sam_processor(
+                 image,
+                 input_boxes=[[box.tolist()]],
+                 return_tensors="pt"
+             ).to(device)
+
+             with torch.no_grad():
+                 sam_outputs = sam_model(**sam_inputs)
+
+             # Get the segmentation masks
+             masks = sam_processor.image_processor.post_process_masks(
+                 sam_outputs.pred_masks.cpu(),
+                 sam_inputs["original_sizes"].cpu(),
+                 sam_inputs["reshaped_input_sizes"].cpu()
+             )
+
+             # Take the first mask and convert it to numpy
+             mask = masks[0].squeeze().numpy()
+             if mask.ndim == 3:
+                 mask = mask[0]  # keep the first mask
+             mask = mask > 0
+
+             # Mark each object with its own color
+             color = colors[idx % len(colors)]
+             mask_overlay = np.zeros_like(image_array)
+             mask_overlay[mask] = color
+
+             # Blend the mask into the result image
+             overlay[mask] = overlay[mask] * 0.5 + mask_overlay[mask] * 0.5
+
+         return overlay.astype(np.uint8), debug_info
+
+     except Exception as e:
+         debug_info = {"error": str(e)}
+         return np.zeros((480, 640, 3), dtype=np.uint8), debug_info
+
+
+ def segment_and_display(image, text):
+     if image is None:
+         return None, "Please upload an image"
+     if not text or text.strip() == "":
+         return None, "Please enter a text prompt"
+
+     # Convert to a PIL Image if needed
+     if isinstance(image, np.ndarray):
+         image = Image.fromarray(image)
+
+     result_image, debug_info = segment_image_with_text(image, text)
+
+     output_text = f"Detected {debug_info.get('num_detections', 0)} object(s)\n"
+     output_text += f"Original text prompt: '{debug_info.get('original_prompt', text)}'\n"
+
+     if 'labels' in debug_info and len(debug_info['labels']) > 0:
+         output_text += "Detections:\n"
+         for i, (label, score) in enumerate(zip(debug_info['labels'], debug_info.get('scores', []))):
+             color_map = ['red', 'green', 'blue', 'yellow']
+             color_name = color_map[i % len(color_map)]
+             output_text += f"  {i+1}. {label} (confidence: {score:.2f}, color: {color_name})\n"
+
+     if 'error' in debug_info:
+         output_text += f"\nError: {debug_info['error']}"
+
+     return result_image, output_text
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Text-Guided Image Segmentation")
+     gr.Markdown("""
+     ### Instructions
+     1. Upload an image (a default image is provided for the demo)
+     2. Enter a text description (e.g. car, sky, road)
+     3. Separate multiple objects with periods (e.g. car. sky. road.)
+     """)
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             # Default image: car.jpg
+             image_input = gr.Image(label="Input Image", value="car.jpg")
+             text_input = gr.Textbox(
+                 label="Text Prompt",
+                 placeholder="e.g. 'car. sky. road.'",
+                 lines=1
+             )
+             with gr.Row():
+                 segment_button = gr.Button("Segment", variant="primary")
+                 clear_button = gr.Button("Clear")
+         with gr.Column(scale=2):
+             output_mask = gr.Image(label="Segmentation Mask", type="numpy")
+             debug_output = gr.Textbox(
+                 label="Summary",
+                 interactive=False,
+                 lines=3,
+                 max_lines=20,
+                 show_copy_button=True
+             )
+
+     text_input.submit(segment_and_display, inputs=[image_input, text_input], outputs=[output_mask, debug_output])
+     segment_button.click(segment_and_display, inputs=[image_input, text_input], outputs=[output_mask, debug_output])
+     clear_button.click(lambda: (None, ""), inputs=None, outputs=[image_input, text_input])
+
+
+ demo.launch()
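
Aside from the Gradio UI, segment_image_with_text can be exercised directly. A minimal sketch, assuming demo.launch() is first moved behind an `if __name__ == "__main__":` guard (as committed, importing app also launches the server); the output filename segmented.png is hypothetical:

    from PIL import Image

    from app import segment_image_with_text

    image = Image.open("car.jpg").convert("RGB")
    overlay, debug = segment_image_with_text(image, "car. road")
    print(f"{debug['num_detections']} detection(s): {debug['labels']}")
    Image.fromarray(overlay).save("segmented.png")  # hypothetical output path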
car.jpg ADDED

Git LFS Details

  • SHA256: 705731827597f812c2872db2c89fa9df45a22612714779bb8c38dd1e6af3ff3e
  • Pointer size: 130 Bytes
  • Size of remote file: 17.3 kB
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ gradio
+ transformers
+ torch
+ Pillow
+ numpy