Spaces:
Build error
Build error
# 手元で推論を行うための最低限のコード。HuggingFace/DiffusersのCLIP、schedulerとVAEを使う
# Minimal code for running inference locally. Uses the HuggingFace/Diffusers CLIP, scheduler and VAE.
| import argparse | |
| import datetime | |
| import math | |
| import os | |
| import random | |
| from einops import repeat | |
| import numpy as np | |
| import torch | |
| from tqdm import tqdm | |
| from transformers import CLIPTokenizer | |
| from diffusers import EulerDiscreteScheduler | |
| from PIL import Image | |
| import open_clip | |
| from safetensors.torch import load_file | |
| from library import model_util, sdxl_model_util | |
| import networks.lora as lora | |
# Scheduler settings. According to the original author these can apparently
# stay the same as for SD1/2. They are passed to EulerDiscreteScheduler in the
# __main__ section below.
SCHEDULER_LINEAR_START = 0.00085  # passed as beta_start
SCHEDULER_LINEAR_END = 0.0120  # passed as beta_end
SCHEDULER_TIMESTEPS = 1000  # passed as num_train_timesteps
# NOTE: the name is misspelled ("SCHEDLER"); it is kept as-is because the
# script refers to it under this exact spelling.
SCHEDLER_SCHEDULE = "scaled_linear"  # passed as beta_schedule
| # Time EmbeddingはDiffusersからのコピー | |
| # Time Embedding is copied from Diffusers | |
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Build sinusoidal embeddings for a batch of timesteps.

    :param timesteps: a 1-D Tensor holding one (possibly fractional) value per
                      batch element.
    :param dim: the width of each output row.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: when true, skip the sinusoids and tile every timestep
                        value ``dim`` times instead.
    :return: an [N x dim] Tensor.
    """
    if repeat_only:
        return repeat(timesteps, "b -> b d", d=dim)

    half = dim // 2
    # Frequencies decay geometrically starting at 1; they are computed in
    # float32 and then moved to the device of the input.
    exponents = torch.arange(start=0, end=half, dtype=torch.float32)
    freqs = torch.exp(-math.log(max_period) * exponents / half)
    freqs = freqs.to(device=timesteps.device)

    angles = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
    if dim % 2 != 0:
        # cos/sin halves only fill 2 * half columns; pad one zero column so
        # the row length equals an odd ``dim``.
        zero_column = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, zero_column], dim=-1)
    return embedding
def get_timestep_embedding(x, outdim):
    """
    Embed every scalar of a 2-D tensor and concatenate the embeddings per row.

    :param x: a 2-D Tensor of shape [b, dims].
    :param outdim: the embedding width used for each individual scalar.
    :return: a Tensor of shape [b, dims * outdim].
    """
    assert x.ndim == 2
    batch = x.shape[0]
    values_per_row = x.shape[1]

    # Flatten "b d -> (b d)" so each scalar is embedded independently ...
    flat_embedding = timestep_embedding(x.flatten(), outdim)
    # ... then regroup "(b d) d2 -> b (d d2)" so each row holds its scalars'
    # embeddings side by side.
    return flat_embedding.reshape(batch, values_per_row * outdim)
if __name__ == "__main__":
    # Script entry point: parse arguments, load the SDXL checkpoint, optionally
    # merge LoRA weights, then generate image(s) once or in an interactive loop.

    # Change here to change image generation conditions.

    # Values passed to the additional vector embedding of SDXL.
    target_height = 1024
    target_width = 1024
    original_height = target_height
    original_width = target_width
    crop_top = 0
    crop_left = 0

    steps = 50
    guidance_scale = 7
    seed = None  # e.g. 1 for reproducible output

    DEVICE = "cuda"
    DTYPE = torch.float16  # bfloat16 may work

    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--prompt", type=str, default="A photo of a cat")
    parser.add_argument("--negative_prompt", type=str, default="")
    parser.add_argument("--output_dir", type=str, default=".")
    parser.add_argument(
        "--lora_weights",
        type=str,
        nargs="*",
        default=[],
        help="LoRA weights, only supports networks.lora, each arguement is a `path;multiplier` (semi-colon separated)",
    )
    parser.add_argument("--interactive", action="store_true")
    args = parser.parse_args()

    # Create the output directory up front so a missing directory is caught
    # before the expensive model load and sampling rather than at save time.
    # An empty value is skipped: os.path.join("", name) saves into the cwd.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    # HuggingFace model ids
    text_encoder_1_name = "openai/clip-vit-large-patch14"
    text_encoder_2_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"

    # Load checkpoint. For model conversion, see the called function.
    # If the main RAM is small, it may be better to load it on the GPU.
    text_model1, text_model2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
        sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.ckpt_path, "cpu"
    )

    # Notes from the original author:
    # - In SDXL itself, Text Encoder 1 is also HuggingFace's.
    # - In SDXL itself, Text Encoder 2 is open_clip's. That could be used, but
    #   to match the Diffusers version of SD2 the HuggingFace one is used; the
    #   weight conversion code is almost the same as SD2.
    # - The VAE structure is the same as SD1/2, but the weights seem to differ
    #   and, above all, the scale value is different. NaN seems to be more
    #   likely to occur in fp16.

    unet.to(DEVICE, dtype=DTYPE)
    unet.eval()

    # Run the VAE in float32 when the rest of the pipeline is float16 (see the
    # NaN note above).
    vae_dtype = DTYPE
    if DTYPE == torch.float16:
        print("use float32 for vae")
        vae_dtype = torch.float32
    vae.to(DEVICE, dtype=vae_dtype)
    vae.eval()

    text_model1.to(DEVICE, dtype=DTYPE)
    text_model1.eval()
    text_model2.to(DEVICE, dtype=DTYPE)
    text_model2.eval()

    # NOTE(review): the meaning of the two positional flags is defined by the
    # project's U-Net implementation - confirm there before changing them.
    unet.set_use_memory_efficient_attention(True, False)
    vae.set_use_memory_efficient_attention_xformers(True)

    # Tokenizers
    tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name)
    tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)

    # LoRA: each spec is "path" or "path;multiplier".
    for lora_spec in args.lora_weights:
        # Split on the LAST semicolon only, so a path that itself contains
        # ";" no longer breaks the two-value unpacking.
        path_part, separator, multiplier_part = lora_spec.rpartition(";")
        if separator:
            weights_file = path_part
            multiplier = float(multiplier_part)
        else:
            weights_file = lora_spec
            multiplier = 1.0

        lora_model, weights_sd = lora.create_network_from_weights(
            multiplier, weights_file, vae, [text_model1, text_model2], unet, None, True
        )
        lora_model.merge_to([text_model1, text_model2], unet, weights_sd, DTYPE, DEVICE)

    # scheduler
    scheduler = EulerDiscreteScheduler(
        num_train_timesteps=SCHEDULER_TIMESTEPS,
        beta_start=SCHEDULER_LINEAR_START,
        beta_end=SCHEDULER_LINEAR_END,
        beta_schedule=SCHEDLER_SCHEDULE,
    )

    def generate_image(prompt, negative_prompt, seed=None):
        """
        Generate image(s) for a prompt pair and save them as PNG files.

        Files are written to ``args.output_dir`` as
        ``image_<timestamp>_<index>.png``.

        :param prompt: text used for the conditional branch.
        :param negative_prompt: text used for the unconditional branch of
                                classifier-free guidance.
        :param seed: when not None, seeds ``random``, NumPy and torch (CPU and
                     CUDA) before the initial noise is drawn.
        :return: None.
        """
        # TODO: make it possible to change the size information in the future.

        # Prepare the size/crop part of the vector embedding.
        with torch.no_grad():
            emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
            emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
            emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
            c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
            # Original author's note: not sure whether reusing the same
            # size/crop vector for the unconditional branch is right.
            uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE)

        def call_text_encoder(text):
            """Run both text encoders; return (concatenated hidden states, pooled output of encoder 2)."""
            # text encoder 1
            batch_encoding = tokenizer1(
                text,
                truncation=True,
                return_length=True,
                return_overflowing_tokens=False,
                padding="max_length",
                return_tensors="pt",
            )
            tokens = batch_encoding["input_ids"].to(DEVICE)

            with torch.no_grad():
                enc_out = text_model1(tokens, output_hidden_states=True, return_dict=True)
                # hidden_states[11] is used directly; per the original author
                # the final layer norm is apparently not applied.
                text_embedding1 = enc_out["hidden_states"][11]

            # text encoder 2
            with torch.no_grad():
                tokens = tokenizer2(text).to(DEVICE)

                enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
                text_embedding2_penu = enc_out["hidden_states"][-2]
                text_embedding2_pool = enc_out["text_embeds"]

            # Concatenate both encoders' hidden states along dim 2 and finish.
            text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2)
            return text_embedding, text_embedding2_pool

        # cond
        c_ctx, c_ctx_pool = call_text_encoder(prompt)
        c_vector = torch.cat([c_ctx_pool, c_vector], dim=1)

        # uncond
        uc_ctx, uc_ctx_pool = call_text_encoder(negative_prompt)
        uc_vector = torch.cat([uc_ctx_pool, uc_vector], dim=1)

        # Unconditional first, conditional second: this order must match the
        # chunk() unpacking in the denoising loop.
        text_embeddings = torch.cat([uc_ctx, c_ctx])
        vector_embeddings = torch.cat([uc_vector, c_vector])

        # To reduce memory usage, the text encoders could be deleted or moved
        # to the CPU here.

        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

        # Initial noise is drawn from torch's global RNG (no explicit
        # generator). SDXL creates latents on the CPU, so this does the same;
        # Diffusers creates them on the target device.
        latents_shape = (1, 4, target_height // 8, target_width // 8)
        latents = torch.randn(
            latents_shape,
            device="cpu",
            dtype=torch.float32,
        ).to(DEVICE, dtype=DTYPE)

        # Scale the initial noise by the standard deviation required by the scheduler.
        latents = latents * scheduler.init_noise_sigma

        # set timesteps
        scheduler.set_timesteps(steps, DEVICE)

        # The sampling loop below is adapted from Diffusers.
        timesteps = scheduler.timesteps.to(DEVICE)
        num_latent_input = 2  # unconditional + conditional
        with torch.no_grad():
            for t in tqdm(timesteps):
                # Duplicate the latents for classifier-free guidance.
                latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
                latent_model_input = scheduler.scale_model_input(latent_model_input, t)

                noise_pred = unet(latent_model_input, t, text_embeddings, vector_embeddings)

                # First chunk is the unconditional (negative prompt) prediction.
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # Compute the previous noisy sample x_t -> x_t-1.
                latents = scheduler.step(noise_pred, t, latents).prev_sample

            # Undo the latent scaling (SDXL uses its own factor, not SD1/2's
            # 0.18215) and decode in the VAE's dtype.
            latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
            latents = latents.to(vae_dtype)
            image = vae.decode(latents).sample
            image = (image / 2 + 0.5).clamp(0, 1)

        # Always cast to float32: this does not cause significant overhead and
        # is compatible with bfloat16.
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        image = (image * 255).round().astype("uint8")
        image = [Image.fromarray(im) for im in image]

        # Save and finish.
        timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        for index, img in enumerate(image):
            img.save(os.path.join(args.output_dir, f"image_{timestamp}_{index:03d}.png"))

    if not args.interactive:
        generate_image(args.prompt, args.negative_prompt, seed)
    else:
        # Interactive loop: an empty prompt ends the session.
        while True:
            prompt = input("prompt: ")
            if prompt == "":
                break
            negative_prompt = input("negative prompt: ")
            seed_text = input("seed: ")
            if seed_text == "":
                seed = None
            else:
                try:
                    seed = int(seed_text)
                except ValueError:
                    # A typo must not kill the session and throw away the
                    # loaded models; fall back to an unseeded run instead.
                    print(f"invalid seed {seed_text!r}, generating without a fixed seed")
                    seed = None
            generate_image(prompt, negative_prompt, seed)

    print("Done!")