Spaces:
Build error
Build error
| import sys | |
| import copy | |
| from typing import List | |
| import numpy as np | |
| import torch | |
| from einops import rearrange | |
| from omegaconf import OmegaConf | |
| from PIL import Image | |
| from pytorch_lightning import seed_everything | |
| from pytorch3d.renderer.cameras import PerspectiveCameras | |
| from pytorch3d.renderer import look_at_view_transform | |
| from pytorch3d.renderer.camera_utils import join_cameras_as_batch | |
| import json | |
| sys.path.append('./custom-diffusion360/') | |
| from sgm.util import instantiate_from_config, load_safetensors | |
# Global list of reference-camera indices: written by sample() (via torch.linspace
# over the training cameras) and read inside _customforward() to pick reference
# features out of each block's `references` buffer.
choices = []
def load_base_model(config, ckpt=None, verbose=True):
    """Instantiate the base SDXL model from an OmegaConf YAML path and optionally
    load weights from a .ckpt or .safetensors checkpoint.

    config: path to an OmegaConf YAML file (the string is replaced by the loaded config).
    ckpt:   optional checkpoint path; must end in "ckpt" or "safetensors",
            anything else raises NotImplementedError.
    verbose: unused in the visible body.
    Returns the model in eval mode.
    """
    config = OmegaConf.load(config)
    # load model
    config.model.params.network_config.params.far = 3
    config.model.params.first_stage_config.params.ckpt_path = "pretrained-models/sdxl_vae.safetensors"
    # Override the sampler's guidance module with image+text+reference CFG,
    # default text scale 7.5 and image scale 3.5 (overridden later in sample()).
    guider_config = {'target': 'sgm.modules.diffusionmodules.guiders.ScheduledCFGImgTextRef',
                     'params': {'scale': 7.5, 'scale_im': 3.5}
                     }
    config.model.params.sampler_config.params.guider_config = guider_config
    model = instantiate_from_config(config.model)
    if ckpt is not None:
        print(f"Loading model from {ckpt}")
        if ckpt.endswith("ckpt"):
            pl_sd = torch.load(ckpt, map_location="cpu")
            if "global_step" in pl_sd:
                print(f"Global Step: {pl_sd['global_step']}")
            sd = pl_sd["state_dict"]
        elif ckpt.endswith("safetensors"):
            sd = load_safetensors(ckpt)
            # When custom modifier tokens are trained, the text-encoder embedding
            # tables have a different vocab size, so drop them from the base
            # state dict and let the (strict=False) load keep the model's own.
            if 'modifier_token' in config.data.params:
                del sd['conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight']
                del sd['conditioner.embedders.1.model.token_embedding.weight']
        else:
            raise NotImplementedError
        # strict=False: missing/unexpected keys are tolerated (deltas come later).
        m, u = model.load_state_dict(sd, strict=False)
    model.eval()
    return model
def load_delta_model(model, delta_ckpt=None, verbose=True, freeze=True):
    """Load fine-tuned "delta" weights on top of a preloaded base model.

    model:      preloaded base stable diffusion model.
    delta_ckpt: optional path to a torch checkpoint containing "delta_state_dict".
                Per-block reference features (keys ending in ".references") are
                registered as buffers on their transformer blocks instead of
                going through load_state_dict.
    verbose:    print missing/unexpected keys from the (non-strict) load.
    freeze:     disable gradients on all parameters and switch to eval mode.
    Returns (model, msg) where msg is always None in the visible body.
    """
    msg = None
    if delta_ckpt is not None:
        pl_sd_delta = torch.load(delta_ckpt, map_location="cpu")
        sd_delta = pl_sd_delta["delta_state_dict"]
        # TODO: add new delta loading embedding stuff?
        for name, module in model.model.diffusion_model.named_modules():
            # Only transformer sub-blocks (parent scope named 'transformer_blocks')
            # that carry pose_emb_layers receive a precomputed `references` buffer.
            if len(name.split('.')) > 1 and name.split('.')[-2] == 'transformer_blocks':
                if hasattr(module, 'pose_emb_layers'):
                    module.register_buffer('references', sd_delta[f'model.diffusion_model.{name}.references'])
                    del sd_delta[f'model.diffusion_model.{name}.references']
        m, u = model.load_state_dict(sd_delta, strict=False)
        # Fix: previously only the headers were printed, never the keys themselves.
        if len(m) > 0 and verbose:
            print("missing keys:")
            print(m)
        if len(u) > 0 and verbose:
            print("unexpected keys:")
            print(u)
    if freeze:
        for param in model.parameters():
            param.requires_grad = False
        model.eval()
    return model, msg
def get_unique_embedder_keys_from_conditioner(conditioner):
    """Return the deduplicated input keys across all conditioner embedders,
    with 'jpg_ref' appended (order of the deduplicated part is unspecified)."""
    unique_keys = {key for embedder in conditioner.embedders for key in embedder.input_keys}
    return list(unique_keys) + ['jpg_ref']
def customforward(self, x, xr, context=None, contextr=None, pose=None, mask_ref=None, prev_weights=None, timesteps=None, drop_im=None):
    """Replacement forward for SpatialTransformer (monkey-patched in
    load_and_return_model_and_data). Runs the transformer blocks, invoking
    reference cross-attention every `poscontrol_interval`-th block when
    `image_cross` is set, and returns a 6-tuple:
    (x + residual, None, fg_masks-or-None, prev_weights, alphas-or-None, rgbs-or-None).

    NOTE(review): xr, contextr and timesteps are unused in the visible body.
    Assumes x is (b, c, h, w) — established by the unpack below.
    """
    # note: if no context is given, cross-attention defaults to self-attention
    if not isinstance(context, list):
        context = [context]
    b, c, h, w = x.shape
    x_in = x  # kept for the residual connection on return
    fg_masks = []
    alphas = []
    rgbs = []
    x = self.norm(x)
    # Convolutional proj_in runs before flattening; linear proj_in after.
    if not self.use_linear:
        x = self.proj_in(x)
    x = rearrange(x, "b c h w -> b (h w) c").contiguous()
    if self.use_linear:
        x = self.proj_in(x)
    # NOTE(review): incoming prev_weights argument is deliberately discarded;
    # weights are threaded block-to-block within this call only.
    prev_weights = None
    counter = 0
    for i, block in enumerate(self.transformer_blocks):
        if i > 0 and len(context) == 1:
            i = 0  # use same context for each block
        # Every poscontrol_interval-th block does reference-image cross attention.
        if self.image_cross and (counter % self.poscontrol_interval == 0):
            x, fg_mask, weights, alpha, rgb = block(x, context=context[i], context_ref=x, pose=pose, mask_ref=mask_ref, prev_weights=prev_weights, drop_im=drop_im)
            prev_weights = weights
            fg_masks.append(fg_mask)
            if alpha is not None:
                alphas.append(alpha)
            if rgb is not None:
                rgbs.append(rgb)
        else:
            x, _, _, _, _ = block(x, context=context[i], drop_im=drop_im)
        counter += 1
    # Mirror of the proj_in ordering above.
    if self.use_linear:
        x = self.proj_out(x)
    x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
    if not self.use_linear:
        x = self.proj_out(x)
    if len(fg_masks) > 0:
        # Normalize empty accumulators to None for callers.
        if len(rgbs) <= 0:
            rgbs = None
        if len(alphas) <= 0:
            alphas = None
        return x + x_in, None, fg_masks, prev_weights, alphas, rgbs
    else:
        return x + x_in, None, None, prev_weights, None, None
def _customforward(
    self, x, context=None, context_ref=None, pose=None, mask_ref=None, prev_weights=None, drop_im=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
):
    """Replacement forward for BasicTransformerBlock (monkey-patched in
    load_and_return_model_and_data). Standard self-attn + cross-attn + FF,
    plus an optional reference-attention branch driven by the precomputed
    `self.references` buffer and the module-level global `choices`.

    Returns (x, fg_mask, weights, alphas, predicted_rgb); the last four are
    None unless the reference branch actually runs.
    """
    if context_ref is not None:
        global choices
        batch_size = x.size(0)
        # IP2P like sampling or default sampling
        # 3-way batch (cond / img-uncond / text-uncond) vs 2-way CFG batch;
        # NOTE(review): dispatch on divisibility by 3 — confirm batch layout.
        if batch_size % 3 == 0:
            batch_size = batch_size // 3
            # references[:-1] indexed by `choices` -> reference features;
            # references[-1:] -> the unconditional/null reference row.
            context_ref = torch.stack([self.references[:-1][y] for y in choices]).unsqueeze(0).expand(batch_size, -1, -1, -1)
            context_ref = torch.cat([self.references[-1:].unsqueeze(0).expand(batch_size, context_ref.size(1), -1, -1), context_ref, context_ref], dim=0)
        else:
            batch_size = batch_size // 2
            context_ref = torch.stack([self.references[:-1][y] for y in choices]).unsqueeze(0).expand(batch_size, -1, -1, -1)
            context_ref = torch.cat([self.references[-1:].unsqueeze(0).expand(batch_size, context_ref.size(1), -1, -1), context_ref], dim=0)
    fg_mask = None
    weights = None
    alphas = None
    predicted_rgb = None
    # Self-attention (cross-attention if disable_self_attn) with residual.
    x = (
        self.attn1(
            self.norm1(x),
            context=context if self.disable_self_attn else None,
            additional_tokens=additional_tokens,
            n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
            if not self.disable_self_attn
            else 0,
        )
        + x
    )
    # Text cross-attention with residual.
    x = (
        self.attn2(
            self.norm2(x), context=context, additional_tokens=additional_tokens,
        )
        + x
    )
    if context_ref is not None:
        if self.rendered_feat is not None:
            # Reuse features rendered on a previous call (cache cleared via
            # model.clear_rendered_feat()) instead of re-running reference attention.
            x = self.pose_emb_layers(torch.cat([x, self.rendered_feat], dim=-1))
        else:
            xref, fg_mask, weights, alphas, predicted_rgb = self.reference_attn(x,
                                                                                context_ref,
                                                                                context,
                                                                                pose,
                                                                                prev_weights,
                                                                                mask_ref)
            self.rendered_feat = xref  # cache for subsequent sampler steps
            x = self.pose_emb_layers(torch.cat([x, xref], -1))
    x = self.ff(self.norm3(x)) + x
    return x, fg_mask, weights, alphas, predicted_rgb
def log_images(
    model,
    batch,
    N: int = 1,
    noise=None,
    scale_im=3.5,
    num_steps: int = 10,
    ucg_keys: List[str] = None,
    **kwargs,
):
    """Run conditional sampling for one batch and return {'samples': cpu tensor}.

    N:        number of samples (caller passes 1).
    noise:    initial latent noise; shape[1:] is used as the sample shape,
              so it must not be None.
    scale_im: image-guidance scale; > 0 triples the pose batch (cond /
              img-uncond / text-uncond), otherwise it is doubled for plain CFG.
    NOTE(review): the ucg_keys parameter is immediately overwritten below,
    so the passed-in value is ignored.
    """
    log = dict()
    conditioner_input_keys = [e.input_keys for e in model.conditioner.embedders]
    ucg_keys = conditioner_input_keys
    pose = batch['pose']
    c, uc = model.conditioner.get_unconditional_conditioning(
        batch,
        force_uc_zero_embeddings=ucg_keys
        if len(model.conditioner.embedders) > 0
        else [],
        force_ref_zero_embeddings=True
    )
    # n = number of reference views (pose[0] is the target view).
    _, n = 1, len(pose)-1
    sampling_kwargs = {}
    # Replicate the target pose to match the guidance batch layout
    # (x3 with image guidance, x2 without).
    if scale_im > 0:
        if uc is not None:
            if isinstance(pose, list):
                pose = pose[:N]*3
            else:
                pose = torch.cat([pose[:N]] * 3)
    else:
        if uc is not None:
            if isinstance(pose, list):
                pose = pose[:N]*2
            else:
                pose = torch.cat([pose[:N]] * 2)
    sampling_kwargs['pose'] = pose
    sampling_kwargs['drop_im'] = None
    sampling_kwargs['mask_ref'] = None
    # Trim each tensor conditioning to (n+1)*N rows and move to GPU;
    # the lambda receives the dicts (c, uc) and indexes them by k.
    for k in c:
        if isinstance(c[k], torch.Tensor):
            c[k], uc[k] = map(lambda y: y[k][:(n+1)*N].to('cuda'), (c, uc))
    import time
    st = time.time()
    with model.ema_scope("Plotting"):
        samples = model.sample(
            c, shape=noise.shape[1:], uc=uc, batch_size=N, num_steps=num_steps, noise=noise, **sampling_kwargs
        )
        # Drop cached per-block reference features before decoding.
        model.clear_rendered_feat()
        samples = model.decode_first_stage(samples)
    print("Time taken for sampling", time.time() - st)
    log["samples"] = samples.cpu()
    return log
def process_camera_json(camera_json, example_cam):
    """Convert a plotly-style camera JSON string into a PyTorch3D
    PerspectiveCameras, borrowing intrinsics from `example_cam`."""
    # JSON requires double quotes; the incoming string may use single quotes.
    camera_json = camera_json.replace("'", "\"")
    print("input camera json")
    print(camera_json)
    camera_dict = json.loads(camera_json)["scene.camera"]

    def _triple(name):
        # One xyz entry of the camera dict as a (1, 3) float32 tensor.
        entry = camera_dict[name]
        return torch.tensor([entry["x"], entry["y"], entry["z"]], dtype=torch.float32).unsqueeze(0)

    eye = _triple("eye")
    up = _triple("up")
    center = _triple("center")
    new_R, new_T = look_at_view_transform(eye=eye, at=center, up=up)
    print("focal length", example_cam.focal_length)
    print("principal point", example_cam.principal_point)
    newcam = PerspectiveCameras(
        R=new_R,
        T=new_T,
        focal_length=example_cam.focal_length,
        principal_point=example_cam.principal_point,
        image_size=512,
    )
    print("input pose")
    print(newcam.get_world_to_view_transform().get_matrix())
    return newcam
def load_and_return_model_and_data(config, model,
                                   ckpt="pretrained-models/sd_xl_base_1.0.safetensors",
                                   delta_ckpt=None,
                                   train=False,
                                   valid=False,
                                   far=3,
                                   num_images=1,
                                   num_ref=8,
                                   max_images=20,
                                   ):
    """Apply delta weights to `model`, move it to CUDA, and monkey-patch the
    custom forwards onto its SpatialTransformer / BasicTransformerBlock modules.

    Returns (model, data); data is always None here (dataset loading is
    commented out below). NOTE(review): ckpt, train, valid, far, num_images,
    num_ref and max_images are unused in the visible body.
    """
    config = OmegaConf.load(config)
    # load data
    data = None
    # config.data.params.jitter = False
    # config.data.params.addreg = False
    # config.data.params.bbox = False
    # data = instantiate_from_config(config.data)
    # data = data.train_dataset
    # single_id = data.single_id
    # if hasattr(data, 'rotations'):
    #     total_images = len(data.rotations[data.sequence_list[single_id]])
    # else:
    #     total_images = len(data.annotations['chair'])
    # print(f"Total images in dataset: {total_images}")
    model, msg = load_delta_model(model, delta_ckpt,)
    model = model.cuda()

    # change forward methods to store rendered features and use the pre-calculated reference features
    def register_recr(net_):
        # Recursively bind customforward onto every SpatialTransformer.
        if net_.__class__.__name__ == 'SpatialTransformer':
            print(net_.__class__.__name__, "adding control")
            bound_method = customforward.__get__(net_, net_.__class__)
            setattr(net_, 'forward', bound_method)
            return
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                register_recr(net__)
            return

    def register_recr2(net_):
        # Recursively bind _customforward onto every BasicTransformerBlock.
        if net_.__class__.__name__ == 'BasicTransformerBlock':
            print(net_.__class__.__name__, "adding control")
            bound_method = _customforward.__get__(net_, net_.__class__)
            setattr(net_, 'forward', bound_method)
            return
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                register_recr2(net__)
            return

    sub_nets = model.model.diffusion_model.named_children()
    for net in sub_nets:
        register_recr(net[1])
        register_recr2(net[1])

    # start sampling
    model.clear_rendered_feat()
    return model, data
def sample(model, data,
           num_images=1,
           prompt="",
           appendpath="",
           camera_json=None,
           train=False,
           scale=7.5,
           scale_im=3.5,
           beta=1.0,
           num_ref=8,
           skipreflater=False,
           num_steps=10,
           valid=False,
           max_images=20,
           seed=42,
           camera_path="pretrained-models/car0/camera.bin",
           ):
    """
    Only works with num_images=1 (because of camera_json processing)

    Generates one image from `prompt` at the camera pose described by
    `camera_json`, using reference views chosen from the training cameras.
    Sets the module-level global `choices` consumed by _customforward.
    Returns the decoded sample tensor (on CPU).
    NOTE(review): data, appendpath, train, beta, skipreflater, valid and
    max_images are unused in the visible body; num_ref is overwritten to 8.
    """
    if num_images != 1:
        print("forcing num_images to be 1")
        num_images = 1
    # set guidance scales
    model.sampler.guider.scale_im = scale_im
    model.sampler.guider.scale = scale
    seed_everything(seed)
    # load cameras
    cameras_val, cameras_train = torch.load(camera_path)
    global choices
    num_ref = 8
    # Evenly spaced reference-view indices across the training cameras.
    max_diff = len(cameras_train)/num_ref
    choices = [int(x) for x in torch.linspace(0, len(cameras_train) - max_diff, num_ref)]
    cameras_train_final = [cameras_train[i] for i in choices]
    # start sampling
    model.clear_rendered_feat()
    if prompt == "":
        prompt = None
    noise = torch.randn(1, 4, 64, 64).to('cuda').repeat(num_images, 1, 1, 1)
    # random sample camera poses
    pose_ids = np.random.choice(len(cameras_val), num_images, replace=False)
    print(pose_ids)
    # NOTE(review): hard-coded override — the target pose is always validation
    # camera 21; its intrinsics seed process_camera_json below.
    pose_ids[0] = 21
    pose = [cameras_val[i] for i in pose_ids]
    print("example camera")
    print(pose[0].R)
    print(pose[0].T)
    print(pose[0].focal_length)
    print(pose[0].principal_point)
    # prepare batches [if translating then call required functions on the target pose]
    batches = []
    for i in range(num_images):
        # pose[0] = target view, followed by the reference training cameras.
        batch = {'pose': [pose[i]] + cameras_train_final,
                 "original_size_as_tuple": torch.tensor([512, 512]).reshape(-1, 2),
                 "target_size_as_tuple": torch.tensor([512, 512]).reshape(-1, 2),
                 "crop_coords_top_left": torch.tensor([0, 0]).reshape(-1, 2),
                 "original_size_as_tuple_ref": torch.tensor([512, 512]).reshape(-1, 2),
                 "target_size_as_tuple_ref": torch.tensor([512, 512]).reshape(-1, 2),
                 "crop_coords_top_left_ref": torch.tensor([0, 0]).reshape(-1, 2),
                 }
        batch_ = copy.deepcopy(batch)
        # Replace the target pose with the user-supplied camera, then batch all
        # cameras into a single PerspectiveCameras object.
        batch_["pose"][0] = process_camera_json(camera_json, pose[0])
        batch_["pose"] = [join_cameras_as_batch(batch_["pose"])]
        # print('batched')
        # print(batch_["pose"][0].get_world_to_view_transform().get_matrix())
        batches.append(batch_)
    print(f'len batches: {len(batches)}')
    image = None
    with torch.no_grad():
        for batch in batches:
            # Move tensors (and camera objects under 'pose' keys) to the GPU.
            for key in batch.keys():
                if isinstance(batch[key], torch.Tensor):
                    batch[key] = batch[key].to('cuda')
                elif 'pose' in key:
                    batch[key] = [x.to('cuda') for x in batch[key]]
                else:
                    pass
            if prompt is not None:
                batch["txt"] = [prompt for _ in range(1)]
                batch["txt_ref"] = [prompt for _ in range(len(batch["pose"])-1)]
                print(batch["txt"])
            N = 1
            log_ = log_images(model, batch, N=N, noise=noise.clone()[:N], num_steps=num_steps, scale_im=scale_im)
            image = log_["samples"]
            torch.cuda.empty_cache()
    model.clear_rendered_feat()
    print("generation done")
    return image