# iitolstykh's picture
# Update app.py
# 18b7145 verified
import spaces
import os
# import shlex
# import subprocess
# if os.getenv('SYSTEM', "") == 'spaces' and os.getenv('USE_PRIVATE_PACKAGE', False):
# GITHUB_TOKEN = os.getenv('GITHUB_TOKEN')
# GITHUB_USER = os.getenv('GITHUB_USER')
# git_repo = f"https://{GITHUB_TOKEN}@github.com/{GITHUB_USER}/VIBE.git"
# subprocess.call(shlex.split(f'pip install git+{git_repo}'))
from functools import partial
from gradio.components import Image, Textbox
import gradio as gr
from PIL import Image as PILImage
from huggingface_hub import snapshot_download
import os
import random
import torch
import numpy as np
import pathlib
from vibe.editor import ImageEditor
# Largest seed offered in the UI: the maximum value of a signed 32-bit integer.
MAX_SEED = int(np.iinfo(np.int32).max)
def load_pipeline():
    """Download the VIBE-Image-Edit checkpoint and build an ``ImageEditor`` on CUDA.

    The Hub token is read from the ``HF_TOKEN`` environment variable
    (``None`` when it is unset) and passed to ``snapshot_download``.

    Returns:
        The constructed ``ImageEditor``, configured with default sampling
        settings (20 steps, guidance 4.5, image guidance 1.2).
    """
    checkpoint_dir = snapshot_download(
        repo_id="iitolstykh/VIBE-Image-Edit",
        repo_type="model",
        token=os.getenv('HF_TOKEN'),
    )

    # Default sampling settings baked into the editor instance.
    editor_settings = dict(
        image_guidance_scale=1.2,
        guidance_scale=4.5,
        num_inference_steps=20,
        device="cuda",
    )
    editor = ImageEditor(checkpoint_path=checkpoint_dir, **editor_settings)

    print(f"Model loaded. Model device: {editor.pipe.device}")
    return editor
# Build the editor once, at import time; generate_img reads this module-level
# instance on every request.
pipeline = load_pipeline()
def set_env(seed=0):
    """Seed torch's global RNG and turn autograd off.

    Args:
        seed: Value passed to ``torch.manual_seed`` (default 0).
    """
    grad_mode = False
    torch.manual_seed(seed)
    # This flips the process-wide (per-thread) grad mode, not a scoped block,
    # so every later torch op in this thread runs without gradient tracking.
    torch.set_grad_enabled(grad_mode)
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a freshly drawn seed when requested, otherwise ``seed`` unchanged.

    Args:
        seed: Seed to keep when ``randomize_seed`` is falsy.
        randomize_seed: When truthy, draw a new seed uniformly from
            ``[0, MAX_SEED]``.

    Returns:
        The seed to use.
    """
    if not randomize_seed:
        return seed
    # random.randint is inclusive on both ends, so MAX_SEED itself can be drawn.
    return random.randint(0, MAX_SEED)
@spaces.GPU(duration=180)
def generate_img(
    pil_image,
    edit_prompt: str,
    sample_steps,
    scale,
    image_guidance_scale,
    seed,
    progress=gr.Progress(track_tqdm=True),
):
    """Apply a text editing instruction to an image using the global ``pipeline``.

    Runs under ``spaces.GPU`` with a 180-second duration.

    Args:
        pil_image: Input image from the ``gr.Image(type="pil")`` component.
        edit_prompt: The editing instruction.
        sample_steps: Number of inference steps (slider value).
        scale: Text guidance scale (slider value).
        image_guidance_scale: Image guidance scale (slider value).
        seed: Seed forwarded to the pipeline (slider value).
        progress: Gradio progress tracker; ``track_tqdm=True`` surfaces tqdm
            progress bars in the UI. Not referenced directly.

    Returns:
        The first image returned by ``pipeline.generate_edited_image``.
    """
    edited_images = pipeline.generate_edited_image(
        instruction=edit_prompt,
        conditioning_image=pil_image,
        num_images_per_prompt=1,
        # gr.Slider delivers its value as a float (e.g. 20.0). Step counts and
        # RNG seeds must be integers, so coerce them here.
        num_inference_steps=int(sample_steps),
        guidance_scale=float(scale),
        image_guidance_scale=float(image_guidance_scale),
        seed=int(seed),
    )
    return edited_images[0]
if __name__ == "__main__":
    DESCRIPTION = "DEMO for VIBE-Image-Edit model: https://huggingface.co/iitolstykh/VIBE-Image-Edit"

    # Each example row follows the order of the Interface inputs below:
    # image path, prompt, sample steps, guidance scale, image guidance scale, seed.
    image_dir = pathlib.Path('images')
    example_prompt = "let this case swim in the river"
    examples = [
        [image_path.as_posix(), example_prompt, 20, 4.5, 1.2, 42]
        for image_path in sorted(image_dir.glob('*.png'))
    ]

    # Input components, in the same order as generate_img's parameters.
    input_image = gr.Image(label="Input", type="pil")
    prompt_box = gr.Textbox(label="Prompt", placeholder="Please enter your prompt. \n")
    steps_slider = gr.Slider(label="Sample Steps", minimum=1, maximum=100, value=20, step=1)
    guidance_slider = gr.Slider(
        label="Guidance Scale", minimum=0.1, maximum=30.0, value=4.5, step=0.1
    )
    image_guidance_slider = gr.Slider(
        label="Image Guidance Scale", minimum=0.1, maximum=30.0, value=1.2, step=0.1
    )
    seed_slider = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)

    result_image = gr.Image(label="Result", show_label=False, type="pil")

    demo = gr.Interface(
        fn=generate_img,
        inputs=[
            input_image,
            prompt_box,
            steps_slider,
            guidance_slider,
            image_guidance_slider,
            seed_slider,
        ],
        outputs=[result_image],
        title="",
        description=DESCRIPTION,
        examples=examples,
    )
    demo.queue(max_size=100)
    demo.launch()