from diffusers import DDIMScheduler
import torchvision.transforms.functional as TF

import numpy as np
import rembg  # used by process_im for background removal
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
from torchvision.utils import save_image
from torchvision import transforms
import torch.nn.functional as F
from einops import rearrange

import sys
sys.path.append('./')

from sparseags.guidance_utils.zero123 import Zero123Pipeline


# Mapping from the original LDM (CompVis) UNet parameter prefixes to the
# corresponding diffusers UNet module prefixes. It is used below to translate
# the fine-tuned 6-DoF checkpoint into names the diffusers UNet understands.
name_mapping = {
    "model.diffusion_model.input_blocks.1.1.": "down_blocks.0.attentions.0.",
    "model.diffusion_model.input_blocks.2.1.": "down_blocks.0.attentions.1.",
    "model.diffusion_model.input_blocks.4.1.": "down_blocks.1.attentions.0.",
    "model.diffusion_model.input_blocks.5.1.": "down_blocks.1.attentions.1.",
    "model.diffusion_model.input_blocks.7.1.": "down_blocks.2.attentions.0.",
    "model.diffusion_model.input_blocks.8.1.": "down_blocks.2.attentions.1.",
    "model.diffusion_model.middle_block.1.": "mid_block.attentions.0.",
    "model.diffusion_model.output_blocks.3.1.": "up_blocks.1.attentions.0.",
    "model.diffusion_model.output_blocks.4.1.": "up_blocks.1.attentions.1.",
    "model.diffusion_model.output_blocks.5.1.": "up_blocks.1.attentions.2.",
    "model.diffusion_model.output_blocks.6.1.": "up_blocks.2.attentions.0.",
    "model.diffusion_model.output_blocks.7.1.": "up_blocks.2.attentions.1.",
    "model.diffusion_model.output_blocks.8.1.": "up_blocks.2.attentions.2.",
    "model.diffusion_model.output_blocks.9.1.": "up_blocks.3.attentions.0.",
    "model.diffusion_model.output_blocks.10.1.": "up_blocks.3.attentions.1.",
    "model.diffusion_model.output_blocks.11.1.": "up_blocks.3.attentions.2.",
}
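

# Illustrative sketch (not used elsewhere in this file): how a single LDM-style
# parameter name is rewritten with `name_mapping`. The helper name and the
# example key below are hypothetical.
def _convert_ldm_key(name: str) -> str:
    for old_prefix, new_prefix in name_mapping.items():
        if old_prefix in name:
            return name.replace(old_prefix, new_prefix)
    return name

# _convert_ldm_key("model.diffusion_model.middle_block.1.proj_in.weight")
# -> "mid_block.attentions.0.proj_in.weight"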


class Zero123(nn.Module):
    def __init__(self, device, fp16=True, t_range=[0.02, 0.98], model_key="ashawkey/zero123-xl-diffusers"):
        super().__init__()

        self.device = device
        self.fp16 = fp16
        self.dtype = torch.float16 if fp16 else torch.float32

        self.pipe = Zero123Pipeline.from_pretrained(
            model_key,
            trust_remote_code=True,
            torch_dtype=self.dtype,
        ).to(self.device)

        # Load the 6-DoF fine-tuned checkpoint and splice its weights into the pipeline.
        ckpt_path = "checkpoints/zero123_6dof_23k.ckpt"
        print(f'[INFO] loading checkpoint from {ckpt_path} ...')
        old_state = torch.load(ckpt_path)

        # The camera-conditioning projection takes the CLIP image embedding (768)
        # concatenated with the 18-D relative pose embedding (see get_cam_embeddings_6D).
        pretrained_weights = old_state['state_dict']['cc_projection.weight']
        pretrained_biases = old_state['state_dict']['cc_projection.bias']
        linear_layer = torch.nn.Linear(768 + 18, 768)
        linear_layer.weight.data = pretrained_weights
        linear_layer.bias.data = pretrained_biases
        self.pipe.clip_camera_projection.proj = linear_layer.to(dtype=self.dtype, device=self.device)

        # Duplicate the fine-tuned UNet weights under their diffusers names so they
        # can be loaded below; the remaining LDM-named keys are ignored by strict=False.
        for name in list(old_state['state_dict'].keys()):
            for k, v in name_mapping.items():
                if k in name:
                    old_state['state_dict'][name.replace(k, v)] = old_state['state_dict'][name].to(dtype=self.dtype, device=self.device)

        m, u = self.pipe.unet.load_state_dict(old_state['state_dict'], strict=False)  # m: missing keys, u: unexpected keys

        # stable-zero123 uses a different camera parameterization (see get_cam_embeddings).
        self.use_stable_zero123 = 'stable' in model_key

        self.pipe.image_encoder.eval()
        self.pipe.vae.eval()
        self.pipe.unet.eval()
        self.pipe.clip_camera_projection.eval()

        self.vae = self.pipe.vae
        self.unet = self.pipe.unet

        self.pipe.set_progress_bar_config(disable=True)

        self.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
        self.num_train_timesteps = self.scheduler.config.num_train_timesteps

        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)

        self.embeddings = None

    @torch.no_grad()
    def get_img_embeds(self, x):
        # x: [B, 3, H, W] image tensor in [0, 1]
        x = F.interpolate(x, (256, 256), mode='bilinear', align_corners=False)
        x_pil = [TF.to_pil_image(image) for image in x]
        x_clip = self.pipe.feature_extractor(images=x_pil, return_tensors="pt").pixel_values.to(device=self.device, dtype=self.dtype)
        c = self.pipe.image_encoder(x_clip).image_embeds
        v = self.encode_imgs(x.to(self.dtype)) / self.vae.config.scaling_factor
        self.embeddings = [c, v]

    def get_cam_embeddings(self, polar, azimuth, radius, default_elevation=0):
        if self.use_stable_zero123:
            T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), np.deg2rad([90 + default_elevation] * len(polar))], axis=-1)
        else:
            # original zero123 parameterization: (delta polar, sin/cos delta azimuth, delta radius)
            T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), radius], axis=-1)
        T = torch.from_numpy(T).unsqueeze(1).to(dtype=self.dtype, device=self.device)
        return T
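
    # T has shape [N, 1, 4]: one relative viewpoint per row, given as
    # (delta polar angle in radians, sin/cos of delta azimuth, delta radius);
    # the stable-zero123 variant replaces the radius with np.deg2rad(90 + default_elevation).
    # Note that the cc_projection spliced in __init__ expects the 18-D embedding
    # from get_cam_embeddings_6D below; this 4-D form matches the original zero123 projection.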

    def get_cam_embeddings_6D(self, target_RT, cond_RT):
        T_target = torch.from_numpy(target_RT["c2w"])
        focal_len_target = torch.from_numpy(target_RT["focal_length"])

        T_cond = torch.from_numpy(cond_RT["c2w"])
        focal_len_cond = torch.from_numpy(cond_RT["focal_length"])

        # relative focal length of the target view w.r.t. the conditioning view
        focal_len = focal_len_target / focal_len_cond

        # relative camera transform between the target and conditioning views
        d_T = torch.linalg.inv(T_target) @ T_cond
        d_T = torch.cat([d_T.flatten(), torch.log(focal_len)])
        return d_T.unsqueeze(0).unsqueeze(0).to(dtype=self.dtype, device=self.device)
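
    # The resulting embedding has 18 entries: the flattened 4x4 relative transform
    # (16 values) followed by the log focal-length ratio along both image axes
    # (2 values). This matches the Linear(768 + 18, 768) cc_projection loaded in __init__.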

    @torch.no_grad()
    def refine(self, pred_rgb, cam_embed,
               guidance_scale=5, steps=50, strength=0.8, idx=None
               ):

        if pred_rgb is not None:
            batch_size = pred_rgb.shape[0]
        else:
            batch_size = 1

        self.scheduler.set_timesteps(steps)

        # strength == 0: sample from pure noise; otherwise start DDIM from a
        # partially noised encoding of pred_rgb (img2img-style refinement).
        if strength == 0:
            init_step = 0
            latents = torch.randn((1, 4, 32, 32), device=self.device, dtype=self.dtype)
        else:
            init_step = int(steps * strength)
            pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
            latents = self.encode_imgs(pred_rgb_256.to(self.dtype))
            latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step])

        # CLIP image embedding + relative camera embedding, plus a zeroed copy
        # as the unconditional branch for classifier-free guidance.
        T = cam_embed
        if idx is not None:
            cc_emb = torch.cat([self.embeddings[0][idx].repeat(batch_size, 1, 1), T], dim=-1)
        else:
            cc_emb = torch.cat([self.embeddings[0].repeat(batch_size, 1, 1), T], dim=-1)
        cc_emb = self.pipe.clip_camera_projection(cc_emb)
        cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0)

        # Latent of the conditioning image, concatenated channel-wise with the
        # denoised latent inside the loop below.
        if idx is not None:
            vae_emb = self.embeddings[1][idx].repeat(batch_size, 1, 1, 1)
        else:
            vae_emb = self.embeddings[1].repeat(batch_size, 1, 1, 1)
        vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0)

        for i, t in enumerate(self.scheduler.timesteps[init_step:]):

            x_in = torch.cat([latents] * 2)
            t_in = t.view(1).to(self.device)

            noise_pred = self.unet(
                torch.cat([x_in, vae_emb], dim=1),
                t_in.to(self.unet.dtype),
                encoder_hidden_states=cc_emb,
            ).sample

            # classifier-free guidance
            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        imgs = self.decode_latents(latents)
        return imgs
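
    # Typical use (illustrative; the variable names are hypothetical): cache the
    # conditioning image first, then refine a rendering from a new viewpoint:
    #   guidance.get_img_embeds(cond_image)              # cond_image: [1, 3, H, W] in [0, 1]
    #   T = guidance.get_cam_embeddings_6D(target_RT, cond_RT)
    #   out = guidance.refine(render, T, strength=0.8)   # out: [1, 3, 256, 256]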

    def train_step(self, pred_rgb, polar, azimuth, radius, step_ratio=None, guidance_scale=5, as_latent=False):
        # pred_rgb: [B, 3, H, W] rendering in [0, 1]
        # (polar, azimuth, radius): viewpoint of pred_rgb relative to the conditioning image

        batch_size = pred_rgb.shape[0]

        if as_latent:
            latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1
        else:
            pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
            latents = self.encode_imgs(pred_rgb_256.to(self.dtype))

        if step_ratio is not None:
            # anneal the timestep from max_step down to min_step as training progresses
            t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
            t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
        else:
            t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)

        # SDS weighting w(t) = 1 - alpha_bar_t
        w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)

        with torch.no_grad():
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)

            x_in = torch.cat([latents_noisy] * 2)
            t_in = torch.cat([t] * 2)

            T = self.get_cam_embeddings(polar, azimuth, radius)
            cc_emb = torch.cat([self.embeddings[0].repeat(batch_size, 1, 1), T], dim=-1)
            cc_emb = self.pipe.clip_camera_projection(cc_emb)
            cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0)

            vae_emb = self.embeddings[1].repeat(batch_size, 1, 1, 1)
            vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0)

            noise_pred = self.unet(
                torch.cat([x_in, vae_emb], dim=1),
                t_in.to(self.unet.dtype),
                encoder_hidden_states=cc_emb,
            ).sample

        # classifier-free guidance
        noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

        grad = w * (noise_pred - noise)
        grad = torch.nan_to_num(grad)

        target = (latents - grad).detach()
        loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum')

        return loss
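
    # Note on the loss above: with target = (latents - grad).detach(), the MSE
    # 0.5 * ||latents - target||^2 has gradient d(loss)/d(latents) = grad
    # = w * (noise_pred - noise), i.e. the score-distillation (SDS) update is
    # applied to the rendering without backpropagating through the UNet.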

    def angle_between(self, sph_v1, sph_v2):
        def sph2cart(sv):
            r, theta, phi = sv[0], sv[1], sv[2]
            # spherical (radius, elevation, azimuth) -> Cartesian
            return torch.tensor([r * torch.cos(theta) * torch.cos(phi), r * torch.cos(theta) * torch.sin(phi), r * torch.sin(theta)])
        def unit_vector(v):
            return v / torch.linalg.norm(v)
        def angle_between_2_sph(sv1, sv2):
            v1, v2 = sph2cart(sv1), sph2cart(sv2)
            v1_u, v2_u = unit_vector(v1), unit_vector(v2)
            return torch.arccos(torch.clip(torch.dot(v1_u, v2_u), -1.0, 1.0))
        angles = torch.empty(len(sph_v1), len(sph_v2))
        for i, sv1 in enumerate(sph_v1):
            for j, sv2 in enumerate(sph_v2):
                angles[i][j] = angle_between_2_sph(sv1, sv2)
        return angles
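
    # Example (hypothetical values), with viewpoints given as (radius, elevation, azimuth)
    # in radians; two views 90 degrees apart in azimuth at the same elevation:
    #   a = torch.tensor([[1.0, 0.0, 0.0]])
    #   b = torch.tensor([[1.0, 0.0, np.pi / 2]])
    #   self.angle_between(a, b)  # -> tensor([[1.5708]])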

    def batch_train_step(self, pred_rgb, target_RT, cond_cams, step_ratio=None, guidance_scale=5, as_latent=False, step=None):
        # pred_rgb: [B, 3, H, W] rendering in [0, 1]
        # target_RT: dict with "c2w" and "focal_length" for the rendered view
        # cond_cams: cameras of the self.num_views conditioning images cached by get_img_embeds

        batch_size = pred_rgb.shape[0]

        if as_latent:
            latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1
        else:
            pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
            latents = self.encode_imgs(pred_rgb_256.to(self.dtype))

        if step_ratio is not None:
            # anneal the timestep from max_step down to min_step as training progresses
            t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
            t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
        else:
            t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)

        w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)

        with torch.no_grad():
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)

            # one (conditional, unconditional) pair per conditioning view
            x_in = torch.cat([latents_noisy] * 2 * self.num_views)
            t_in = torch.cat([t] * 2 * self.num_views)

            cc_embs = []
            vae_embs = []
            noise_preds = []
            for idx in range(self.num_views):
                cond_RT = {
                    "c2w": cond_cams[idx].c2w,
                    "focal_length": cond_cams[idx].focal_length,
                }
                T = self.get_cam_embeddings_6D(target_RT, cond_RT)
                cc_emb = torch.cat([self.embeddings[0][idx].repeat(batch_size, 1, 1), T], dim=-1)
                cc_emb = self.pipe.clip_camera_projection(cc_emb)
                cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0)

                vae_emb = self.embeddings[1][idx].repeat(batch_size, 1, 1, 1)
                vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0)

                cc_embs.append(cc_emb)
                vae_embs.append(vae_emb)

            cc_emb = torch.cat(cc_embs, dim=0)
            vae_emb = torch.cat(vae_embs, dim=0)
            noise_pred = self.unet(
                torch.cat([x_in, vae_emb], dim=1),
                t_in.to(self.unet.dtype),
                encoder_hidden_states=cc_emb,
            ).sample

        # classifier-free guidance per conditioning view (assumes batch_size == 1),
        # then average the guided predictions over the views
        noise_pred_chunks = noise_pred.chunk(self.num_views)
        for idx in range(self.num_views):
            noise_pred_cond, noise_pred_uncond = noise_pred_chunks[idx][0], noise_pred_chunks[idx][1]
            noise_pred_view = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            noise_preds.append(noise_pred_view)

        noise_pred = torch.stack(noise_preds).sum(dim=0) / len(noise_preds)

        grad = w * (noise_pred - noise)
        grad = torch.nan_to_num(grad)

        target = (latents - grad).detach()
        loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum')

        return loss
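
    # batch_train_step averages the guided noise prediction over all conditioning
    # views before forming the SDS gradient. It expects self.num_views to be set
    # to the number of conditioning images and self.embeddings to be populated by
    # get_img_embeds on those images beforehand.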

    def decode_latents(self, latents):
        latents = 1 / self.vae.config.scaling_factor * latents

        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)

        return imgs

    def encode_imgs(self, imgs, mode=False):
        # imgs: [B, 3, H, W] in [0, 1]
        imgs = 2 * imgs - 1

        posterior = self.vae.encode(imgs).latent_dist
        if mode:
            latents = posterior.mode()
        else:
            latents = posterior.sample()
        latents = latents * self.vae.config.scaling_factor

        return latents
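
    # For 256x256 inputs the VAE downsamples by 8x, so encode_imgs maps
    # [B, 3, 256, 256] images in [0, 1] to [B, 4, 32, 32] scaled latents,
    # and decode_latents inverts the mapping back to [0, 1] images.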


def process_im(im, device, bg_remover=None):
    # im: [H, W, 3] or [H, W, 4] uint8 array; if there is no alpha channel,
    # segment the foreground with rembg first.
    if im.shape[-1] == 3:
        if bg_remover is None:
            bg_remover = rembg.new_session()
        im = rembg.remove(im, session=bg_remover)

    im = im.astype(np.float32) / 255.0

    # composite the foreground onto a white background and reverse the channel order
    input_mask = im[..., 3:]
    input_img = im[..., :3] * input_mask + (1 - input_mask)
    input_img = input_img[..., ::-1].copy()
    image = torch.from_numpy(input_img).permute(2, 0, 1).unsqueeze(0).contiguous().to(device)
    image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False)

    return image
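
# Illustrative call (hypothetical path; the channel flip above suggests an
# image loaded in BGR order, e.g. with OpenCV):
#   img = cv2.imread("data/example.png", cv2.IMREAD_UNCHANGED)
#   image = process_im(img, torch.device("cuda"))  # [1, 3, 256, 256] in [0, 1]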


def get_T_6d(target_RT, cond_RT, use_objaverse):
    if use_objaverse:
        # Objaverse-style poses are 3x4 matrices: pad to 4x4, invert, and permute
        # the axes into the convention used by the 6-DoF embedding.
        new_row = torch.tensor([[0., 0., 0., 1.]])

        T_target = torch.from_numpy(target_RT)
        T_target = torch.cat((T_target, new_row), dim=0)
        T_target = torch.linalg.inv(T_target)
        T_target[:3, :] = T_target[[1, 2, 0]]

        T_cond = torch.from_numpy(cond_RT)
        T_cond = torch.cat((T_cond, new_row), dim=0)
        T_cond = torch.linalg.inv(T_cond)
        T_cond[:3, :] = T_cond[[1, 2, 0]]

        # identical intrinsics: the focal ratio is fixed to 1
        focal_len = torch.tensor([1., 1.])

    else:
        T_target = torch.from_numpy(target_RT["c2w"])
        focal_len_target = torch.from_numpy(target_RT["focal_length"])

        T_cond = torch.from_numpy(cond_RT["c2w"])
        focal_len_cond = torch.from_numpy(cond_RT["focal_length"])

        # relative focal length of the target view w.r.t. the conditioning view
        focal_len = focal_len_target / focal_len_cond

    d_T = torch.linalg.inv(T_target) @ T_cond
    d_T = torch.cat([d_T.flatten(), torch.log(focal_len)])
    return d_T
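

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): assumes a CUDA device, the
    # checkpoint path hard-coded in __init__, and identity poses in place of
    # real camera matrices. All inputs below are stand-ins.
    device = torch.device("cuda")
    guidance = Zero123(device)

    # Stand-in for a real conditioning image; in practice this would come from
    # process_im() above ([1, 3, 256, 256], values in [0, 1]).
    image = torch.rand(1, 3, 256, 256, device=device)
    guidance.get_img_embeds(image)

    # Relative pose embedding: flattened 4x4 transform plus the log focal ratio.
    pose = np.eye(4, dtype=np.float32)[:3]
    T = get_T_6d(pose, pose, use_objaverse=True)
    T = T.unsqueeze(0).unsqueeze(0).to(dtype=guidance.dtype, device=device)

    # Refine the image from the (here: identical) target viewpoint.
    out = guidance.refine(image, T, strength=0.8)
    save_image(out, "zero123_refine.png")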