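# Gradio app for Stable Diffusion: text-to-image, image-to-image, and
# inpainting tabs, plus textual-inversion training.
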
import argparse
import gc
import importlib
import json
import os
import random
import shutil

import gradio as gr
import torch
from PIL import Image
from diffusers import (
    StableDiffusionImg2ImgPipeline,
    StableDiffusionPipeline,
)

from .inpaint_pipeline import SDInpaintPipeline as StableDiffusionInpaintPipelineLegacy
from .shared import default_scheduler, scheduler_dict, model_ids
from .textual_inversion import main as run_textual_inversion

_xformers_available = importlib.util.find_spec("xformers") is not None
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
low_vram_mode = False

# Map Gradio tab index -> pipeline class (1: txt2img, 2: img2img, 3: inpainting).
tab_to_pipeline = {
    1: StableDiffusionPipeline,
    2: StableDiffusionImg2ImgPipeline,
    3: StableDiffusionInpaintPipelineLegacy,
}

def load_pipe(model_id, scheduler_name, tab_index=1, pipe_kwargs="{}"):
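    """Return a Stable Diffusion pipeline for the given model and scheduler.

    Reuses the already-loaded weights when only the pipeline class or the
    scheduler changes; `pipe_kwargs` is a JSON string of extra keyword
    arguments forwarded to `from_pretrained`.
    """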
    global pipe, loaded_model_id

    scheduler = scheduler_dict[scheduler_name]
    pipe_class = tab_to_pipeline[tab_index]

    # Load new weights from disk only when changing model_id.
    if model_id != loaded_model_id:
        pipe = pipe_class.from_pretrained(
            model_id,
            torch_dtype=dtype,
            safety_checker=None,
            requires_safety_checker=False,
            scheduler=scheduler.from_pretrained(model_id, subfolder="scheduler"),
            **json.loads(pipe_kwargs),
        )
        loaded_model_id = model_id
    # Same model_id: instantiate a new pipeline around the same underlying
    # PyTorch modules to avoid reloading weights from disk. `pipe.components`
    # returns a fresh dict on every access, so the scheduler swap must be made
    # on a local copy before re-instantiating.
    elif pipe_class != pipe.__class__ or not isinstance(pipe.scheduler, scheduler):
        components = pipe.components
        components["scheduler"] = scheduler.from_pretrained(
            model_id, subfolder="scheduler"
        )
        pipe = pipe_class(**components)

    if device == "cuda":
        pipe = pipe.to(device)
        if _xformers_available:
            pipe.enable_xformers_memory_efficient_attention()
            print("using xformers")
        if low_vram_mode:
            pipe.enable_attention_slicing()
            print("using attention slicing to lower VRAM")

    return pipe

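# Module-level cache: the active pipeline and the model id it was loaded from.
# The default model is loaded once at import time.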
pipe = None
loaded_model_id = ""
pipe = load_pipe(model_ids[0], default_scheduler)

def pad_image(image):
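    """Pad a PIL image with black to a square whose side is max(width, height)."""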
    w, h = image.size
    if w == h:
        return image
    elif w > h:
        new_image = Image.new(image.mode, (w, w), (0, 0, 0))
        new_image.paste(image, (0, (w - h) // 2))
        return new_image
    else:
        new_image = Image.new(image.mode, (h, h), (0, 0, 0))
        new_image.paste(image, ((h - w) // 2, 0))
        return new_image

def generate(
    model_name,
    scheduler_name,
    prompt,
    guidance,
    steps,
    n_images=1,
    width=512,
    height=512,
    seed=0,
    image=None,
    strength=0.5,
    inpaint_image=None,
    inpaint_strength=0.5,
    inpaint_radio="",
    neg_prompt="",
    tab_index=1,
    pipe_kwargs="{}",
    progress=gr.Progress(track_tqdm=True),
):
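    """Dispatch to text-to-image, image-to-image, or inpainting based on
    `tab_index` and return `(images, status_message)`."""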
    if seed == -1:
        seed = random.randint(0, 2147483647)
    generator = torch.Generator(device).manual_seed(seed)

    pipe = load_pipe(
        model_id=model_name,
        scheduler_name=scheduler_name,
        tab_index=tab_index,
        pipe_kwargs=pipe_kwargs,
    )

    status_message = (
        f"Prompt: '{prompt}' | Seed: {seed} | Guidance: {guidance} | "
        f"Scheduler: {scheduler_name} | Steps: {steps}"
    )

    if tab_index == 1:
        status_message = "Text to Image " + status_message
        result = pipe(
            prompt,
            negative_prompt=neg_prompt,
            num_images_per_prompt=n_images,
            num_inference_steps=int(steps),
            guidance_scale=guidance,
            width=width,
            height=height,
            generator=generator,
        )
    elif tab_index == 2:
        status_message = "Image to Image " + status_message
        image = image.resize((width, height))
        result = pipe(
            prompt,
            negative_prompt=neg_prompt,
            num_images_per_prompt=n_images,
            image=image,
            num_inference_steps=int(steps),
            strength=strength,
            guidance_scale=guidance,
            generator=generator,
        )
    elif tab_index == 3:
        status_message = "Inpainting " + status_message
        init_image = inpaint_image["image"].resize((width, height))
        mask = inpaint_image["mask"].resize((width, height))
        result = pipe(
            prompt,
            negative_prompt=neg_prompt,
            num_images_per_prompt=n_images,
            image=init_image,
            mask_image=mask,
            num_inference_steps=int(steps),
            strength=inpaint_strength,
            preserve_unmasked_image=(
                inpaint_radio == "preserve non-masked portions of image"
            ),
            guidance_scale=guidance,
            generator=generator,
        )
    else:
        return None, f"Unhandled tab index: {tab_index}"

    return result.images, status_message

# based on lvkaokao/textual-inversion-training
def train_textual_inversion(
    model_name,
    scheduler_name,
    type_of_thing,
    files,
    concept_word,
    init_word,
    text_train_steps,
    text_train_bsz,
    text_learning_rate,
    progress=gr.Progress(track_tqdm=True),
):
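    """Square-pad and resize the uploaded images into a concept folder, then
    run textual inversion to learn `concept_word` as a new token."""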
    if device == "cpu":
        raise gr.Error("Textual inversion training not supported on CPU")

    pipe = load_pipe(
        model_id=model_name,
        scheduler_name=scheduler_name,
        tab_index=1,
    )
    pipe.disable_xformers_memory_efficient_attention()  # xformers handled by textual inversion script

    concept_dir = "concept_images"
    output_dir = "output_model"
    training_resolution = 512

    # Start from clean concept/output directories.
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    if os.path.exists(concept_dir):
        shutil.rmtree(concept_dir)
    os.makedirs(concept_dir, exist_ok=True)
    os.makedirs(output_dir, exist_ok=True)

    gc.collect()
    torch.cuda.empty_cache()

    if not concept_word:
        raise gr.Error("You forgot to define your concept prompt")

    # Square-pad, resize, and save each training image as RGB.
    for j, file_temp in enumerate(files):
        file = Image.open(file_temp.name)
        image = pad_image(file)
        image = image.resize((training_resolution, training_resolution))
        extension = file_temp.name.split(".")[-1]
        image = image.convert("RGB")
        image.save(f"{concept_dir}/{j + 1}.{extension}", quality=100)

    args_general = argparse.Namespace(
        train_data_dir=concept_dir,
        learnable_property=type_of_thing,
        placeholder_token=concept_word,
        initializer_token=init_word,
        resolution=training_resolution,
        train_batch_size=text_train_bsz,
        gradient_accumulation_steps=1,
        gradient_checkpointing=True,
        mixed_precision="fp16",
        use_bf16=False,
        max_train_steps=int(text_train_steps),
        learning_rate=text_learning_rate,
        scale_lr=True,
        lr_scheduler="constant",
        lr_warmup_steps=0,
        output_dir=output_dir,
    )

    try:
        run_textual_inversion(pipe, args_general)
    except Exception as e:
        raise gr.Error(str(e))

    # Restore the (now fine-tuned) modules to eval mode and the inference dtype.
    pipe.text_encoder = pipe.text_encoder.eval().to(device, dtype=dtype)
    pipe.unet = pipe.unet.eval().to(device, dtype=dtype)

    gc.collect()
    torch.cuda.empty_cache()

    return f"Finished training! Check the {output_dir} directory for saved model weights"