from __future__ import annotations
from typing import Any, Dict
import numpy as np
import torch
from PIL import Image
import gradio as gr
from transformers import (
    AutoModelForZeroShotObjectDetection,
    AutoProcessor,
    SamModel,
    SamProcessor,
)
GROUNDING_MODEL_ID = "IDEA-Research/grounding-dino-base"
SAM_MODEL_ID = "facebook/sam-vit-base"
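# Note: larger checkpoints on the Hub (e.g. "facebook/sam-vit-large") may segment
# more accurately, at the cost of much slower CPU inference.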
# Load the models (CPU only)
device = "cpu"
grounding_processor = AutoProcessor.from_pretrained(GROUNDING_MODEL_ID)
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(GROUNDING_MODEL_ID).to(device)
sam_processor = SamProcessor.from_pretrained(SAM_MODEL_ID)
sam_model = SamModel.from_pretrained(SAM_MODEL_ID).to(device)
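# A minimal tweak for machines with a GPU (assuming CUDA is available) would be
# to replace the hard-coded device above with:
#     device = "cuda" if torch.cuda.is_available() else "cpu"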
def segment_image_with_text(
    image: Image.Image,
    text_prompt: str) -> tuple[np.ndarray, Dict[str, Any]]:
    """
    Detect objects with Grounding DINO, then segment them with SAM.

    Args:
        image: PIL Image
        text_prompt: text prompt describing the objects to segment

    Returns:
        tuple: (segmentation overlay, debug info)
    """
    try:
        # Format the text prompt: make sure it ends with a period.
        # Grounding DINO expects prompts of the form "object1. object2. object3.";
        # without the trailing period, objects can go undetected.
        formatted_prompt = text_prompt.strip()
        if not formatted_prompt.endswith('.'):
            formatted_prompt += '.'
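        # Illustration of the rule above: "car. sky. road" becomes "car. sky. road.",
        # while a prompt that already ends in '.' passes through unchanged.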
        # Step 1: detect objects with Grounding DINO
        inputs = grounding_processor(images=image, text=formatted_prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = grounding_model(**inputs)

        # Post-process the detection results
        results = grounding_processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            threshold=0.15,  # lower threshold so that more objects are detected
            target_sizes=[image.size[::-1]]
        )
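        # Note: target_sizes expects (height, width); PIL's image.size is (width, height),
        # hence the [::-1] above, which maps boxes back to original image coordinates.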
        boxes = results[0]["boxes"].cpu().numpy()
        scores = results[0]["scores"].cpu().numpy()
        labels = results[0]["labels"]

        debug_info = {
            "num_detections": len(boxes),
            "scores": scores.tolist(),
            "labels": labels,
            "boxes": boxes.tolist(),
            "original_prompt": text_prompt
        }

        # Prepare the image array
        image_array = np.array(image)
        if image_array.shape[-1] == 4:  # drop the alpha channel if present
            image_array = image_array[..., :3]

        if len(boxes) == 0:
            # Nothing was detected: return the original image
            return image_array, debug_info
        # Step 2: segment each detected box with SAM
        overlay = image_array.copy()

        # Assign a differently colored mask to each detected object
        colors = [
            [255, 0, 0],    # red
            [0, 255, 0],    # green
            [0, 0, 255],    # blue
            [255, 255, 0],  # yellow
        ]
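        # With more than four detections the palette simply wraps around
        # (see idx % len(colors) in the loop below).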
        for idx, (box, label, score) in enumerate(zip(boxes, labels, scores)):
            # Run SAM separately for each detection box
            # box format should be [x_min, y_min, x_max, y_max]
            sam_inputs = sam_processor(
                image,
                input_boxes=[[box.tolist()]],
                return_tensors="pt"
            ).to(device)
            with torch.no_grad():
                sam_outputs = sam_model(**sam_inputs)

            # Retrieve the segmentation masks
            masks = sam_processor.image_processor.post_process_masks(
                sam_outputs.pred_masks.cpu(),
                sam_inputs["original_sizes"].cpu(),
                sam_inputs["reshaped_input_sizes"].cpu()
            )
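            # SAM proposes several candidate masks per box (multimask output);
            # the squeeze-and-index step below keeps only the first candidate.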
            # Take the first mask and convert it to numpy
            mask = masks[0].squeeze().numpy()
            if mask.ndim == 3:
                mask = mask[0]  # keep the first mask
            mask = mask > 0

            # Mark each object with a different color
            color = colors[idx % len(colors)]
            mask_overlay = np.zeros_like(image_array)
            mask_overlay[mask] = color

            # Blend into the result image
            overlay[mask] = overlay[mask] * 0.5 + mask_overlay[mask] * 0.5
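            # The fixed 0.5/0.5 weights give a half-transparent tint; for a
            # subtler overlay one could weight the original image more heavily,
            # e.g. overlay[mask] * 0.7 + mask_overlay[mask] * 0.3.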
        return overlay.astype(np.uint8), debug_info
    except Exception as e:
        debug_info = {"error": str(e)}
        return np.zeros((480, 640, 3), dtype=np.uint8), debug_info
def segment_and_display(image, text):
    if image is None:
        return None, "Please upload an image"
    if not text or text.strip() == "":
        return None, "Please enter a text prompt"

    # Convert to a PIL Image if necessary
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    result_image, debug_info = segment_image_with_text(image, text)

    output_text = f"Detected {debug_info.get('num_detections', 0)} object(s)\n---\n"
    output_text += f"Input text: '{debug_info.get('original_prompt', text)}'\n---\n"
    if 'labels' in debug_info and len(debug_info['labels']) > 0:
        output_text += "Detections:\n"
        for i, (label, score) in enumerate(zip(debug_info['labels'], debug_info.get('scores', []))):
            color_map = ['red', 'green', 'blue', 'yellow']
            color_name = color_map[i % len(color_map)]
            output_text += f"  {i+1}. {label} (confidence: {score:.2f}, color: {color_name})\n"
    if 'error' in debug_info:
        output_text += f"\nError: {debug_info['error']}"
    return result_image, output_text
with gr.Blocks() as demo:
    gr.Markdown("# Text-Guided Image Segmentation Demo")
    gr.Markdown("""
    ### How to use
    1. Upload an image (a default image is provided for the demo)
    2. Enter a text description (e.g. car, sky, road)
    3. Separate multiple objects with periods (e.g. car. sky. road.)
    """)
    with gr.Row():
        with gr.Column(scale=1):
            # Default demo image: car.jpg
            image_input = gr.Image(label="Input Image", value="sample_images/car.jpg")
            text_input = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g. 'car. sky. road.'",
                lines=1
            )
            with gr.Row():
                segment_button = gr.Button("Segment", variant="primary")
                clear_button = gr.Button("Clear")
        with gr.Column(scale=2):
            output_mask = gr.Image(label="Segmentation Mask", type="numpy")
            debug_output = gr.Textbox(
                label="Summary",
                interactive=False,
                lines=3,
                max_lines=20,
                show_copy_button=True
            )
    text_input.submit(segment_and_display, inputs=[image_input, text_input], outputs=[output_mask, debug_output])
    segment_button.click(segment_and_display, inputs=[image_input, text_input], outputs=[output_mask, debug_output])
    clear_button.click(lambda: (None, ""), inputs=None, outputs=[image_input, text_input])
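    # Note: "Clear" resets only the two inputs; the output panels keep their last result.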
demo.launch()
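# To expose the app beyond localhost (assuming the default port 7860 is free),
# Gradio's launch accepts e.g.:
#     demo.launch(server_name="0.0.0.0", server_port=7860)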