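# PCDMs stage-2 inpainting fine-tuning script: given a source person image and a
# target pose, the pipeline inpaints the right half of a side-by-side canvas with
# the person re-rendered in the target pose, and the loop at the bottom attempts
# to fine-tune the UNet against a v-prediction target.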
import torch
from torch import nn, optim
import torch.nn.functional as F
from PIL import Image
from accelerate import Accelerator
from torchvision import transforms
from transformers import CLIPImageProcessor, Dinov2Model
from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler
from diffusers.models.controlnet import ControlNetConditioningEmbedding
from src.models.stage2_inpaint_unet_2d_condition import Stage2_InapintUNet2DConditionModel
from src.pipelines.PCDMs_pipeline import PCDMsPipeline
from single_extract_pose import inference_pose
device = "cuda"
pretrained_model_name_or_path = "stabilityai/stable-diffusion-2-1-base"
image_encoder_path = "facebook/dinov2-giant"
model_ckpt_path = "./pcdms_ckpt.pt"  # pretrained PCDMs checkpoint path
num_samples = 1
image_size = (512, 512)
s_img_path = 'imgs/sm.png'  # source (appearance) image
target_pose_img = 'imgs/pose.png'  # target pose image
def image_grid(imgs, rows, cols):
    """Paste a list of equally sized PIL images into a rows x cols grid."""
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
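# Example usage: image_grid([img_a, img_b], rows=1, cols=2) returns the two
# images pasted side by side as a single strip.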
def load_mydict(model_ckpt_path):
    """Split the combined PCDMs checkpoint into per-submodule state dicts."""
    model_sd = torch.load(model_ckpt_path, map_location="cpu")["module"]
    image_proj_model_dict = {}
    pose_proj_dict = {}
    unet_dict = {}
    for k in model_sd.keys():
        if k.startswith("pose_proj"):
            pose_proj_dict[k.replace("pose_proj.", "")] = model_sd[k]
        elif k.startswith("image_proj_model"):
            image_proj_model_dict[k.replace("image_proj_model.", "")] = model_sd[k]
        elif k.startswith("unet"):
            unet_dict[k.replace("unet.", "")] = model_sd[k]
        else:
            print(k)  # report any key that does not belong to a known submodule
    return image_proj_model_dict, pose_proj_dict, unet_dict
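# The checkpoint stores everything under a single "module" entry, so keys are
# expected to look like "unet.<param>", "pose_proj.<param>" and
# "image_proj_model.<param>"; the exact parameter names depend on the PCDMs
# checkpoint being loaded.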
class ImageProjModel(torch.nn.Module):
    """MLP that projects image-encoder features into the UNet cross-attention space."""
    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, out_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
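# Here the projection maps DINOv2-giant hidden states (1536-dim) to the 1024-dim
# embeddings that SD 2.1's cross-attention expects; see the in_dim/out_dim values
# SDModel passes in below.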
clip_image_processor = CLIPImageProcessor()
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
generator = torch.Generator(device=device).manual_seed(42)

unet = Stage2_InapintUNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path,
    torch_dtype=torch.float16,
    subfolder="unet",
    in_channels=9,  # latents (4) + masked-source latents (4) + mask (1)
    low_cpu_mem_usage=False,
    ignore_mismatched_sizes=True,
).to(device)
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae").to(device, dtype=torch.float16)
image_encoder = Dinov2Model.from_pretrained(image_encoder_path).to(device, dtype=torch.float16)
noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)
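# These DDIM hyperparameters (scaled_linear betas 0.00085-0.012 over 1000 steps)
# match the schedule Stable Diffusion was trained with, so get_velocity() below
# produces targets consistent with the pretrained weights.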
# noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
print('====================== model load finish ===================')
class SDModel(torch.nn.Module):
    """SD model with image prompt: UNet + image projection + pose conditioning."""
    def __init__(self, unet) -> None:
        super().__init__()
        self.unet = unet
        self.image_proj_model = ImageProjModel(in_dim=1536, hidden_dim=768, out_dim=1024).to(device).to(dtype=torch.float16)
        self.pose_proj = ControlNetConditioningEmbedding(
            conditioning_embedding_channels=320,
            block_out_channels=(16, 32, 96, 256),
            conditioning_channels=3).to(device).to(dtype=torch.float16)
        # load pretrained PCDMs weights into the three submodules
        image_proj_model_dict, pose_proj_dict, unet_dict = load_mydict(model_ckpt_path)
        self.image_proj_model.load_state_dict(image_proj_model_dict)
        self.pose_proj.load_state_dict(pose_proj_dict)
        unet.load_state_dict(unet_dict)
    def forward(self, s_img_path, t_pose_path, t_img_path, epoch):
        pipe = PCDMsPipeline.from_pretrained(
            pretrained_model_name_or_path,
            unet=self.unet,
            torch_dtype=torch.float16,
            scheduler=noise_scheduler,
            feature_extractor=None,
            safety_checker=None,
        ).to(device)
        # target pose is extracted from the target image; t_pose_path is unused
        # unless you switch to the commented-out line below
        t_pose = inference_pose(t_img_path, image_size=(image_size[1], image_size[0])).convert("RGB").resize(image_size, Image.BICUBIC)
        target_img = Image.open(t_img_path).convert("RGB").resize(image_size, Image.BICUBIC)
        s_img = Image.open(s_img_path).convert("RGB").resize(image_size, Image.BICUBIC)

        # source image on the left, black placeholder on the right (the region to inpaint)
        black_image = Image.new("RGB", s_img.size, (0, 0, 0))
        s_img_t_mask = Image.new("RGB", (s_img.width * 2, s_img.height))
        s_img_t_mask.paste(s_img, (0, 0))
        s_img_t_mask.paste(black_image, (s_img.width, 0))

        s_pose = inference_pose(s_img_path, image_size=(image_size[1], image_size[0])).resize(image_size, Image.BICUBIC)
        # t_pose = Image.open(t_pose_path).convert("RGB").resize(image_size, Image.BICUBIC)

        # source pose on the left, target pose on the right
        st_pose = Image.new("RGB", (s_pose.width * 2, s_pose.height))
        st_pose.paste(s_pose, (0, 0))
        st_pose.paste(t_pose, (s_pose.width, 0))

        clip_s_img = clip_image_processor(images=s_img, return_tensors="pt").pixel_values
        vae_image = torch.unsqueeze(img_transform(s_img_t_mask), 0)
        cond_st_pose = torch.unsqueeze(img_transform(st_pose), 0)

        # latent-resolution mask: ones over the source half, zeros over the target half
        mask1 = torch.ones((1, 1, image_size[0] // 8, image_size[1] // 8)).to(device, dtype=torch.float16)
        mask0 = torch.zeros((1, 1, image_size[0] // 8, image_size[1] // 8)).to(device, dtype=torch.float16)
        mask = torch.cat([mask1, mask0], dim=3)
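        # The pipeline works on a side-by-side canvas: the mask (together with the
        # masked source latents encoded below) tells the inpainting UNet to keep
        # the left half and synthesize the right half at the target pose. The /8
        # is the VAE's spatial downsampling factor, so 512 px maps to 64 latents.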
        # ground-truth side-by-side image (source | target), used to build the
        # training-target latents below
        st_img = Image.new("RGB", (image_size[0] * 2, image_size[1]))
        st_img.paste(s_img, (0, 0))
        st_img.paste(target_img, (image_size[0], 0))
        st_img.save('tar.png')
        st_img = torch.unsqueeze(img_transform(st_img), 0)
        # NOTE: torch.inference_mode() disables autograd, so nothing computed in
        # this block can propagate gradients back to the model parameters.
        with torch.inference_mode():
            cond_pose = self.pose_proj(cond_st_pose.to(dtype=torch.float16, device=device))
            simg_mask_latents = pipe.vae.encode(vae_image.to(device, dtype=torch.float16)).latent_dist.sample()
            simg_mask_latents = simg_mask_latents * 0.18215  # SD VAE scaling factor (== pipe.vae.config.scaling_factor)
            images_embeds = image_encoder(clip_s_img.to(device, dtype=torch.float16)).last_hidden_state
            image_prompt_embeds = self.image_proj_model(images_embeds)
            uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(images_embeds))

            latents = pipe.vae.encode(st_img.to(device, dtype=torch.float16)).latent_dist.sample()
            latents = latents * pipe.vae.config.scaling_factor
            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=latents.device).long()
            # v-prediction target for the randomly sampled timestep
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
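            # get_velocity() computes v_t = sqrt(alpha_bar_t) * noise
            #                             - sqrt(1 - alpha_bar_t) * latents,
            # i.e. the v-prediction objective; comparing it against the pipeline's
            # model_pred only makes sense if the UNet was trained with
            # prediction_type="v_prediction".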
            # duplicate the embeddings per sample for classifier-free guidance
            bs_embed, seq_len, _ = image_prompt_embeds.shape
            image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
            image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
            uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
            uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

            output, model_pred = pipe(
                simg_mask_latents=simg_mask_latents,
                mask=mask,
                cond_pose=cond_pose,
                prompt_embeds=image_prompt_embeds,
                negative_prompt_embeds=uncond_image_prompt_embeds,
                height=image_size[1],
                width=image_size[0] * 2,
                num_images_per_prompt=num_samples,
                guidance_scale=2.0,
                generator=generator,
                num_inference_steps=50,
            )
        output = output.images[-1]
        output.save('out' + str(epoch) + '.png')
| """ | |
| with torch.inference_mode(): | |
| output = torch.unsqueeze(img_transform(output), 0) | |
| latents = pipe.vae.encode(output.to(device, dtype=torch.float16)).latent_dist.sample() | |
| latents = latents * pipe.vae.config.scaling_factor | |
| noise = torch.randn_like(latents) | |
| timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,),device=latents.device, ) | |
| timesteps = timesteps.long() | |
| model_pred = noise_scheduler.get_velocity(latents, noise, timesteps) | |
| """ | |
        return model_pred, target
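# Caveat: because forward() runs entirely under torch.inference_mode(), the
# returned tensors carry no autograd graph, so the loop below cannot actually
# update the model through them; to fine-tune for real, the denoising prediction
# would need to be computed outside inference_mode (e.g. a single UNet forward
# pass with gradients enabled) rather than through the sampling pipeline.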
# Training setup
sd_model = SDModel(unet)
sd_model.train()
optimizer = optim.AdamW(sd_model.parameters(), lr=1e-5)
loss_fn = nn.MSELoss()
accelerator = Accelerator()
sd_model, optimizer = accelerator.prepare(sd_model, optimizer)
# snapshot the unet weights by value; state_dict() returns live references, so
# without clone() the "previous" weights would silently track every update
prev = {k: v.detach().clone() for k, v in sd_model.unet.state_dict().items()}

# Fine-tuning loop
num_epochs = 5
for epoch in range(num_epochs):
    for s_img_path, t_pose_path, t_img_path in zip(['imgs/sm.png'], ['imgs/p1.png'], ['imgs/target.png']):
        with accelerator.accumulate(sd_model):
            optimizer.zero_grad()
            model_pred, target = sd_model(s_img_path, t_pose_path, t_img_path, epoch)
            # loss = loss_fn(torch.unsqueeze(img_transform(output), 0), torch.unsqueeze(img_transform(target_img), 0))
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
            # WARNING: model_pred was produced under inference_mode, so this flag
            # only silences the backward() error; it does not reconnect the loss
            # to the model parameters, and optimizer.step() will change nothing.
            loss.requires_grad = True
            accelerator.backward(loss)
            optimizer.step()

            # count how many unet tensors actually changed this step; the original
            # set(prev.items()) ^ set(curr.items()) compared tensors by object
            # identity and therefore never reflected real weight changes
            curr = {k: v.detach().clone() for k, v in sd_model.unet.state_dict().items()}
            changed = sum(1 for k in curr if not torch.equal(prev[k], curr[k]))
            print(changed)
            prev = curr
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}")
# Save fine-tuned model (state_dict of the unwrapped model, so it reloads cleanly)
torch.save(accelerator.unwrap_model(sd_model).state_dict(), "fine_tuned_pcdms.pt")
# sd_model.save_checkpoint("outputs", "0", {})
print("Fine-tuning completed. Model saved.")