Spaces:
Runtime error
Runtime error
| # Copyright 2024 Bytedance Ltd. and/or its affiliates | |
| # SPDX-License-Identifier: Apache-2.0 | |
| import torch | |
| from pathlib import Path | |
| import numpy as np | |
| from diffusers import ControlNetModel, EulerDiscreteScheduler | |
| from diffusers.loaders.unet import UNet2DConditionLoadersMixin | |
| from .pipeline_idpatch_sd_xl import StableDiffusionXLIDPatchPipeline | |
| class IDPatchInferencer: | |
| def __init__(self, base_model_path, idp_model_path, patch_size=64, torch_device='cuda:0', torch_dtype=torch.bfloat16): | |
| super().__init__() | |
| self.patch_size = patch_size | |
| self.torch_device = torch_device | |
| self.torch_dtype = torch_dtype | |
| idp_state_dict = torch.load(Path(idp_model_path) / 'id-patch.bin', map_location="cpu") | |
| loader = UNet2DConditionLoadersMixin() | |
| self.id_patch_projection = loader._convert_ip_adapter_image_proj_to_diffusers(idp_state_dict['patch_proj']).to(self.torch_device, dtype=self.torch_dtype).eval() | |
| self.id_prompt_projection = loader._convert_ip_adapter_image_proj_to_diffusers(idp_state_dict['prompt_proj']).to(self.torch_device, dtype=self.torch_dtype).eval() | |
| controlnet = ControlNetModel.from_pretrained(Path(idp_model_path) / 'ControlNetModel').to(self.torch_device, dtype=self.torch_dtype).eval() | |
| scheduler = EulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler") | |
| self.pipe = StableDiffusionXLIDPatchPipeline.from_pretrained( | |
| base_model_path, | |
| controlnet=controlnet, | |
| scheduler=scheduler, | |
| torch_dtype=self.torch_dtype, | |
| ).to(self.torch_device) | |
| def get_text_embeds_from_strings(self, text_strings): | |
| pipe = self.pipe | |
| device = pipe.device | |
| tokenizer_1 = pipe.tokenizer | |
| tokenizer_2 = pipe.tokenizer_2 | |
| text_encoder_1 = pipe.text_encoder | |
| text_encoder_2 = pipe.text_encoder_2 | |
| text_embeds = [] | |
| for tokenizer, text_encoder in [(tokenizer_1, text_encoder_1), (tokenizer_2, text_encoder_2)]: | |
| input_ids = tokenizer( | |
| text_strings, | |
| max_length=tokenizer.model_max_length, | |
| padding="max_length", | |
| truncation=True, | |
| return_tensors="pt" | |
| ).input_ids.to(device) | |
| text_embeds.append(text_encoder(input_ids, output_hidden_states=True)) | |
| pooled_embeds = text_embeds[1]['text_embeds'] | |
| text_embeds = torch.concat([text_embeds[0]['hidden_states'][-2], text_embeds[1]['hidden_states'][-2]], dim=2) | |
| return text_embeds, pooled_embeds | |
| def generate(self, face_embeds, face_locations, control_image, prompt, negative_prompt="", guidance_scale=5.0, num_inference_steps=50, controlnet_conditioning_scale=0.8, id_injection_ratio=0.8, seed=-1): | |
| """ | |
| face_embeds: n_faces x 512 | |
| face_locations: n_faces x 2[xy] | |
| control_image: PIL image | |
| """ | |
| face_locations = face_locations.to(self.torch_device, self.torch_dtype) | |
| control_image = torch.from_numpy(np.array(control_image)).to(self.torch_device, dtype=self.torch_dtype).permute(2,0,1)[None] / 255.0 | |
| height, width = control_image.shape[2:4] | |
| text_embeds, pooled_embeds = self.get_text_embeds_from_strings([negative_prompt, prompt]) # text_embeds: 2 x 77 x 2048, pooled_embeds: 2 x 1280 | |
| negative_pooled_embeds, pooled_embeds = pooled_embeds[:1], pooled_embeds[1:] | |
| negative_text_embeds, text_embeds = text_embeds[:1], text_embeds[1:] | |
| n_faces = len(face_embeds) | |
| negative_id_embeds = self.id_prompt_projection(torch.zeros(n_faces, 1, 512, device=self.torch_device, dtype=self.torch_dtype)) # (BxF) x 16 x 2048 | |
| negative_id_embeds = negative_id_embeds.reshape(1, -1, negative_id_embeds.shape[2]) # B x (Fx16) x 2048 | |
| negative_text_id_embeds = torch.concat([negative_text_embeds, negative_id_embeds], dim=1) | |
| face_embeds = face_embeds[None].to(self.torch_device, self.torch_dtype) # 1 x faces x 512 | |
| id_embeds = self.id_prompt_projection(face_embeds.reshape(-1, 1, 512)) # (BxF) x 16 x 2048 | |
| id_embeds = id_embeds.reshape(face_embeds.shape[0], -1, id_embeds.shape[2]) # B x (Fx16) x 2048 | |
| text_id_embeds = torch.concat([text_embeds, id_embeds], dim=1) # B x (77+Fx16) x 2048 | |
| patch_prompt_embeds = self.id_patch_projection(face_embeds.reshape(-1, 1, 512)) # (Bxn_faces) x 3 x (64*64) | |
| patch_prompt_embeds = patch_prompt_embeds.reshape(1, n_faces, 3, self.patch_size, self.patch_size) | |
| pad = self.patch_size // 2 | |
| canvas = torch.zeros((1, 3, height + pad * 2, width + pad * 2), device=self.torch_device) | |
| xymin = torch.round(face_locations - self.patch_size // 2).int() | |
| xymax =torch.round(face_locations + self.patch_size // 2).int() | |
| for f in range(n_faces): | |
| xmin, ymin = xymin[f,0], xymin[f,1] | |
| xmax, ymax = xymax[f,0], xymax[f,1] | |
| if xmin+pad < 0 or xmax-pad >= width or ymin+pad < 0 or ymax-pad >= height: | |
| continue | |
| canvas[0,:,ymin+pad:ymax+pad,xmin+pad:xmax+pad] += patch_prompt_embeds[0,f] | |
| condition_image = control_image + canvas[:,:,pad:-pad,pad:-pad] | |
| if seed >= 0: | |
| generator = torch.Generator(self.torch_device).manual_seed(seed) | |
| else: | |
| generator = None | |
| output_image = self.pipe( | |
| prompt_embeds=text_id_embeds, | |
| pooled_prompt_embeds=pooled_embeds, | |
| negative_prompt_embeds=negative_text_id_embeds, | |
| negative_pooled_prompt_embeds=negative_pooled_embeds, | |
| image=condition_image, | |
| guidance_scale=guidance_scale, | |
| controlnet_conditioning_scale=controlnet_conditioning_scale, | |
| num_inference_steps=num_inference_steps, | |
| id_injection_ratio=id_injection_ratio, | |
| output_type='pil', | |
| generator=generator, | |
| ).images[0] | |
| return output_image |