| import sys |
| from pathlib import Path |
| from typing import List, Optional |
|
|
| import gradio as gr |
| import torch |
| from PIL import Image |
| from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler |
| from huggingface_hub import snapshot_download |
| from transformers import CLIPTokenizer |
|
|
# Make the repository root (and its parent) importable before loading the local NeTI modules below.
sys.path.append(".")
sys.path.append("..")

import constants
from checkpoint_handler import CheckpointHandler
from models.neti_clip_text_encoder import NeTICLIPTextModel
from models.xti_attention_processor import XTIAttenProc
from prompt_manager import PromptManager
from scripts.inference import run_inference
|
|
DESCRIPTION = '''
# A Neural Space-Time Representation for Text-to-Image Personalization
<p style="text-align: center;">
This is a demo for our <a href="https://arxiv.org/abs/2305.15391">paper</a>: "A Neural Space-Time Representation
for Text-to-Image Personalization".
<br>
The project page and code are available <a href="https://neuraltextualinversion.github.io/NeTI/">here</a>.
<br>
We introduce a new text-conditioning latent space P* that depends on both the denoising timestep and the U-Net
layer. This space-time representation is learned implicitly via a small mapping network.
<br>
Here, you can generate images using one of the concepts trained in our paper. Simply select a concept and a
random seed.
<br>
You can also choose different truncation values to trade off the concept's reconstruction quality against its
editability.
</p>
'''
|
|
| CONCEPT_TO_PLACEHOLDER = { |
| 'barn': '<barn>', |
| 'cat': '<cat>', |
| 'clock': '<clock>', |
| 'colorful_teapot': '<colorful-teapot>', |
| 'dangling_child': '<dangling-child>', |
| 'dog': '<dog>', |
| 'elephant': '<elephant>', |
| 'fat_stone_bird': '<stone-bird>', |
| 'headless_statue': '<headless-statue>', |
| 'lecun': '<lecun>', |
| 'maeve': '<maeve-dog>', |
| 'metal_bird': '<metal-bird>', |
| 'mugs_skulls': '<mug-skulls>', |
| 'rainbow_cat': '<rainbow-cat>', |
| 'red_bowl': '<red-bowl>', |
| 'teddybear': '<teddybear>', |
| 'tortoise_plushy': '<tortoise-plushy>', |
| 'wooden_pot': '<wooden-pot>' |
| } |
|
|
| MODELS_PATH = Path('./trained_models') |
| MODELS_PATH.mkdir(parents=True, exist_ok=True) |
|
|
|
|
| def load_stable_diffusion_model(pretrained_model_name_or_path: str, |
| num_denoising_steps: int = 50, |
| torch_dtype: torch.dtype = torch.float16) -> StableDiffusionPipeline: |
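    """Load the given Stable Diffusion checkpoint with the NeTI CLIP text encoder, swap in a
    DPM-Solver multistep scheduler, and install the XTI attention processor used for
    per-layer conditioning."""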
| tokenizer = CLIPTokenizer.from_pretrained( |
| pretrained_model_name_or_path, subfolder="tokenizer") |
| text_encoder = NeTICLIPTextModel.from_pretrained( |
| pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch_dtype, |
| ) |
| pipeline = StableDiffusionPipeline.from_pretrained( |
| pretrained_model_name_or_path, |
| torch_dtype=torch_dtype, |
| text_encoder=text_encoder, |
| tokenizer=tokenizer |
| ).to("cuda") |
| pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) |
| pipeline.scheduler.set_timesteps(num_denoising_steps, device=pipeline.device) |
| pipeline.unet.set_attn_processor(XTIAttenProc()) |
| return pipeline |
|
|
|
|
| def get_possible_concepts() -> List[str]: |
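    """Return the concept names available as sub-directories of MODELS_PATH."""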
| objects = [x for x in MODELS_PATH.iterdir() if x.is_dir()] |
| return [x.name for x in objects] |
|
|
|
|
| def load_sd_and_all_tokens(): |
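    """Download all trained NeTI checkpoints from the HF Hub and load each concept's
    mapper and learned placeholder-token embedding into the shared pipeline."""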
| mappers = {} |
| pipeline = load_stable_diffusion_model(pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4") |
| print("Downloading all models from HF Hub...") |
| snapshot_download(repo_id="neural-ti/NeTI", local_dir='./trained_models') |
| print("Done.") |
| concepts = get_possible_concepts() |
| for concept in concepts: |
| print(f"Loading model for concept: {concept}") |
| learned_embeds_path = MODELS_PATH / concept / f"{concept}-learned_embeds.bin" |
| mapper_path = MODELS_PATH / concept / f"{concept}-mapper.pt" |
| train_cfg, mapper = CheckpointHandler.load_mapper(mapper_path=mapper_path) |
| placeholder_token, placeholder_token_id = CheckpointHandler.load_learned_embed_in_clip( |
| learned_embeds_path=learned_embeds_path, |
| text_encoder=pipeline.text_encoder, |
| tokenizer=pipeline.tokenizer |
| ) |
| mappers[concept] = { |
| "mapper": mapper, |
| "placeholder_token": placeholder_token, |
| "placeholder_token_id": placeholder_token_id |
| } |
| return mappers, pipeline |
|
|
|
|
| mappers, pipeline = load_sd_and_all_tokens() |
|
|
|
|
| def main_pipeline(concept_name: str, |
| prompt_input: str, |
| seed: int, |
| use_truncation: bool = False, |
                  truncation_idx: Optional[int] = None) -> List[Image.Image]:
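    """Generate an image of the selected concept: bind its mapper and placeholder token,
    build the per-timestep/per-layer prompt embeddings, and run inference with the given seed."""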
| pipeline.text_encoder.text_model.embeddings.set_mapper(mappers[concept_name]["mapper"]) |
| placeholder_token = mappers[concept_name]["placeholder_token"] |
| placeholder_token_id = mappers[concept_name]["placeholder_token_id"] |
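    # Build the per-denoising-timestep / per-U-Net-layer embeddings for the placeholder token.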
| prompt_manager = PromptManager(tokenizer=pipeline.tokenizer, |
| text_encoder=pipeline.text_encoder, |
| timesteps=pipeline.scheduler.timesteps, |
| unet_layers=constants.UNET_LAYERS, |
| placeholder_token=placeholder_token, |
| placeholder_token_id=placeholder_token_id, |
| torch_dtype=torch.float16) |
| image = run_inference(prompt=prompt_input.replace("*", CONCEPT_TO_PLACEHOLDER[concept_name]), |
| pipeline=pipeline, |
| prompt_manager=prompt_manager, |
| seeds=[int(seed)], |
| num_images_per_prompt=1, |
| truncation_idx=truncation_idx if use_truncation else None) |
| return [image] |
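

# A minimal sketch of calling the generation function directly, outside the Gradio UI
# (assumes the checkpoints above have already been downloaded; "cat" is one of the
# concepts listed in CONCEPT_TO_PLACEHOLDER and the prompt is purely illustrative):
#     images = main_pipeline(concept_name="cat",
#                            prompt_input="A photo of * on the beach",
#                            seed=42,
#                            use_truncation=True,
#                            truncation_idx=16)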
|
|
|
|
| with gr.Blocks(css='style.css') as demo: |
| gr.Markdown(DESCRIPTION) |
|
|
| gr.HTML('''<a href="https://huggingface.co/spaces/neural-ti/NeTI?duplicate=true"><img src="https://bit.ly/3gLdBN6" |
| alt="Duplicate Space"></a>''') |
|
|
| with gr.Row(): |
| with gr.Column(): |
| concept = gr.Dropdown(get_possible_concepts(), multiselect=False, label="Concept", |
| info="Choose your concept") |
| prompt = gr.Textbox(label="Input prompt", info="Input prompt with placeholder for concept. " |
| "Please use * to specify the concept.") |
| random_seed = gr.Number(value=42, label="Random seed", precision=0) |
| use_truncation = gr.Checkbox(label="Use inference-time dropout", |
| info="Whether to use our dropout technique when computing the concept " |
| "embeddings.") |
| truncation_idx = gr.Slider(8, 128, label="Truncation index", |
| info="If using truncation, which index to truncate from. Lower numbers tend to " |
| "result in more editable images, but at the cost of reconstruction.") |
| run_button = gr.Button('Generate') |
|
|
| with gr.Column(): |
| result = gr.Gallery(label='Result') |
| inputs = [concept, prompt, random_seed, use_truncation, truncation_idx] |
| outputs = [result] |
| run_button.click(fn=main_pipeline, inputs=inputs, outputs=outputs) |
|
|
| with gr.Row(): |
| examples = [ |
| ["maeve", "A photo of * swimming in the ocean", 5196, True, 16], |
| ["dangling_child", "A photo of * in Times Square", 3552126062741487430, False, 8], |
| ["teddybear", "A photo of * at his graduation ceremony after finishing his PhD", 263, True, 32], |
| ["red_bowl", "A * vase filled with flowers", 13491504810502930872, False, 8], |
| ["metal_bird", "* in a comic book", 1028, True, 24], |
| ["fat_stone_bird", "A movie poster of The Rock, featuring * about on Godzilla", 7393181316156044422, True, |
| 64], |
| ] |
| gr.Examples(examples=examples, |
| inputs=[concept, prompt, random_seed, use_truncation, truncation_idx], |
| outputs=[result], |
| fn=main_pipeline, |
| cache_examples=True) |
|
|
| demo.queue(max_size=50).launch(share=False) |
|
|