# Hugging Face Spaces status at time of capture: Runtime error
from multiprocessing import cpu_count
from pathlib import Path

import gradio as gr

from src.ui_shared import (
    model_ids,
    scheduler_names,
    default_scheduler,
    controlnet_ids,
    assets_directory,
)
from src.ui_functions import generate, run_training
# Default width/height (pixels) for generated images; also used as the
# gallery height in the UI below.
default_img_size = 512

# Markdown rendered above and below the UI. Read explicitly as UTF-8 so
# startup does not depend on the platform's default locale encoding.
header = Path(assets_directory, "header.MD").read_text(encoding="utf-8")
footer = Path(assets_directory, "footer.MD").read_text(encoding="utf-8")

theme = gr.themes.Soft(
    primary_hue="blue",
    neutral_hue="slate",
)
with gr.Blocks(theme=theme) as demo:
    header_component = gr.Markdown(header)

    # --- Top row: prompts (left) and model / scheduler selection (right) ---
    with gr.Row().style(equal_height=True):
        with gr.Column(scale=70):
            prompt = gr.Textbox(
                label="Prompt", placeholder="Press <Shift+Enter> to generate", lines=2
            )
            neg_prompt = gr.Textbox(label="Negative Prompt", placeholder="", lines=2)
            with gr.Row():
                controlnet_prompt = gr.Textbox(
                    label="Controlnet Prompt",
                    placeholder="If empty, defaults to base `Prompt`",
                    lines=2,
                )
                controlnet_negative_prompt = gr.Textbox(
                    label="Controlnet Negative Prompt",
                    placeholder="If empty, defaults to base `Negative Prompt`",
                    lines=2,
                )
        with gr.Column(scale=30):
            # allow_custom_value=True lets the user type an id that is not in
            # the predefined choices.
            model_name = gr.Dropdown(
                label="Model", choices=model_ids, value=model_ids[0], allow_custom_value=True
            )
            controlnet_name = gr.Dropdown(
                label="Controlnet", choices=controlnet_ids, value=controlnet_ids[0], allow_custom_value=True
            )
            scheduler_name = gr.Dropdown(
                label="Scheduler", choices=scheduler_names, value=default_scheduler, allow_custom_value=True
            )
            with gr.Row():
                generate_button = gr.Button(value="Generate", variant="primary")
                dark_mode_btn = gr.Button("Dark Mode", variant="secondary")

    # --- Main row: settings tabs (left) and output gallery (right) ---
    with gr.Row():
        with gr.Column():
            with gr.Tab("Inference"):
                guidance_image = gr.Image(
                    label="Guidance Image",
                    source="upload",
                    tool="editor",
                    type="pil",
                ).style(height=256)
                with gr.Row():
                    controlnet_cond_scale = gr.Slider(
                        label="Controlnet Weight",
                        value=1.0,
                        minimum=0.0,
                        maximum=2.0,
                        step=0.1,
                    )
                with gr.Row():
                    batch_size = gr.Slider(
                        label="Batch Size", value=1, minimum=1, maximum=4, step=1
                    )
                    # Slider range runs from -1 up to 2**31 - 1.
                    seed = gr.Slider(-1, 2147483647, label="Seed", value=-1, step=1)
                with gr.Row():
                    guidance = gr.Slider(
                        label="Guidance scale", value=7.5, minimum=0, maximum=20
                    )
                    steps = gr.Slider(
                        label="Steps", value=20, minimum=1, maximum=100, step=1
                    )
                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        value=default_img_size,
                        minimum=64,
                        maximum=1024,
                        step=32,
                    )
                    height = gr.Slider(
                        label="Height",
                        value=default_img_size,
                        minimum=64,
                        maximum=1024,
                        step=32,
                    )

            with gr.Tab("Train Anime ControlNet"):
                with gr.Row():
                    train_batch_size = gr.Slider(
                        label="Training Batch Size",
                        minimum=1,
                        maximum=8,
                        step=1,
                        value=1,
                    )
                    gradient_accumulation_steps = gr.Slider(
                        label="Gradient Accumulation steps",
                        minimum=1,
                        maximum=6,
                        step=1,
                        value=4,
                    )
                # precision=0 makes Gradio round these counts and deliver them
                # as int rather than float (e.g. 2 instead of 2.0).
                with gr.Row():
                    num_train_epochs = gr.Number(
                        label="Total training epochs", value=2, precision=0
                    )
                    train_learning_rate = gr.Number(label="Learning Rate", value=5.0e-6)
                with gr.Row():
                    checkpointing_steps = gr.Number(
                        label="Steps between saving checkpoints",
                        value=4000,
                        precision=0,
                    )
                    image_logging_steps = gr.Number(
                        label="Steps between logging example images (pass 0 to disable)",
                        value=0,
                        precision=0,
                    )
                with gr.Row():
                    train_data_dir = gr.Textbox(
                        label="Path to training image folder",
                        value="lint/anybooru",
                    )
                    valid_data_dir = gr.Textbox(
                        label="Path to validation image folder",
                        value="",
                    )
                with gr.Row():
                    controlnet_weights_path = gr.Textbox(
                        label="Repo for initializing Controlnet Weights",
                        value="lint/anime_control/anime_merge",
                    )
                    output_dir = gr.Textbox(
                        label="Output directory for trained weights", value="./models"
                    )
                with gr.Row():
                    train_whole_controlnet = gr.Checkbox(
                        label="Train whole controlnet", value=True
                    )
                    save_whole_pipeline = gr.Checkbox(
                        label="Save whole pipeline", value=True
                    )
                training_button = gr.Button(
                    value="Train Anime ControlNet", variant="primary"
                )
                training_status = gr.Text(label="Training Status")

        with gr.Column():
            gallery = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery"
            ).style(height=default_img_size, grid=2)
            generation_details = gr.Markdown()
            # NOTE(review): disabled experiments, kept for reference.
            # pipe_kwargs = gr.Textbox(label="Pipe kwargs", value="{\n\t\n}", visible=False)
            # if torch.cuda.is_available():
            #     giga = 2**30
            #     vram_guage = gr.Slider(0, torch.cuda.memory_reserved(0)/giga, label='VRAM Allocated to Reserved (GB)', value=0, step=1)
            #     demo.load(lambda : torch.cuda.memory_allocated(0)/giga, inputs=[], outputs=vram_guage, every=0.5, show_progress=False)

    footer_component = gr.Markdown(footer)

    # --- Event wiring ---
    # Order must match generate()'s positional parameters.
    inputs = [
        model_name,
        guidance_image,
        controlnet_name,
        scheduler_name,
        prompt,
        guidance,
        steps,
        batch_size,
        width,
        height,
        seed,
        neg_prompt,
        controlnet_prompt,
        controlnet_negative_prompt,
        controlnet_cond_scale,
        # pipe_kwargs,
    ]
    outputs = [gallery, generation_details]
    # Generation is triggered by submitting the prompt box or by the button.
    prompt.submit(generate, inputs=inputs, outputs=outputs)
    generate_button.click(generate, inputs=inputs, outputs=outputs)

    # Order must match run_training()'s positional parameters.
    training_inputs = [
        model_name,
        controlnet_weights_path,
        train_data_dir,
        valid_data_dir,
        train_batch_size,
        train_whole_controlnet,
        gradient_accumulation_steps,
        num_train_epochs,
        train_learning_rate,
        output_dir,
        checkpointing_steps,
        image_logging_steps,
        save_whole_pipeline,
    ]
    training_button.click(
        run_training,
        inputs=training_inputs,
        outputs=[training_status],
    )

    # from gradio.themes.builder
    # Client-side only (fn=None): if any element carries the `dark` class,
    # remove it everywhere; otherwise add `dark` to <body>.
    toggle_dark_mode_args = dict(
        fn=None,
        inputs=None,
        outputs=None,
        _js="""() => {
            if (document.querySelectorAll('.dark').length) {
                document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
            } else {
                document.querySelector('body').classList.add('dark');
            }
        }""",
    )
    # Running the toggle on page load flips the initial theme, so a page
    # that starts light opens in dark mode. The button toggles it afterwards.
    demo.load(**toggle_dark_mode_args)
    dark_mode_btn.click(**toggle_dark_mode_args)
if __name__ == "__main__":
    # The request queue's concurrency is set to the machine's CPU count.
    # NOTE: the original call passed favicon_path=favicon_path, but no
    # `favicon_path` is defined in this module, so launching raised NameError.
    # If a custom icon is wanted, pass a real file path to launch().
    demo.queue(concurrency_count=cpu_count()).launch()