import os

# Cache/config locations are configured before the third-party imports below so
# those libraries see these values when they initialise. /data is used when it
# exists and is writable (presumably persistent Space storage); otherwise /tmp.
_hf_cache = (
    "/data/.cache/huggingface"
    if os.path.isdir("/data") and os.access("/data", os.W_OK)
    else "/tmp/hf_home"
)
os.environ.setdefault("HF_HOME", _hf_cache)
os.environ.setdefault("HF_MODULES_CACHE", "/tmp/hf_modules")
os.environ.setdefault("MPLCONFIGDIR", "/tmp/matplotlib")
os.environ.setdefault("GRADIO_SSR_MODE", "false")

import threading
import time
from pathlib import Path
from typing import Dict, Tuple

import spaces
import gradio as gr
import torch
from diffusers import DDIMScheduler
from diffusers.models import AutoencoderKL
from huggingface_hub import hf_hub_download, snapshot_download
from PIL import Image

from removal.v1_2 import build_removal_model, load_cfg, load_removal_model
from removal.v1_2.pipeline import RemovalSDXLPipeline_BatchMode

ROOT = Path(__file__).resolve().parent
CONFIG_PATH = ROOT / "config" / "model_cfg" / "moebius.yaml"
MOEBIUS_REPO = "hustvl/Moebius"
PIXELHACKER_REPO = "hustvl/PixelHacker"
DEFAULT_MODEL_KEY = "ft_places2"

# UI label -> checkpoint sub-folder name inside MOEBIUS_REPO.
MODEL_CHOICES = {
    "General scenes (Places2)": "ft_places2",
    "Portraits (CelebA-HQ)": "ft_celebahq",
    "Faces (FFHQ)": "ft_ffhq",
    "Pretrained": "pretrained",
}

# Built pipelines, keyed by checkpoint name. Entries live on CPU between requests.
_PIPELINE_CACHE: Dict[str, RemovalSDXLPipeline_BatchMode] = {}

# Serializes cache population and the cuda -> inference -> cpu sequence on the
# shared cached pipelines. Without it, two events running concurrently in one
# process (e.g. the button and a lazily cached example) could build a checkpoint
# twice, or move a pipeline back to CPU while another call is still using it.
# Re-entrant because run_inpaint holds it while calling _get_pipeline.
_PIPELINE_LOCK = threading.RLock()


def _download_vae_dir() -> str:
    """Download only the ``vae/`` folder of the PixelHacker repo; return its local path."""
    repo_dir = snapshot_download(
        repo_id=PIXELHACKER_REPO,
        allow_patterns=["vae/*"],
    )
    return str(Path(repo_dir) / "vae")


def _download_model_weight(model_key: str) -> str:
    """Download ``<model_key>/diffusion_pytorch_model.bin`` from the Moebius repo.

    Returns the local file path of the downloaded weight file.
    """
    return hf_hub_download(
        repo_id=MOEBIUS_REPO,
        filename=f"{model_key}/diffusion_pytorch_model.bin",
    )


def _build_cpu_pipeline(model_key: str) -> RemovalSDXLPipeline_BatchMode:
    """Build a float32 CPU inpainting pipeline for checkpoint ``model_key``.

    Downloads the VAE and the checkpoint weights, builds the removal model from
    CONFIG_PATH, and pairs it with a DDIM scheduler.
    """
    model_cfg = load_cfg(str(CONFIG_PATH))
    model_cfg["vae"]["model_dir"] = _download_vae_dir()
    # NOTE(review): meaning of the second argument (20) is defined by
    # build_removal_model; confirm before changing.
    removal_model = build_removal_model(model_cfg, 20)
    weight_path = _download_model_weight(model_key)
    # Print whatever load_removal_model returns for the startup logs.
    print(load_removal_model(removal_model, weight_path, device="cpu"))
    vae = AutoencoderKL.from_pretrained(model_cfg["vae"]["model_dir"])
    scheduler = DDIMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
        clip_sample=False,
    )
    return RemovalSDXLPipeline_BatchMode(
        removal_model=removal_model,
        vae=vae,
        scheduler=scheduler,
        device="cpu",
        dtype=torch.float32,
    )


def _get_pipeline(model_key: str) -> RemovalSDXLPipeline_BatchMode:
    """Return the cached pipeline for ``model_key``, building it on first use."""
    with _PIPELINE_LOCK:
        pipe = _PIPELINE_CACHE.get(model_key)
        if pipe is None:
            pipe = _build_cpu_pipeline(model_key)
            _PIPELINE_CACHE[model_key] = pipe
        return pipe


def _set_pipeline_device(pipe: RemovalSDXLPipeline_BatchMode, device: str) -> None:
    """Move ``pipe``'s models to ``device`` and rebuild its id tensor there.

    The VAE and removal model are cast to ``pipe.dtype`` and put in eval mode.
    ``pipe.input_ids`` becomes a 2-row int64 tensor. Row 0 holds the upper half
    of the embedding ids and row 1 the lower half. Row 0 is named ``un_input_ids``
    and is presumably the unconditional branch for CFG (confirm against the
    pipeline). This assumes ``num_embeddings`` is even; otherwise the two rows
    differ in length and ``torch.cat`` raises.
    """
    pipe.device = device
    pipe.vae.to(device=device, dtype=pipe.dtype).eval()
    pipe.removal_model.to(device=device, dtype=pipe.dtype).eval()

    id_num = pipe.removal_model.num_embeddings
    half_id_num = id_num // 2
    input_ids = torch.arange(half_id_num, dtype=torch.int64, device=device).unsqueeze(0)
    un_input_ids = torch.arange(half_id_num, id_num, dtype=torch.int64, device=device).unsqueeze(0)
    pipe.input_ids = torch.cat([un_input_ids, input_ids])


def _normalize_inputs(image: Image.Image, mask: Image.Image) -> Tuple[Image.Image, Image.Image]:
    """Validate the uploads and return ``(RGB image, L-mode mask at image size)``.

    White mask pixels mark the region to inpaint. The mask is resized with
    nearest-neighbour so it stays hard-edged.

    Raises:
        gr.Error: if either input is missing, or the mask is entirely dark
            (max < 8) or entirely bright (min > 247).
    """
    if image is None:
        raise gr.Error("Upload an image.")
    if mask is None:
        raise gr.Error("Upload a mask.")
    image = image.convert("RGB")
    mask = mask.convert("L").resize(image.size, Image.Resampling.NEAREST)
    mask_min, mask_max = mask.getextrema()
    if mask_max < 8:
        raise gr.Error("The mask is empty. Use white pixels for the area to inpaint.")
    if mask_min > 247:
        raise gr.Error("The mask covers the whole image. Leave black pixels outside the edit area.")
    return image, mask


def _model_key(label: str) -> str:
    """Map a dropdown label to a checkpoint key; unknown labels use the default."""
    return MODEL_CHOICES.get(label, DEFAULT_MODEL_KEY)


def _estimate_duration(image, mask, model_name, steps, guidance_scale, paste, compensate, seed, *args, **kwargs):
    """Return the GPU time budget in seconds for one ``run_inpaint`` call.

    The signature mirrors ``run_inpaint`` because ``spaces.GPU(duration=...)``
    calls it with the same arguments. The budget is 90 s plus 5 s per step,
    capped at 240 s.
    """
    del image, mask, model_name, guidance_scale, paste, compensate, seed, args, kwargs
    return min(240, 90 + int(steps) * 5)


# Build the default checkpoint at import time so the first request does not pay
# for the download and model construction.
_get_pipeline(DEFAULT_MODEL_KEY)


@spaces.GPU(duration=1)
def _zerogpu_probe():
    """Trivial GPU-decorated function (presumably so ZeroGPU registers one at startup)."""
    return "ready"


@spaces.GPU(duration=_estimate_duration)
def run_inpaint(image, mask, model_name, steps, guidance_scale, paste, compensate, seed):
    """Inpaint the white region of ``mask`` in ``image`` with the chosen checkpoint.

    Returns ``(result, status_text)``. ``result`` is the first pipeline output,
    which is shown in the Result image component. The pipeline is moved to CUDA
    only for the duration of the call and always moved back to CPU afterwards.
    An empty seed is treated as 0.
    """
    image, mask = _normalize_inputs(image, mask)
    model_key = _model_key(model_name)
    seed_value = 0 if seed is None else int(seed)

    with _PIPELINE_LOCK:
        pipe = _get_pipeline(model_key)
        started = time.perf_counter()
        try:
            _set_pipeline_device(pipe, "cuda")
            with torch.inference_mode():
                outputs = pipe(
                    [image],
                    [mask],
                    image_size=512,
                    num_steps=int(steps),
                    guidance_scale=float(guidance_scale),
                    paste=bool(paste),
                    compensate=bool(compensate),
                    noise_offset=0.0357,
                    retry=seed_value,  # the seed is passed to the pipeline as `retry`
                    mute=True,
                )
            elapsed = time.perf_counter() - started
            return outputs[0], f"Completed in {elapsed:.1f}s"
        finally:
            # Return weights to CPU even on failure so GPU memory is released.
            _set_pipeline_device(pipe, "cpu")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()


with gr.Blocks(title="Moebius Inpainting", fill_width=True) as demo:
    gr.Markdown("# Moebius Inpainting")
    with gr.Row():
        with gr.Column(scale=1, min_width=320):
            input_image = gr.Image(
                label="Image",
                type="pil",
                image_mode="RGB",
                sources=["upload", "clipboard"],
                height=360,
            )
            input_mask = gr.Image(
                label="Mask",
                type="pil",
                image_mode="L",
                sources=["upload", "clipboard"],
                height=360,
            )
        with gr.Column(scale=1, min_width=320):
            output_image = gr.Image(label="Result", type="pil", height=520)
            status = gr.Markdown()

    with gr.Row():
        model_name = gr.Dropdown(
            label="Checkpoint",
            choices=list(MODEL_CHOICES.keys()),
            value="General scenes (Places2)",
            min_width=240,
        )
        steps = gr.Slider(4, 30, value=20, step=1, label="Steps", min_width=180)
        guidance_scale = gr.Slider(1.0, 6.0, value=2.0, step=0.1, label="CFG", min_width=180)
        seed = gr.Number(value=0, precision=0, label="Seed", min_width=140)

    with gr.Row():
        paste = gr.Checkbox(value=True, label="Paste")
        compensate = gr.Checkbox(value=False, label="Compensate")
        run_button = gr.Button("Inpaint", variant="primary")

    run_button.click(
        fn=run_inpaint,
        inputs=[input_image, input_mask, model_name, steps, guidance_scale, paste, compensate, seed],
        outputs=[output_image, status],
        api_name="inpaint",
        concurrency_limit=1,
    )

    gr.Examples(
        examples=[
            ["examples/road.png", "examples/road_rocks_mask.png", "General scenes (Places2)", 20, 2.0, True, False, 0],
            ["examples/bench.png", "examples/bench_mask.png", "General scenes (Places2)", 20, 2.0, True, False, 1],
        ],
        inputs=[input_image, input_mask, model_name, steps, guidance_scale, paste, compensate, seed],
        outputs=[output_image, status],
        fn=run_inpaint,
        cache_examples=True,
        cache_mode="lazy",
    )

demo.queue(max_size=8, default_concurrency_limit=1)

if __name__ == "__main__":
    demo.launch()