File size: 3,386 Bytes
49afada
1199a4a
18b7145
 
 
 
 
 
 
 
1199a4a
 
 
 
 
b810fa6
1199a4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c480799
1199a4a
 
3e983ee
 
1199a4a
 
a73d81d
a87ce70
 
 
1199a4a
 
 
 
 
 
 
 
 
 
 
37e6d36
1199a4a
37e6d36
1199a4a
 
 
 
eb3d4ce
a87ce70
37e6d36
1199a4a
 
 
 
 
 
 
eb3d4ce
37e6d36
cf4eff2
1199a4a
 
 
 
55031b2
1199a4a
 
 
 
 
a87ce70
1199a4a
37e6d36
1199a4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa90884
 
1199a4a
 
 
 
 
37e6d36
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import spaces
import os
# import shlex
# import subprocess

# if os.getenv('SYSTEM', "") == 'spaces' and os.getenv('USE_PRIVATE_PACKAGE', False):
#     GITHUB_TOKEN = os.getenv('GITHUB_TOKEN')
#     GITHUB_USER = os.getenv('GITHUB_USER')
#     git_repo = f"https://{GITHUB_TOKEN}@github.com/{GITHUB_USER}/VIBE.git"
#     subprocess.call(shlex.split(f'pip install git+{git_repo}'))

from functools import partial
from gradio.components import Image, Textbox
import gradio as gr

from PIL import Image as PILImage
from huggingface_hub import snapshot_download
import os
import random
import torch
import numpy as np
import pathlib

from vibe.editor import ImageEditor

# Upper bound for user-selectable RNG seeds (largest 32-bit signed int).
MAX_SEED = np.iinfo(np.int32).max


def load_pipeline():
    """Download the VIBE-Image-Edit checkpoint and build a CUDA ImageEditor.

    Reads the optional ``HF_TOKEN`` environment variable to authenticate
    the Hugging Face Hub download.

    Returns:
        ImageEditor: the ready-to-use editing pipeline.
    """
    # Fetch (or reuse the cached copy of) the model repository.
    checkpoint_dir = snapshot_download(
        repo_id="iitolstykh/VIBE-Image-Edit",
        repo_type="model",
        token=os.getenv('HF_TOKEN'),
    )

    # Build the editor with the demo's default sampling settings.
    editor_pipeline = ImageEditor(
        checkpoint_path=checkpoint_dir,
        image_guidance_scale=1.2,
        guidance_scale=4.5,
        num_inference_steps=20,
        device="cuda",
    )

    print(f"Model loaded. Model device: {editor_pipeline.pipe.device}")

    return editor_pipeline


# Instantiate the editor once at import time so every request reuses the same
# loaded model (download + weight load is expensive).
pipeline = load_pipeline()


def set_env(seed=0):
    """Seed torch's RNG and globally disable autograd (inference-only demo)."""
    # Statements are independent; gradient tracking is switched off first.
    torch.set_grad_enabled(False)
    torch.manual_seed(seed)


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a fresh seed in [0, MAX_SEED] when requested, else pass through."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed


@spaces.GPU(duration=180)
def generate_img(
    pil_image,
    edit_prompt: str,
    sample_steps,
    scale,
    image_guidance_scale,
    seed,
    progress=gr.Progress(track_tqdm=True),
):
    """Apply `edit_prompt` to `pil_image` via the global pipeline.

    Returns the single edited image (first element of the batch of one).
    """
    results = pipeline.generate_edited_image(
        conditioning_image=pil_image,
        instruction=edit_prompt,
        num_images_per_prompt=1,
        num_inference_steps=sample_steps,
        guidance_scale=scale,
        image_guidance_scale=image_guidance_scale,
        seed=seed,
    )
    return results[0]


if __name__ == "__main__":

    # Plain string — the original used an f-string with no placeholders.
    DESCRIPTION = """DEMO for VIBE-Image-Edit model: https://huggingface.co/iitolstykh/VIBE-Image-Edit"""

    # One example row per bundled PNG; values mirror the widget defaults below.
    image_dir = pathlib.Path('images')
    examples = [[path.as_posix(), "let this case swim in the river", 20, 4.5, 1.2, 42] for path in sorted(image_dir.glob('*.png'))]

    demo = gr.Interface(
        fn=generate_img,
        inputs=[
            gr.Image(label="Input", type="pil"),
            # gr.* namespace used consistently with the other components
            # (was the bare `Textbox` import).
            gr.Textbox(label="Prompt", placeholder="Please enter your prompt. \n"),
            gr.Slider(label="Sample Steps", minimum=1, maximum=100, value=20, step=1),
            gr.Slider(
                label="Guidance Scale", minimum=0.1, maximum=30.0, value=4.5, step=0.1
            ),
            gr.Slider(
                label="Image Guidance Scale",
                minimum=0.1,
                maximum=30.0,
                value=1.2,
                step=0.1,
            ),
            gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            ),
        ],
        outputs=[gr.Image(label="Result", show_label=False, type="pil")],
        title="",
        description=DESCRIPTION,
        examples=examples,
    )

    # Queue bounds the number of requests waiting on the shared GPU pipeline.
    demo.queue(max_size=100).launch()