Spaces:
Runtime error
Runtime error
| from pathlib import Path | |
| import gradio as gr | |
| import torch | |
| from finetuning import FineTunedModel | |
| from StableDiffuser import StableDiffuser | |
| from tqdm import tqdm | |
class Demo:
    """Gradio demo for erasing a concept from Stable Diffusion.

    The left column fine-tunes selected attention modules of the diffuser so
    that a user-supplied prompt is "erased"; the right column generates an
    image pair (fine-tuned vs. original) for an arbitrary prompt and seed.

    Constructing an instance loads the diffuser onto CUDA, builds the UI and
    launches the Gradio server.
    """

    # Regex selecting the modules to fine-tune for each training method.
    _TRAIN_MODULES = {
        'ESD-x': ".*attn2$",
        'ESD-self': ".*attn1$",
    }

    def __init__(self) -> None:
        # Busy flags: a second click while a job is running is ignored.
        self.training = False
        self.generating = False

        # Number of scheduler steps used for the partial denoising run
        # performed on every training iteration.
        self.nsteps = 50

        self.diffuser = StableDiffuser(scheduler='DDIM', seed=42).to('cuda')

        # Set by `train`; stays None until a fine-tuning run has completed.
        self.finetuner = None

        with gr.Blocks() as demo:
            self.layout()

        # NOTE(review): `concurrency_count` is a Gradio 3.x `queue()` argument;
        # confirm against the Gradio version pinned for this Space.
        demo.queue(concurrency_count=2).launch()

    def disable(self):
        """Return updates that make the Train and Generate buttons non-interactive."""
        return [gr.update(interactive=False), gr.update(interactive=False)]

    def layout(self):
        """Create all UI components and wire the button callbacks.

        Must be called inside an active ``gr.Blocks()`` context.
        """
        with gr.Row():
            self.explain = gr.HTML(
                interactive=False,
                value="<p>This page demonstrates Erasing Concepts in Stable Diffusion (Ganikota, Materzynska, Fiotto-Kaufman and Bau; paper and code linked from https://erasing.baulab.info/). <br> Use it in two steps <br> 1. First, on the left fine-tune your own custom model by naming the concept that you want to erase. For example, you can try erasing all cars from a model by entering the prompt corresponding to the concept to erase as 'car'. This can take awhile. For example, with the default settings, this can take about an hour. <br> 2. Second, on the right once you have your model fine-tuned, you can try running it in inference. <br>If you want to run it yourself, then you can create your own instance. Configuration, code, and details are at https://github.com/xxxx/xxxx/xxx</p>",
            )

        with gr.Row():
            # Left column: training controls.
            with gr.Column(scale=1):
                self.prompt_input = gr.Text(
                    placeholder="Enter prompt...",
                    label="Prompt to Erase",
                    info="Prompt corresponding to concept to erase",
                )
                self.train_method_input = gr.Dropdown(
                    choices=['ESD-x', 'ESD-self'],
                    value='ESD-x',
                    label='Train Method',
                    info='Method of training',
                )
                self.neg_guidance_input = gr.Number(
                    value=1,
                    label="Negative Guidance",
                    info='Guidance of negative training used to train',
                )
                self.iterations_input = gr.Number(
                    value=150,
                    precision=0,
                    label="Iterations",
                    info='iterations used to train',
                )
                self.lr_input = gr.Number(
                    value=1e-5,
                    label="Learning Rate",
                    info='Learning rate used to train',
                )
                self.progress_bar = gr.Text(interactive=False, label="Training Progress")
                self.train_button = gr.Button(
                    value="Train",
                )

            # Right column: inference controls and outputs.
            with gr.Column(scale=2):
                with gr.Row():
                    with gr.Column(scale=5):
                        self.prompt_input_infr = gr.Text(
                            placeholder="Enter prompt...",
                            label="Prompt",
                            info="Prompt to generate",
                        )
                    with gr.Column(scale=1):
                        # precision=0 so the handler receives an integer seed,
                        # matching how the Iterations field is declared.
                        self.seed_infr = gr.Number(
                            label="Seed",
                            value=42,
                            precision=0,
                        )

                with gr.Row():
                    self.image_new = gr.Image(
                        label="New Image",
                        interactive=False,
                    )
                    self.image_orig = gr.Image(
                        label="Orig Image",
                        interactive=False,
                    )

                with gr.Row():
                    # Disabled until a training run has finished.
                    self.infr_button = gr.Button(
                        value="Generate",
                        interactive=False,
                    )

        self.infr_button.click(
            self.inference,
            inputs=[
                self.prompt_input_infr,
                self.seed_infr,
            ],
            outputs=[
                self.image_new,
                self.image_orig,
            ],
        )

        # Two handlers on the same click: the first greys out both buttons,
        # the second runs the training job and re-enables them when done.
        self.train_button.click(
            self.disable,
            outputs=[self.train_button, self.infr_button],
        )
        self.train_button.click(
            self.train,
            inputs=[
                self.prompt_input,
                self.train_method_input,
                self.neg_guidance_input,
                self.iterations_input,
                self.lr_input,
            ],
            outputs=[self.train_button, self.infr_button, self.progress_bar],
        )

    def train(self, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
        """Fine-tune the diffuser to erase ``prompt`` and store the result.

        Args:
            prompt: Text describing the concept to erase.
            train_method: ``'ESD-x'`` or ``'ESD-self'``; selects which modules
                are fine-tuned.
            neg_guidance: Scale applied to ``(positive - neutral)`` when
                building the regression target.
            iterations: Number of optimisation steps.
            lr: Adam learning rate.
            pbar: Gradio progress tracker; with ``track_tqdm=True`` Gradio
                mirrors the tqdm loop used during training.

        Returns:
            ``[None, None, None]`` if a training run is already in progress;
            otherwise updates re-enabling the Train and Generate buttons plus
            ``None`` for the progress text box.

        Raises:
            ValueError: If ``train_method`` is not a known method.
        """
        if self.training:
            return [None, None, None]

        self.training = True
        try:
            # Drop the previous fine-tuned weights before allocating new ones.
            # Rebinding to None (rather than `del`) keeps the attribute
            # present even if this run fails part-way.
            self.finetuner = None
            torch.cuda.empty_cache()

            finetuner = self._fit(prompt, train_method, neg_guidance, iterations, lr)

            self.finetuner = finetuner.eval().half()
            self.diffuser = self.diffuser.eval().half()
            torch.cuda.empty_cache()
        finally:
            # Always release the busy flag so a failed run cannot lock the
            # demo until restart.
            self.training = False

        return [gr.update(interactive=True), gr.update(interactive=True), None]

    def _fit(self, prompt, train_method, neg_guidance, iterations, lr):
        """Run the optimisation loop and return the trained ``FineTunedModel``."""
        try:
            modules = self._TRAIN_MODULES[train_method]
        except KeyError:
            raise ValueError(f"Unknown train method: {train_method!r}") from None

        self.diffuser = self.diffuser.train().float()

        finetuner = FineTunedModel(self.diffuser, modules)
        optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
        criteria = torch.nn.MSELoss()

        with torch.no_grad():
            neutral_text_embeddings = self.diffuser.get_text_embeddings([''], n_imgs=1)
            positive_text_embeddings = self.diffuser.get_text_embeddings([prompt], n_imgs=1)

        for _ in tqdm(range(iterations)):
            # Everything in this block only produces inputs/targets, so no
            # gradients are recorded.
            with torch.no_grad():
                self.diffuser.set_scheduler_timesteps(self.nsteps)

                optimizer.zero_grad()

                # Random stopping point along the nsteps-long schedule.
                iteration = torch.randint(1, self.nsteps - 1, (1,)).item()

                latents = self.diffuser.get_initial_latents(1, 512, 1)

                # Partially denoise from step 0 up to `iteration` with the
                # fine-tuner active.
                with finetuner:
                    latents_steps, _ = self.diffuser.diffusion(
                        latents,
                        positive_text_embeddings,
                        start_iteration=0,
                        end_iteration=iteration,
                        guidance_scale=3,
                        show_progress=False,
                    )

                # Switch to a 1000-step schedule and rescale the step index
                # from the nsteps schedule onto it.
                self.diffuser.set_scheduler_timesteps(1000)
                iteration = int(iteration / self.nsteps * 1000)

                # Reference predictions made outside the fine-tuner context.
                positive_latents = self.diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=3)
                neutral_latents = self.diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=3)

            # Prediction with the fine-tuner active; computed with gradients
            # enabled because the loss is back-propagated through it.
            with finetuner:
                negative_latents = self.diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=3)

            positive_latents.requires_grad = False
            neutral_latents.requires_grad = False

            # Target: neutral - neg_guidance * (positive - neutral).
            # Original author note: loss = criteria(e_n, e_0) works the best try 5000 epochs
            target = neutral_latents - (neg_guidance * (positive_latents - neutral_latents))
            loss = criteria(negative_latents, target)
            loss.backward()
            optimizer.step()

        return finetuner

    def inference(self, prompt, seed, pbar = gr.Progress(track_tqdm=True)):
        """Generate one image with the fine-tuned model and one without it.

        Args:
            prompt: Text prompt to render.
            seed: Seed assigned to the diffuser before generation.
            pbar: Gradio progress tracker (tracks tqdm loops).

        Returns:
            ``[None, None]`` if a generation is already in progress; otherwise
            ``(edited_image, orig_image)``.

        Raises:
            gr.Error: If no fine-tuned model is available yet.
        """
        if self.generating:
            return [None, None]

        if self.finetuner is None:
            raise gr.Error("No fine-tuned model available; train a model first.")

        self.generating = True
        try:
            self.diffuser._seed = seed

            # Image from the diffuser without the fine-tuner active.
            images = self.diffuser(
                prompt,
                n_steps=50,
                reseed=True,
            )
            orig_image = images[0][0]

            torch.cuda.empty_cache()

            # Same prompt with the fine-tuner active; reseed=True is passed
            # to both calls.
            with self.finetuner:
                images = self.diffuser(
                    prompt,
                    n_steps=50,
                    reseed=True,
                )
            edited_image = images[0][0]
        finally:
            # Always release the busy flag, even if generation failed.
            self.generating = False
            torch.cuda.empty_cache()

        return edited_image, orig_image
# Instantiating Demo runs its __init__, which builds the UI and launches the app.
demo = Demo()