Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from huggingface_hub import hf_hub_download | |
| from utils import utils, tools, preprocess | |
# Alternative base model that was tried at some point (currently unused):
#   stablediffusionapi/neta-art-xl-v2

# Hugging Face Hub repo holding the ControlNeXt weights, and the files inside
# its anime-canny subfolder that get downloaded at startup.
REPO_ID = "Pbihao/ControlNeXt"
_CONTROLNEXT_SUBDIR = "ControlAny-SDXL/anime_canny"
UNET_FILENAME = f"{_CONTROLNEXT_SUBDIR}/unet.safetensors"
CONTROLNET_FILENAME = f"{_CONTROLNEXT_SUBDIR}/controlnet.safetensors"

# VAE repo id handed to tools.get_pipeline as `vae_model_name_or_path`.
VAE_PATH = "madebyollin/sdxl-vae-fp16-fix"

# Passed as `cache_dir` to every download; None selects huggingface_hub's
# default cache location.
CACHE_DIR = None
def _load_pipeline(device):
    """Download the base, UNet and ControlNet weights and build the pipeline.

    Args:
        device: torch device string ("cuda" or "cpu") the pipeline is built for.

    Returns:
        The object returned by ``tools.get_pipeline``. ``ui`` uses it as a
        callable whose result has ``.images`` and which exposes a
        ``.scheduler`` attribute.
    """
    model_file = hf_hub_download(
        repo_id='Lykon/AAM_XL_AnimeMix',
        filename='AAM_XL_Anime_Mix.safetensors',
        cache_dir=CACHE_DIR,
    )
    unet_file = hf_hub_download(
        repo_id=REPO_ID,
        filename=UNET_FILENAME,
        cache_dir=CACHE_DIR,
    )
    controlnet_file = hf_hub_download(
        repo_id=REPO_ID,
        filename=CONTROLNET_FILENAME,
        cache_dir=CACHE_DIR,
    )
    return tools.get_pipeline(
        pretrained_model_name_or_path=model_file,
        unet_model_name_or_path=unet_file,
        controlnet_model_name_or_path=controlnet_file,
        vae_model_name_or_path=VAE_PATH,
        load_weight_increasement=True,
        device=device,
        hf_cache_dir=CACHE_DIR,
        use_safetensors=True,
        # xformers memory-efficient attention only runs on CUDA. Requesting it
        # on the CPU fallback makes startup fail.
        # NOTE(review): confirm tools.get_pipeline skips enabling it when False.
        enable_xformers_memory_efficient_attention=(device == 'cuda'),
    )


def _make_generator(seed, device):
    """Return a seeded ``torch.Generator``, or None to let the pipeline pick.

    ``-1`` (the UI default) and ``None`` (a cleared Number field) both mean
    "random seed". Any other value is clamped into ``[0, int32 max]``.
    """
    if seed is None or seed == -1:
        return None
    clamped = int(max(0, min(int(seed), int(np.iinfo(np.int32).max))))
    return torch.Generator(device=device).manual_seed(clamped)


def ui():
    """Build the ControlNeXt Gradio demo.

    Model weights are downloaded and the pipeline is built when this is
    called. The returned Blocks app has two actions:
    * "Process" runs the selected preprocessor on the uploaded image and
      writes the result back into the Condition image.
    * "Generate" runs the pipeline conditioned on that image and shows the
      resulting images in the gallery.

    Returns:
        gr.Blocks: the assembled (not yet launched) demo.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipeline = _load_pipeline(device)

    preprocessors = ['canny']
    schedulers = ['Euler A', 'UniPC', 'Euler', 'DDIM', 'DDPM']
    # NOTE(review): no component below sets elem_id='col-container', so this
    # rule currently matches nothing.
    css = """
    #col-container {
        margin: 0 auto;
        max-width: 520px;
    }
    """

    with gr.Blocks(css=css) as demo:
        gr.Markdown("""
        # [ControlNeXt](https://github.com/dvlab-research/ControlNeXt) Official Demo
        """)
        with gr.Row():
            with gr.Column(scale=9):
                prompt = gr.Textbox(lines=3, placeholder='prompt', container=False)
                negative_prompt = gr.Textbox(lines=3, placeholder='negative prompt', container=False)
            with gr.Column(scale=1):
                generate_button = gr.Button("Generate", variant='primary', min_width=96)
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Row():
                    control_image = gr.Image(
                        value=None,
                        label='Condition',
                        sources=['upload'],
                        type='pil',
                        height=512,
                        show_download_button=True,
                        show_share_button=True,
                    )
                with gr.Row():
                    scheduler = gr.Dropdown(
                        label='Scheduler',
                        choices=schedulers,
                        value='Euler A',
                        multiselect=False,
                        allow_custom_value=False,
                        filterable=True,
                    )
                    num_inference_steps = gr.Slider(minimum=1, maximum=100, step=1, value=20, label='Steps')
                with gr.Row():
                    # step=0.5 so the 7.5 default lies on a slider tick.
                    cfg_scale = gr.Slider(minimum=1, maximum=30, step=0.5, value=7.5, label='CFG Scale')
                    controlnet_scale = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label='ControlNet Scale')
                with gr.Row():
                    seed = gr.Number(label='Seed', step=1, precision=0, value=-1)
                with gr.Row():
                    processor = gr.Dropdown(
                        label='Image Preprocessor',
                        choices=preprocessors,
                        value='canny',
                    )
                    process_button = gr.Button("Process", variant='primary', min_width=96, scale=0)
            with gr.Column(scale=1):
                output = gr.Gallery(
                    label='Output',
                    value=None,
                    object_fit='scale-down',
                    columns=4,
                    height=512,
                    show_download_button=True,
                    show_share_button=True,
                )

        def generate(
            prompt,
            control_image,
            negative_prompt,
            cfg_scale,
            controlnet_scale,
            num_inference_steps,
            scheduler,
            seed,
        ):
            """Run the pipeline on the condition image and return its images.

            Raises:
                gr.Error: if no condition image was uploaded.
            """
            # Validate before touching the shared pipeline's scheduler.
            if control_image is None:
                raise gr.Error('Please upload an image.')
            pipeline.scheduler = tools.get_scheduler(scheduler, pipeline.scheduler.config)
            generator = _make_generator(seed, device)
            width, height = utils.around_reso(
                control_image.width, control_image.height,
                reso=1024, max_width=2048, max_height=2048, divisible=32,
            )
            # Convert first: PIL always resizes mode 'P'/'1' images with
            # nearest-neighbour resampling, whatever filter is requested.
            control_image = control_image.convert('RGB').resize((width, height))
            with torch.autocast(device):
                output_images = pipeline(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    controlnet_image=control_image,
                    controlnet_scale=controlnet_scale,
                    width=width,
                    height=height,
                    generator=generator,
                    guidance_scale=cfg_scale,
                    # The slider value may arrive as a float.
                    num_inference_steps=int(num_inference_steps),
                ).images
            return output_images

        def process(
            image,
            processor,
        ):
            """Apply the named preprocessor to ``image`` and return the result.

            Raises:
                gr.Error: if no image was uploaded.
            """
            if image is None:
                raise gr.Error('Please upload an image.')
            extractor = preprocess.get_extractor(processor)
            return extractor(image)

        generate_button.click(
            fn=generate,
            inputs=[prompt, control_image, negative_prompt, cfg_scale, controlnet_scale, num_inference_steps, scheduler, seed],
            outputs=[output],
        )
        process_button.click(
            fn=process,
            inputs=[control_image, processor],
            outputs=[control_image],
        )

    return demo
if __name__ == '__main__':
    # Script entry point: build the demo, enable request queuing, and serve it.
    ui().queue().launch()