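# Gradio demo for weights2weights (w2w): sample a new identity-encoding diffusion
# model from the learned weight space, then edit facial attributes by moving the
# sampled weights along learned directions.
# Project page: https://snap-research.github.io/weights2weights/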
import gradio as gr
import sys
import os
import tqdm
# Make the w2w repo modules (utils, sampling, editing) importable from the parent directory.
sys.path.append(os.path.abspath(os.path.join("", "..")))
import torch
import gc
import warnings
warnings.filterwarnings("ignore")

from PIL import Image
from utils import load_models, save_model_w2w, save_model_for_diffusers
from sampling import sample_weights
from editing import get_direction, debias
from huggingface_hub import snapshot_download
# Module-level state shared by the Gradio callbacks below.
device = "cuda:0"
generator = torch.Generator(device=device)

# Download the pretrained w2w assets (PCA statistics, projections, attribute data).
models_path = snapshot_download(repo_id="Snapchat/w2w")

mean = torch.load(f"{models_path}/mean.pt").bfloat16().to(device)
std = torch.load(f"{models_path}/std.pt").bfloat16().to(device)
v = torch.load(f"{models_path}/V.pt").bfloat16().to(device)
proj = torch.load(f"{models_path}/proj_1000pc.pt").bfloat16().to(device)
df = torch.load(f"{models_path}/identity_df.pt")
weight_dimensions = torch.load(f"{models_path}/weight_dimensions.pt")
pinverse = torch.load(f"{models_path}/pinverse_1000pc.pt").bfloat16().to(device)

unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
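# A rough guide to the downloaded assets, inferred from the variable names and
# the weights2weights paper (treat as an informal summary, not a spec):
#   mean, std  - statistics of the flattened identity model weights
#   v          - principal components (V) spanning the w2w weight space
#   proj       - dataset identities projected onto the first 1000 components
#   pinverse   - pseudo-inverse used when mapping attribute directions
#   df         - per-identity attribute annotations used to fit directions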
def sample_model():
    """Reload a fresh UNet and sample a new identity from the w2w space."""
    global unet
    global network
    del unet  # free the previous UNet before loading a fresh copy
    unet, _, _, _, _ = load_models(device)
    network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor=1.00)
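# Note: judging from its use below, `network` wraps the sampled identity weights,
# acts as a context manager that applies them to the UNet during a forward pass
# (`with network:`), and exposes the sampled PC coordinates as `network.proj`.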
def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
    global device
    global generator
    global unet
    global vae
    global text_encoder
    global tokenizer
    global noise_scheduler

    generator = generator.manual_seed(int(seed))
    latents = torch.randn(
        (1, unet.config.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    # Encode the prompt and negative prompt for classifier-free guidance.
    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(int(ddim_steps))
    latents = latents * noise_scheduler.init_noise_sigma

    for t in tqdm.tqdm(noise_scheduler.timesteps):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
        with network:
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample
        # Classifier-free guidance: pred = uncond + scale * (text - uncond).
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    # Decode latents to an image (0.18215 is the SD 1.x VAE scaling factor).
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))
    return [image]
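# Example usage outside the UI, mirroring the defaults in sample_then_run():
#   sample_model()
#   [img] = inference("sks person", "low quality, blurry, unfinished, cartoon", 3.0, 50, 5)
#   img.save("sampled_identity.png")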
def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
    global device
    global generator
    global unet
    global vae
    global text_encoder
    global tokenizer
    global noise_scheduler
    global young
    global pointy
    global wavy
    global large

    original_weights = network.proj.clone()
    # Move the sampled identity along the attribute directions; the 1e6/2e6
    # factors scale the [-1, 1] slider values up to the magnitude of the space.
    edited_weights = original_weights + a1 * 1e6 * young + a2 * 1e6 * pointy + a3 * 1e6 * wavy + a4 * 2e6 * large

    generator = generator.manual_seed(int(seed))
    latents = torch.randn(
        (1, unet.config.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(int(ddim_steps))
    latents = latents * noise_scheduler.init_noise_sigma

    for t in tqdm.tqdm(noise_scheduler.timesteps):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
        # Keep the original weights for the early (high-noise) steps; inject the
        # edited weights once the timestep drops to start_noise or below.
        if t <= start_noise:
            network.proj = torch.nn.Parameter(edited_weights)
            network.reset()
        with network:
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample
        # Classifier-free guidance.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))

    # Restore the original identity weights.
    network.proj = torch.nn.Parameter(original_weights)
    network.reset()
    return [image]
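# Worked example of the injection step: with ddim_steps=50 the scheduler's
# timesteps run from ~981 down to 1 in steps of ~20. With start_noise=800, the
# first ~10 steps use the original weights (fixing pose and layout) and the
# remaining ~40 use the edited weights, so attribute edits change appearance
# without restructuring the whole image.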
def sample_then_run():
    global young_val
    global pointy_val
    global wavy_val
    global large_val
    sample_model()
    # Record where the sampled identity sits along each attribute direction
    # (scalar projection: (proj @ dir) / ||dir||^2).
    young_val = (network.proj @ young[0] / torch.norm(young) ** 2).item()
    pointy_val = (network.proj @ pointy[0] / torch.norm(pointy) ** 2).item()
    wavy_val = (network.proj @ wavy[0] / torch.norm(wavy) ** 2).item()
    large_val = (network.proj @ large[0] / torch.norm(large) ** 2).item()
    prompt = "sks person"
    negative_prompt = "low quality, blurry, unfinished, cartoon"
    seed = 5
    cfg = 3.0
    steps = 50
    image = inference(prompt, negative_prompt, cfg, steps, seed)
    return image
# Attribute directions in w2w space: each direction is estimated for one
# attribute, then debiased against correlated attributes so that moving along
# one axis (e.g. "Young") perturbs the others (e.g. "Male") as little as possible.
young = get_direction(df, "Young", pinverse, 1000, device)
young = debias(young, "Male", df, pinverse, device)
young = debias(young, "Pointy_Nose", df, pinverse, device)
young = debias(young, "Wavy_Hair", df, pinverse, device)
young = debias(young, "Chubby", df, pinverse, device)
young_max = torch.max(proj @ young[0] / torch.norm(young) ** 2).item()
young_min = torch.min(proj @ young[0] / torch.norm(young) ** 2).item()

pointy = get_direction(df, "Pointy_Nose", pinverse, 1000, device)
pointy = debias(pointy, "Young", df, pinverse, device)
pointy = debias(pointy, "Male", df, pinverse, device)
pointy = debias(pointy, "Wavy_Hair", df, pinverse, device)
pointy = debias(pointy, "Chubby", df, pinverse, device)
pointy = debias(pointy, "Heavy_Makeup", df, pinverse, device)
pointy_max = torch.max(proj @ pointy[0] / torch.norm(pointy) ** 2).item()
pointy_min = torch.min(proj @ pointy[0] / torch.norm(pointy) ** 2).item()

wavy = get_direction(df, "Wavy_Hair", pinverse, 1000, device)
wavy = debias(wavy, "Young", df, pinverse, device)
wavy = debias(wavy, "Male", df, pinverse, device)
wavy = debias(wavy, "Pointy_Nose", df, pinverse, device)
wavy = debias(wavy, "Chubby", df, pinverse, device)
wavy = debias(wavy, "Heavy_Makeup", df, pinverse, device)
wavy_max = torch.max(proj @ wavy[0] / torch.norm(wavy) ** 2).item()
wavy_min = torch.min(proj @ wavy[0] / torch.norm(wavy) ** 2).item()

large = get_direction(df, "Chubby", pinverse, 1000, device)
large = debias(large, "Male", df, pinverse, device)
large = debias(large, "Young", df, pinverse, device)
large = debias(large, "Pointy_Nose", df, pinverse, device)
large = debias(large, "Wavy_Hair", df, pinverse, device)
large_max = torch.max(proj @ large[0] / torch.norm(large) ** 2).item()
large_min = torch.min(proj @ large[0] / torch.norm(large) ** 2).item()
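# The *_min / *_max values estimate the range of each attribute coordinate
# across the dataset identities. They are unused below, but could calibrate
# the [-1, 1] slider ranges to the actual spread of the w2w space.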
| intro = """ | |
| <div style="display: flex;align-items: center;justify-content: center"> | |
| <h1 style="margin-left: 12px;text-align: center;margin-bottom: 7px;display: inline-block">weights2weights</h1> | |
| <h3 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Interpreting the Weight Space of Customized Diffusion Models</h3> | |
| </div> | |
| <p style="font-size: 0.95rem;margin: 0rem;line-height: 1.2em;margin-top:1em;display: inline-block"> | |
| <a href="https://snap-research.github.io/weights2weights/" target="_blank">project page</a> | <a href="https://arxiv.org/abs/2406.09413" target="_blank">paper</a> | |
| | | |
| <a href="https://huggingface.co/spaces/Snapchat/w2w-demo?duplicate=true" target="_blank" style=" | |
| display: inline-block; | |
| "> | |
| <img style="margin-top: -1em;margin-bottom: 0em;position: absolute;" src="https://bit.ly/3CWLGkA" alt="Duplicate Space"></a> | |
| </p> | |
| """ | |
with gr.Blocks(css="style.css") as demo:
    gr.HTML(intro)
    with gr.Row():
        with gr.Column():
            gallery1 = gr.Gallery(label="Identity from Sampled Model")
            sample = gr.Button("Sample New Model")
        gallery2 = gr.Gallery(label="Identity from Edited Model")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt",
                                info="Make sure to include 'sks person'",
                                placeholder="sks person",
                                value="sks person")
            negative_prompt = gr.Textbox(label="Negative Prompt",
                                         placeholder="low quality, blurry, unfinished, cartoon",
                                         value="low quality, blurry, unfinished, cartoon")
            with gr.Row():
                a1 = gr.Slider(label="+Young", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                a2 = gr.Slider(label="+Pointy Nose", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
            with gr.Row():
                a3 = gr.Slider(label="+Curly Hair", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                a4 = gr.Slider(label="+Large", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
            with gr.Accordion("Advanced Options", open=False):
                with gr.Column():
                    # precision=0 makes the Number return an int, as manual_seed expects.
                    seed = gr.Number(value=5, precision=0, label="Seed", interactive=True)
                    cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
                    steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
                    injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
            submit = gr.Button("Submit")

    sample.click(fn=sample_then_run, outputs=gallery1)
    submit.click(fn=edit_inference,
                 inputs=[prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3, a4],
                 outputs=gallery2)

demo.launch(share=True)
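# To run this demo locally (assumptions, not verified against the repo): a CUDA
# GPU with bfloat16 support, the weights2weights repository checked out so that
# utils.py, sampling.py, and editing.py resolve via the sys.path entry above,
# and gradio, torch, diffusers, transformers, and huggingface_hub installed.
# share=True creates a public Gradio link; drop it for local-only use.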