import argparse
import json
import os
from typing import List

import torch
from tqdm.auto import tqdm
from torchvision import transforms
from diffusers import (
    StableDiffusionPipeline,
    DDIMScheduler,
    StableDiffusionXLPipeline,
)
from diffusers.image_processor import VaeImageProcessor
from omegaconf import OmegaConf

from models import attn_injection
import utils.exp_utils

device = torch.device("cuda")


def _get_text_embeddings(prompt: str, tokenizer, text_encoder, device):
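    """Tokenize `prompt` and run it through one of the SDXL text encoders.

    Returns `(prompt_embeds, pooled_prompt_embeds)`: the penultimate hidden
    state used for cross-attention and the pooled embedding used for SDXL's
    added conditioning. For an empty string prompt, both are replaced by
    zeros so the unconditional branch carries no text signal.
    """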
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids

    with torch.no_grad():
        prompt_embeds = text_encoder(
            text_input_ids.to(device),
            output_hidden_states=True,
        )

    # Pooled output for SDXL's added conditioning; penultimate hidden state
    # for cross-attention, as in the standard SDXL text-encoding path.
    pooled_prompt_embeds = prompt_embeds[0]
    prompt_embeds = prompt_embeds.hidden_states[-2]
    if prompt == "":
        # Empty prompt: return zeroed embeddings for the unconditional branch.
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
        return negative_prompt_embeds, negative_pooled_prompt_embeds
    return prompt_embeds, pooled_prompt_embeds


def _encode_text_sdxl(model: StableDiffusionXLPipeline, prompt: List[str]):
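    """Encode `prompt` with both SDXL text encoders and build the added
    conditioning kwargs (pooled text embeddings plus size/crop time ids,
    assuming a 1024x1024 target with no cropping)."""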
    device = model._execution_device
    (
        prompt_embeds,
        pooled_prompt_embeds,
    ) = _get_text_embeddings(prompt, model.tokenizer, model.text_encoder, device)
    (
        prompt_embeds_2,
        pooled_prompt_embeds_2,
    ) = _get_text_embeddings(prompt, model.tokenizer_2, model.text_encoder_2, device)
    # SDXL concatenates the two encoders' sequence embeddings channel-wise.
    prompt_embeds = torch.cat((prompt_embeds, prompt_embeds_2), dim=-1)
    text_encoder_projection_dim = model.text_encoder_2.config.projection_dim
    add_time_ids = model._get_add_time_ids(
        (1024, 1024), (0, 0), (1024, 1024), torch.float16, text_encoder_projection_dim
    ).to(device)
    add_time_ids = add_time_ids.repeat(len(prompt), 1)
    added_cond_kwargs = {
        "text_embeds": pooled_prompt_embeds_2,
        "time_ids": add_time_ids,
    }
    return added_cond_kwargs, prompt_embeds


def _encode_text_sdxl_with_negative(
    model: StableDiffusionXLPipeline, prompt: List[str]
):
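    """Encode `prompt` together with a batch of empty negative prompts and
    concatenate them (unconditional first) for classifier-free guidance."""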
    B = len(prompt)
    added_cond_kwargs, prompt_embeds = _encode_text_sdxl(model, prompt)
    added_cond_kwargs_uncond, prompt_embeds_uncond = _encode_text_sdxl(
        model, ["" for _ in range(B)]
    )
    # Stack unconditional embeddings before conditional ones, matching the
    # `torch.cat([latents] * 2)` ordering used during sampling.
    prompt_embeds = torch.cat(
        (
            prompt_embeds_uncond,
            prompt_embeds,
        )
    )
    added_cond_kwargs = {
        "text_embeds": torch.cat(
            (added_cond_kwargs_uncond["text_embeds"], added_cond_kwargs["text_embeds"])
        ),
        "time_ids": torch.cat(
            (added_cond_kwargs_uncond["time_ids"], added_cond_kwargs["time_ids"])
        ),
    }
    return added_cond_kwargs, prompt_embeds


@torch.no_grad()
def sample(
    pipe,
    prompt,
    start_step=0,
    start_latents=None,
    intermediate_latents=None,
    guidance_scale=3.5,
    num_inference_steps=30,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt="",
    device=device,
):
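    """Run CFG-guided DDIM sampling from `start_latents`, re-anchoring the
    first batch element to the stored inversion latents at every step.

    `intermediate_latents` is assumed to be ordered from least to most noised
    (as produced by `invert`), so step `i` reads `intermediate_latents[-(i + 1)]`.
    """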
    # The caller-supplied negative prompt is replaced by one empty prompt per
    # batch element, as required by the batched CFG encoding below.
    negative_prompt = [""] * len(prompt)

    if isinstance(pipe, StableDiffusionPipeline):
        text_embeddings = pipe._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
        )
        added_cond_kwargs = None
    elif isinstance(pipe, StableDiffusionXLPipeline):
        added_cond_kwargs, text_embeddings = _encode_text_sdxl_with_negative(
            pipe, prompt
        )

    pipe.scheduler.set_timesteps(num_inference_steps, device=device)

    if start_latents is None:
        start_latents = torch.randn(1, 4, 64, 64, device=device)
        start_latents *= pipe.scheduler.init_noise_sigma

    latents = start_latents.clone()
    latents = latents.repeat(len(prompt), 1, 1, 1)

    for i in tqdm(range(start_step, num_inference_steps)):
        # Anchor the first batch element to the inverted latent for this step
        # (inversion latents are stored from least to most noised).
        latents[0] = intermediate_latents[-(i + 1)]
        t = pipe.scheduler.timesteps[i]

        latent_model_input = (
            torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        )
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        noise_pred = pipe.unet(
            latent_model_input,
            t,
            encoder_hidden_states=text_embeddings,
            added_cond_kwargs=added_cond_kwargs,
        ).sample

        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
        latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample

    images = pipe.decode_latents(latents)
    images = pipe.numpy_to_pil(images)

    return images


@torch.no_grad()
def sample_disentangled(
    pipe,
    prompt,
    start_step=0,
    start_latents=None,
    intermediate_latents=None,
    guidance_scale=3.5,
    num_inference_steps=30,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    use_content_anchor=True,
    negative_prompt="",
    device=device,
):
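    """Disentangled variant of `sample` for a three-element prompt batch
    (content delegation, style delegation, stylized output).

    Batch element 0 is optionally re-anchored to the inversion latents at
    every step; element 1 starts from fresh Gaussian noise so the attention
    injection can delegate style generation to it.
    """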
    negative_prompt = [""] * len(prompt)
    # Spatial VAE downsampling factor (8 for SD/SDXL); note this processor is
    # currently unused, since latents are decoded manually below.
    vae_decoder = VaeImageProcessor(
        vae_scale_factor=2 ** (len(pipe.vae.config.block_out_channels) - 1)
    )
    if isinstance(pipe, StableDiffusionPipeline):
        text_embeddings = pipe._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
        )
        added_cond_kwargs = None
    elif isinstance(pipe, StableDiffusionXLPipeline):
        added_cond_kwargs, text_embeddings = _encode_text_sdxl_with_negative(
            pipe, prompt
        )

    pipe.scheduler.set_timesteps(num_inference_steps, device=device)

    latent_shape = (
        (1, 4, 64, 64)
        if isinstance(pipe, StableDiffusionPipeline)
        else (1, 4, 128, 128)  # SDXL uses 128x128 latents at 1024px
    )
    generative_latent = torch.randn(latent_shape, device=device)
    generative_latent *= pipe.scheduler.init_noise_sigma

    latents = start_latents.clone()
    latents = latents.repeat(len(prompt), 1, 1, 1)

    # The style-delegation branch starts from fresh noise rather than the
    # inverted content latent.
    latents[1] = generative_latent

    for i in tqdm(range(start_step, num_inference_steps), desc="Stylizing"):
        if use_content_anchor:
            latents[0] = intermediate_latents[-(i + 1)]
        t = pipe.scheduler.timesteps[i]

        latent_model_input = (
            torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        )
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        noise_pred = pipe.unet(
            latent_model_input,
            t,
            encoder_hidden_states=text_embeddings,
            added_cond_kwargs=added_cond_kwargs,
        ).sample

        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample

    # Decode in float32 for numerical stability, then restore fp16 for SDXL.
    pipe.vae.to(dtype=torch.float32)
    latents = latents.to(next(iter(pipe.vae.post_quant_conv.parameters())).dtype)
    latents = 1 / pipe.vae.config.scaling_factor * latents
    images = pipe.vae.decode(latents, return_dict=False)[0]
    images = (images / 2 + 0.5).clamp(0, 1)
    images = images.cpu().permute(0, 2, 3, 1).float().numpy()
    images = pipe.numpy_to_pil(images)
    if isinstance(pipe, StableDiffusionXLPipeline):
        pipe.vae.to(dtype=torch.float16)

    return images


@torch.no_grad()
def invert(
    pipe,
    start_latents,
    prompt,
    guidance_scale=3.5,
    num_inference_steps=50,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt="",
    device=device,
):
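    """DDIM inversion: walk the clean latent `start_latents` backwards along
    the deterministic DDIM trajectory and collect the latent after each step.

    Returns a tensor of intermediate latents ordered from least to most
    noised, which `sample`/`sample_disentangled` index from the end.
    """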
    if isinstance(pipe, StableDiffusionPipeline):
        text_embeddings = pipe._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
        )
        added_cond_kwargs = None
        latents = start_latents.clone().detach()
    elif isinstance(pipe, StableDiffusionXLPipeline):
        added_cond_kwargs, text_embeddings = _encode_text_sdxl_with_negative(
            pipe, [prompt]
        )
        latents = start_latents.clone().detach().half()

    intermediate_latents = []

    pipe.scheduler.set_timesteps(num_inference_steps, device=device)

    # Invert by stepping through the schedule in reverse (low to high noise).
    timesteps = reversed(pipe.scheduler.timesteps)

    for i in tqdm(
        range(1, num_inference_steps),
        total=num_inference_steps - 1,
        desc="DDIM Inversion",
    ):
        # Skip the final step, as in the standard DDIM-inversion recipe.
        if i >= num_inference_steps - 1:
            continue

        t = timesteps[i]

        latent_model_input = (
            torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        )
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        noise_pred = pipe.unet(
            latent_model_input,
            t,
            encoder_hidden_states=text_embeddings,
            added_cond_kwargs=added_cond_kwargs,
        ).sample

        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        current_t = max(0, t.item() - (1000 // num_inference_steps))  # t
        next_t = t  # t + 1
        alpha_t = pipe.scheduler.alphas_cumprod[current_t]
        alpha_t_next = pipe.scheduler.alphas_cumprod[next_t]

        # Reversed DDIM update:
        # x_{t+1} = sqrt(a_{t+1}/a_t) * (x_t - sqrt(1 - a_t) * eps) + sqrt(1 - a_{t+1}) * eps
        latents = (latents - (1 - alpha_t).sqrt() * noise_pred) * (
            alpha_t_next.sqrt() / alpha_t.sqrt()
        ) + (1 - alpha_t_next).sqrt() * noise_pred

        intermediate_latents.append(latents)

    return torch.cat(intermediate_latents)


def style_image_with_inversion(
    pipe,
    input_image,
    input_image_prompt,
    style_prompt,
    num_steps=100,
    start_step=30,
    guidance_scale=3.5,
    disentangle=False,
    share_attn=False,
    share_cross_attn=False,
    share_resnet_layers=[0, 1],
    share_attn_layers=[],
    c2s_layers=[0, 1],
    share_key=True,
    share_query=True,
    share_value=False,
    use_adain=True,
    use_content_anchor=True,
    output_dir: str = None,
    resnet_mode: str = None,
    return_intermediate=False,
    intermediate_latents=None,
):
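    """End-to-end stylization: VAE-encode the input, DDIM-invert it (unless
    `intermediate_latents` is supplied), register the attention/ResNet
    injection processors, sample with the style prompt, then unset the
    processors before returning the images."""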
    with torch.no_grad():
        # Encode in float32 to avoid fp16 overflow in the VAE.
        pipe.vae.to(dtype=torch.float32)
        latent = pipe.vae.encode(input_image.to(device) * 2 - 1)
        init_latents = pipe.vae.config.scaling_factor * latent.latent_dist.sample()
        if isinstance(pipe, StableDiffusionXLPipeline):
            pipe.vae.to(dtype=torch.float16)
    if intermediate_latents is None:
        inverted_latents = invert(
            pipe, init_latents, input_image_prompt, num_inference_steps=num_steps
        )
    else:
        inverted_latents = intermediate_latents

    attn_injection.register_attention_processors(
        pipe,
        base_dir=output_dir,
        resnet_mode=resnet_mode,
        attn_mode="artist" if disentangle else "pnp",
        disentangle=disentangle,
        share_resblock=True,
        share_attn=share_attn,
        share_cross_attn=share_cross_attn,
        share_resnet_layers=share_resnet_layers,
        share_attn_layers=share_attn_layers,
        share_key=share_key,
        share_query=share_query,
        share_value=share_value,
        use_adain=use_adain,
        c2s_layers=c2s_layers,
    )

    if disentangle:
        final_im = sample_disentangled(
            pipe,
            style_prompt,
            start_latents=inverted_latents[-(start_step + 1)][None],
            intermediate_latents=inverted_latents,
            start_step=start_step,
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
            use_content_anchor=use_content_anchor,
        )
    else:
        final_im = sample(
            pipe,
            style_prompt,
            start_latents=inverted_latents[-(start_step + 1)][None],
            intermediate_latents=inverted_latents,
            start_step=start_step,
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
        )

    attn_injection.unset_attention_processors(
        pipe,
        unset_share_attn=True,
        unset_share_resblock=True,
    )
    if return_intermediate:
        return final_im, inverted_latents
    return final_im


if __name__ == "__main__":
    pipe = StableDiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1-base"
    ).to(device)

    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    parser = argparse.ArgumentParser(description="Stable Diffusion with OmegaConf")
    parser.add_argument(
        "--config", type=str, default="config.yaml", help="Path to the config file"
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="dataset",
        choices=["dataset", "cli", "app"],
        help="Run mode: batch over a dataset, stylize a single image, or launch the Gradio app",
    )
    parser.add_argument(
        "--image_dir", type=str, default="test.png", help="Path to the input image (cli mode)"
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="an impressionist painting",
        help="Stylization prompt",
    )

    args = parser.parse_args()
    config_dir = args.config
    mode = args.mode

    # Output naming follows the three-element batch in `prompt_in`.
    out_name = ["content_delegation", "style_delegation", "style_out"]

    if mode == "dataset":
        cfg = OmegaConf.load(config_dir)

        base_output_path = cfg.out_path
        if not os.path.exists(cfg.out_path):
            os.makedirs(cfg.out_path)
        base_output_path = os.path.join(base_output_path, cfg.exp_name)

        experiment_output_path = utils.exp_utils.make_unique_experiment_path(
            base_output_path
        )

        # Save a copy of the config and annotation for reproducibility.
        config_file_path = os.path.join(experiment_output_path, "config.yaml")
        OmegaConf.save(cfg, config_file_path)

        annotation = json.load(open(cfg.annotation))
        with open(os.path.join(experiment_output_path, "annotation.json"), "w") as f:
            json.dump(annotation, f)
        for i, entry in enumerate(annotation):
            utils.exp_utils.seed_all(cfg.seed)
            image_path = entry["image_path"]
            src_prompt = entry["source_prompt"]
            tgt_prompt = entry["target_prompt"]
            resolution = 1024 if isinstance(pipe, StableDiffusionXLPipeline) else 512
            input_image = utils.exp_utils.get_processed_image(
                image_path, device, resolution
            )

            prompt_in = [
                src_prompt,  # content delegation
                tgt_prompt,  # style delegation
                "",  # stylized output
            ]

            imgs = style_image_with_inversion(
                pipe,
                input_image,
                src_prompt,
                style_prompt=prompt_in,
                num_steps=cfg.num_steps,
                start_step=cfg.start_step,
                guidance_scale=cfg.style_cfg_scale,
                disentangle=cfg.disentangle,
                resnet_mode=cfg.resnet_mode,
                share_attn=cfg.share_attn,
                share_cross_attn=cfg.share_cross_attn,
                share_resnet_layers=cfg.share_resnet_layers,
                share_attn_layers=cfg.share_attn_layers,
                share_key=cfg.share_key,
                share_query=cfg.share_query,
                share_value=cfg.share_value,
                use_content_anchor=cfg.use_content_anchor,
                use_adain=cfg.use_adain,
                output_dir=experiment_output_path,
            )

            for j, img in enumerate(imgs):
                img.save(f"{experiment_output_path}/out_{i}_{out_name[j]}.png")
                print(
                    f"Image saved as {experiment_output_path}/out_{i}_{out_name[j]}.png"
                )
    elif mode == "cli":
        cfg = OmegaConf.load(config_dir)
        utils.exp_utils.seed_all(cfg.seed)
        image = utils.exp_utils.get_processed_image(args.image_dir, device, 512)
        tgt_prompt = args.prompt
        src_prompt = ""
        prompt_in = [
            "",
            tgt_prompt,
            "",
        ]
        out_dir = "./out"
        os.makedirs(out_dir, exist_ok=True)
        imgs = style_image_with_inversion(
            pipe,
            image,
            src_prompt,
            style_prompt=prompt_in,
            num_steps=cfg.num_steps,
            start_step=cfg.start_step,
            guidance_scale=cfg.style_cfg_scale,
            disentangle=cfg.disentangle,
            resnet_mode=cfg.resnet_mode,
            share_attn=cfg.share_attn,
            share_cross_attn=cfg.share_cross_attn,
            share_resnet_layers=cfg.share_resnet_layers,
            share_attn_layers=cfg.share_attn_layers,
            share_key=cfg.share_key,
            share_query=cfg.share_query,
            share_value=cfg.share_value,
            use_content_anchor=cfg.use_content_anchor,
            use_adain=cfg.use_adain,
            output_dir=out_dir,
        )
        image_base_name = os.path.basename(args.image_dir).split(".")[0]
        for j, img in enumerate(imgs):
            img.save(f"{out_dir}/{image_base_name}_out_{out_name[j]}.png")
            print(f"Image saved as {out_dir}/{image_base_name}_out_{out_name[j]}.png")
    elif mode == "app":
        import gradio as gr

        def style_transfer_app(
            prompt,
            image,
            cfg_scale=7.5,
            num_content_layers=4,
            num_style_layers=9,
            seed=0,
            progress=gr.Progress(track_tqdm=True),
        ):
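            """Gradio callback: stylize `image` with `prompt` using fixed demo
            settings (50 steps, disentangled mode) and return only the final
            stylized output (batch element 2)."""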
            utils.exp_utils.seed_all(seed)
            image = utils.exp_utils.process_image(image, device, 512)

            tgt_prompt = prompt
            src_prompt = ""
            prompt_in = [
                "",
                tgt_prompt,
                "",
            ]

            # Sliders control how many of the first UNet layers receive
            # content (ResNet) and style (attention) injection.
            share_resnet_layers = (
                list(range(num_content_layers)) if num_content_layers != 0 else None
            )
            share_attn_layers = (
                list(range(num_style_layers)) if num_style_layers != 0 else None
            )
            imgs = style_image_with_inversion(
                pipe,
                image,
                src_prompt,
                style_prompt=prompt_in,
                num_steps=50,
                start_step=0,
                guidance_scale=cfg_scale,
                disentangle=True,
                resnet_mode="hidden",
                share_attn=True,
                share_cross_attn=True,
                share_resnet_layers=share_resnet_layers,
                share_attn_layers=share_attn_layers,
                share_key=True,
                share_query=True,
                share_value=False,
                use_content_anchor=True,
                use_adain=True,
                output_dir="./",
            )

            # imgs = [content delegation, style delegation, stylized output]
            return imgs[2]

        examples = []
        annotation = json.load(open("data/example/annotation.json"))
        for entry in annotation:
            image = utils.exp_utils.get_processed_image(
                entry["image_path"], device, 512
            )
            image = transforms.ToPILImage()(image[0])

            # One example row per input component, using the slider defaults.
            examples.append([entry["target_prompt"], image, 7.5, 4, 9, 0])

        text_input = gr.Textbox(
            value="An impressionist painting",
            label="Text Prompt",
            info="Describe the style to apply to the image; do not describe the image content itself",
            lines=2,
            placeholder="Enter a text prompt",
        )
        image_input = gr.Image(
            height="80%",
            width="80%",
            label="Content image (will be resized to 512x512)",
            interactive=True,
        )
        cfg_slider = gr.Slider(
            0,
            15,
            value=7.5,
            label="Classifier-Free Guidance (CFG) Scale",
            info="Higher values give more style; 7.5 should be good for most cases",
        )
        content_slider = gr.Slider(
            0,
            9,
            value=4,
            step=1,
            label="Number of content control layers",
            info="Higher values keep the result closer to the original image. Defaults to controlling the first 4 layers",
        )
        style_slider = gr.Slider(
            0,
            9,
            value=9,
            step=1,
            label="Number of style control layers",
            info="Higher values push the result closer to the target style. Defaults to controlling the first 9 layers; usually no need to change",
        )
        seed_slider = gr.Slider(
            0,
            100,
            value=0,
            step=1,
            label="Seed",
            info="Random seed for the model",
        )
        app = gr.Interface(
            fn=style_transfer_app,
            inputs=[
                text_input,
                image_input,
                cfg_slider,
                content_slider,
                style_slider,
                seed_slider,
            ],
            outputs=["image"],
            title="Artist Interactive Demo",
            examples=examples,
        )
        app.launch()