# text-image-seg / app.py
# Author: tingul4
# Commit b754bf5: "1. refactor variables"
from __future__ import annotations
from typing import Any, Dict
import numpy as np
import torch
from PIL import Image
import gradio as gr
from transformers import (
AutoModelForZeroShotObjectDetection,
AutoProcessor,
SamModel,
SamProcessor,
)
# Hugging Face Hub checkpoints used by the pipeline.
GROUNDING_MODEL_ID = "IDEA-Research/grounding-dino-base"
SAM_MODEL_ID = "facebook/sam-vit-base"

# Both models are loaded once, at import time, and run on the CPU.
device = "cpu"

# Grounding DINO: text prompt -> bounding boxes.
grounding_processor = AutoProcessor.from_pretrained(GROUNDING_MODEL_ID)
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(
    GROUNDING_MODEL_ID
).to(device)

# SAM: bounding box -> segmentation mask.
sam_processor = SamProcessor.from_pretrained(SAM_MODEL_ID)
sam_model = SamModel.from_pretrained(SAM_MODEL_ID).to(device)
def segment_image_with_text(
    image: Image.Image,
    text_prompt: str,
) -> tuple[np.ndarray, Dict[str, Any]]:
    """Detect objects matching a text prompt and overlay their segmentation masks.

    Grounding DINO turns the prompt into bounding boxes. SAM then segments each
    box, and each mask is blended 50/50 into the image. Colors cycle through
    red, green, blue and yellow in detection order.

    Args:
        image: Input PIL image. It is converted to RGB first, so grayscale,
            palette and RGBA inputs all produce a 3-channel result.
        text_prompt: Free-text description of the objects to find. Several
            phrases may be separated by periods, e.g. ``"car. sky. road."``.

    Returns:
        A ``(overlay, debug_info)`` tuple.

        ``overlay`` is an ``(H, W, 3)`` uint8 array: the RGB image with the
        masks blended in, or the unmodified RGB image when nothing is detected.

        ``debug_info`` holds ``num_detections``, ``scores``, ``labels``,
        ``boxes`` and ``original_prompt``.

        On any exception the overlay is a black 480x640 image and
        ``debug_info`` is ``{"error": <message>}``. Errors are not re-raised.
    """
    # Mask colors, cycled per detection. The color names reported by
    # segment_and_display follow this same order.
    colors = [
        [255, 0, 0],  # red
        [0, 255, 0],  # green
        [0, 0, 255],  # blue
        [255, 255, 0],  # yellow
    ]
    try:
        # Normalize the image mode once. Grayscale ("L") and palette ("P")
        # images would otherwise yield a 2-D array that breaks the color
        # overlay below. RGBA simply loses its alpha channel.
        image = image.convert("RGB")

        # Grounding DINO expects lowercase, period-terminated phrases
        # ("object1. object2."); without a trailing period detections can be
        # missed. Only the final period is added here. Separating multiple
        # phrases with periods is left to the user.
        formatted_prompt = text_prompt.strip().lower()
        if not formatted_prompt.endswith("."):
            formatted_prompt += "."

        # Step 1: text-conditioned box detection with Grounding DINO.
        inputs = grounding_processor(
            images=image, text=formatted_prompt, return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            outputs = grounding_model(**inputs)
        results = grounding_processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            threshold=0.15,  # deliberately low so that more objects are detected
            target_sizes=[image.size[::-1]],  # PIL size is (W, H); (H, W) expected
        )
        boxes = results[0]["boxes"].cpu().numpy()
        scores = results[0]["scores"].cpu().numpy()
        labels = results[0]["labels"]
        debug_info = {
            "num_detections": len(boxes),
            "scores": scores.tolist(),
            "labels": labels,
            "boxes": boxes.tolist(),
            "original_prompt": text_prompt,
        }

        image_array = np.array(image)
        if len(boxes) == 0:
            # Nothing detected: return the (RGB) input unchanged.
            return image_array, debug_info

        # Step 2: SAM segmentation, one box at a time.
        # The ViT image encoder is the expensive part, and its output does not
        # depend on the box. So the image is embedded once and the embedding
        # is reused for every box, instead of re-encoding per box.
        sam_image_inputs = sam_processor(image, return_tensors="pt").to(device)
        with torch.no_grad():
            image_embeddings = sam_model.get_image_embeddings(
                sam_image_inputs["pixel_values"]
            )

        overlay = image_array.copy()
        for idx, box in enumerate(boxes):
            # box is [x_min, y_min, x_max, y_max] in original-image pixels; the
            # processor rescales it to SAM's input resolution.
            sam_inputs = sam_processor(
                image, input_boxes=[[box.tolist()]], return_tensors="pt"
            ).to(device)
            # SAM accepts either pixel_values or image_embeddings, not both.
            sam_inputs.pop("pixel_values", None)
            with torch.no_grad():
                sam_outputs = sam_model(**sam_inputs, image_embeddings=image_embeddings)
            masks = sam_processor.image_processor.post_process_masks(
                sam_outputs.pred_masks.cpu(),
                sam_inputs["original_sizes"].cpu(),
                sam_inputs["reshaped_input_sizes"].cpu(),
            )
            # masks[0] is (num_boxes=1, num_candidate_masks, H, W).
            # Keep the first candidate mask for the single box.
            mask = masks[0][0, 0].numpy() > 0

            # 50/50 blend of the pixel and the object's color. Assigning the
            # float result back into the uint8 array truncates it.
            color = np.asarray(colors[idx % len(colors)])
            overlay[mask] = overlay[mask] * 0.5 + color * 0.5

        return overlay, debug_info
    except Exception as e:
        # Deliberately broad: this is the UI boundary. The message is surfaced
        # in the summary text instead of crashing the Gradio callback.
        debug_info = {"error": str(e)}
        return np.zeros((480, 640, 3), dtype=np.uint8), debug_info
def segment_and_display(image, text):
    """Gradio callback: validate inputs, run segmentation and build a summary.

    Returns ``(result_image, summary_text)``. If no image is supplied, or the
    prompt is empty or whitespace-only, returns ``None`` and a short message
    instead. The image check happens first.
    """
    if image is None:
        return None, "請上傳圖片"
    if not text or not text.strip():
        return None, "請輸入文字提示"

    # Gradio delivers numpy arrays; the pipeline works on PIL images.
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    result_image, debug_info = segment_image_with_text(pil_image, text)

    # Color names, in the same order as the mask colors used for overlays.
    color_names = ['紅色', '綠色', '藍色', '黃色']

    parts = [
        f"檢測到 {debug_info.get('num_detections', 0)} 個物件\n---\n",
        f"輸入文字: '{debug_info.get('original_prompt', text)}'\n---\n",
    ]
    if 'labels' in debug_info and len(debug_info['labels']) > 0:
        parts.append("檢測結果:\n")
        detections = zip(debug_info['labels'], debug_info.get('scores', []))
        for rank, (label, score) in enumerate(detections, start=1):
            color_name = color_names[(rank - 1) % len(color_names)]
            parts.append(f" {rank}. {label} (信心度: {score:.2f}, 顏色: {color_name})\n")
    if 'error' in debug_info:
        parts.append(f"\n錯誤: {debug_info['error']}")
    return result_image, "".join(parts)
with gr.Blocks() as demo:
    gr.Markdown("# Text-Guided Image Segmentation Demo")
    gr.Markdown("""
### 使用說明
1. 上傳一張圖片 (有提供預設圖片方便 demo)
2. 輸入文字描述(例如:car、sky、road)
3. 多個物件請用句號分隔(例如:car. sky. road.)
""")

    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column(scale=1):
            # Pre-filled with sample_images/car.jpg so the demo works immediately.
            image_input = gr.Image(label="Input Image", value="sample_images/car.jpg")
            text_input = gr.Textbox(
                label="Text Prompt", placeholder="e.g. 'car. sky. road.'", lines=1
            )
            with gr.Row():
                segment_button = gr.Button("Segment", variant="primary")
                clear_button = gr.Button("Clear")

        # Right column: segmentation result and text summary.
        with gr.Column(scale=2):
            output_mask = gr.Image(label="Segmentation Mask", type="numpy")
            summary = gr.Textbox(
                label="Summary",
                interactive=False,
                lines=3,
                max_lines=20,
                show_copy_button=True,
            )

    # Pressing Enter in the prompt box and clicking "Segment" run the same callback.
    for trigger in (text_input.submit, segment_button.click):
        trigger(
            segment_and_display,
            inputs=[image_input, text_input],
            outputs=[output_mask, summary],
        )

    # "Clear" resets the input image, the prompt, the result image and the summary.
    clear_button.click(
        lambda: (None, "", None, ""),
        inputs=None,
        outputs=[image_input, text_input, output_mask, summary],
    )
# Start the server only when this file is executed as a script, so importing
# the module does not block on a running Gradio server.
if __name__ == "__main__":
    demo.launch()