# Hosting-page header residue (kept for provenance):
# separator / app.py
# Kanae-K's picture
# Update app.py
# 75bfdf3 verified
"""
Part Segmentation API — ポイントプロンプト版
SAM (Apache-2.0) + アニメランドマーク28点座標でパーツ別マスクを生成
使い方:
- /run: 従来の自動マスク生成(ランドマークなし)
- /run_with_landmarks: ランドマーク座標付きポイントプロンプト(高精度)
"""
import gradio as gr
import os
import uuid
import json
import zipfile
from typing import List, Tuple, Optional
from PIL import Image
import numpy as np
import torch
from transformers import SamModel, SamProcessor, pipeline
# =========================
# Commercially usable model (Apache-2.0)
# =========================
SAM_MODEL = "facebook/sam-vit-large"
# =========================
# Model initialisation
# =========================
# Resolve the device first so every model below is placed on it.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Model + processor used for point/box prompting (run_with_landmarks).
sam_model = SamModel.from_pretrained(SAM_MODEL).to(device)
sam_processor = SamProcessor.from_pretrained(SAM_MODEL)
# Automatic mask generation for the legacy /run endpoint. It reuses the SAM
# weights loaded above instead of loading a second full copy of the model,
# and is pinned to the same device explicitly (without `device=` the
# pipeline's placement is not tied to the GPU chosen above).
mask_gen = pipeline(
    task="mask-generation",
    model=sam_model,
    image_processor=sam_processor.image_processor,
    device=device,
)
# Root directory under which each request gets its own output folder.
OUT_ROOT = "/tmp/vtuber_masks"
os.makedirs(OUT_ROOT, exist_ok=True)
# =========================
# VTuber part definitions (28-point landmark layout)
# 0-4: jaw, 5-7: browR, 8-10: browL, 11-16: eyeR, 17-22: eyeL, 23: nose, 24-27: mouth
# =========================
# part name -> landmark indices that outline the part, plus whether a
# bounding-box prompt is sent to SAM together with the point prompt.
# Dict order is the order in which parts are segmented and reported.
# The face is intentionally absent: it is generated geometrically in
# post-processing rather than by SAM.
PART_GROUPS = {
    "eyeL": {"indices": list(range(17, 23)), "use_box": True},
    "eyeR": {"indices": list(range(11, 17)), "use_box": True},
    "mouth": {"indices": list(range(24, 28)), "use_box": True},
    "browL": {"indices": list(range(8, 11)), "use_box": True},
    "browR": {"indices": list(range(5, 8)), "use_box": True},
}
# =========================
# ヘルパー
# =========================
def create_zip(files: List[str], zip_path: str) -> str:
    """Bundle ``files`` into a deflate-compressed ZIP archive at ``zip_path``.

    Each file is stored under its base name only (directory components are
    dropped). Returns ``zip_path`` unchanged for convenient chaining.
    """
    archive = zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED)
    with archive:
        for source in files:
            archive.write(source, arcname=os.path.basename(source))
    return zip_path
def get_all_points(points: List[dict], indices: List[int]) -> List[List[float]]:
    """Return ``[x, y]`` pairs for the landmarks at ``indices``, in that order.

    Indices at or beyond ``len(points)`` are silently skipped.
    """
    count = len(points)
    selected = []
    for idx in indices:
        if idx >= count:
            continue
        landmark = points[idx]
        selected.append([landmark["x"], landmark["y"]])
    return selected
def compute_bbox(points: List[List[float]], padding: float = 0.15) -> List[float]:
    """Return ``[x1, y1, x2, y2]`` enclosing ``points``, grown on every side.

    The margin added on each side is ``padding`` times the longer side of the
    tight box (the same margin in x and y). Raises ValueError when ``points``
    is empty.
    """
    left = min(pt[0] for pt in points)
    right = max(pt[0] for pt in points)
    top = min(pt[1] for pt in points)
    bottom = max(pt[1] for pt in points)
    margin = padding * max(right - left, bottom - top)
    return [left - margin, top - margin, right + margin, bottom + margin]
def segment_part(image: Image.Image, image_embeddings, part_points: List[List[float]], original_size, reshaped_size, use_box: bool = True, negative_points: Optional[List[List[float]]] = None) -> Optional[np.ndarray]:
    """Segment a single part with SAM from a point prompt (plus optional box).

    The positive prompt is the centroid of ``part_points``; each entry of
    ``negative_points`` is added as a background (label 0) point. When
    ``use_box`` is set and at least two points are given, the points'
    bounding box (padding 0.25) is sent as a box prompt as well.

    ``image_embeddings`` must come from ``sam_model.get_image_embeddings`` for
    the same ``image``; the processor's ``pixel_values`` are discarded so the
    image encoder is not run again. ``original_size`` / ``reshaped_size`` are
    forwarded to ``post_process_masks``.

    Returns the highest-IoU-score mask as a uint8 array (255 = part, 0 =
    elsewhere), or None when ``part_points`` is empty.
    """
    if not part_points:
        return None

    target_device = image_embeddings.device

    # Foreground prompt: centroid of the landmark group.
    centroid = [
        sum(p[0] for p in part_points) / len(part_points),
        sum(p[1] for p in part_points) / len(part_points),
    ]
    # Background prompts, if any, follow the centroid.
    background = list(negative_points) if negative_points else []
    prompt_points = [centroid] + background
    prompt_labels = [1] + [0] * len(background)

    # Box prompt from the landmark group's padded bounding box.
    if use_box and len(part_points) >= 2:
        boxes = [[compute_bbox(part_points, padding=0.25)]]
    else:
        boxes = None

    encoded = sam_processor(
        image,
        input_points=[prompt_points],
        input_labels=[prompt_labels],
        input_boxes=boxes,
        return_tensors="pt",
    ).to(target_device)
    # The precomputed embeddings replace the pixel input.
    encoded.pop("pixel_values", None)

    model_kwargs = dict(
        image_embeddings=image_embeddings,
        input_points=encoded["input_points"],
        input_labels=encoded["input_labels"],
    )
    box_tensor = encoded.get("input_boxes")
    if box_tensor is not None:
        model_kwargs["input_boxes"] = box_tensor

    with torch.no_grad():
        outputs = sam_model(**model_kwargs)

    # Upsample the predicted masks back to the original image size.
    upsampled = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        original_size,
        reshaped_size,
    )

    # Keep the candidate with the best predicted IoU.
    best = int(outputs.iou_scores[0, 0].argmax())
    chosen = upsampled[0][0][best].numpy()
    return (chosen > 0).astype(np.uint8) * 255
def compute_face_mask(points: List[dict], image_size: Tuple[int, int]) -> np.ndarray:
    """Build a face mask from the jaw outline (landmarks 0-4) closed by an upper arc.

    The jaw polyline is closed across the top by a half-ellipse. Its apex
    sits at a forehead line: the highest brow landmark (indices 5-10), raised
    by 20% of the jaw width. The horizontal radius is 1.1x half the jaw width.

    Args:
        points: landmark dicts with ``"x"`` / ``"y"`` keys.
        image_size: ``(width, height)`` of the output mask.

    Returns:
        uint8 array of shape ``(height, width)``: 255 inside the face polygon,
        0 elsewhere. All zeros when fewer than 3 jaw landmarks are available.
    """
    from PIL import ImageDraw
    import math

    w, h = image_size

    # The five jaw landmarks, in landmark order.
    jaw = [(points[i]["x"], points[i]["y"]) for i in range(5) if i < len(points)]
    if len(jaw) < 3:
        return np.zeros((h, w), dtype=np.uint8)

    # Topmost brow landmark -> forehead reference line.
    brow_ys = [points[i]["y"] for i in range(5, 11) if i < len(points)]
    forehead_y = min(brow_ys) if brow_ys else jaw[0][1]

    start, end = jaw[0], jaw[-1]
    face_width = abs(end[0] - start[0])
    # Extra margin above the brows for the forehead.
    top_y = forehead_y - face_width * 0.2

    cx = (start[0] + end[0]) / 2
    rx = face_width / 2 * 1.1  # slightly wider than the jaw
    # The arc's endpoints and its radius share one baseline, so the apex lands
    # exactly on top_y even when the two jaw ends are at different heights.
    base_y = min(start[1], end[1])
    # Clamp so degenerate landmarks never make the arc bulge downward.
    ry = max(0.0, base_y - top_y)

    # The polygon continues from the last jaw point, so the arc must start on
    # that side: +x -> -x when the jaw runs left-to-right, the mirror image
    # otherwise. Starting on the wrong side makes the polygon self-intersect.
    sweep = 1 if end[0] >= start[0] else -1
    n_arc = 16
    arc_points = []
    for i in range(n_arc + 1):
        angle = math.pi * i / n_arc  # 0 -> pi
        ax = cx + sweep * rx * math.cos(angle)
        ay = base_y - ry * math.sin(angle)
        arc_points.append((ax, ay))

    # Closed outline: jaw (0 -> 4) followed by the arc back over the forehead.
    polygon = jaw + arc_points

    mask_img = Image.new("L", (w, h), 0)
    ImageDraw.Draw(mask_img).polygon(polygon, fill=255)
    return np.array(mask_img)
def estimate_hair_points(points: List[dict], image_size: Tuple[int, int]) -> List[List[float]]:
    """Guess seven hair prompt locations from the jaw landmarks (indices 0-4).

    The seven locations are:
    - three on a crown line above the face (centre and +/-0.3 face widths);
    - two beside the face at 30% of the face height;
    - two beside the bottom of the jaw, for long hair.

    The crown line is clamped to y >= 0. The side points are clamped to
    x >= 0 on the left and x <= image width on the right. Returns ``[]``
    when there are no jaw landmarks.
    """
    width, _height = image_size

    jaw = points[:5]
    if not jaw:
        return []
    ys = [p["y"] for p in jaw]
    xs = [p["x"] for p in jaw]

    top_y, bottom_y = min(ys), max(ys)
    center_x = sum(xs) / len(xs)
    face_width = max(xs) - min(xs) if len(xs) > 1 else 100
    face_height = bottom_y - top_y if bottom_y > top_y else face_width

    # 80% of the face height above the top of the jaw points.
    crown_y = max(0, top_y - face_height * 0.8)
    # Roughly eye height.
    side_y = top_y + face_height * 0.3

    return [
        [center_x, crown_y],
        [center_x - face_width * 0.3, crown_y],
        [center_x + face_width * 0.3, crown_y],
        [max(0, center_x - face_width * 0.7), side_y],
        [min(width, center_x + face_width * 0.7), side_y],
        [max(0, center_x - face_width * 0.6), bottom_y],
        [min(width, center_x + face_width * 0.6), bottom_y],
    ]
# =========================
# メイン処理: ポイントプロンプト版
# =========================
def _is_xy_point(p) -> bool:
    """Return True if ``p`` is a dict whose ``"x"`` and ``"y"`` are real numbers."""
    if not isinstance(p, dict):
        return False
    for key in ("x", "y"):
        value = p.get(key)
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return False
    return True


def _save_mask(arr: np.ndarray, out_dir: str, name: str) -> str:
    """Write a uint8 mask to ``<out_dir>/<name>.png`` (mode "L") and return the path."""
    path = os.path.join(out_dir, f"{name}.png")
    Image.fromarray(arr, "L").save(path)
    return path


def _split_hair_body(remainder: np.ndarray, points: List[dict]) -> Tuple[np.ndarray, np.ndarray]:
    """Split the not-yet-assigned region into ``(hair_mask, body_mask)``.

    A remainder pixel counts as hair in two cases:
    - it lies above the mean jaw y;
    - it lies less than half a jaw width below that line and more than 0.35
      jaw widths horizontally from the mean jaw x (side hair).

    Every other remainder pixel is body. Expects at least one jaw landmark
    (indices 0-4) in ``points``.
    """
    h_img, w_img = remainder.shape
    jaw_xs = [points[i]["x"] for i in range(5) if i < len(points)]
    jaw_ys = [points[i]["y"] for i in range(5) if i < len(points)]
    face_center_x = sum(jaw_xs) / len(jaw_xs)
    face_center_y = sum(jaw_ys) / len(jaw_ys)
    face_w = (max(jaw_xs) - min(jaw_xs)) if len(jaw_xs) > 1 else w_img * 0.4

    # Vectorised over the whole image (replaces an O(H*W) per-pixel Python loop).
    rows = np.arange(h_img)[:, None]
    cols = np.arange(w_img)[None, :]
    above_face = rows < face_center_y
    beside_face = (rows < face_center_y + face_w * 0.5) & (np.abs(cols - face_center_x) > face_w * 0.35)
    hair_mask = np.where((remainder > 0) & (above_face | beside_face), 255, 0).astype(np.uint8)
    body_mask = np.where(hair_mask > 0, 0, remainder).astype(np.uint8)
    return hair_mask, body_mask


def run_with_landmarks(image: Image.Image, landmarks_json: str) -> Tuple[List[str], str]:
    """Generate per-part masks using 28 anime-face landmarks as SAM prompts.

    The pipeline has three stages:
    - eyes, brows and mouth (``PART_GROUPS``) are segmented by SAM from
      point + box prompts;
    - the face is the geometric ``compute_face_mask`` polygon minus those parts;
    - the rest of the image is split heuristically into hair and body.
      A hair or body region is only emitted if it covers more than 0.5% of
      the image.

    Args:
        image: input image (converted to RGB).
        landmarks_json: JSON object ``{"points": [{"x": ..., "y": ...}, ...]}``
            with at least 28 points; only indices 0-27 are used.

    Returns:
        ``(files, status)``: the PNG mask paths followed by a ZIP of them, and
        a status message. On invalid input, ``([], error message)``.
    """
    if image is None:
        return [], "画像がありません"
    image = image.convert("RGB")

    # Parse landmarks. A non-object payload (e.g. a top-level array), a
    # non-list "points", or malformed entries are reported instead of raising.
    try:
        landmarks_data = json.loads(landmarks_json)
    except (json.JSONDecodeError, TypeError):
        return [], "ランドマークJSONの解析に失敗しました"
    points = landmarks_data.get("points", []) if isinstance(landmarks_data, dict) else None
    if not isinstance(points, list):
        return [], "ランドマークJSONの解析に失敗しました"
    if len(points) < 28:
        return [], f"ランドマークが不足しています({len(points)}/28)"
    if not all(_is_xy_point(p) for p in points[:28]):
        return [], "ランドマークJSONの解析に失敗しました"

    out_dir = os.path.join(OUT_ROOT, uuid.uuid4().hex)
    os.makedirs(out_dir, exist_ok=True)

    # Encode the image once; every part prompt reuses these embeddings.
    device = next(sam_model.parameters()).device
    inputs = sam_processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        image_embeddings = sam_model.get_image_embeddings(inputs["pixel_values"])
    original_size = inputs["original_sizes"]
    reshaped_size = inputs["reshaped_input_sizes"]

    w_img, h_img = image.size
    # Insertion order = order of output files and of the status list.
    part_masks = {}

    # Small parts (eyes, mouth, brows): SAM point + box prompts.
    for part_name, config in PART_GROUPS.items():
        part_points = get_all_points(points, config["indices"])
        mask_arr = segment_part(image, image_embeddings, part_points, original_size, reshaped_size, use_box=config.get("use_box", True))
        if mask_arr is None:
            continue
        if mask_arr.shape != (h_img, w_img):
            # NEAREST keeps the mask strictly binary when resizing.
            mask_arr = np.array(Image.fromarray(mask_arr, "L").resize(image.size, Image.NEAREST))
        part_masks[part_name] = mask_arr

    # Face: geometric jaw/forehead polygon with the small parts cut out.
    small_union = np.zeros((h_img, w_img), dtype=np.uint8)
    for arr in part_masks.values():
        small_union = np.maximum(small_union, arr)
    face_outline = compute_face_mask(points, image.size)
    part_masks["face"] = np.where(small_union > 0, 0, face_outline).astype(np.uint8)

    status_parts = list(part_masks)
    mask_files = [_save_mask(arr, out_dir, name) for name, arr in part_masks.items()]

    # Everything no part claimed is split into hair and body.
    union = np.maximum(small_union, part_masks["face"])
    remainder = np.where(union > 0, 0, 255).astype(np.uint8)
    # Regions covering no more than 0.5% of the image are not emitted.
    min_pixels = remainder.size * 0.005
    if np.count_nonzero(remainder) > min_pixels:
        hair_mask, body_mask = _split_hair_body(remainder, points)
        for name, arr in (("hair", hair_mask), ("body", body_mask)):
            if np.count_nonzero(arr) > min_pixels:
                mask_files.append(_save_mask(arr, out_dir, name))
                status_parts.append(name)

    # ZIP of all masks, appended last.
    if mask_files:
        zip_path = os.path.join(out_dir, "parts.zip")
        create_zip(mask_files, zip_path)
        mask_files.append(zip_path)

    status = f"生成完了: {', '.join(status_parts)} ({len(status_parts)} parts)"
    return mask_files, status
# =========================
# 従来の自動マスク生成(互換用)
# =========================
def run(image: Image.Image, max_masks: int = 12) -> Tuple[List[str], str]:
    """Legacy endpoint: SAM automatic mask generation, no landmarks needed.

    Args:
        image: input image (converted to RGB).
        max_masks: maximum number of masks to save. Coerced to a non-negative
            int because Gradio/API clients may send floats such as 12.0.
            ``None`` keeps all masks.

    Returns:
        ``(files, status)``: up to ``max_masks`` PNG masks (``mask_NN.png``)
        followed by a ZIP of them, and an ``"<n> masks generated"`` message.
    """
    if image is None:
        return [], "画像がありません"
    image = image.convert("RGB")
    limit = None if max_masks is None else max(0, int(max_masks))

    result = mask_gen(image)
    if isinstance(result, dict):
        # First non-empty candidate key. Explicit checks instead of an `or`
        # chain, which raises on array-valued masks.
        masks = []
        for key in ("masks", "mask", "pred_masks"):
            candidate = result.get(key)
            if candidate is not None and len(candidate) > 0:
                masks = candidate
                break
    else:
        masks = result

    out_dir = os.path.join(OUT_ROOT, uuid.uuid4().hex)
    os.makedirs(out_dir, exist_ok=True)

    mask_files = []
    for i, item in enumerate(masks[:limit]):
        m = item.get("mask") if isinstance(item, dict) else item
        if m is None:
            continue
        arr = np.asarray(m)
        # Drop a leading singleton channel axis, if present.
        if arr.ndim == 3 and arr.shape[0] == 1:
            arr = arr[0]
        # Non-uint8 (e.g. bool/float) masks are binarised to 0/255.
        if arr.dtype != np.uint8:
            arr = (arr > 0).astype(np.uint8) * 255
        mask_img = Image.fromarray(arr, "L")
        if mask_img.size != image.size:
            # NEAREST keeps the mask binary.
            mask_img = mask_img.resize(image.size, Image.NEAREST)
        path = os.path.join(out_dir, f"mask_{i:02d}.png")
        mask_img.save(path)
        mask_files.append(path)

    # Count before the ZIP is appended (previously reported -1 when empty).
    n_masks = len(mask_files)
    if mask_files:
        zip_path = os.path.join(out_dir, "all_masks.zip")
        create_zip(mask_files, zip_path)
        mask_files.append(zip_path)
    return mask_files, f"{n_masks} masks generated"
# =========================
# UI
# =========================
# Gradio UI: one tab per handler. Each click() passes api_name so the handler
# is also exposed as a named API endpoint ("run", "run_with_landmarks").
# Component creation order inside the context managers defines the layout.
with gr.Blocks(title="Part Segmentation API (SAM + Landmarks)") as demo:
    # Header shown above the tabs.
    # NOTE(review): the part list below omits face and body, which
    # run_with_landmarks also emits.
    gr.Markdown(
        """
# Part Segmentation API
## モード
- **自動マスク**: 画像だけで自動分割
- **ランドマーク指定**: 28点アニメランドマーク座標でパーツ別に高精度分割(目・口・眉・髪)
商用利用可能(Apache-2.0)
"""
    )
    # Tab 1: automatic masks -> run(image, max_masks).
    with gr.Tab("自動マスク"):
        with gr.Row():
            # Left column: inputs.
            with gr.Column():
                auto_image = gr.Image(type="pil", label="Upload image")
                auto_max = gr.Slider(minimum=1, maximum=50, value=12, step=1, label="Max masks")
                auto_btn = gr.Button("Generate", variant="primary")
            # Right column: outputs.
            with gr.Column():
                auto_files = gr.Files(label="Masks")
                auto_status = gr.Textbox(label="Status")
        auto_btn.click(fn=run, inputs=[auto_image, auto_max], outputs=[auto_files, auto_status], api_name="run")
    # Tab 2: landmark-prompted part masks -> run_with_landmarks(image, json).
    with gr.Tab("ランドマーク指定"):
        with gr.Row():
            # Left column: image + landmarks JSON input.
            with gr.Column():
                lm_image = gr.Image(type="pil", label="Upload image")
                lm_json = gr.Textbox(label="Landmarks JSON", placeholder='{"points": [{"x":100,"y":200}, ...]}', lines=3)
                lm_btn = gr.Button("Generate Parts", variant="primary")
            # Right column: outputs.
            with gr.Column():
                lm_files = gr.Files(label="Part Masks")
                lm_status = gr.Textbox(label="Status")
        lm_btn.click(fn=run_with_landmarks, inputs=[lm_image, lm_json], outputs=[lm_files, lm_status], api_name="run_with_landmarks")
# Start the Gradio server; this executes at module level.
demo.launch()