Spaces:
Runtime error
Runtime error
| import warnings | |
| warnings.filterwarnings("ignore") | |
| from diffusers import StableDiffusionPipeline, DDIMInverseScheduler, DDIMScheduler | |
| import torch | |
| from typing import Optional | |
| from tqdm import tqdm | |
| from diffusers.models.attention_processor import Attention, AttnProcessor2_0 | |
| import torchvision | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import gc | |
| import gradio as gr | |
| import numpy as np | |
| import os | |
| import pickle | |
| from transformers import CLIPImageProcessor | |
| from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker | |
| import argparse | |
# Blend weights used by AttnReplaceProcessor for self-attention injection.
# Outer key: UNet stage ('down' / 'mid' / 'up'). Inner key: number of spatial
# tokens seen by the attention layer (4096 is labelled 64x64 in the UI).
# Per the UI text: 1.0 = fully source attention, 0.0 = fully target attention.
weights = dict(
    down={4096: 0.0, 1024: 1.0, 256: 1.0},
    mid={64: 1.0},
    up={256: 1.0, 1024: 1.0, 4096: 0.0},
)

# Default number of DDIM steps; mutated at runtime by the "Steps" slider.
num_inference_steps = 10
# Checkpoint shared by the pipeline and both DDIM schedulers.
model_id = "stabilityai/stable-diffusion-2-1-base"

# Stable Diffusion pipeline, moved to the GPU once at import time.
pipe = StableDiffusionPipeline.from_pretrained(model_id)
pipe = pipe.to("cuda")

# The inverse scheduler maps an image latent to noise; the forward scheduler denoises it.
inverse_scheduler = DDIMInverseScheduler.from_pretrained(model_id, subfolder="scheduler")
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")

# Safety checker and its CLIP preprocessor, applied to every decoded image.
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
safety_checker = safety_checker.to("cuda")
feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Set by stop_reconstruct() and polled by reconstruct().
should_stop = False
def save_state_to_file(state):
    """Pickle *state* to ``state.pkl`` in the working directory.

    Returns:
        The name of the file that was written (always ``"state.pkl"``).
    """
    target = "state.pkl"
    with open(target, mode='wb') as sink:
        pickle.dump(state, sink)
    return target
def load_state_from_file(filename):
    """Return the object unpickled from *filename*.

    NOTE: unpickling executes arbitrary code, so only pass paths to files
    this application produced or ships itself.
    """
    with open(filename, mode='rb') as source:
        return pickle.load(source)
def stop_reconstruct():
    """Raise the module-level ``should_stop`` flag.

    The flag is polled by ``reconstruct``; this function does not interrupt
    anything by itself.
    """
    global should_stop
    should_stop = True
def _latents_to_uint8_image(latents):
    """Decode *latents* with the VAE, run the safety checker, return a uint8 array.

    The result is channel-last (``permute(1, 2, 0)``) and scaled to 0-255.
    Runs under ``torch.no_grad()`` because the image is only used for display.
    """
    with torch.no_grad():
        image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
        # The VAE works in [-1, 1] (see the ``*2 - 1`` on encode); map back to [0, 1].
        image = (image / 2.0 + 0.5).clamp(0.0, 1.0)
        safety_checker_input = feature_extractor(image, return_tensors="pt", do_rescale=False).to("cuda")
        image = safety_checker(images=[image], clip_input=safety_checker_input.pixel_values.to("cuda"))[0]
    image_np = image[0].squeeze(0).float().permute(1, 2, 0).detach().cpu().numpy()
    return (image_np * 255).astype(np.uint8)


def reconstruct(input_img, caption):
    """Invert an image with DDIM and optimise per-step unconditional embeddings.

    Steps:
      1. Encode the image to latents and run DDIM inversion (guidance scale 1),
         recording the latent seen before every inverse step.
      2. Optimise ``QT`` (one unconditional prompt embedding per step, initialised
         from the empty-prompt embedding) so that guided denoising at scale 7.5
         reproduces the recorded inversion trajectory.

    Args:
        input_img: Source image (a PIL image, as delivered by the ``type="pil"``
            Gradio component).
        caption: Prompt describing the source image.

    Yields:
        ``(image_np, caption, [caption, real_image_initial_latents, QT])`` after
        every optimisation epoch and once more at the end. The list is the
        state later consumed by ``apply_prompt``.

    Raises:
        gr.Error: If no image was supplied.
    """
    global should_stop

    if input_img is None:
        raise gr.Error("Please upload a source image before reconstructing.")

    # Discard a stop request left over from a previous run.
    should_stop = False
    # Snapshot the step count: the slider can change the global while we run,
    # which would desynchronise QT, the timesteps and the loop bounds.
    steps = int(num_inference_steps)

    cond_prompt_embeds = pipe.encode_prompt(prompt=caption, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
    uncond_prompt_embeds = pipe.encode_prompt(prompt="", device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
    prompt_embeds_combined = torch.cat([uncond_prompt_embeds, cond_prompt_embeds])

    img = input_img
    # Grayscale / palette / CMYK inputs would reach the VAE with the wrong
    # channel count; RGB and RGBA keep their original handling.
    if getattr(img, "mode", "RGB") not in ("RGB", "RGBA"):
        img = img.convert("RGB")

    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((512, 512)),
        torchvision.transforms.ToTensor()
    ])
    loaded_image = transform(img).to("cuda").unsqueeze(0)
    if loaded_image.shape[1] == 4:
        # Drop the alpha channel.
        loaded_image = loaded_image[:, :3, :, :]

    with torch.no_grad():
        encoded_image = pipe.vae.encode(loaded_image * 2 - 1)
        real_image_latents = pipe.vae.config.scaling_factor * encoded_image.latent_dist.sample()

    # ---- DDIM inversion -------------------------------------------------
    guidance_scale = 1
    inverse_scheduler.set_timesteps(steps, device="cuda")
    timesteps = inverse_scheduler.timesteps

    latents = real_image_latents
    inversed_latents = []

    with torch.no_grad():
        # Make sure no attention injection is active while inverting.
        replace_attention_processor(pipe.unet, True)

        for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Inference steps"):
            inversed_latents.append(latents)
            latent_model_input = torch.cat([latents] * 2)
            noise_pred = pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds_combined,
                cross_attention_kwargs=None,
                return_dict=False,
            )[0]
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = inverse_scheduler.step(noise_pred, t, latents, return_dict=False)[0]

    # initial state
    real_image_initial_latents = latents

    # ---- Per-step unconditional embedding optimisation ------------------
    W_values = uncond_prompt_embeds.repeat(steps, 1, 1)
    QT = nn.Parameter(W_values.clone())

    guidance_scale = 7.5
    scheduler.set_timesteps(steps, device="cuda")
    timesteps = scheduler.timesteps

    optimizer = torch.optim.AdamW([QT], lr=0.008)

    pipe.vae.eval()
    pipe.vae.requires_grad_(False)
    pipe.unet.eval()
    pipe.unet.requires_grad_(False)

    last_loss = 1
    stop_requested = False

    for epoch in range(50):
        gc.collect()
        torch.cuda.empty_cache()

        # Early exit / learning-rate schedule driven by the most recent step loss.
        if last_loss < 0.02:
            break
        elif last_loss < 0.03:
            for param_group in optimizer.param_groups:
                param_group['lr'] = 0.003
        elif last_loss < 0.035:
            for param_group in optimizer.param_groups:
                param_group['lr'] = 0.006

        intermediate_values = real_image_initial_latents.clone()

        for i in range(steps):
            # Detach so each step is optimised independently (no backprop through time).
            latents = intermediate_values.detach().clone()
            t = timesteps[i]

            prompt_embeds = torch.cat([QT[i].unsqueeze(0), cond_prompt_embeds.detach()])
            latent_model_input = torch.cat([latents] * 2)

            noise_pred_model = pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=None,
                return_dict=False,
            )[0]
            noise_pred_uncond, noise_pred_text = noise_pred_model.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            intermediate_values = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            # Target: the latent recorded at the mirrored position of the inversion.
            loss = F.mse_loss(inversed_latents[len(timesteps) - 1 - i].detach(), intermediate_values, reduction="mean")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Keep a plain float so the autograd graph is not held across epochs.
            last_loss = loss.item()

            if should_stop:
                should_stop = False
                stop_requested = True
                break

        if stop_requested:
            # Leave the whole optimisation, not only the current pass.
            break

        # NOTE(review): this previews `latents` (the input of the last step run),
        # not `intermediate_values` (its output) — kept as in the original.
        yield _latents_to_uint8_image(latents), caption, [caption, real_image_initial_latents, QT]

    yield _latents_to_uint8_image(latents), caption, [caption, real_image_initial_latents, QT]
class AttnReplaceProcessor(AttnProcessor2_0):
    """Attention processor that blends source self-attention into the target.

    When ``replace_all`` is true and the call is self-attention, the batch is
    split into four equal chunks, treated as
    ``(uncond_src, uncond_dst, cond_src, cond_dst)``, and each ``dst`` score
    map is overwritten in place with
    ``w * src + (1 - w) * dst``, where ``w = weight[num_tokens]``.

    Args:
        replace_all: Enable the blending described above.
        weight: Mapping from token count (``hidden_states.shape[1]``) to blend
            weight. Only indexed when ``replace_all`` is true, so a placeholder
            value is acceptable for a disabled processor.
    """

    def __init__(self, replace_all, weight):
        super().__init__()
        self.replace_all = replace_all
        self.weight = weight

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        """Compute attention explicitly (bmm + softmax) with optional blending.

        ``attention_mask`` and ``temb`` are accepted for interface
        compatibility but are not used.
        """
        residual = hidden_states
        is_cross = encoder_hidden_states is not None

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten the spatial dimensions into a token axis.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, _, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_scores = attn.scale * torch.bmm(query, key.transpose(-1, -2))

        if self.replace_all and not is_cross:
            # Look the weight up once; raises KeyError for an unconfigured resolution.
            blend = self.weight[hidden_states.shape[1]]
            ucond_attn_scores_src, ucond_attn_scores_dst, attn_scores_src, attn_scores_dst = attention_scores.chunk(4)
            # copy_ writes through the chunk views, i.e. into attention_scores itself.
            attn_scores_dst.copy_(blend * attn_scores_src + (1.0 - blend) * attn_scores_dst)
            ucond_attn_scores_dst.copy_(blend * ucond_attn_scores_src + (1.0 - blend) * ucond_attn_scores_dst)

        attention_probs = attention_scores.softmax(dim=-1)
        del attention_scores

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        del attention_probs

        # Output projection followed by dropout, as in the stock processors
        # (dropout is a no-op in eval mode).
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
def replace_attention_processor(unet, clear = False):
    """Install or disable attention blending on the UNet's ``attn1`` modules.

    Args:
        unet: Model whose ``named_modules()`` are scanned.
        clear: When true, every matched module gets a disabled processor.
            When false, modules whose name starts with ``down`` / ``mid`` /
            ``up`` get a blending processor with that stage's ``weights``;
            other matched modules are left untouched.
    """
    for name, module in unet.named_modules():
        # Match the attn1 module itself; names containing 'to' (e.g. to_q) are skipped.
        if 'attn1' not in name or 'to' in name:
            continue

        if clear:
            module.processor = AttnReplaceProcessor(False, 0.0)
            continue

        # "down_blocks.0..." -> "down", "mid_block..." -> "mid", "up_blocks..." -> "up"
        stage = name.split('.')[0].split('_')[0]
        if stage in ('down', 'mid', 'up'):
            module.processor = AttnReplaceProcessor(True, weights[stage])
def apply_prompt(meta_data, new_prompt):
    """Re-generate the reconstructed image under a new prompt.

    Two samples are denoised together from the same inverted latents: index 0
    is conditioned on the original caption (source), index 1 on ``new_prompt``
    (target). Attention blending is enabled for the duration of the loop, and
    the target sample is decoded and returned.

    Args:
        meta_data: ``[caption, real_image_initial_latents, QT]`` as produced by
            ``reconstruct`` (or loaded from an example state file).
        new_prompt: Prompt for the edited image.

    Returns:
        The edited image as a channel-last uint8 numpy array.

    Raises:
        gr.Error: If no reconstruction state is available yet.
    """
    if meta_data is None:
        raise gr.Error("Reconstruct an image (or pick an example) before generating.")

    caption, real_image_initial_latents, QT = meta_data
    # One optimised unconditional embedding per step fixes the step count.
    inference_steps = len(QT)

    cond_prompt_embeds = pipe.encode_prompt(prompt=caption, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
    new_prompt_embeds = pipe.encode_prompt(prompt=new_prompt, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]

    guidance_scale = 7.5
    scheduler.set_timesteps(inference_steps, device="cuda")
    timesteps = scheduler.timesteps

    # Batch of two: [source, target], both starting from the same noise.
    latents = torch.cat([real_image_initial_latents] * 2)

    with torch.no_grad():
        replace_attention_processor(pipe.unet)
        try:
            for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Inference steps"):
                # UNet batch layout: [uncond_src, uncond_dst, cond_src, cond_dst],
                # matching the chunk(4) order used by AttnReplaceProcessor.
                modified_prompt_embeds = torch.cat([QT[i].unsqueeze(0), QT[i].unsqueeze(0), cond_prompt_embeds, new_prompt_embeds])
                latent_model_input = torch.cat([latents] * 2)

                noise_pred = pipe.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=modified_prompt_embeds,
                    cross_attention_kwargs=None,
                    return_dict=False,
                )[0]

                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
        finally:
            # Always restore the disabled processors, even if a step fails,
            # so the shared UNet is not left with blending enabled.
            replace_attention_processor(pipe.unet, True)

        # latents[1] is the target (new prompt) sample.
        image = pipe.vae.decode(latents[1].unsqueeze(0) / pipe.vae.config.scaling_factor, return_dict=False)[0]
        image = (image / 2.0 + 0.5).clamp(0.0, 1.0)
        safety_checker_input = feature_extractor(image, return_tensors="pt", do_rescale=False).to("cuda")
        image = safety_checker(images=[image], clip_input=safety_checker_input.pixel_values.to("cuda"))[0]
        image_np = image[0].squeeze(0).float().permute(1, 2, 0).detach().cpu().numpy()
        image_np = (image_np * 255).astype(np.uint8)

    return image_np
def on_image_change(filepath):
    """Load a bundled example's saved state and render its edited image.

    For the known examples (``example1`` .. ``example4``) this loads
    ``assets/<name>.pkl``, sets the global step count from the stored ``QT``,
    applies the example's attention scale and runs ``apply_prompt`` with the
    example's target prompt.

    Args:
        filepath: Path of the selected example image, or ``None``.

    Returns:
        A 6-tuple for ``(image_input, reconstructed_image, meta_data,
        steps_slider, invisible_slider, interpolate_slider)``. For an empty or
        unrecognised path, the outputs that cannot be computed are returned as
        ``gr.update()`` (no change) instead of raising.
    """
    global num_inference_steps

    # (attention scale, target prompt) per bundled example.
    presets = {
        "example1": (7, "a photo of a tree, summer, colourful"),
        "example2": (8, "a photo of a panda, two ears, white background"),
        "example3": (7, "a realistic photo of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds"),
        "example4": (7, "a photo of plastic bottle on some sand, beach background, sky background"),
    }

    if not filepath:
        # The hidden example image was cleared: leave every output untouched.
        return tuple(gr.update() for _ in range(6))

    # File name without directory and extension.
    filename = os.path.splitext(os.path.basename(filepath))[0]
    if filename not in presets:
        # Not a bundled example: show the image, change nothing else.
        return (filepath,) + tuple(gr.update() for _ in range(5))

    scale_value, new_prompt = presets[filename]

    # Trusted pickle shipped with the app (name comes from the preset table above).
    meta_data_raw = load_state_from_file(f"assets/{filename}.pkl")
    _, _, QT_raw = meta_data_raw
    num_inference_steps = len(QT_raw)

    update_scale(scale_value)
    img = apply_prompt(meta_data_raw, new_prompt)

    return filepath, img, meta_data_raw, num_inference_steps, scale_value, scale_value
def update_value(value, key, res):
    """Store one manually edited blend weight in the global ``weights`` table.

    Args:
        value: New weight from the UI, or ``None`` when the number field is
            empty (ignored so the table never contains a non-numeric entry).
        key: UNet stage, one of the outer keys of ``weights``.
        res: Token-count key inside that stage.

    The stored value is a float clamped to ``[0.0, 1.0]``.
    """
    if value is None:
        return
    weights[key][res] = min(1.0, max(0.0, float(value)))
def update_step(value):
    """Set the global ``num_inference_steps`` from the Steps slider.

    The value is coerced to an ``int`` of at least 1, since it is later used
    as a repeat count and a ``range`` bound. ``None`` is ignored.
    """
    global num_inference_steps
    if value is None:
        return
    num_inference_steps = max(1, int(value))
def update_scale(scale):
    """Translate the 0-9 influence slider into the seven blend weights.

    Weights start at 1.0. Every two points below 9 zero out one more pair of
    layers, working symmetrically from the outermost (down 4096 / up 4096)
    towards the middle; a leftover single point sets the next pair to 0.5.

    The global ``weights`` table is always updated, including for ``scale == 9``.

    Args:
        scale: Slider value; 9 means all weights are 1.0.

    Returns:
        Tuple of seven weights in the order down 4096, down 1024, down 256,
        mid 64, up 256, up 1024, up 4096.
    """
    # Explicit order so the result does not depend on dict iteration order.
    slots = (
        ('down', 4096), ('down', 1024), ('down', 256),
        ('mid', 64),
        ('up', 256), ('up', 1024), ('up', 4096),
    )
    values = [1.0] * len(slots)

    reduction_steps = (9 - scale) * 0.5
    for i in range(4):  # There are 4 positions to reduce symmetrically
        if reduction_steps >= 1:
            values[i] = 0.0
            values[-(i + 1)] = 0.0
            reduction_steps -= 1
        elif reduction_steps > 0:
            values[i] = 0.5
            values[-(i + 1)] = 0.5
            break
        else:
            break

    for (stage, res), value in zip(slots, values):
        weights[stage][res] = value

    return tuple(values)
# Gradio UI: layout first, then event wiring. Event handlers mutate the
# module-level `weights`, `num_inference_steps` and `should_stop` globals.
with gr.Blocks() as demo:
    # Page header (logo, title, social badges, "duplicate this Space" link).
    gr.Markdown(
        '''
        <div style="text-align: center;">
        <div style="display: flex; justify-content: center;">
        <img src="https://github.com/user-attachments/assets/55a38e74-ab93-4d80-91c8-0fa6130af45a" alt="Logo">
        </div>
        <h1>Out of Focus 1.0</h1>
        <p style="font-size:16px;">Out of AI presents a flexible tool to manipulate your images. This is our first version of Image modification tool through prompt manipulation by reconstruction through diffusion inversion process</p>
        </div>
        <br>
        <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
        <a href="https://www.buymeacoffee.com/outofai" target="_blank"><img src="https://img.shields.io/badge/-buy_me_a%C2%A0coffee-red?logo=buy-me-a-coffee" alt="Buy Me A Coffee"></a>  
        <a href="https://twitter.com/OutofAi" target="_blank"><img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Ashleigh%20Watson"></a>  
        <a href="https://twitter.com/banterless_ai" target="_blank"><img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Alex%20Nasa"></a>
        </div>
        <br>
        <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
        <p style="display: flex;gap: 6px;">
        <a href="https://huggingface.co/spaces/fffiloni/OutofFocus?duplicate=true">
        <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-md.svg" alt="Duplicate this Space">
        </a> to skip the queue and enjoy faster inference on the GPU of your choice
        </p>
        </div>
        '''
    )
    # NOTE(review): the nesting of the layout containers below is reconstructed
    # from a flattened source — confirm against the running app.
    with gr.Row():
        # Left column: source image, step count, caption and run/stop buttons.
        with gr.Column():
            with gr.Row():
                # Hidden image that gr.Examples writes to; its change event loads the example.
                example_input = gr.Image(height=512, width=512, type="filepath", visible=False)
                image_input = gr.Image(height=512, width=512, type="pil", label="Upload Source Image")
            steps_slider = gr.Slider(minimum=5, maximum=25, step=5, value=num_inference_steps, label="Steps", info="Number of inference steps required to reconstruct and modify the image")
            prompt_input = gr.Textbox(label="Prompt", info="Give an initial prompt in details, describing the image")
            reconstruct_button = gr.Button("Reconstruct")
            stop_button = gr.Button("Stop", variant="stop", interactive=False)
        # Right column: result image, influence slider, new prompt and generate button.
        with gr.Column():
            reconstructed_image = gr.Image(type="pil", label="Reconstructed")
            with gr.Row():
                # Hidden twin of the influence slider; its change event refreshes the
                # advanced weight fields when an example sets the scale programmatically.
                invisible_slider = gr.Slider(minimum=0, maximum=9, step=1, value=7, visible=False)
                interpolate_slider = gr.Slider(minimum=0, maximum=9, step=1, value=7, label="Cross-Attention Influence", info="Scales the related influence the source image has on the target image")
            with gr.Row():
                # Disabled until a reconstruction (or example) provides state.
                new_prompt_input = gr.Textbox(label="New Prompt", interactive=False, info="Manipulate the image by changing the prompt or word addition at the end, achieve the best results by swapping words instead of adding or removing in between")
            with gr.Row():
                apply_button = gr.Button("Generate Vision", variant="primary", interactive=False)
    with gr.Row():
        # Manual per-layer weights; initial values come from the global `weights` table.
        with gr.Accordion(label="Advanced Options", open=False):
            gr.Markdown(
                '''
                <div style="text-align: center;">
                <h1>Weight Adjustment</h1>
                <p style="font-size:16px;">Specific Cross-Attention Influence weights can be manually modified for given resolutions (1.0 = Fully Source Attn 0.0 = Fully Target Attn)</p>
                </div>
                '''
            )
            down_slider_4096 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['down'][4096], label="Self-Attn Down 64x64")
            down_slider_1024 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['down'][1024], label="Self-Attn Down 32x32")
            down_slider_256 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['down'][256], label="Self-Attn Down 16x16")
            mid_slider_64 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['mid'][64], label="Self-Attn Mid 8x8")
            up_slider_256 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['up'][256], label="Self-Attn Up 16x16")
            up_slider_1024 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['up'][1024], label="Self-Attn Up 32x32")
            up_slider_4096 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['up'][4096], label="Self-Attn Up 64x64")
    with gr.Row():
        # Each example row is (image path, source caption, target prompt).
        show_case = gr.Examples(
            examples=[
                ["assets/example4.png", "a photo of plastic bottle on a rock, mountain background, sky background", "a photo of plastic bottle on some sand, beach background, sky background"],
                ["assets/example1.png", "a photo of a tree, spring, foggy", "a photo of a tree, summer, colourful"],
                ["assets/example2.png", "a photo of a cat, two ears, white background", "a photo of a panda, two ears, white background"],
                ["assets/example3.png", "a digital illustration of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds", "a realistic photo of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds"],
            ],
            inputs=[example_input, prompt_input, new_prompt_input],
            label=None
        )
    # Holds [caption, initial latents, QT] between reconstruct and apply_prompt.
    meta_data = gr.State()
    # Selecting an example: load its saved state, render it, then unlock editing.
    example_input.change(
        fn=on_image_change,
        inputs=example_input,
        outputs=[image_input, reconstructed_image, meta_data, steps_slider, invisible_slider, interpolate_slider]
    ).then(
        lambda: gr.update(interactive=True),
        outputs=apply_button
    ).then(
        lambda: gr.update(interactive=True),
        outputs=new_prompt_input
    )
    # Sliders write straight into module-level state.
    steps_slider.release(update_step, inputs=steps_slider)
    interpolate_slider.release(update_scale, inputs=interpolate_slider, outputs=[down_slider_4096, down_slider_1024, down_slider_256, mid_slider_64, up_slider_256, up_slider_1024, up_slider_4096 ])
    invisible_slider.change(update_scale, inputs=invisible_slider, outputs=[down_slider_4096, down_slider_1024, down_slider_256, mid_slider_64, up_slider_256, up_slider_1024, up_slider_4096 ])
    # Each advanced field updates its own entry in `weights`; the constant
    # gr.State inputs carry the (stage, token-count) key.
    up_slider_4096.change(update_value, inputs=[up_slider_4096, gr.State('up'), gr.State(4096)])
    up_slider_1024.change(update_value, inputs=[up_slider_1024, gr.State('up'), gr.State(1024)])
    up_slider_256.change(update_value, inputs=[up_slider_256, gr.State('up'), gr.State(256)])
    down_slider_4096.change(update_value, inputs=[down_slider_4096, gr.State('down'), gr.State(4096)])
    down_slider_1024.change(update_value, inputs=[down_slider_1024, gr.State('down'), gr.State(1024)])
    down_slider_256.change(update_value, inputs=[down_slider_256, gr.State('down'), gr.State(256)])
    mid_slider_64.change(update_value, inputs=[mid_slider_64, gr.State('mid'), gr.State(64)])
    # Main reconstruction run; the chained steps restore the button states afterwards.
    reconstruct_button.click(reconstruct, inputs=[image_input, prompt_input], outputs=[reconstructed_image, new_prompt_input, meta_data]).then(
        lambda: gr.update(interactive=True),
        outputs=reconstruct_button
    ).then(
        lambda: gr.update(interactive=True),
        outputs=new_prompt_input
    ).then(
        lambda: gr.update(interactive=True),
        outputs=apply_button
    ).then(
        lambda: gr.update(interactive=False),
        outputs=stop_button
    )
    # Additional listeners on the same click: lock Reconstruct / Generate and
    # enable Stop while the run is in progress.
    reconstruct_button.click(
        lambda: gr.update(interactive=False),
        outputs=reconstruct_button
    )
    reconstruct_button.click(
        lambda: gr.update(interactive=True),
        outputs=stop_button
    )
    reconstruct_button.click(
        lambda: gr.update(interactive=False),
        outputs=apply_button
    )
    # Stop: disable the button and raise the stop flag (two listeners on one click).
    stop_button.click(
        lambda: gr.update(interactive=False),
        outputs=stop_button
    )
    apply_button.click(apply_prompt, inputs=[meta_data, new_prompt_input], outputs=reconstructed_image)
    stop_button.click(stop_reconstruct)
# Script entry point: `--share` asks Gradio for a public share link.
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("--share", action="store_true")
    options = cli.parse_args()

    # Enable request queuing before launching the server.
    demo.queue()
    demo.launch(share=options.share)