| import os |
|
|
# Environment defaults for cache/config locations. They are applied here,
# ahead of the third-party imports further down, and only when the variable
# is not already set in the environment.
_data_is_writable = os.path.isdir("/data") and os.access("/data", os.W_OK)
if _data_is_writable:
    # Prefer the /data volume when it exists and is writable.
    _hf_cache = "/data/.cache/huggingface"
else:
    _hf_cache = "/tmp/hf_home"

for _env_name, _env_default in (
    ("HF_HOME", _hf_cache),
    ("HF_MODULES_CACHE", "/tmp/hf_modules"),
    ("MPLCONFIGDIR", "/tmp/matplotlib"),
    ("GRADIO_SSR_MODE", "false"),
):
    os.environ.setdefault(_env_name, _env_default)
|
|
| import time |
| from pathlib import Path |
| from typing import Dict, Tuple |
|
|
| import spaces |
| import gradio as gr |
| import torch |
| from diffusers import DDIMScheduler |
| from diffusers.models import AutoencoderKL |
| from huggingface_hub import hf_hub_download, snapshot_download |
| from PIL import Image |
|
|
| from removal.v1_2 import build_removal_model, load_cfg, load_removal_model |
| from removal.v1_2.pipeline import RemovalSDXLPipeline_BatchMode |
|
|
|
|
# Directory containing this file; the model config is resolved relative to it.
ROOT = Path(__file__).resolve().parent
CONFIG_PATH = ROOT.joinpath("config", "model_cfg", "moebius.yaml")

# Hugging Face Hub repositories: checkpoints come from MOEBIUS_REPO, the VAE
# sub-folder from PIXELHACKER_REPO.
MOEBIUS_REPO = "hustvl/Moebius"
PIXELHACKER_REPO = "hustvl/PixelHacker"

# Checkpoint key used for the import-time warm-up and as the fallback when a
# dropdown label is not recognised.
DEFAULT_MODEL_KEY = "ft_places2"

# Dropdown label -> checkpoint key (the key is the sub-folder name used when
# downloading "<key>/diffusion_pytorch_model.bin").
MODEL_CHOICES: Dict[str, str] = {
    "General scenes (Places2)": "ft_places2",
    "Portraits (CelebA-HQ)": "ft_celebahq",
    "Faces (FFHQ)": "ft_ffhq",
    "Pretrained": "pretrained",
}

# Process-wide cache of built pipelines, keyed by checkpoint key.
_PIPELINE_CACHE: Dict[str, RemovalSDXLPipeline_BatchMode] = {}
|
|
|
|
def _download_vae_dir() -> str:
    """Fetch the ``vae/`` files of ``PIXELHACKER_REPO`` and return their folder.

    Only files matching ``vae/*`` are requested from the Hub snapshot.

    Returns:
        Path (as ``str``) of the ``vae`` directory inside the local snapshot.
    """
    snapshot_root = snapshot_download(repo_id=PIXELHACKER_REPO, allow_patterns=["vae/*"])
    vae_dir = Path(snapshot_root, "vae")
    return str(vae_dir)
|
|
|
|
def _download_model_weight(model_key: str) -> str:
    """Download the weight file of one checkpoint from ``MOEBIUS_REPO``.

    Args:
        model_key: Checkpoint key; used as the sub-folder that holds
            ``diffusion_pytorch_model.bin`` in the repository.

    Returns:
        Whatever ``hf_hub_download`` returns for that file (the local path).
    """
    weight_file = f"{model_key}/diffusion_pytorch_model.bin"
    local_path = hf_hub_download(repo_id=MOEBIUS_REPO, filename=weight_file)
    return local_path
|
|
|
|
def _build_cpu_pipeline(model_key: str) -> RemovalSDXLPipeline_BatchMode:
    """Assemble a float32 inpainting pipeline on the CPU for one checkpoint.

    Steps, in order: load the YAML config, point its VAE entry at the
    downloaded VAE folder, build the removal model, download and load the
    checkpoint weights, load the VAE, create a DDIM scheduler, and wrap
    everything in ``RemovalSDXLPipeline_BatchMode``.

    Args:
        model_key: Checkpoint key passed to ``_download_model_weight``.

    Returns:
        A pipeline constructed with ``device="cpu"`` and ``torch.float32``.
    """
    cfg = load_cfg(str(CONFIG_PATH))
    cfg["vae"]["model_dir"] = _download_vae_dir()

    # NOTE(review): the meaning of the literal 20 is defined by
    # build_removal_model and is not visible in this file.
    network = build_removal_model(cfg, 20)
    checkpoint_path = _download_model_weight(model_key)
    load_report = load_removal_model(network, checkpoint_path, device="cpu")
    print(load_report)

    autoencoder = AutoencoderKL.from_pretrained(cfg["vae"]["model_dir"])

    scheduler_options = {
        "beta_start": 0.00085,
        "beta_end": 0.012,
        "beta_schedule": "scaled_linear",
        "num_train_timesteps": 1000,
        "clip_sample": False,
    }
    ddim = DDIMScheduler(**scheduler_options)

    pipeline = RemovalSDXLPipeline_BatchMode(
        removal_model=network,
        vae=autoencoder,
        scheduler=ddim,
        device="cpu",
        dtype=torch.float32,
    )
    return pipeline
|
|
|
|
def _get_pipeline(model_key: str) -> RemovalSDXLPipeline_BatchMode:
    """Return the pipeline for *model_key*, building it on first use.

    Built pipelines are stored in ``_PIPELINE_CACHE`` so later calls with the
    same key return the same object.
    """
    if model_key in _PIPELINE_CACHE:
        return _PIPELINE_CACHE[model_key]

    pipeline = _build_cpu_pipeline(model_key)
    _PIPELINE_CACHE[model_key] = pipeline
    return pipeline
|
|
|
|
def _set_pipeline_device(pipe: RemovalSDXLPipeline_BatchMode, device: str) -> None:
    """Move *pipe* to *device* and rebuild its ``input_ids`` tensor there.

    The VAE and the removal model are moved to ``device`` with ``pipe.dtype``
    and switched to eval mode, and ``pipe.device`` is updated to match.
    ``pipe.input_ids`` is then recreated on ``device`` as an int64 tensor with
    two rows: row 0 holds the ids ``[num_embeddings // 2, num_embeddings)`` and
    row 1 holds ``[0, num_embeddings // 2)``.

    Args:
        pipe: Pipeline whose ``device``, ``vae``, ``removal_model`` and
            ``input_ids`` attributes are updated in place.
        device: Target device string, e.g. ``"cpu"`` or ``"cuda"``.

    Note:
        Both rows must have equal length for ``torch.cat`` to succeed, so an
        odd ``num_embeddings`` raises here.
    """
    pipe.device = device
    pipe.vae.to(device=device, dtype=pipe.dtype).eval()
    pipe.removal_model.to(device=device, dtype=pipe.dtype).eval()

    id_num = pipe.removal_model.num_embeddings
    half_id_num = id_num // 2
    # torch.arange creates the id ranges directly on the target device, with
    # no intermediate Python lists; unsqueeze(0) gives each a (1, n) shape.
    input_ids = torch.arange(half_id_num, dtype=torch.int64, device=device).unsqueeze(0)
    un_input_ids = torch.arange(half_id_num, id_num, dtype=torch.int64, device=device).unsqueeze(0)
    # Second-half ids first, then first-half ids; both already live on
    # `device`, so no further transfer is needed.
    pipe.input_ids = torch.cat([un_input_ids, input_ids])
|
|
|
|
def _normalize_inputs(image: Image.Image, mask: Image.Image) -> Tuple[Image.Image, Image.Image]:
    """Validate the uploads and bring them to a common format.

    The image is converted to RGB; the mask is converted to single-channel
    ``"L"`` and resized to the image's size with nearest-neighbour sampling.

    Args:
        image: Uploaded picture, or ``None`` when nothing was uploaded.
        mask: Uploaded mask, or ``None`` when nothing was uploaded.

    Returns:
        ``(rgb_image, gray_mask)`` with identical sizes.

    Raises:
        gr.Error: If either input is missing, if the brightest mask pixel is
            below 8 (mask treated as empty), or if the darkest mask pixel is
            above 247 (mask treated as covering everything).
    """
    required = ((image, "Upload an image."), (mask, "Upload a mask."))
    for upload, complaint in required:
        if upload is None:
            raise gr.Error(complaint)

    rgb_image = image.convert("RGB")
    gray_mask = mask.convert("L")
    gray_mask = gray_mask.resize(rgb_image.size, Image.Resampling.NEAREST)

    # getextrema() on an "L" image yields (darkest, brightest).
    darkest, brightest = gray_mask.getextrema()
    if brightest < 8:
        raise gr.Error("The mask is empty. Use white pixels for the area to inpaint.")
    if darkest > 247:
        raise gr.Error("The mask covers the whole image. Leave black pixels outside the edit area.")

    return rgb_image, gray_mask
|
|
|
|
def _model_key(label: str) -> str:
    """Translate a dropdown label into its checkpoint key.

    Labels missing from ``MODEL_CHOICES`` map to ``DEFAULT_MODEL_KEY``.
    """
    try:
        return MODEL_CHOICES[label]
    except KeyError:
        return DEFAULT_MODEL_KEY
|
|
|
|
def _estimate_duration(image, mask, model_name, steps, guidance_scale, paste, compensate, seed, *args, **kwargs):
    """Compute the value passed as ``duration`` to ``spaces.GPU`` for a request.

    The signature mirrors ``run_inpaint`` (plus catch-all ``*args`` and
    ``**kwargs``); only ``steps`` influences the result.

    Returns:
        ``90 + 5 * int(steps)``, capped at 240 (presumably seconds — confirm
        against the ``spaces.GPU`` documentation).
    """
    base = 90
    per_step = 5
    cap = 240
    requested = base + per_step * int(steps)
    return requested if requested < cap else cap
|
|
|
|
# Build the default checkpoint's pipeline at import time so it is already in
# _PIPELINE_CACHE when the first request arrives.
_get_pipeline(model_key=DEFAULT_MODEL_KEY)
|
|
|
|
@spaces.GPU(duration=1)
def _zerogpu_probe():
    """Trivial GPU-decorated function that returns the string ``"ready"``.

    NOTE(review): nothing in the visible part of this file calls it;
    presumably it exists so a ``spaces.GPU`` function is registered at
    import time — confirm before removing.
    """
    probe_status = "ready"
    return probe_status
|
|
|
|
@spaces.GPU(duration=_estimate_duration)
def _run_inpaint_on_gpu(image, mask, model_name, steps, guidance_scale, paste, compensate, seed):
    """GPU-held part of an inpainting request.

    The positional signature deliberately matches ``run_inpaint`` and
    ``_estimate_duration``, which ``spaces.GPU`` uses to size the GPU
    reservation from the call arguments. ``image`` and ``mask`` are expected
    to have been passed through ``_normalize_inputs`` already.

    Returns:
        ``(result_image, status_markdown)`` where ``result_image`` is the
        first element of the pipeline output.
    """
    pipe = _get_pipeline(_model_key(model_name))
    seed_value = 0 if seed is None else int(seed)

    started = time.perf_counter()
    try:
        _set_pipeline_device(pipe, "cuda")
        with torch.inference_mode():
            outputs = pipe(
                [image],
                [mask],
                image_size=512,
                num_steps=int(steps),
                guidance_scale=float(guidance_scale),
                paste=bool(paste),
                compensate=bool(compensate),
                noise_offset=0.0357,
                # The UI "Seed" value is forwarded as the pipeline's `retry`
                # argument; its exact meaning is defined by the pipeline.
                retry=seed_value,
                mute=True,
            )
        elapsed = time.perf_counter() - started
        return outputs[0], f"Completed in {elapsed:.1f}s"
    finally:
        # Always hand the cached pipeline back to the CPU, even on failure,
        # and release cached CUDA memory.
        _set_pipeline_device(pipe, "cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def run_inpaint(image, mask, model_name, steps, guidance_scale, paste, compensate, seed):
    """Gradio handler: validate the request on the CPU, then inpaint on the GPU.

    Input validation and CPU-side pipeline construction happen here, before
    the GPU-decorated helper is entered. A missing image or mask therefore
    fails immediately without a GPU reservation, and downloading or building
    a not-yet-cached checkpoint does not count against the reserved GPU time.

    Args:
        image: Uploaded PIL image, or ``None``.
        mask: Uploaded PIL mask (white marks the area to inpaint), or ``None``.
        model_name: Dropdown label; unknown labels fall back to the default
            checkpoint.
        steps: Number of sampling steps.
        guidance_scale: CFG scale.
        paste: Forwarded to the pipeline as ``paste``.
        compensate: Forwarded to the pipeline as ``compensate``.
        seed: Integer seed, or ``None`` for 0.

    Returns:
        ``(result_image, status_markdown)``.

    Raises:
        gr.Error: If the image or mask is missing or the mask is unusable.
    """
    image, mask = _normalize_inputs(image, mask)
    # Make sure the checkpoint is built and cached on the CPU up front, in the
    # same way the default checkpoint is warmed up at import time.
    _get_pipeline(_model_key(model_name))
    return _run_inpaint_on_gpu(image, mask, model_name, steps, guidance_scale, paste, compensate, seed)
|
|
|
|
with gr.Blocks(title="Moebius Inpainting", fill_width=True) as demo:
    gr.Markdown("# Moebius Inpainting")

    # Top row: uploads on the left, result image and status text on the right.
    with gr.Row():
        with gr.Column(scale=1, min_width=320):
            input_image = gr.Image(
                label="Image", type="pil", image_mode="RGB", sources=["upload", "clipboard"], height=360
            )
            input_mask = gr.Image(
                label="Mask", type="pil", image_mode="L", sources=["upload", "clipboard"], height=360
            )
        with gr.Column(scale=1, min_width=320):
            output_image = gr.Image(label="Result", type="pil", height=520)
            status = gr.Markdown()

    # Sampling controls.
    with gr.Row():
        model_name = gr.Dropdown(
            choices=list(MODEL_CHOICES),
            value="General scenes (Places2)",
            label="Checkpoint",
            min_width=240,
        )
        steps = gr.Slider(minimum=4, maximum=30, value=20, step=1, label="Steps", min_width=180)
        guidance_scale = gr.Slider(minimum=1.0, maximum=6.0, value=2.0, step=0.1, label="CFG", min_width=180)
        seed = gr.Number(value=0, precision=0, label="Seed", min_width=140)

    # Post-processing toggles and the trigger button.
    with gr.Row():
        paste = gr.Checkbox(value=True, label="Paste")
        compensate = gr.Checkbox(value=False, label="Compensate")
        run_button = gr.Button("Inpaint", variant="primary")

    # The button and the examples drive the same handler with the same
    # components; the order matches run_inpaint's positional parameters.
    inpaint_inputs = [input_image, input_mask, model_name, steps, guidance_scale, paste, compensate, seed]
    inpaint_outputs = [output_image, status]

    run_button.click(
        fn=run_inpaint,
        inputs=inpaint_inputs,
        outputs=inpaint_outputs,
        api_name="inpaint",
        concurrency_limit=1,
    )

    # Each row supplies one value per entry of inpaint_inputs; results are
    # cached lazily rather than at startup.
    example_rows = [
        ["examples/road.png", "examples/road_rocks_mask.png", "General scenes (Places2)", 20, 2.0, True, False, 0],
        ["examples/bench.png", "examples/bench_mask.png", "General scenes (Places2)", 20, 2.0, True, False, 1],
    ]
    gr.Examples(
        examples=example_rows,
        inputs=inpaint_inputs,
        outputs=inpaint_outputs,
        fn=run_inpaint,
        cache_examples=True,
        cache_mode="lazy",
    )
|
|
|
|
# Queue settings: at most 8 pending requests, and events run one at a time
# unless an event listener sets its own concurrency limit.
demo.queue(default_concurrency_limit=1, max_size=8)


if __name__ == "__main__":
    # Start the server only when this file is executed as a script.
    demo.launch()
|
|