| """ |
| Part Segmentation API — ポイントプロンプト版 |
| SAM (Apache-2.0) + アニメランドマーク28点座標でパーツ別マスクを生成 |
| |
| 使い方: |
| - /run: 従来の自動マスク生成(ランドマークなし) |
| - /run_with_landmarks: ランドマーク座標付きポイントプロンプト(高精度) |
| """ |
|
|
| import gradio as gr |
| import os |
| import uuid |
| import json |
| import zipfile |
| from typing import List, Tuple, Optional |
|
|
| from PIL import Image |
| import numpy as np |
| import torch |
| from transformers import SamModel, SamProcessor, pipeline |
|
|
| |
| |
| |
SAM_MODEL = "facebook/sam-vit-large"

# Pick the device first so every model below lands on the same one.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Prompt-driven SAM used by run_with_landmarks (point/box prompts, with
# image embeddings computed once per request).
sam_model = SamModel.from_pretrained(SAM_MODEL)
sam_processor = SamProcessor.from_pretrained(SAM_MODEL)
sam_model = sam_model.to(device)

# Automatic mask generation used by run (no landmarks). The device is passed
# explicitly; otherwise the pipeline stays on CPU even when a GPU is present.
mask_gen = pipeline(task="mask-generation", model=SAM_MODEL, device=device)

# Root directory for per-request output folders (one uuid-named subdir each).
OUT_ROOT = "/tmp/vtuber_masks"
os.makedirs(OUT_ROOT, exist_ok=True)
|
|
| |
| |
| |
| |
# Small facial parts segmented with SAM in run_with_landmarks, in output order.
# Each key is also the PNG file name ("<key>.png"). "indices" are positions in
# the 28-entry landmark list (index 23 is not used by any group); "use_box" is
# forwarded to segment_part to add a padded bounding-box prompt.
PART_GROUPS = dict(
    eyeL=dict(indices=list(range(17, 23)), use_box=True),
    eyeR=dict(indices=list(range(11, 17)), use_box=True),
    mouth=dict(indices=list(range(24, 28)), use_box=True),
    browL=dict(indices=list(range(8, 11)), use_box=True),
    browR=dict(indices=list(range(5, 8)), use_box=True),
)
|
|
|
|
| |
| |
| |
def create_zip(files: List[str], zip_path: str) -> str:
    """Pack ``files`` into a deflate-compressed ZIP archive at ``zip_path``.

    Each entry is stored under the file's base name only, so directory
    components are dropped. An existing archive at ``zip_path`` is
    overwritten.

    Returns:
        ``zip_path``, for call chaining.
    """
    archive = zipfile.ZipFile(zip_path, mode="w", compression=zipfile.ZIP_DEFLATED)
    with archive:
        for source_path in files:
            archive.write(source_path, arcname=os.path.basename(source_path))
    return zip_path
|
|
|
|
def get_all_points(points: List[dict], indices: List[int]) -> List[List[float]]:
    """Collect ``[x, y]`` pairs for the landmarks selected by ``indices``.

    Output order follows ``indices``. Any index that is not smaller than
    ``len(points)`` is skipped silently.
    """
    available = len(points)
    selected = []
    for idx in indices:
        if idx >= available:
            continue
        landmark = points[idx]
        selected.append([landmark["x"], landmark["y"]])
    return selected
|
|
|
|
def compute_bbox(points: List[List[float]], padding: float = 0.15) -> List[float]:
    """Return the padded axis-aligned bounding box ``[x1, y1, x2, y2]`` of ``points``.

    The same margin is added on every side: ``padding`` times the longer
    side of the tight box. An empty ``points`` raises ``ValueError``.
    """
    left = min(p[0] for p in points)
    top = min(p[1] for p in points)
    right = max(p[0] for p in points)
    bottom = max(p[1] for p in points)
    margin = max(right - left, bottom - top) * padding
    return [left - margin, top - margin, right + margin, bottom + margin]
|
|
|
|
def segment_part(
    image: Image.Image,
    image_embeddings,
    part_points: List[List[float]],
    original_size,
    reshaped_size,
    use_box: bool = True,
    negative_points: Optional[List[List[float]]] = None,
) -> Optional[np.ndarray]:
    """Segment one part with SAM from point (and optional box) prompts.

    The positive prompt is a single point at the centroid of ``part_points``.
    Every entry of ``negative_points`` is appended with label 0. When
    ``use_box`` is true and at least two points are given, the box
    ``compute_bbox(part_points, padding=0.25)`` is added as well.

    Args:
        image: Source image. It is passed to ``sam_processor`` together with
            the prompts.
        image_embeddings: Precomputed embeddings for ``image`` (from
            ``sam_model.get_image_embeddings``). The prompt tensors are moved
            to their device.
        part_points: ``[x, y]`` coordinates of the part's landmarks.
        original_size: ``original_sizes`` from the processor output used for
            the embeddings.
        reshaped_size: ``reshaped_input_sizes`` from that same output.
        use_box: Whether to add the bounding-box prompt.
        negative_points: Optional ``[x, y]`` points to exclude.

    Returns:
        A 0/255 ``uint8`` mask for the candidate with the highest
        ``iou_scores`` value, or ``None`` if ``part_points`` is empty.
    """
    if not part_points:
        return None

    target_device = image_embeddings.device

    count = len(part_points)
    centroid = [
        sum(p[0] for p in part_points) / count,
        sum(p[1] for p in part_points) / count,
    ]
    negatives = list(negative_points) if negative_points else []
    prompt_points = [centroid, *negatives]
    prompt_labels = [1] + [0] * len(negatives)

    box_prompt = None
    if use_box and count >= 2:
        box_prompt = [[compute_bbox(part_points, padding=0.25)]]

    encoded = sam_processor(
        image,
        input_points=[prompt_points],
        input_labels=[prompt_labels],
        input_boxes=box_prompt,
        return_tensors="pt",
    ).to(target_device)
    # The embeddings are supplied directly, so the re-preprocessed pixels are dropped.
    encoded.pop("pixel_values", None)

    model_inputs = {
        "image_embeddings": image_embeddings,
        "input_points": encoded["input_points"],
        "input_labels": encoded["input_labels"],
    }
    boxes_tensor = encoded.get("input_boxes")
    if boxes_tensor is not None:
        model_inputs["input_boxes"] = boxes_tensor

    with torch.no_grad():
        prediction = sam_model(**model_inputs)

    full_res = sam_processor.image_processor.post_process_masks(
        prediction.pred_masks.cpu(),
        original_size,
        reshaped_size,
    )
    best = int(torch.argmax(prediction.iou_scores[0][0]))
    chosen = full_res[0][0][best].numpy()
    return (chosen > 0).astype(np.uint8) * 255
|
|
|
|
def compute_face_mask(points: List[dict], image_size: Tuple[int, int]) -> np.ndarray:
    """Rasterise a face mask: the jaw contour (landmarks 0-4) closed on top by an arc.

    The arc is half an ellipse centred between the two jaw end points. Its
    half-width is 1.1x half their horizontal distance. Its height is derived
    from a target top line: the smallest eyebrow y (landmarks 5-10) minus 20%
    of the jaw width. If there are no eyebrow landmarks, the first jaw point's
    y is used instead.

    The jaw contour may be listed left-to-right or right-to-left. The arc is
    always traced from the last jaw point's side back to the first point's
    side, so the closing edges stay at the jaw ends instead of cutting across
    the face.

    Args:
        points: Landmark dicts with "x" and "y" keys.
        image_size: ``(width, height)`` of the output mask.

    Returns:
        A ``(height, width)`` uint8 array with 255 inside the polygon and 0
        outside. It is all zeros when fewer than three jaw landmarks exist.
    """
    from PIL import ImageDraw
    import math

    w, h = image_size

    jaw = [(points[i]["x"], points[i]["y"]) for i in range(5) if i < len(points)]
    if len(jaw) < 3:
        return np.zeros((h, w), dtype=np.uint8)

    brow_ys = [points[i]["y"] for i in range(5, 11) if i < len(points)]
    forehead_y = min(brow_ys) if brow_ys else jaw[0][1]

    face_width = abs(jaw[0][0] - jaw[-1][0])
    top_y = forehead_y - face_width * 0.2

    start, end = jaw[0], jaw[-1]
    cx = (start[0] + end[0]) / 2
    rx = abs(end[0] - start[0]) / 2 * 1.1
    # NOTE(review): ry is measured from the lower jaw end point, but the arc
    # base sits at the higher one. On a tilted face the apex therefore lands
    # above top_y by the end points' height difference. Kept as-is; confirm
    # whether that is intended.
    ry = max(start[1], end[1]) - top_y
    base_y = min(start[1], end[1])

    # The polygon walks the jaw from ``start`` to ``end``, so the arc has to
    # begin on ``end``'s side (+x when the jaw runs left-to-right, -x when it
    # runs right-to-left). Otherwise the closing edges cross the face and the
    # polygon self-intersects.
    direction = 1.0 if end[0] >= start[0] else -1.0

    n_arc = 16
    arc_points = []
    for i in range(n_arc + 1):
        angle = math.pi * i / n_arc
        ax = cx + direction * rx * math.cos(angle)
        ay = base_y - ry * math.sin(angle)
        arc_points.append((ax, ay))

    polygon = list(jaw) + arc_points

    mask_img = Image.new("L", (w, h), 0)
    draw = ImageDraw.Draw(mask_img)
    draw.polygon(polygon, fill=255)

    return np.array(mask_img)
|
|
|
|
def estimate_hair_points(points: List[dict], image_size: Tuple[int, int]) -> List[List[float]]:
    """Guess seven hair prompt locations from the jaw landmarks (indices 0-4).

    Let ``top``/``bottom`` be the min/max jaw y and ``center_x`` the mean jaw
    x. Face width is the x-extent of the jaw, or 100 for a single point. Face
    height is ``bottom - top``, or the face width when that extent is not
    positive. The points returned, in order:

    * three on a row ``0.8 * face height`` above ``top`` (clamped at y=0):
      at ``center_x`` and at ``center_x ± 0.3 * face width``;
    * two at ``top + 0.3 * face height``, at ``center_x ± 0.7 * face width``;
    * two at ``bottom``, at ``center_x ± 0.6 * face width``.

    The x-coordinates of the last four points are clamped to
    ``[0, image width]``.

    NOTE(review): this helper is not referenced elsewhere in the visible module.

    Returns:
        ``[x, y]`` points, or ``[]`` when ``points`` is empty.
    """
    width, _height = image_size
    jaw = points[:5]
    ys = [p["y"] for p in jaw]
    xs = [p["x"] for p in jaw]
    if not ys:
        return []

    top, bottom = min(ys), max(ys)
    center_x = sum(xs) / len(xs)
    face_w = max(xs) - min(xs) if len(xs) > 1 else 100
    face_h = bottom - top if bottom > top else face_w

    crown_y = max(0, top - face_h * 0.8)
    side_y = top + face_h * 0.3

    return [
        [center_x, crown_y],
        [center_x - face_w * 0.3, crown_y],
        [center_x + face_w * 0.3, crown_y],
        [max(0, center_x - face_w * 0.7), side_y],
        [min(width, center_x + face_w * 0.7), side_y],
        [max(0, center_x - face_w * 0.6), bottom],
        [min(width, center_x + face_w * 0.6), bottom],
    ]
|
|
|
|
| |
| |
| |
# Number of landmarks in the anime face-landmark format this endpoint expects.
_NUM_LANDMARKS = 28
# Parts whose pixels are removed from the face mask so the layers do not overlap.
_SMALL_PARTS = ("eyeL", "eyeR", "mouth", "browL", "browR")
# A leftover region must cover more than this fraction of the image to be split/emitted.
_MIN_AREA_RATIO = 0.005


def _parse_landmark_points(landmarks_json: str) -> Optional[list]:
    """Decode ``landmarks_json`` and return its ``"points"`` list.

    Returns ``[]`` when the object has no ``"points"`` key. Returns ``None``
    when the payload is not valid JSON, is not a JSON object, or its
    ``"points"`` value is not a list.
    """
    try:
        data = json.loads(landmarks_json)
    except (json.JSONDecodeError, TypeError):
        return None
    if not isinstance(data, dict):
        return None
    points = data.get("points", [])
    return points if isinstance(points, list) else None


def _is_xy_point(point) -> bool:
    """Return True if ``point`` is a dict whose "x" and "y" are numbers (not bools)."""
    if not isinstance(point, dict):
        return False
    for key in ("x", "y"):
        value = point.get(key)
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return False
    return True


def _save_mask(arr: np.ndarray, out_dir: str, name: str) -> str:
    """Write ``arr`` as an "L"-mode PNG named ``<name>.png`` in ``out_dir`` and return its path."""
    path = os.path.join(out_dir, f"{name}.png")
    Image.fromarray(arr, "L").save(path)
    return path


def _split_hair_body(
    remainder: np.ndarray, points: List[dict]
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """Split the not-yet-covered ``remainder`` mask into ``(hair, body)`` masks.

    A remainder pixel becomes hair when its row is above the mean jaw y
    (landmarks 0-4). It also becomes hair when its row is less than half a
    jaw width below that line and its column is more than 0.35 jaw widths
    from the mean jaw x. Every other remainder pixel is body.

    Returns ``None`` when there are no jaw landmarks.
    """
    jaw = points[:5]
    if not jaw:
        return None
    jaw_ys = [p["y"] for p in jaw]
    jaw_xs = [p["x"] for p in jaw]
    h_img, w_img = remainder.shape
    center_y = sum(jaw_ys) / len(jaw_ys)
    center_x = sum(jaw_xs) / len(jaw_xs)
    face_w = (max(jaw_xs) - min(jaw_xs)) if len(jaw_xs) > 1 else w_img * 0.4

    # Vectorized form of the per-pixel rule (broadcast rows x cols).
    rows = np.arange(h_img)[:, None]
    cols = np.arange(w_img)[None, :]
    hair_zone = (rows < center_y) | (
        (rows < center_y + face_w * 0.5) & (np.abs(cols - center_x) > face_w * 0.35)
    )
    hair_mask = np.where((remainder != 0) & hair_zone, 255, 0).astype(np.uint8)
    body_mask = np.where(hair_mask > 0, 0, remainder).astype(np.uint8)
    return hair_mask, body_mask


def run_with_landmarks(image: Image.Image, landmarks_json: str) -> Tuple[List[str], str]:
    """Generate per-part masks using SAM prompted by 28 anime face landmarks.

    Steps:
      1. One SAM mask per group in ``PART_GROUPS``, all sharing a single
         image embedding.
      2. A "face" mask from ``compute_face_mask`` minus the small parts.
      3. Pixels covered by none of the above (the remainder) are split into
         "hair" and "body" by ``_split_hair_body``. If the remainder is at
         most 0.5% of the image, it is saved as "body" without splitting.

    Every mask is saved as a PNG in a fresh directory under ``OUT_ROOT`` and
    bundled into ``parts.zip``.

    Args:
        image: Input image; it is converted to RGB.
        landmarks_json: JSON object ``{"points": [{"x": ..., "y": ...}, ...]}``
            with at least 28 points.

    Returns:
        ``(files, status)``. ``files`` holds the mask PNG paths followed by the
        ZIP path, or is empty on error. ``status`` is a message for the UI.
    """
    if image is None:
        return [], "画像がありません"

    image = image.convert("RGB")

    points = _parse_landmark_points(landmarks_json)
    if points is None:
        return [], "ランドマークJSONの解析に失敗しました"
    if len(points) < _NUM_LANDMARKS:
        return [], f"ランドマークが不足しています({len(points)}/{_NUM_LANDMARKS})"
    # Reject malformed entries here rather than failing inside the geometry helpers.
    if not all(_is_xy_point(p) for p in points[:_NUM_LANDMARKS]):
        return [], "ランドマークJSONの解析に失敗しました"

    out_dir = os.path.join(OUT_ROOT, uuid.uuid4().hex)
    os.makedirs(out_dir, exist_ok=True)

    # Encode the image once; every part prompt reuses these embeddings.
    device = next(sam_model.parameters()).device
    inputs = sam_processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        image_embeddings = sam_model.get_image_embeddings(inputs["pixel_values"])
    original_size = inputs["original_sizes"]
    reshaped_size = inputs["reshaped_input_sizes"]

    # 1. SAM masks for the landmark groups (dict keeps PART_GROUPS order).
    part_masks = {}
    for part_name, config in PART_GROUPS.items():
        part_points = get_all_points(points, config["indices"])
        mask_arr = segment_part(
            image,
            image_embeddings,
            part_points,
            original_size,
            reshaped_size,
            use_box=config.get("use_box", True),
        )
        if mask_arr is None:
            continue
        mask_img = Image.fromarray(mask_arr, "L")
        if mask_img.size != image.size:
            mask_arr = np.array(mask_img.resize(image.size))
        part_masks[part_name] = mask_arr

    if part_masks:
        h_img, w_img = next(iter(part_masks.values())).shape
    else:
        h_img, w_img = image.size[1], image.size[0]

    # 2. Face = jaw/brow polygon with the small parts cut out.
    face_poly = compute_face_mask(points, image.size)
    if face_poly.shape != (h_img, w_img):
        face_poly = np.array(Image.fromarray(face_poly, "L").resize((w_img, h_img)))
    small_union = np.zeros((h_img, w_img), dtype=np.uint8)
    for name in _SMALL_PARTS:
        if name in part_masks:
            small_union = np.maximum(small_union, part_masks[name])
    part_masks["face"] = np.where(small_union > 0, 0, face_poly).astype(np.uint8)

    # Each mask is written exactly once.
    status_parts = list(part_masks)
    mask_files = [_save_mask(arr, out_dir, name) for name, arr in part_masks.items()]

    # 3. Whatever is still uncovered becomes hair / body.
    union = np.zeros((h_img, w_img), dtype=np.uint8)
    for arr in part_masks.values():
        union = np.maximum(union, arr)
    remainder = np.where(union > 0, 0, 255).astype(np.uint8)

    min_area = remainder.size * _MIN_AREA_RATIO
    if np.sum(remainder > 0) > min_area:
        split = _split_hair_body(remainder, points)
        if split is not None:
            for name, region in zip(("hair", "body"), split):
                if np.sum(region > 0) > min_area:
                    mask_files.append(_save_mask(region, out_dir, name))
                    status_parts.append(name)
    else:
        # NOTE(review): a small leftover is still emitted as "body" here;
        # confirm this is intended.
        mask_files.append(_save_mask(remainder, out_dir, "body"))
        status_parts.append("body")

    if mask_files:
        zip_path = os.path.join(out_dir, "parts.zip")
        create_zip(mask_files, zip_path)
        mask_files.append(zip_path)

    status = f"生成完了: {', '.join(status_parts)} ({len(status_parts)} parts)"
    return mask_files, status
|
|
|
|
| |
| |
| |
def run(image: Image.Image, max_masks: int = 12) -> Tuple[List[str], str]:
    """Generate masks automatically with the SAM mask-generation pipeline (no landmarks).

    Up to ``max_masks`` masks are written as 0/255 PNGs (``mask_NN.png``)
    into a fresh directory under ``OUT_ROOT`` and bundled into
    ``all_masks.zip``.

    Args:
        image: Input image; it is converted to RGB.
        max_masks: Maximum number of masks to keep. ``None`` keeps all of
            them. Non-integer numbers (e.g. ``12.0`` sent by an API client)
            are truncated with ``int``.

    Returns:
        ``(files, status)``. ``files`` holds the mask PNG paths followed by the
        ZIP path, or is empty when no mask was produced. ``status`` reports the
        number of masks.
    """
    if image is None:
        return [], "画像がありません"

    image = image.convert("RGB")
    result = mask_gen(image)

    if isinstance(result, dict):
        # Take the first non-empty entry. The explicit None/len checks replace
        # ``a or b``, which raises for array-valued entries.
        masks = []
        for key in ("masks", "mask", "pred_masks"):
            candidate = result.get(key)
            if candidate is not None and len(candidate) > 0:
                masks = candidate
                break
    else:
        masks = result

    # Slice bounds must be ints; a float here would raise TypeError.
    limit = None if max_masks is None else int(max_masks)

    out_dir = os.path.join(OUT_ROOT, uuid.uuid4().hex)
    os.makedirs(out_dir, exist_ok=True)

    mask_files = []
    for i, item in enumerate(masks[:limit]):
        m = item.get("mask") if isinstance(item, dict) else item
        if m is None:
            continue
        arr = np.array(m)
        if arr.ndim == 3 and arr.shape[0] == 1:
            arr = arr[0]  # drop a leading singleton axis
        if arr.dtype != np.uint8:
            arr = (arr > 0).astype(np.uint8) * 255
        mask_img = Image.fromarray(arr, "L")
        if mask_img.size != image.size:
            mask_img = mask_img.resize(image.size)
        path = os.path.join(out_dir, f"mask_{i:02d}.png")
        mask_img.save(path)
        mask_files.append(path)

    # Count PNGs before the ZIP path is appended, so zero masks reports 0 (not -1).
    n_masks = len(mask_files)
    if mask_files:
        zip_path = os.path.join(out_dir, "all_masks.zip")
        create_zip(mask_files, zip_path)
        mask_files.append(zip_path)

    return mask_files, f"{n_masks} masks generated"
|
|
|
|
| |
| |
| |
# Gradio UI with two tabs.
# Components are created inside the nested Blocks/Tab/Row/Column context
# managers, so the statement order below defines the page layout. Each
# button's ``api_name`` names the API endpoint for the same handler
# ("run" / "run_with_landmarks").
with gr.Blocks(title="Part Segmentation API (SAM + Landmarks)") as demo:
    # Header text.
    # NOTE(review): the landmark-mode line lists eyes/mouth/brows/hair, but
    # run_with_landmarks also emits "face" and "body" masks.
    gr.Markdown(
        """
# Part Segmentation API

## モード
- **自動マスク**: 画像だけで自動分割
- **ランドマーク指定**: 28点アニメランドマーク座標でパーツ別に高精度分割(目・口・眉・髪)

商用利用可能(Apache-2.0)
"""
    )

    # Tab 1: automatic mask generation -> run(image, max_masks).
    with gr.Tab("自動マスク"):
        with gr.Row():
            with gr.Column():
                auto_image = gr.Image(type="pil", label="Upload image")
                auto_max = gr.Slider(minimum=1, maximum=50, value=12, step=1, label="Max masks")
                auto_btn = gr.Button("Generate", variant="primary")
            with gr.Column():
                auto_files = gr.Files(label="Masks")
                auto_status = gr.Textbox(label="Status")
        auto_btn.click(fn=run, inputs=[auto_image, auto_max], outputs=[auto_files, auto_status], api_name="run")

    # Tab 2: landmark-prompted part masks -> run_with_landmarks(image, json).
    with gr.Tab("ランドマーク指定"):
        with gr.Row():
            with gr.Column():
                lm_image = gr.Image(type="pil", label="Upload image")
                # Expected payload: {"points": [{"x": ..., "y": ...}, ...]} with 28 entries.
                lm_json = gr.Textbox(label="Landmarks JSON", placeholder='{"points": [{"x":100,"y":200}, ...]}', lines=3)
                lm_btn = gr.Button("Generate Parts", variant="primary")
            with gr.Column():
                lm_files = gr.Files(label="Part Masks")
                lm_status = gr.Textbox(label="Status")
        lm_btn.click(fn=run_with_landmarks, inputs=[lm_image, lm_json], outputs=[lm_files, lm_status], api_name="run_with_landmarks")


# Called unconditionally at module level (there is no ``__main__`` guard).
demo.launch()
|
|