File size: 3,230 Bytes
b988971
bb2e2c7
b988971
a697bd3
 
 
 
 
 
 
 
 
 
b988971
 
 
 
 
 
 
0f018d2
 
b988971
 
 
 
 
d9ccc15
 
 
 
 
 
 
 
0f018d2
 
 
 
 
 
d9ccc15
b988971
 
 
87c1aac
b988971
87c1aac
b988971
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9ccc15
b988971
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Copyright 2024 Adobe. All rights reserved.
import spaces

# Work around a gradio_client 5.21.0 bug: when a JSON schema contains
# "additionalProperties": false, _json_schema_to_python_type receives a bare
# bool and crashes on the `"const" in schema` membership test. The wrapper
# below answers any non-dict schema with "any" before the original
# implementation ever sees it.
import gradio_client.utils as _gc_utils

_orig_schema_to_type = _gc_utils._json_schema_to_python_type


def _patched_schema_to_type(schema, defs=None):
    """Delegate to gradio_client's schema-to-type converter for dict schemas only.

    Args:
        schema: A JSON-schema node. May be a bool (e.g. the value of
            ``additionalProperties``) instead of a dict.
        defs: Schema definitions, forwarded unchanged to the original converter.

    Returns:
        str: ``"any"`` when ``schema`` is not a dict; otherwise whatever the
        original ``_json_schema_to_python_type`` returns.
    """
    if isinstance(schema, dict):
        return _orig_schema_to_type(schema, defs)
    return "any"


_gc_utils._json_schema_to_python_type = _patched_schema_to_type

from huggingface_hub import hf_hub_download
from run_magicfu import MagicFixup
import torchvision
from torch import autocast
from PIL import Image
import gradio as gr
import numpy as np
import shutil
import os


# Fetch the Magic Fixup weights from the Hugging Face Hub when the app starts.
checkpoint_path = hf_hub_download(repo_id="HadiZayer/MagicFixup", filename="magicfu_weights")

# (original image, coarse edit) example pairs. They live in the
# HadiZayer/MagicFixup *model* repo (repo_type="model"). They are downloaded at
# startup, then copied out of the HF cache into a temp directory, presumably
# so Gradio is allowed to serve them as example files.
EXAMPLE_PAIRS = [
    ("examples/fox_drinking_og.png",    "examples/fox_drinking__edit__01.png"),
    ("examples/palm_tree_og.png",       "examples/palm_tree__edit__01.png"),
    ("examples/kingfisher_og.png",      "examples/kingfisher__edit__001.png"),
    ("examples/pipes_og.png",           "examples/pipes__edit__01.png"),
    ("examples/dog_beach_og.png",       "examples/dog_beach__edit__003.png"),
]
_EXAMPLES_DIR = "/tmp/magicfixup_examples"
os.makedirs(_EXAMPLES_DIR, exist_ok=True)


def _fetch_example(filename):
    """Download ``filename`` from the model repo and copy it into ``_EXAMPLES_DIR``.

    Returns the destination path reported by ``shutil.copy``.
    """
    cached_path = hf_hub_download(repo_id="HadiZayer/MagicFixup", filename=filename, repo_type="model")
    return shutil.copy(cached_path, _EXAMPLES_DIR)


# Each row is [original_copy_path, edit_copy_path]; the original is fetched
# before its edit, pair by pair, in EXAMPLE_PAIRS order.
examples = [[_fetch_example(og), _fetch_example(edit)] for og, edit in EXAMPLE_PAIRS]

magic_fixup = MagicFixup(model_path=checkpoint_path)


@spaces.GPU
def sample(original_image, coarse_edit):
    """Refine a coarse edit of ``original_image`` with Magic Fixup.

    Both images are resized to 512x512 before being handed to the model. The
    alpha channel of ``coarse_edit`` is passed to the model as the edit mask,
    and the refined result is resized back to the size of ``coarse_edit``.

    Args:
        original_image: PIL image of the unedited photo (converted to RGB).
        coarse_edit: PIL image of the coarse edit (converted to RGBA; an image
            without alpha therefore yields a fully opaque mask).

    Returns:
        PIL.Image.Image: the refined image, same size as ``coarse_edit``.

    Raises:
        gr.Error: if either input image is missing.
    """
    # Validate before touching the GPU so an empty submission fails with a
    # readable message instead of an AttributeError on `None.size`.
    if original_image is None or coarse_edit is None:
        raise gr.Error("Please provide both the original image and the coarse edit.")

    # Normalize modes so the channel slicing below is always meaningful: the
    # reference must be 3-channel and the edit 4-channel. Otherwise `[-1]`
    # would silently pick the blue channel as the mask.
    original_image = original_image.convert("RGB")
    coarse_edit = coarse_edit.convert("RGBA")

    magic_fixup.model.cuda()
    to_tensor = torchvision.transforms.ToTensor()
    with autocast("cuda"):
        w, h = coarse_edit.size
        ref_image_t = to_tensor(original_image.resize((512, 512))).half().cuda()
        # A single resized RGBA tensor supplies both the RGB edit (channels
        # 0-2) and the mask (channel 3), so there is no second resize/upload.
        coarse_edit_t = to_tensor(coarse_edit.resize((512, 512))).half().cuda()
        mask_t = coarse_edit_t[-1][None, None, ...]
        coarse_edit_t_rgb = coarse_edit_t[:-1]

        out_rgb = magic_fixup.edit_image(ref_image_t, coarse_edit_t_rgb, mask_t, start_step=1.0, steps=50)
        output = out_rgb.squeeze().detach().cpu().moveaxis(0, -1).float().numpy()
        # Clip before the uint8 cast. Values slightly outside [0, 1] would
        # otherwise wrap around (e.g. 1.01 -> 257 -> 1) and appear as speckles.
        output = (np.clip(output, 0.0, 1.0) * 255.0).astype(np.uint8)
        return Image.fromarray(output).resize((w, h))


# Inputs: the untouched photo (forced to RGB) and the coarse edit (forced to
# RGBA, because `sample` uses its last channel as the edit mask).
_inputs = [
    gr.Image(type="pil", image_mode="RGB", label="Original Image"),
    gr.Image(type="pil", image_mode="RGBA", label="Coarse Edit (with alpha mask)"),
]
_description = (
    "Upload your original image and a coarse edit (PNG with alpha channel marking the edited region). "
    "Magic Fixup will refine the edit to look photorealistic."
)

demo = gr.Interface(
    fn=sample,
    inputs=_inputs,
    outputs=gr.Image(label="Result"),
    examples=examples,
    title="Magic Fixup",
    description=_description,
)

# Start the Gradio app.
demo.launch()