# Asset2Scene / sd_service.py
# MetricMogul's picture
# Update sd_service.py
# b8726b6 verified
import base64
import gc
import io
import uuid
from pathlib import Path
import gradio as gr
import torch
from PIL import Image
from diffusers import (
ControlNetModel,
StableDiffusionControlNetPipeline,
UniPCMultistepScheduler,
)
from transformers import pipeline as hf_pipeline
# Base models. These can later be swapped for local or preferred checkpoints.
BASE_SD_ID = "runwayml/stable-diffusion-v1-5"
CONTROLNET_ID = "lllyasviel/sd-controlnet-depth"
DEPTH_MODEL_ID = "Intel/dpt-hybrid-midas"

# Use CUDA with float16 weights when a GPU is available, else CPU with float32.
if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.float16
else:
    DEVICE = "cpu"
    DTYPE = torch.float32

# Output images go to <directory of this file>/data/sd_outputs; the directory
# is created at import time if it does not exist yet.
ROOT_DIR = Path(__file__).resolve().parent
DATA_DIR = ROOT_DIR.joinpath("data")
SD_OUTPUTS_DIR = DATA_DIR.joinpath("sd_outputs")
SD_OUTPUTS_DIR.mkdir(parents=True, exist_ok=True)

# Module-level slots for the lazily created model objects.
sd_pipe = None
depth_estimator = None
def get_depth_estimator():
    """Return the shared depth-estimation pipeline, creating it on first use.

    The pipeline is built for ``DEPTH_MODEL_ID`` and stored in the module-level
    ``depth_estimator`` variable, so later calls return the same object.
    """
    global depth_estimator
    if depth_estimator is not None:
        return depth_estimator

    # Device index passed to the transformers pipeline: 0 when running on
    # CUDA, -1 otherwise.
    device_index = -1 if DEVICE != "cuda" else 0
    depth_estimator = hf_pipeline(
        "depth-estimation",
        model=DEPTH_MODEL_ID,
        device=device_index,
    )
    return depth_estimator
def get_sd_pipe():
    """Return the shared ControlNet Stable Diffusion pipeline, building it once.

    On first use this loads the ControlNet weights (``CONTROLNET_ID``) and the
    base pipeline (``BASE_SD_ID``) in ``DTYPE`` with the safety checker
    disabled, requests the ``fp16`` weight variant when running on CUDA,
    replaces the scheduler with a ``UniPCMultistepScheduler`` built from the
    pipeline's existing scheduler config, and moves the pipeline to ``DEVICE``.

    The result is cached in the module-level ``sd_pipe`` only after every step
    has succeeded. If loading, the scheduler swap, or the device move raises,
    nothing is cached and the next call retries from scratch instead of
    returning a half-configured pipeline.
    """
    global sd_pipe
    if sd_pipe is not None:
        return sd_pipe

    controlnet = ControlNetModel.from_pretrained(
        CONTROLNET_ID,
        torch_dtype=DTYPE,
    )
    kwargs = {
        "controlnet": controlnet,
        "torch_dtype": DTYPE,
        "safety_checker": None,
    }
    if DEVICE == "cuda":
        kwargs["variant"] = "fp16"

    # Configure a local object first; the global is published last so that a
    # failure below (e.g. running out of GPU memory in .to()) leaves it None.
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        BASE_SD_ID,
        **kwargs,
    )
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(DEVICE)

    sd_pipe = pipe
    return sd_pipe
def decode_data_url_to_image(data_url: str) -> Image.Image:
    """Decode a base64 data URL into an RGBA PIL image.

    Args:
        data_url: String of the form ``<header>,<base64 payload>``. Everything
            after the first comma is treated as the base64 payload.

    Returns:
        The decoded image converted to ``RGBA`` mode.

    Raises:
        gr.Error: If ``data_url`` is empty or contains no comma, or if the
            payload is not valid base64 or not an image Pillow can read.
    """
    if not data_url or "," not in data_url:
        raise gr.Error("Canvas is empty. Add assets to the scene first.")

    _, encoded = data_url.split(",", 1)
    try:
        binary = base64.b64decode(encoded)
        # Image.open() is lazy; convert() forces the pixel data to be read, so
        # truncated or corrupt image data also fails inside this try block.
        return Image.open(io.BytesIO(binary)).convert("RGBA")
    except (ValueError, OSError) as exc:
        # binascii.Error (bad base64) is a ValueError subclass, and Pillow's
        # UnidentifiedImageError is an OSError subclass.
        raise gr.Error("Could not decode the scene image from the canvas.") from exc
def flatten_rgba_on_white(img: Image.Image) -> Image.Image:
    """Composite ``img`` over an opaque white background and return it as RGB.

    Args:
        img: Source image; it is converted to ``RGBA`` before compositing.

    Returns:
        A new ``RGB`` image of the same size as ``img``.
    """
    foreground = img.convert("RGBA")
    white_canvas = Image.new("RGBA", img.size, (255, 255, 255, 255))
    return Image.alpha_composite(white_canvas, foreground).convert("RGB")
def resize_for_depth_and_sd(img: Image.Image, target_max_side: int = 768) -> Image.Image:
    """Resize ``img`` so both sides are multiples of 8 and at least 64 pixels.

    Args:
        img: Image to resize.
        target_max_side: Upper bound used to derive the scale factor from the
            longer side of ``img``.

    Returns:
        ``img`` itself when the computed size equals the current size,
        otherwise a LANCZOS-resampled copy at the computed size.
    """
    width, height = img.size
    longest = max(width, height)
    # The factor is capped at 1.0; a zero-sized image is left unscaled.
    factor = min(target_max_side / longest, 1.0) if longest > 0 else 1.0

    def snap(side):
        # Round the scaled side to a multiple of 8, with a floor of 64.
        return max(64, int(round((side * factor) / 8) * 8))

    target_size = (snap(width), snap(height))
    if target_size == (width, height):
        return img
    return img.resize(target_size, Image.LANCZOS)
def make_depth_image(scene_image: Image.Image) -> Image.Image:
    """Estimate a depth map for ``scene_image`` and return it as an RGB image.

    Args:
        scene_image: Image passed to the depth estimator.

    Returns:
        The estimator's ``"depth"`` output as an ``RGB`` image with the same
        size as ``scene_image``.
    """
    depth_map = get_depth_estimator()(scene_image)["depth"]

    # Wrap a non-PIL result into a PIL image before converting.
    if not isinstance(depth_map, Image.Image):
        depth_map = Image.fromarray(depth_map)

    depth_rgb = depth_map.convert("RGB")
    if depth_rgb.size == scene_image.size:
        return depth_rgb
    # Bring the depth map back to the scene resolution.
    return depth_rgb.resize(scene_image.size, Image.LANCZOS)
def save_image(img: Image.Image, prefix: str) -> str:
    """Save ``img`` as a PNG in ``SD_OUTPUTS_DIR`` and return its path.

    Args:
        img: Image to write.
        prefix: Leading part of the file name; an underscore and eight hex
            characters from a random UUID are appended to it.

    Returns:
        The path of the written file as a string.
    """
    unique_suffix = uuid.uuid4().hex[:8]
    target = SD_OUTPUTS_DIR.joinpath(f"{prefix}_{unique_suffix}.png")
    img.save(target)
    return str(target)
def generate_with_depth_from_scene(
    scene_png_data: str,
    prompt: str,
    negative_prompt: str,
    steps: int,
    guidance_scale: float,
    controlnet_scale: float,
    seed: int,
):
    """Generate an image from ``prompt`` conditioned on the scene's depth map.

    The scene data URL is decoded, flattened onto white, resized, and turned
    into a depth map; the depth map is then used as the ControlNet input of
    the Stable Diffusion pipeline. The flattened scene, the depth map and the
    generated image are written to disk.

    Args:
        scene_png_data: Base64 data URL of the scene image.
        prompt: Text prompt; surrounding whitespace is stripped.
        negative_prompt: Negative prompt; a falsy value is passed as ``None``.
        steps: Number of inference steps (converted with ``int``).
        guidance_scale: Guidance scale (converted with ``float``).
        controlnet_scale: ControlNet conditioning scale (converted with
            ``float``).
        seed: Seed for the random generator (converted with ``int``).

    Returns:
        A tuple ``(scene_path, depth_path, output_path)`` of file paths as
        strings.

    Raises:
        gr.Error: If the prompt is empty after stripping, or if
            ``decode_data_url_to_image`` rejects ``scene_png_data``.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        raise gr.Error("Prompt is empty.")

    rgba_scene = decode_data_url_to_image(scene_png_data)

    try:
        # Used only for depth estimation; this image itself is not fed to SD.
        scene_rgb = flatten_rgba_on_white(rgba_scene)
        scene_rgb = resize_for_depth_and_sd(scene_rgb, target_max_side=768)
        depth_image = make_depth_image(scene_rgb)

        pipe = get_sd_pipe()
        generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt or None,
            image=depth_image,
            num_inference_steps=int(steps),
            guidance_scale=float(guidance_scale),
            controlnet_conditioning_scale=float(controlnet_scale),
            generator=generator,
            width=depth_image.width,
            height=depth_image.height,
        )
        output_image = result.images[0]

        scene_path = save_image(scene_rgb, "scene_for_depth")
        depth_path = save_image(depth_image, "depth")
        output_path = save_image(output_image, "sd")
        return scene_path, depth_path, output_path
    finally:
        # Run the cleanup on failures too, so an aborted generation does not
        # skip releasing cached GPU memory before the next request.
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()