# Asset2Scene / shape_e_service.py
# MetricMogul's picture
# Update shape_e_service.py
# 8dc66af verified
import gc
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import gradio as gr
import torch
from diffusers import ShapEPipeline
from PIL import Image
# Model identifier passed to ShapEPipeline.from_pretrained() in get_pipeline().
MODEL_ID = "openai/shap-e"
# Shared pipeline instance; stays None until get_pipeline() loads it.
pipe = None
# Use CUDA when torch reports it as available, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# float16 on CUDA, float32 everywhere else.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
# Generated files live in <directory of this file>/data/assets.
ROOT_DIR = Path(__file__).resolve().parent
DATA_DIR = ROOT_DIR / "data"
ASSETS_DIR = DATA_DIR / "assets"
# Created at import time; save_frames_to_files() adds one sub-folder per asset.
ASSETS_DIR.mkdir(parents=True, exist_ok=True)
# Sample prompts.
# NOTE(review): presumably offered as examples in the UI - the consumer is
# outside this view, confirm before relying on it.
EXAMPLES = [
    "A cute stylized robot with a round head",
    "A fantasy treasure chest with gold trim",
    "A small dragon figurine, toy-like, colorful",
    "A low-poly medieval house",
    "A ceramic teapot shaped like an owl",
    "A cartoon submarine with tiny windows",
]
def get_pipeline():
    """Return the shared Shap-E pipeline, loading it on first use.

    The loaded pipeline is cached in the module-level ``pipe`` variable and
    reused by later calls. When ``DEVICE`` is ``"cuda"`` the ``fp16`` weight
    variant is requested; otherwise no variant is passed.
    """
    global pipe
    if pipe is not None:
        return pipe

    load_options = {"torch_dtype": DTYPE}
    if DEVICE == "cuda":
        load_options["variant"] = "fp16"
    pipe = ShapEPipeline.from_pretrained(MODEL_ID, **load_options)
    pipe = pipe.to(DEVICE)
    return pipe
def make_white_background_transparent(frame: Image.Image, threshold: int = 245) -> Image.Image:
    """Make the near-white background of ``frame`` transparent.

    A pixel is treated as background when its R, G and B values are all
    ``>= threshold``; such pixels are replaced by fully transparent white.
    Every other pixel keeps its original RGBA value.

    Returns:
        The RGBA image produced by ``frame.convert("RGBA")`` with the
        background pixels rewritten.
    """
    rgba = frame.convert("RGBA")

    def _clear_if_background(pixel):
        red, green, blue, _alpha = pixel
        # All three channels reach the threshold exactly when the smallest does.
        if min(red, green, blue) >= threshold:
            return (255, 255, 255, 0)
        return pixel

    rgba.putdata([_clear_if_background(pixel) for pixel in rgba.getdata()])
    return rgba
def crop_to_nontransparent_content(img: Image.Image, padding: int = 8) -> Image.Image:
    """Crop away the transparent margins surrounding the object.

    The bounding box reported for the alpha channel is grown by ``padding``
    pixels on every side, clamped to the image bounds, and used as the crop
    rectangle. When the alpha channel has no bounding box (``getbbox()``
    returns ``None``) the image is returned as-is.
    """
    content_box = img.getchannel("A").getbbox()
    if content_box is None:
        return img

    x0, y0, x1, y1 = content_box
    padded_box = (
        max(0, x0 - padding),
        max(0, y0 - padding),
        min(img.width, x1 + padding),
        min(img.height, y1 + padding),
    )
    return img.crop(padded_box)
def save_frames_to_files(frames, prompt: str) -> List[str]:
    """Post-process rendered frames and write them to disk as PNG files.

    A fresh ``asset_<8 hex chars>`` directory is created under ``ASSETS_DIR``.
    Each frame is converted to RGBA, has its near-white background made
    transparent, is cropped to its content, and is saved as ``view_NNN.png``.

    Args:
        frames: Iterable of image objects (each must provide ``convert``).
        prompt: Not used by this function; kept in the signature for callers.

    Returns:
        The saved file paths as strings, in frame order.
    """
    target_dir = ASSETS_DIR / f"asset_{uuid.uuid4().hex[:8]}"
    target_dir.mkdir(parents=True, exist_ok=True)

    saved_paths = []
    for view_number, frame in enumerate(frames):
        processed = make_white_background_transparent(frame.convert("RGBA"), threshold=245)
        processed = crop_to_nontransparent_content(processed, padding=8)
        out_path = target_dir / f"view_{view_number:03d}.png"
        processed.save(out_path)
        saved_paths.append(str(out_path))
    return saved_paths
def make_asset(prompt: str, frame_paths: List[str]) -> Dict[str, Any]:
    """Build the record that describes one generated asset.

    The record stores the prompt, the list of frame file paths (the same
    list object that was passed in), and ``selected_index`` set to 0 so the
    first view is the one initially shown.
    """
    return dict(prompt=prompt, frame_paths=frame_paths, selected_index=0)
def gallery_items_from_assets(saved_assets: List[Dict[str, Any]]) -> List[Tuple[str, str]]:
    """Build ``(image_path, caption)`` pairs for the gallery.

    For every asset that has at least one frame, the frame at its
    ``selected_index`` (clamped to the valid range) is paired with the
    caption ``"<n>. <prompt>"``, where ``n`` is the asset's 1-based position
    in ``saved_assets``. Assets without frames are skipped, but they still
    consume a number.
    """
    entries = []
    for number, asset in enumerate(saved_assets, start=1):
        paths = asset.get("frame_paths", [])
        if paths:
            last = len(paths) - 1
            chosen = min(max(int(asset.get("selected_index", 0)), 0), last)
            entries.append((paths[chosen], f"{number}. {asset.get('prompt', '')}"))
    return entries
def current_view_from_selected(
    saved_assets: List[Dict[str, Any]],
    selected_asset_index: Optional[int],
):
    """Describe the currently selected view of the selected asset.

    Returns a 4-tuple ``(image_path, label, prev_update, next_update)``.
    When nothing is selected, the index is out of range, or the asset has no
    frames, the path is ``None``, the label is ``"No asset selected."`` and
    both updates set ``interactive=False``. Otherwise the path is the frame
    at the asset's clamped ``selected_index``, the label names the asset and
    view (1-based), and both updates set ``interactive=True``.
    """

    def _nothing_selected():
        return None, "No asset selected.", gr.update(interactive=False), gr.update(interactive=False)

    if selected_asset_index is None:
        return _nothing_selected()
    if not 0 <= selected_asset_index < len(saved_assets):
        return _nothing_selected()

    paths = saved_assets[selected_asset_index].get("frame_paths", [])
    if not paths:
        return _nothing_selected()

    total = len(paths)
    position = min(max(int(saved_assets[selected_asset_index].get("selected_index", 0)), 0), total - 1)
    label = f"Asset {selected_asset_index + 1} · View {position + 1} / {total}"
    return paths[position], label, gr.update(interactive=True), gr.update(interactive=True)
def selected_view_path(
    saved_assets: List[Dict[str, Any]],
    selected_asset_index: Optional[int],
):
    """Return the file path of the selected asset's current view.

    Returns ``None`` when no asset is selected, the index is outside
    ``saved_assets``, or the asset has no frames. Otherwise returns the frame
    at the asset's ``selected_index``, clamped to the valid range.
    """
    has_selection = (
        selected_asset_index is not None
        and 0 <= selected_asset_index < len(saved_assets)
    )
    if not has_selection:
        return None

    chosen_asset = saved_assets[selected_asset_index]
    paths = chosen_asset.get("frame_paths", [])
    if not paths:
        return None

    position = int(chosen_asset.get("selected_index", 0))
    return paths[min(max(position, 0), len(paths) - 1)]
def generate_and_add_asset(
    prompt: str,
    steps: int,
    guidance_scale: float,
    frame_size: int,
    seed: int,
    saved_assets: List[Dict[str, Any]],
):
    """Generate a new asset from ``prompt`` and append it to ``saved_assets``.

    Args:
        prompt: Text prompt; surrounding whitespace is stripped.
        steps: Number of inference steps (coerced with ``int``).
        guidance_scale: Guidance scale (coerced with ``float``).
        frame_size: Rendered frame size (coerced with ``int``).
        seed: Seed for the torch generator (coerced with ``int``).
        saved_assets: Existing asset records; ``None`` or empty is treated as
            an empty list. The input list is not mutated - a new list is
            returned.

    Returns:
        A 7-tuple ``(saved_assets, selected_asset_index, gallery_items,
        current_view, view_text, prev_btn, next_btn)`` in which the new asset
        is the selected one.

    Raises:
        gr.Error: If the prompt is empty after stripping.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        raise gr.Error("Prompt is empty.")
    saved_assets = saved_assets or []

    pipeline = get_pipeline()
    generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
    try:
        result = pipeline(
            prompt,
            guidance_scale=float(guidance_scale),
            num_inference_steps=int(steps),
            frame_size=int(frame_size),
            generator=generator,
        )
        # The first entry of result.images holds the frames for this prompt.
        frame_paths = save_frames_to_files(result.images[0], prompt)
    finally:
        # Best-effort memory cleanup that must also run when generation or
        # saving fails; otherwise a failed request skips it entirely.
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    saved_assets = saved_assets + [make_asset(prompt, frame_paths)]
    selected_asset_index = len(saved_assets) - 1
    gallery_items = gallery_items_from_assets(saved_assets)
    current_view, view_text, prev_btn, next_btn = current_view_from_selected(
        saved_assets, selected_asset_index
    )
    return saved_assets, selected_asset_index, gallery_items, current_view, view_text, prev_btn, next_btn
def select_asset(
    saved_assets: List[Dict[str, Any]],
    evt: gr.SelectData,
):
    """Handle a gallery selection event.

    Returns a 6-tuple ``(selected_index, gallery_items, current_view,
    view_text, prev_btn, next_btn)``. With no saved assets, or when the event
    carries no index, the selection is ``None`` and both button updates set
    ``interactive=False``. If ``evt.index`` is a list or tuple, its first
    element is used as the asset index.
    """
    if not saved_assets:
        return None, [], None, "No asset selected.", gr.update(interactive=False), gr.update(interactive=False)

    if evt is None or evt.index is None:
        gallery = gallery_items_from_assets(saved_assets)
        return None, gallery, None, "No asset selected.", gr.update(interactive=False), gr.update(interactive=False)

    raw_index = evt.index
    picked = int(raw_index[0] if isinstance(raw_index, (list, tuple)) else raw_index)

    view, caption, prev_state, next_state = current_view_from_selected(saved_assets, picked)
    gallery = gallery_items_from_assets(saved_assets)
    return picked, gallery, view, caption, prev_state, next_state
def prev_view(
    saved_assets: List[Dict[str, Any]],
    selected_asset_index: Optional[int],
):
    """Step the selected asset back by one view, wrapping around.

    The asset's ``selected_index`` is updated in place inside
    ``saved_assets``. Returns ``(saved_assets, gallery_items, current_view,
    view_text, prev_btn, next_btn)``.

    Raises:
        gr.Error: If no valid asset is selected, or the selected asset has
            no frames.
    """
    has_selection = (
        selected_asset_index is not None
        and 0 <= selected_asset_index < len(saved_assets)
    )
    if not has_selection:
        raise gr.Error("Select an asset in the gallery first.")

    chosen_asset = saved_assets[selected_asset_index]
    paths = chosen_asset.get("frame_paths", [])
    if not paths:
        raise gr.Error("Selected asset has no frames.")

    # Modulo makes stepping back from the first view land on the last one.
    chosen_asset["selected_index"] = (int(chosen_asset.get("selected_index", 0)) - 1) % len(paths)

    gallery = gallery_items_from_assets(saved_assets)
    view, caption, prev_state, next_state = current_view_from_selected(
        saved_assets, selected_asset_index
    )
    return saved_assets, gallery, view, caption, prev_state, next_state
def next_view(
    saved_assets: List[Dict[str, Any]],
    selected_asset_index: Optional[int],
):
    """Advance the selected asset by one view, wrapping around.

    The asset's ``selected_index`` is updated in place inside
    ``saved_assets``. Returns ``(saved_assets, gallery_items, current_view,
    view_text, prev_btn, next_btn)``.

    Raises:
        gr.Error: If no valid asset is selected, or the selected asset has
            no frames.
    """
    has_selection = (
        selected_asset_index is not None
        and 0 <= selected_asset_index < len(saved_assets)
    )
    if not has_selection:
        raise gr.Error("Select an asset in the gallery first.")

    chosen_asset = saved_assets[selected_asset_index]
    paths = chosen_asset.get("frame_paths", [])
    if not paths:
        raise gr.Error("Selected asset has no frames.")

    # Modulo makes stepping forward from the last view land on the first one.
    chosen_asset["selected_index"] = (int(chosen_asset.get("selected_index", 0)) + 1) % len(paths)

    gallery = gallery_items_from_assets(saved_assets)
    view, caption, prev_state, next_state = current_view_from_selected(
        saved_assets, selected_asset_index
    )
    return saved_assets, gallery, view, caption, prev_state, next_state
def clear_saved_assets():
    """Return the values that reset the saved assets and the current view.

    The 7-tuple has the same layout as the one returned by
    ``generate_and_add_asset``: an empty asset list, no selected index, an
    empty gallery, no current image, the ``"No asset selected."`` label, and
    two updates that set ``interactive=False``.
    """
    prev_state = gr.update(interactive=False)
    next_state = gr.update(interactive=False)
    return [], None, [], None, "No asset selected.", prev_state, next_state
def set_prompt(value: str):
    """Return ``value`` unchanged (identity pass-through for a prompt string)."""
    return value