"""Gradio demo for the VIBE-Image-Edit instruction-based image editing model.

Downloads the ``iitolstykh/VIBE-Image-Edit`` checkpoint, wraps it in
``vibe.editor.ImageEditor`` and serves a ``gr.Interface`` that edits an
uploaded image according to a text instruction.
"""

# `spaces` is kept as the very first import: ZeroGPU expects it to be imported
# before torch / CUDA get initialised.
import spaces
import os

# Optional install of a private VIBE package when running on Spaces:
# import shlex
# import subprocess
# if os.getenv('SYSTEM', "") == 'spaces' and os.getenv('USE_PRIVATE_PACKAGE', False):
#     GITHUB_TOKEN = os.getenv('GITHUB_TOKEN')
#     GITHUB_USER = os.getenv('GITHUB_USER')
#     git_repo = f"https://{GITHUB_TOKEN}@github.com/{GITHUB_USER}/VIBE.git"
#     subprocess.call(shlex.split(f'pip install git+{git_repo}'))

from functools import partial
from gradio.components import Image, Textbox
import gradio as gr
from PIL import Image as PILImage
from huggingface_hub import snapshot_download
import random
import torch
import numpy as np
import pathlib

from vibe.editor import ImageEditor

# Largest seed the UI offers (upper bound of the "Seed" slider and of
# randomize_seed_fn).
MAX_SEED = np.iinfo(np.int32).max


def load_pipeline():
    """Download the VIBE-Image-Edit checkpoint and build an ``ImageEditor`` on CUDA.

    The Hugging Face token is read from the ``HF_TOKEN`` environment variable
    and may be unset (``None`` is then passed to ``snapshot_download``).

    Returns:
        The constructed ``ImageEditor``. Its default guidance / step settings
        match the default values of the UI sliders defined below.
    """
    hf_token = os.getenv("HF_TOKEN")
    model_path = snapshot_download(
        repo_id="iitolstykh/VIBE-Image-Edit",
        repo_type="model",
        token=hf_token,
    )

    # Load model
    editor_pipeline = ImageEditor(
        checkpoint_path=model_path,
        image_guidance_scale=1.2,
        guidance_scale=4.5,
        num_inference_steps=20,
        device="cuda",
    )
    print(f"Model loaded. Model device: {editor_pipeline.pipe.device}")
    return editor_pipeline


# Built once at import time so every request reuses the same loaded weights.
pipeline = load_pipeline()


def set_env(seed=0):
    """Seed torch's global RNG and globally disable autograd.

    NOTE(review): not referenced anywhere in this file.
    """
    torch.manual_seed(seed)
    torch.set_grad_enabled(False)


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return ``seed`` unchanged, or a fresh random seed in ``[0, MAX_SEED]``.

    NOTE(review): not wired into the Interface below.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


@spaces.GPU(duration=180)
def generate_img(
    pil_image,
    edit_prompt: str,
    sample_steps,
    scale,
    image_guidance_scale,
    seed,
    progress=gr.Progress(track_tqdm=True),
):
    """Edit ``pil_image`` according to ``edit_prompt`` and return one result.

    Args:
        pil_image: Input image from the ``gr.Image(type="pil")`` component;
            ``None`` when the user submitted without uploading anything.
        edit_prompt: Editing instruction text.
        sample_steps: Number of inference steps (slider value, coerced to int).
        scale: Text guidance scale.
        image_guidance_scale: Image guidance scale.
        seed: RNG seed (slider value, coerced to int).
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors tqdm
            bars emitted during generation.

    Returns:
        The first element of ``pipeline.generate_edited_image(...)``
        (presumably a PIL image, matching the ``type="pil"`` output component).

    Raises:
        gr.Error: If no input image was provided.
    """
    if pil_image is None:
        # Surface a readable message in the UI instead of a pipeline crash.
        raise gr.Error("Please upload an input image.")

    edited_image = pipeline.generate_edited_image(
        instruction=edit_prompt,
        conditioning_image=pil_image,
        num_images_per_prompt=1,
        # Slider values may arrive as floats; step count and seed must be ints.
        num_inference_steps=int(sample_steps),
        guidance_scale=scale,
        image_guidance_scale=image_guidance_scale,
        seed=int(seed),
    )
    return edited_image[0]


if __name__ == "__main__":
    DESCRIPTION = """DEMO for VIBE-Image-Edit model: https://huggingface.co/iitolstykh/VIBE-Image-Edit"""

    # One example row per PNG in ./images (relative to the working directory),
    # using the same defaults as the sliders below.
    image_dir = pathlib.Path('images')
    examples = [
        [path.as_posix(), "let this case swim in the river", 20, 4.5, 1.2, 42]
        for path in sorted(image_dir.glob('*.png'))
    ]

    demo = gr.Interface(
        fn=generate_img,
        inputs=[
            gr.Image(label="Input", type="pil"),
            Textbox(label="Prompt", placeholder="Please enter your prompt. \n"),
            gr.Slider(label="Sample Steps", minimum=1, maximum=100, value=20, step=1),
            gr.Slider(
                label="Guidance Scale", minimum=0.1, maximum=30.0, value=4.5, step=0.1
            ),
            gr.Slider(
                label="Image Guidance Scale",
                minimum=0.1,
                maximum=30.0,
                value=1.2,
                step=0.1,
            ),
            gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            ),
        ],
        # outputs=[gr.Gallery(label="Result", show_label=False, type="pil")],
        outputs=[gr.Image(label="Result", show_label=False, type="pil")],
        title="",
        description=DESCRIPTION,
        examples=examples,
    )
    demo.queue(max_size=100).launch()