Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import random | |
| from PIL import Image | |
| import os | |
| import argparse | |
| import shutil | |
| import gc | |
| import importlib | |
| import json | |
| from multiprocessing import cpu_count | |
| import cv2 | |
| import numpy as np | |
| from pathlib import Path | |
| from diffusers import ( | |
| StableDiffusionControlNetPipeline, | |
| StableDiffusionPipeline, | |
| ControlNetModel, | |
| AutoencoderKL, | |
| ) | |
| from src.controlnet_pipe import ControlNetPipe as StableDiffusionControlNetPipeline | |
| from src.lab import Lab | |
| from src.ui_shared import ( | |
| default_scheduler, | |
| scheduler_dict, | |
| model_ids, | |
| controlnet_ids, | |
| is_hfspace, | |
| ) | |
# Hugging Face Hub repo holding the anime ControlNet weight subfolders.
CONTROLNET_REPO = "lint/anime_control"

# Memory-efficient attention is enabled only when xformers is installed.
_xformers_available = importlib.util.find_spec("xformers") is not None

# Prefer the GPU when present; half precision only makes sense on CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Module-level pipeline cache. load_pipe() rebuilds only the components
# whose ids differ from the ones recorded here.
pipe = None
loaded_model_id = ""
loaded_controlnet_id = ""
def load_pipe(model_id, controlnet_id, scheduler_name):
    """Return the cached ControlNet pipeline, rebuilding only what changed.

    Components are cached in the module-global ``pipe``; the base model,
    controlnet, and scheduler are each swapped out only when their id
    differs from the currently loaded one.

    Args:
        model_id: base Stable Diffusion checkpoint id (Hub repo or path).
        controlnet_id: subfolder of ``CONTROLNET_REPO`` with the weights.
        scheduler_name: key into ``scheduler_dict``.

    Returns:
        A ready-to-use ``StableDiffusionControlNetPipeline``.
    """
    global pipe, loaded_model_id, loaded_controlnet_id

    scheduler_cls = scheduler_dict[scheduler_name]
    reload_pipe = False

    # Start from the currently loaded components so unchanged pieces are reused.
    new_weights = pipe.components if pipe else {}

    if model_id != loaded_model_id:
        new_pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            vae=AutoencoderKL.from_pretrained("lint/anime_vae", torch_dtype=dtype),
            safety_checker=None,
            feature_extractor=None,
            requires_safety_checker=False,
            use_safetensors=False,
            torch_dtype=dtype,
        )
        loaded_model_id = model_id
        new_weights.update(new_pipe.components)
        new_weights["scheduler"] = scheduler_cls.from_pretrained(
            model_id, subfolder="scheduler"
        )
        reload_pipe = True
    elif pipe is not None and not isinstance(pipe.scheduler, scheduler_cls):
        # FIX: previously a new scheduler_name was silently ignored unless the
        # base model also changed. Swap schedulers via from_config so the new
        # one inherits the loaded model's scheduler configuration.
        new_weights["scheduler"] = scheduler_cls.from_config(pipe.scheduler.config)
        reload_pipe = True

    if controlnet_id != loaded_controlnet_id:
        controlnet = ControlNetModel.from_pretrained(
            CONTROLNET_REPO,
            subfolder=controlnet_id,
            torch_dtype=dtype,
        )
        loaded_controlnet_id = controlnet_id
        new_weights["controlnet"] = controlnet
        reload_pipe = True

    if reload_pipe:
        pipe = StableDiffusionControlNetPipeline(
            **new_weights,
            requires_safety_checker=False,
        )

    if device == "cuda":
        # Move every torch module to the GPU in half precision and enable
        # the available memory savers.
        for component in pipe.components.values():
            if isinstance(component, torch.nn.Module):
                component.to("cuda", torch.float16)
        if _xformers_available:
            pipe.enable_xformers_memory_efficient_attention()
        pipe.enable_attention_slicing()
        pipe.enable_vae_tiling()

    return pipe
# On Hugging Face Spaces, build the default pipeline eagerly at import time
# so the first user request does not pay the model-loading cost.
if is_hfspace:
    pipe = load_pipe(model_ids[0], controlnet_ids[0], default_scheduler)
def extract_canny(image, low_threshold=100, high_threshold=200):
    """Extract a Canny edge map from *image* as a 3-channel PIL image.

    Args:
        image: source image (PIL.Image or anything ``np.asarray`` accepts).
        low_threshold: lower hysteresis threshold for ``cv2.Canny``
            (default 100, matching the previously hard-coded value).
        high_threshold: upper hysteresis threshold (default 200).

    Returns:
        PIL.Image.Image: an RGB image whose three channels each hold the
        single-channel edge map — the layout the ControlNet hint expects.
    """
    edges = cv2.Canny(np.asarray(image), low_threshold, high_threshold)
    # Replicate the single edge channel into three identical RGB channels.
    edges_rgb = np.repeat(edges[:, :, None], 3, axis=2)
    return Image.fromarray(edges_rgb)
def generate(
    model_name,
    guidance_image,
    controlnet_name,
    scheduler_name,
    prompt,
    guidance,
    steps,
    n_images=1,
    width=512,
    height=512,
    seed=0,
    neg_prompt="",
    controlnet_prompt=None,
    controlnet_negative_prompt=None,
    controlnet_cond_scale=1.0,
    progress=gr.Progress(),
):
    """Run ControlNet-guided text-to-image generation.

    Args:
        model_name: base model id, forwarded to ``load_pipe``.
        guidance_image: optional image whose Canny edges condition the
            controlnet; when absent an all-zero hint tensor is used.
        controlnet_name: controlnet subfolder id, forwarded to ``load_pipe``.
        scheduler_name: scheduler key, forwarded to ``load_pipe``.
        prompt / neg_prompt: text conditioning for the base model.
        guidance: classifier-free guidance scale.
        steps: number of inference steps.
        n_images: images generated per prompt.
        width / height: output resolution in pixels.
        seed: RNG seed; -1 requests a fresh random seed.
        controlnet_prompt / controlnet_negative_prompt: optional separate
            text conditioning for the controlnet branch; empty strings are
            treated as "not provided".
        controlnet_cond_scale: conditioning scale for the controlnet.
        progress: Gradio progress tracker (inspected by Gradio; unused here).

    Returns:
        (images, status_message) — the generated PIL images and a summary
        string echoing the effective generation settings.
    """
    if seed == -1:
        # Report the concrete seed in the status message so runs are repeatable.
        seed = random.randint(0, 2147483647)

    # FIX: compare against None explicitly — truthiness raises ValueError for
    # numpy arrays (which Gradio image components can return) and is
    # unreliable for arbitrary image objects.
    if guidance_image is not None:
        guidance_image = extract_canny(guidance_image)
    else:
        # No hint supplied: feed an all-zero conditioning tensor instead.
        guidance_image = torch.zeros(1, 3, height, width)

    generator = torch.Generator(device).manual_seed(seed)

    pipe = load_pipe(
        model_id=model_name,
        controlnet_id=controlnet_name,
        scheduler_name=scheduler_name,
    )

    status_message = f"Prompt: '{prompt}' | Seed: {seed} | Guidance: {guidance} | Scheduler: {scheduler_name} | Steps: {steps}"

    # Empty strings mean "not provided": pass None so the pipeline falls back
    # to the base prompt for the controlnet branch.
    if controlnet_prompt == "":
        controlnet_prompt = None
    if controlnet_negative_prompt == "":
        controlnet_negative_prompt = None

    if controlnet_prompt:
        # NOTE(review): _encode_prompt is a private diffusers API — verify it
        # still exists when upgrading diffusers.
        controlnet_prompt_embeds = pipe._encode_prompt(
            controlnet_prompt,
            device,
            n_images,
            do_classifier_free_guidance=guidance > 1.0,
            negative_prompt=controlnet_negative_prompt,
            prompt_embeds=None,
            negative_prompt_embeds=None,
        )
    else:
        controlnet_prompt_embeds = None

    result = pipe(
        prompt,
        image=guidance_image,
        height=height,
        width=width,
        num_inference_steps=int(steps),
        guidance_scale=guidance,
        negative_prompt=neg_prompt,
        num_images_per_prompt=n_images,
        generator=generator,
        controlnet_conditioning_scale=float(controlnet_cond_scale),
        controlnet_prompt_embeds=controlnet_prompt_embeds,
    )

    return result.images, status_message
def run_training(
    model_name,
    controlnet_weights_path,
    train_data_dir,
    valid_data_dir,
    train_batch_size,
    train_whole_controlnet,
    gradient_accumulation_steps,
    num_train_epochs,
    train_learning_rate,
    output_dir,
    checkpointing_steps,
    image_logging_steps,
    save_whole_pipeline,
    progress=gr.Progress(),
):
    """Fine-tune the controlnet in the currently loaded pipeline.

    Loads fresh controlnet weights from *controlnet_weights_path*, rebuilds
    the global pipeline around them, and hands everything to ``Lab`` for
    training. On success the pipeline modules are moved back to the
    inference device/dtype and CUDA memory is released.

    Args:
        model_name: unused here; the currently loaded pipe is trained.
        controlnet_weights_path: path whose last component is the subfolder
            holding the controlnet weights, e.g. ``repo_or_dir/subfolder``.
        train_data_dir / valid_data_dir: dataset locations; the literal
            value "lint/anybooru" selects loading from the HF Hub.
        train_batch_size, gradient_accumulation_steps, num_train_epochs,
        train_learning_rate: standard optimization hyperparameters.
        train_whole_controlnet: when False, only "zero convolutions" train.
        output_dir: where checkpoints and tensorboard logs are written.
        checkpointing_steps / image_logging_steps: logging cadence
            (image logging disabled when 0 — costs extra VRAM).
        save_whole_pipeline: save the full pipeline instead of just weights.
        progress: Gradio progress tracker forwarded to ``Lab.train``.

    Raises:
        gr.Error: on CPU-only machines, when no pipeline is loaded yet, or
            when training fails for any reason.

    Returns:
        A human-readable completion message.
    """
    global pipe

    if device == "cpu":
        raise gr.Error("Training not supported on CPU")

    # FIX: previously this crashed with AttributeError on pipe.components
    # when no pipeline had been loaded yet (e.g. outside HF Spaces before
    # any generation). Fail with a clear message instead.
    if pipe is None:
        raise gr.Error("No pipeline loaded yet — generate an image first")

    # Split "…/subfolder" into the from_pretrained path and its subfolder.
    pathobj = Path(controlnet_weights_path)
    controlnet_path = str(pathobj.parent)
    subfolder = pathobj.name

    controlnet = ControlNetModel.from_pretrained(
        controlnet_path,
        subfolder=subfolder,
        low_cpu_mem_usage=False,
        device_map=None,
    )

    # Rebuild the pipeline around the freshly loaded controlnet.
    pipe.components["controlnet"] = controlnet
    pipe = StableDiffusionControlNetPipeline(
        **pipe.components,
        requires_safety_checker=False,
    )

    training_args = argparse.Namespace(
        # start training from preexisting models
        pretrained_model_name_or_path=None,
        controlnet_weights_path=None,
        # dataset args
        train_data_dir=train_data_dir,
        valid_data_dir=valid_data_dir,
        resolution=512,
        from_hf_hub=train_data_dir == "lint/anybooru",
        controlnet_hint_key="canny",
        # training args
        # options are ["zero convolutions", "input hint blocks"],
        # trains whole controlnet by default
        training_stage="" if train_whole_controlnet else "zero convolutions",
        learning_rate=float(train_learning_rate),
        num_train_epochs=int(num_train_epochs),
        seed=3434554,
        max_grad_norm=1.0,
        gradient_accumulation_steps=int(gradient_accumulation_steps),
        # VRAM args
        batch_size=train_batch_size,
        mixed_precision="fp16",  # set to "fp16" for mixed-precision training.
        gradient_checkpointing=True,  # set True to lower the memory usage.
        use_8bit_adam=False,  # use 8bit optimizer from bitsandbytes
        enable_xformers_memory_efficient_attention=True,
        allow_tf32=True,
        dataloader_num_workers=cpu_count(),
        # logging args
        output_dir=output_dir,
        report_to="tensorboard",
        image_logging_steps=image_logging_steps,  # disabled when 0.
        save_whole_pipeline=save_whole_pipeline,
        checkpointing_steps=checkpointing_steps,
    )

    try:
        lab = Lab(training_args, pipe)
        lab.train(training_args.num_train_epochs, gr_progress=progress)
    except Exception as e:
        # FIX: gr.Error(e) passed the exception object; stringify and chain
        # so the UI shows the message and the log keeps the traceback.
        raise gr.Error(str(e)) from e

    # Restore inference device/dtype and release training memory.
    for component in pipe.components.values():
        if isinstance(component, torch.nn.Module):
            component.to(device, dtype=dtype)
    gc.collect()
    torch.cuda.empty_cache()

    return f"Finished training! Check the {training_args.output_dir} directory for saved model weights"