moebius / app.py
Mike0021's picture
Keep only clear object-mask examples
f3a64c1 verified
Raw
History Blame Contribute Delete
7.67 kB
import os
# Hugging Face cache root: use the /data directory when it exists and is
# writable, otherwise fall back to a scratch location under /tmp.
if os.path.isdir("/data") and os.access("/data", os.W_OK):
    _hf_cache = "/data/.cache/huggingface"
else:
    _hf_cache = "/tmp/hf_home"

# setdefault leaves any value that is already present in the environment alone.
# These lines sit above the remaining imports of this file, so the defaults are
# in place before those libraries are loaded.
for _env_name, _env_default in (
    ("HF_HOME", _hf_cache),
    ("HF_MODULES_CACHE", "/tmp/hf_modules"),
    ("MPLCONFIGDIR", "/tmp/matplotlib"),
    ("GRADIO_SSR_MODE", "false"),
):
    os.environ.setdefault(_env_name, _env_default)
del _env_name, _env_default
import time
from pathlib import Path
from typing import Dict, Tuple
import spaces
import gradio as gr
import torch
from diffusers import DDIMScheduler
from diffusers.models import AutoencoderKL
from huggingface_hub import hf_hub_download, snapshot_download
from PIL import Image
from removal.v1_2 import build_removal_model, load_cfg, load_removal_model
from removal.v1_2.pipeline import RemovalSDXLPipeline_BatchMode
# Directory that contains this file; bundled resources are resolved against it.
ROOT = Path(__file__).resolve().parent
# YAML model configuration shipped next to the app.
CONFIG_PATH = ROOT.joinpath("config", "model_cfg", "moebius.yaml")

# Hugging Face Hub repositories: Moebius provides the per-checkpoint weights,
# PixelHacker provides the VAE files.
MOEBIUS_REPO = "hustvl/Moebius"
PIXELHACKER_REPO = "hustvl/PixelHacker"

# Checkpoint used for unknown labels and pre-built at start-up.
DEFAULT_MODEL_KEY = "ft_places2"
# Dropdown label -> checkpoint key. The key is the sub-folder of the Moebius
# repo that holds the weight file. Insertion order is the order shown in the UI.
MODEL_CHOICES = {
    "General scenes (Places2)": "ft_places2",
    "Portraits (CelebA-HQ)": "ft_celebahq",
    "Faces (FFHQ)": "ft_ffhq",
    "Pretrained": "pretrained",
}
# Module-level cache of built pipelines, keyed by checkpoint key (the values of
# MODEL_CHOICES). Entries are created on first use by _get_pipeline.
_PIPELINE_CACHE: Dict[str, RemovalSDXLPipeline_BatchMode] = {}
def _download_vae_dir() -> str:
    """Download the ``vae/`` files of the PixelHacker repo and return their folder.

    Only files matching ``vae/*`` are requested from the Hub snapshot.

    Returns:
        The local path of the ``vae`` sub-directory inside the snapshot, as a string.
    """
    snapshot_root = Path(
        snapshot_download(repo_id=PIXELHACKER_REPO, allow_patterns=["vae/*"])
    )
    return str(snapshot_root / "vae")
def _download_model_weight(model_key: str) -> str:
    """Download one checkpoint's weight file from the Moebius repo.

    Args:
        model_key: Sub-folder of the repo that holds the checkpoint
            (one of the values of ``MODEL_CHOICES``).

    Returns:
        Whatever ``hf_hub_download`` returns for
        ``<model_key>/diffusion_pytorch_model.bin`` (the local file path).
    """
    weight_file = f"{model_key}/diffusion_pytorch_model.bin"
    return hf_hub_download(repo_id=MOEBIUS_REPO, filename=weight_file)
def _build_cpu_pipeline(model_key: str) -> RemovalSDXLPipeline_BatchMode:
    """Assemble a float32 pipeline on the CPU for one checkpoint.

    Args:
        model_key: Checkpoint sub-folder in the Moebius repo (e.g. ``"ft_places2"``).

    Returns:
        A ``RemovalSDXLPipeline_BatchMode`` built from the removal model, the
        VAE and a DDIM scheduler, created with ``device="cpu"`` and
        ``dtype=torch.float32``.
    """
    cfg = load_cfg(str(CONFIG_PATH))
    # Point the config at the downloaded VAE folder before the model is built from it.
    cfg["vae"]["model_dir"] = _download_vae_dir()

    # NOTE(review): the meaning of the positional 20 is not visible in this file — confirm.
    net = build_removal_model(cfg, 20)
    checkpoint_file = _download_model_weight(model_key)
    # Whatever load_removal_model returns is printed so it shows up in the logs.
    load_result = load_removal_model(net, checkpoint_file, device="cpu")
    print(load_result)

    autoencoder = AutoencoderKL.from_pretrained(cfg["vae"]["model_dir"])

    scheduler_options = {
        "beta_start": 0.00085,
        "beta_end": 0.012,
        "beta_schedule": "scaled_linear",
        "num_train_timesteps": 1000,
        "clip_sample": False,
    }
    ddim = DDIMScheduler(**scheduler_options)

    return RemovalSDXLPipeline_BatchMode(
        removal_model=net,
        vae=autoencoder,
        scheduler=ddim,
        device="cpu",
        dtype=torch.float32,
    )
def _get_pipeline(model_key: str) -> RemovalSDXLPipeline_BatchMode:
    """Return the cached pipeline for ``model_key``, building it on first use.

    Args:
        model_key: Checkpoint key used both as cache key and as build argument.

    Returns:
        The pipeline stored in ``_PIPELINE_CACHE`` for this key.
    """
    cached = _PIPELINE_CACHE.get(model_key)
    if cached is None:
        # First request for this checkpoint: build once and remember it.
        cached = _build_cpu_pipeline(model_key)
        _PIPELINE_CACHE[model_key] = cached
    return cached
def _set_pipeline_device(pipe: RemovalSDXLPipeline_BatchMode, device: str) -> None:
    """Move a pipeline's modules to ``device`` and rebuild ``pipe.input_ids`` there.

    Args:
        pipe: Pipeline to relocate; its ``device`` and ``input_ids`` attributes
            are overwritten and its VAE / removal model are moved and put in
            eval mode.
        device: Target device string, e.g. ``"cuda"`` or ``"cpu"``.
    """
    pipe.device = device
    for module in (pipe.vae, pipe.removal_model):
        module.to(device=device, dtype=pipe.dtype).eval()

    split = pipe.removal_model.num_embeddings // 2
    total_ids = pipe.removal_model.num_embeddings

    # Two single-row int64 tensors: ids [0, split) and ids [split, total_ids).
    lower_ids = torch.arange(split, dtype=torch.int64, device=device).unsqueeze(0)
    upper_ids = torch.arange(split, total_ids, dtype=torch.int64, device=device).unsqueeze(0)

    # The upper half is stacked first, then the lower half.
    # NOTE(review): presumably unconditional ids followed by conditional ids — confirm.
    pipe.input_ids = torch.cat([upper_ids, lower_ids]).to(device=device)
def _normalize_inputs(image: Image.Image, mask: Image.Image) -> Tuple[Image.Image, Image.Image]:
    """Validate the uploaded image and mask and bring them to a common form.

    Args:
        image: Uploaded picture, or ``None`` when nothing was uploaded.
        mask: Uploaded mask, or ``None`` when nothing was uploaded.

    Returns:
        ``(image, mask)`` with the image converted to RGB and the mask converted
        to single-channel ``"L"`` and resized (nearest neighbour) to the image size.

    Raises:
        gr.Error: If either input is missing, if the mask has no pixel of value
            8 or more, or if it has no pixel of value 247 or less.
    """
    # The image is checked before the mask, so its message wins when both are missing.
    for uploaded, message in ((image, "Upload an image."), (mask, "Upload a mask.")):
        if uploaded is None:
            raise gr.Error(message)

    rgb = image.convert("RGB")
    gray = mask.convert("L")
    # Nearest-neighbour resampling avoids introducing interpolated in-between values.
    gray = gray.resize(rgb.size, Image.Resampling.NEAREST)

    darkest, brightest = gray.getextrema()
    if brightest < 8:
        raise gr.Error("The mask is empty. Use white pixels for the area to inpaint.")
    if darkest > 247:
        raise gr.Error("The mask covers the whole image. Leave black pixels outside the edit area.")
    return rgb, gray
def _model_key(label: str) -> str:
    """Translate a dropdown label into a checkpoint key.

    Args:
        label: One of the keys of ``MODEL_CHOICES``.

    Returns:
        The matching checkpoint key, or ``DEFAULT_MODEL_KEY`` for an unknown label.
    """
    if label in MODEL_CHOICES:
        return MODEL_CHOICES[label]
    return DEFAULT_MODEL_KEY
def _estimate_duration(image, mask, model_name, steps, guidance_scale, paste, compensate, seed, *args, **kwargs):
    """Return the GPU time budget for one request, derived from ``steps`` only.

    The parameter list mirrors ``run_inpaint`` so this function can be handed
    to ``spaces.GPU(duration=...)``; every argument except ``steps`` is ignored.

    Args:
        steps: Number of sampling steps; converted with ``int()``.

    Returns:
        ``90 + 5 * int(steps)``, capped at 240.
    """
    estimate = 90 + 5 * int(steps)
    return 240 if estimate > 240 else estimate
# Build and cache the default checkpoint's CPU pipeline at import time, so the
# downloads and model construction for it happen before the first request.
_get_pipeline(DEFAULT_MODEL_KEY)
@spaces.GPU(duration=1)
def _zerogpu_probe():
    """Return the constant ``"ready"`` under ``spaces.GPU(duration=1)``.

    NOTE(review): nothing in the visible code calls this function; presumably
    it exists so that a GPU-decorated function is registered at import time —
    confirm before removing.
    """
    return "ready"
@spaces.GPU(duration=_estimate_duration)
def run_inpaint(image, mask, model_name, steps, guidance_scale, paste, compensate, seed):
    """Inpaint the masked region of one image and report the elapsed time.

    Args:
        image: Source picture from the UI (PIL image).
        mask: Mask picture; validated and resized by ``_normalize_inputs``.
        model_name: Dropdown label, mapped to a checkpoint key by ``_model_key``.
        steps: Number of sampling steps (cast to ``int``).
        guidance_scale: CFG scale (cast to ``float``).
        paste: Forwarded to the pipeline as ``paste`` (cast to ``bool``).
        compensate: Forwarded to the pipeline as ``compensate`` (cast to ``bool``).
        seed: Forwarded as the pipeline's ``retry`` argument; ``None`` becomes 0.

    Returns:
        ``(outputs[0], status)`` where ``status`` reads ``"Completed in <t>s"``.

    Raises:
        gr.Error: If the image or mask is missing or the mask is rejected.
    """
    image, mask = _normalize_inputs(image, mask)
    checkpoint = _model_key(model_name)
    retry_seed = int(seed) if seed is not None else 0
    pipe = _get_pipeline(checkpoint)

    # The timer starts before the device transfer, so the reported time includes it.
    t0 = time.perf_counter()
    try:
        _set_pipeline_device(pipe, "cuda")
        with torch.inference_mode():
            call_options = {
                "image_size": 512,
                "num_steps": int(steps),
                "guidance_scale": float(guidance_scale),
                "paste": bool(paste),
                "compensate": bool(compensate),
                "noise_offset": 0.0357,
                "retry": retry_seed,
                "mute": True,
            }
            results = pipe([image], [mask], **call_options)
        seconds = time.perf_counter() - t0
        return results[0], f"Completed in {seconds:.1f}s"
    finally:
        # Always move the cached pipeline back to the CPU, even when inference fails.
        _set_pipeline_device(pipe, "cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Gradio UI definition.
# NOTE(review): leading indentation is missing in this copy of the file, so the
# exact nesting of the Row/Column containers below cannot be read off the text —
# confirm the layout against the original before re-indenting.
with gr.Blocks(title="Moebius Inpainting", fill_width=True) as demo:
gr.Markdown("# Moebius Inpainting")
with gr.Row():
with gr.Column(scale=1, min_width=320):
# Source picture: handed to the handler as an RGB PIL image.
input_image = gr.Image(
label="Image",
type="pil",
image_mode="RGB",
sources=["upload", "clipboard"],
height=360,
)
# Mask picture: handed to the handler as a single-channel ("L") PIL image.
input_mask = gr.Image(
label="Mask",
type="pil",
image_mode="L",
sources=["upload", "clipboard"],
height=360,
)
with gr.Column(scale=1, min_width=320):
# Result image plus a Markdown line that receives the status text.
output_image = gr.Image(label="Result", type="pil", height=520)
status = gr.Markdown()
with gr.Row():
# Checkpoint selector; its labels are the keys of MODEL_CHOICES.
model_name = gr.Dropdown(
label="Checkpoint",
choices=list(MODEL_CHOICES.keys()),
value="General scenes (Places2)",
min_width=240,
)
# Sampling controls passed straight through to run_inpaint.
steps = gr.Slider(4, 30, value=20, step=1, label="Steps", min_width=180)
guidance_scale = gr.Slider(1.0, 6.0, value=2.0, step=0.1, label="CFG", min_width=180)
seed = gr.Number(value=0, precision=0, label="Seed", min_width=140)
with gr.Row():
paste = gr.Checkbox(value=True, label="Paste")
compensate = gr.Checkbox(value=False, label="Compensate")
run_button = gr.Button("Inpaint", variant="primary")
# The button calls run_inpaint with the controls in the order of its parameters;
# the same handler is exposed under the API name "inpaint".
run_button.click(
fn=run_inpaint,
inputs=[input_image, input_mask, model_name, steps, guidance_scale, paste, compensate, seed],
outputs=[output_image, status],
api_name="inpaint",
concurrency_limit=1,
)
# Bundled examples; each row lists values in the same order as `inputs`.
# Example outputs are cached with cache_mode="lazy".
gr.Examples(
examples=[
["examples/road.png", "examples/road_rocks_mask.png", "General scenes (Places2)", 20, 2.0, True, False, 0],
["examples/bench.png", "examples/bench_mask.png", "General scenes (Places2)", 20, 2.0, True, False, 1],
],
inputs=[input_image, input_mask, model_name, steps, guidance_scale, paste, compensate, seed],
outputs=[output_image, status],
fn=run_inpaint,
cache_examples=True,
cache_mode="lazy",
)
# Route requests through Gradio's queue, configured with max_size=8 and a
# default concurrency limit of 1.
demo.queue(max_size=8, default_concurrency_limit=1)

# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()