# Spaces: Running on Zero (HuggingFace Spaces / ZeroGPU banner residue)
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| import os | |
| from tqdm import tqdm | |
| import torch | |
| from loguru import logger | |
| from functools import wraps | |
| from torch.utils._pytree import tree_map_only | |
def set_attention_backend():
    """Select the flash_attn backends when a supported data-center GPU is present."""
    if not torch.cuda.is_available():
        return
    gpu_name = torch.cuda.get_device_name(0)
    logger.info(f"GPU name is {gpu_name}")
    # Only these accelerators are configured to use the flash_attn kernels here.
    if any(tag in gpu_name for tag in ("A100", "H100", "H200")):
        os.environ["ATTN_BACKEND"] = "flash_attn"
        os.environ["SPARSE_ATTN_BACKEND"] = "flash_attn"


set_attention_backend()
| from typing import List, Union | |
| from hydra.utils import instantiate | |
| from omegaconf import OmegaConf | |
| import numpy as np | |
| from PIL import Image | |
| from sam3d_objects.pipeline import preprocess_utils | |
| from sam3d_objects.data.dataset.tdfy.img_and_mask_transforms import ( | |
| get_mask, | |
| ) | |
| from sam3d_objects.pipeline.inference_utils import ( | |
| get_pose_decoder, | |
| SLAT_MEAN, | |
| SLAT_STD, | |
| downsample_sparse_structure, | |
| prune_sparse_structure, | |
| ) | |
| from sam3d_objects.model.io import ( | |
| load_model_from_checkpoint, | |
| filter_and_remove_prefix_state_dict_fn, | |
| ) | |
| from sam3d_objects.model.backbone.tdfy_dit.modules import sparse as sp | |
| from sam3d_objects.model.backbone.tdfy_dit.utils import postprocessing_utils | |
| from safetensors.torch import load_file | |
class InferencePipeline:
    """End-to-end image -> 3D pipeline: sparse-structure sampling, structured-latent
    (SLAT) sampling, and decoding to gaussians / mesh / GLB."""

    def __init__(
        self,
        ss_generator_config_path,
        ss_generator_ckpt_path,
        slat_generator_config_path,
        slat_generator_ckpt_path,
        ss_decoder_config_path,
        ss_decoder_ckpt_path,
        slat_decoder_gs_config_path,
        slat_decoder_gs_ckpt_path,
        slat_decoder_mesh_config_path,
        slat_decoder_mesh_ckpt_path,
        slat_decoder_gs_4_config_path=None,
        slat_decoder_gs_4_ckpt_path=None,
        ss_encoder_config_path=None,
        ss_encoder_ckpt_path=None,
        decode_formats=None,  # defaults to ["gaussian", "mesh"]
        dtype="bfloat16",
        pad_size=1.0,
        version="v0",
        device="cuda",
        # NOTE(review): these two defaults are evaluated once at import time and
        # shared across instances. Kept as-is on purpose: passing None means
        # "build from config" in init_ss_preprocessor, so a None-sentinel here
        # would change behavior.
        ss_preprocessor=preprocess_utils.get_default_preprocessor(),
        slat_preprocessor=preprocess_utils.get_default_preprocessor(),
        ss_condition_input_mapping=None,  # defaults to ["image"]
        slat_condition_input_mapping=None,  # defaults to ["image"]
        pose_decoder_name="default",
        workspace_dir="",
        downsample_ss_dist=0,  # the distance we use to downsample
        ss_inference_steps=25,
        ss_rescale_t=3,
        ss_cfg_strength=7,
        ss_cfg_interval=None,  # defaults to [0, 500]
        ss_cfg_strength_pm=0.0,
        slat_inference_steps=25,
        slat_rescale_t=3,
        slat_cfg_strength=5,
        slat_cfg_interval=None,  # defaults to [0, 500]
        rendering_engine: str = "nvdiffrast",  # nvdiffrast OR pytorch3d
        shape_model_dtype=None,
        compile_model=False,
        slat_mean=SLAT_MEAN,
        slat_std=SLAT_STD,
    ):
        """Load every sub-model (generators, encoders, decoders, condition
        embedders), apply the CFG/sampling overrides, and optionally compile.

        Config/ckpt paths are resolved relative to ``workspace_dir``.
        """
        # Resolve mutable defaults up front (avoids shared mutable default args).
        if decode_formats is None:
            decode_formats = ["gaussian", "mesh"]
        if ss_condition_input_mapping is None:
            ss_condition_input_mapping = ["image"]
        if slat_condition_input_mapping is None:
            slat_condition_input_mapping = ["image"]
        if ss_cfg_interval is None:
            ss_cfg_interval = [0, 500]
        if slat_cfg_interval is None:
            slat_cfg_interval = [0, 500]
        self.rendering_engine = rendering_engine
        self.device = torch.device(device)
        self.compile_model = compile_model
        logger.info(f"self.device: {self.device}")
        logger.info(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', None)}")
        # BUGFIX: torch.cuda.current_device() raises when CUDA is unavailable,
        # so only log it on CUDA machines.
        if torch.cuda.is_available():
            logger.info(f"Actually using GPU: {torch.cuda.current_device()}")
        with self.device:
            self.decode_formats = decode_formats
            self.pad_size = pad_size
            self.version = version
            self.ss_condition_input_mapping = ss_condition_input_mapping
            self.slat_condition_input_mapping = slat_condition_input_mapping
            self.workspace_dir = workspace_dir
            self.downsample_ss_dist = downsample_ss_dist
            self.ss_inference_steps = ss_inference_steps
            self.ss_rescale_t = ss_rescale_t
            self.ss_cfg_strength = ss_cfg_strength
            self.ss_cfg_interval = ss_cfg_interval
            self.ss_cfg_strength_pm = ss_cfg_strength_pm
            self.slat_inference_steps = slat_inference_steps
            self.slat_rescale_t = slat_rescale_t
            self.slat_cfg_strength = slat_cfg_strength
            self.slat_cfg_interval = slat_cfg_interval
            self.dtype = self._get_dtype(dtype)
            # The shape (stage-1) model may run in a different precision.
            if shape_model_dtype is None:
                self.shape_model_dtype = self.dtype
            else:
                self.shape_model_dtype = self._get_dtype(shape_model_dtype)
            # Setup preprocessors
            self.pose_decoder = self.init_pose_decoder(ss_generator_config_path, pose_decoder_name)
            self.ss_preprocessor = self.init_ss_preprocessor(ss_preprocessor, ss_generator_config_path)
            self.slat_preprocessor = slat_preprocessor
            logger.info("Loading model weights...")
            ss_generator = self.init_ss_generator(
                ss_generator_config_path, ss_generator_ckpt_path
            )
            slat_generator = self.init_slat_generator(
                slat_generator_config_path, slat_generator_ckpt_path
            )
            ss_decoder = self.init_ss_decoder(
                ss_decoder_config_path, ss_decoder_ckpt_path
            )
            ss_encoder = self.init_ss_encoder(
                ss_encoder_config_path, ss_encoder_ckpt_path
            )
            slat_decoder_gs = self.init_slat_decoder_gs(
                slat_decoder_gs_config_path, slat_decoder_gs_ckpt_path
            )
            slat_decoder_gs_4 = self.init_slat_decoder_gs(
                slat_decoder_gs_4_config_path, slat_decoder_gs_4_ckpt_path
            )
            slat_decoder_mesh = self.init_slat_decoder_mesh(
                slat_decoder_mesh_config_path, slat_decoder_mesh_ckpt_path
            )
            # Load condition embedders once and share them across stages.
            ss_condition_embedder = self.init_ss_condition_embedder(
                ss_generator_config_path, ss_generator_ckpt_path
            )
            slat_condition_embedder = self.init_slat_condition_embedder(
                slat_generator_config_path, slat_generator_ckpt_path
            )
            self.condition_embedders = {
                "ss_condition_embedder": ss_condition_embedder,
                "slat_condition_embedder": slat_condition_embedder,
            }
            # Override generator and condition-embedder sampling settings.
            self.override_ss_generator_cfg_config(
                ss_generator,
                cfg_strength=ss_cfg_strength,
                inference_steps=ss_inference_steps,
                rescale_t=ss_rescale_t,
                cfg_interval=ss_cfg_interval,
                cfg_strength_pm=ss_cfg_strength_pm,
            )
            self.override_slat_generator_cfg_config(
                slat_generator,
                cfg_strength=slat_cfg_strength,
                inference_steps=slat_inference_steps,
                rescale_t=slat_rescale_t,
                cfg_interval=slat_cfg_interval,
            )
            self.models = torch.nn.ModuleDict(
                {
                    "ss_generator": ss_generator,
                    "slat_generator": slat_generator,
                    "ss_encoder": ss_encoder,
                    "ss_decoder": ss_decoder,
                    "slat_decoder_gs": slat_decoder_gs,
                    "slat_decoder_gs_4": slat_decoder_gs_4,
                    "slat_decoder_mesh": slat_decoder_mesh,
                }
            )
            logger.info("Loading model weights completed!")
            if self.compile_model:
                logger.info("Compiling model...")
                self._compile()
                logger.info("Model compilation completed!")
            self.slat_mean = torch.tensor(slat_mean)
            self.slat_std = torch.tensor(slat_std)
| def _compile(self): | |
| torch._dynamo.config.cache_size_limit = 64 | |
| torch._dynamo.config.accumulated_cache_size_limit = 2048 | |
| torch._dynamo.config.capture_scalar_outputs = True | |
| compile_mode = "max-autotune" | |
| logger.info(f"Compile mode {compile_mode}") | |
| def clone_output_wrapper(f): | |
| def wrapped(*args, **kwargs): | |
| outputs = f(*args, **kwargs) | |
| return tree_map_only( | |
| torch.Tensor, lambda t: t.clone() if t.is_cuda else t, outputs | |
| ) | |
| return wrapped | |
| self.embed_condition = clone_output_wrapper( | |
| torch.compile( | |
| self.embed_condition, | |
| mode=compile_mode, | |
| fullgraph=True, # _preprocess_input in dino is not compatible with fullgraph | |
| ) | |
| ) | |
| self.models["ss_generator"].reverse_fn.inner_forward = clone_output_wrapper( | |
| torch.compile( | |
| self.models["ss_generator"].reverse_fn.inner_forward, | |
| mode=compile_mode, | |
| fullgraph=True, | |
| ) | |
| ) | |
| self.models["ss_decoder"].forward = clone_output_wrapper( | |
| torch.compile( | |
| self.models["ss_decoder"].forward, | |
| mode=compile_mode, | |
| fullgraph=True, | |
| ) | |
| ) | |
| self._warmup() | |
| def _warmup(self, num_warmup_iters=3): | |
| test_image = np.ones((512, 512, 4), dtype=np.uint8) * 255 | |
| test_image[:, :, :3] = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8) | |
| image = Image.fromarray(test_image) | |
| mask = None | |
| image = self.merge_image_and_mask(image, mask) | |
| for _ in tqdm(range(num_warmup_iters)): | |
| ss_input_dict = self.preprocess_image(image, self.ss_preprocessor) | |
| slat_input_dict = self.preprocess_image(image, self.slat_preprocessor) | |
| ss_return_dict = self.sample_sparse_structure(ss_input_dict) | |
| coords = ss_return_dict["coords"] | |
| slat = self.sample_slat(slat_input_dict, coords) | |
| def instantiate_and_load_from_pretrained( | |
| self, | |
| config, | |
| ckpt_path, | |
| state_dict_fn=None, | |
| state_dict_key="state_dict", | |
| device="cuda", | |
| ): | |
| model = instantiate(config) | |
| if ckpt_path.endswith(".safetensors"): | |
| state_dict = load_file(ckpt_path, device="cuda") | |
| if state_dict_fn is not None: | |
| state_dict = state_dict_fn(state_dict) | |
| model.load_state_dict(state_dict, strict=False) | |
| model.eval() | |
| else: | |
| model = load_model_from_checkpoint( | |
| model, | |
| ckpt_path, | |
| strict=True, | |
| device="cpu", | |
| freeze=True, | |
| eval=True, | |
| state_dict_key=state_dict_key, | |
| state_dict_fn=state_dict_fn, | |
| ) | |
| model = model.to(device) | |
| return model | |
| def init_pose_decoder(self, ss_generator_config_path, pose_decoder_name): | |
| if pose_decoder_name is None: | |
| pose_decoder_name = OmegaConf.load(os.path.join(self.workspace_dir, ss_generator_config_path))["module"]["pose_target_convention"] | |
| logger.info(f"Using pose decoder: {pose_decoder_name}") | |
| return get_pose_decoder(pose_decoder_name) | |
| def init_ss_preprocessor(self, ss_preprocessor, ss_generator_config_path): | |
| if ss_preprocessor is not None: | |
| return ss_preprocessor | |
| config = OmegaConf.load(os.path.join(self.workspace_dir, ss_generator_config_path))["tdfy"]["val_preprocessor"] | |
| return instantiate(config) | |
| def init_ss_generator(self, ss_generator_config_path, ss_generator_ckpt_path): | |
| config = OmegaConf.load( | |
| os.path.join(self.workspace_dir, ss_generator_config_path) | |
| )["module"]["generator"]["backbone"] | |
| state_dict_prefix_func = filter_and_remove_prefix_state_dict_fn( | |
| "_base_models.generator." | |
| ) | |
| return self.instantiate_and_load_from_pretrained( | |
| config, | |
| os.path.join(self.workspace_dir, ss_generator_ckpt_path), | |
| state_dict_fn=state_dict_prefix_func, | |
| device=self.device, | |
| ) | |
| def init_slat_generator(self, slat_generator_config_path, slat_generator_ckpt_path): | |
| config = OmegaConf.load( | |
| os.path.join(self.workspace_dir, slat_generator_config_path) | |
| )["module"]["generator"]["backbone"] | |
| state_dict_prefix_func = filter_and_remove_prefix_state_dict_fn( | |
| "_base_models.generator." | |
| ) | |
| return self.instantiate_and_load_from_pretrained( | |
| config, | |
| os.path.join(self.workspace_dir, slat_generator_ckpt_path), | |
| state_dict_fn=state_dict_prefix_func, | |
| device=self.device, | |
| ) | |
| def init_ss_encoder(self, ss_encoder_config_path, ss_encoder_ckpt_path): | |
| if ss_encoder_ckpt_path is not None: | |
| # override to avoid problem loading | |
| config = OmegaConf.load( | |
| os.path.join(self.workspace_dir, ss_encoder_config_path) | |
| ) | |
| if "pretrained_ckpt_path" in config: | |
| del config["pretrained_ckpt_path"] | |
| return self.instantiate_and_load_from_pretrained( | |
| config, | |
| os.path.join(self.workspace_dir, ss_encoder_ckpt_path), | |
| device=self.device, | |
| state_dict_key=None, | |
| ) | |
| else: | |
| return None | |
| def init_ss_decoder(self, ss_decoder_config_path, ss_decoder_ckpt_path): | |
| # override to avoid problem loading | |
| config = OmegaConf.load( | |
| os.path.join(self.workspace_dir, ss_decoder_config_path) | |
| ) | |
| if "pretrained_ckpt_path" in config: | |
| del config["pretrained_ckpt_path"] | |
| return self.instantiate_and_load_from_pretrained( | |
| config, | |
| os.path.join(self.workspace_dir, ss_decoder_ckpt_path), | |
| device=self.device, | |
| state_dict_key=None, | |
| ) | |
| def init_slat_decoder_gs( | |
| self, slat_decoder_gs_config_path, slat_decoder_gs_ckpt_path | |
| ): | |
| if slat_decoder_gs_config_path is None: | |
| return None | |
| else: | |
| return self.instantiate_and_load_from_pretrained( | |
| OmegaConf.load( | |
| os.path.join(self.workspace_dir, slat_decoder_gs_config_path) | |
| ), | |
| os.path.join(self.workspace_dir, slat_decoder_gs_ckpt_path), | |
| device=self.device, | |
| state_dict_key=None, | |
| ) | |
| def init_slat_decoder_mesh( | |
| self, slat_decoder_mesh_config_path, slat_decoder_mesh_ckpt_path | |
| ): | |
| return self.instantiate_and_load_from_pretrained( | |
| OmegaConf.load( | |
| os.path.join(self.workspace_dir, slat_decoder_mesh_config_path) | |
| ), | |
| os.path.join(self.workspace_dir, slat_decoder_mesh_ckpt_path), | |
| device=self.device, | |
| state_dict_key=None, | |
| ) | |
| def init_ss_condition_embedder( | |
| self, ss_generator_config_path, ss_generator_ckpt_path | |
| ): | |
| conf = OmegaConf.load( | |
| os.path.join(self.workspace_dir, ss_generator_config_path) | |
| ) | |
| if "condition_embedder" in conf["module"]: | |
| return self.instantiate_and_load_from_pretrained( | |
| conf["module"]["condition_embedder"]["backbone"], | |
| os.path.join(self.workspace_dir, ss_generator_ckpt_path), | |
| state_dict_fn=filter_and_remove_prefix_state_dict_fn( | |
| "_base_models.condition_embedder." | |
| ), | |
| device=self.device, | |
| ) | |
| else: | |
| return None | |
| def init_slat_condition_embedder( | |
| self, slat_generator_config_path, slat_generator_ckpt_path | |
| ): | |
| return self.init_ss_condition_embedder( | |
| slat_generator_config_path, slat_generator_ckpt_path | |
| ) | |
| def override_ss_generator_cfg_config( | |
| self, | |
| ss_generator, | |
| cfg_strength=7, | |
| inference_steps=25, | |
| rescale_t=3, | |
| cfg_interval=[0, 500], | |
| cfg_strength_pm=0.0, | |
| ): | |
| # override generator setting | |
| ss_generator.inference_steps = inference_steps | |
| ss_generator.reverse_fn.strength = cfg_strength | |
| ss_generator.reverse_fn.interval = cfg_interval | |
| ss_generator.rescale_t = rescale_t | |
| ss_generator.reverse_fn.backbone.condition_embedder.normalize_images = True | |
| ss_generator.reverse_fn.unconditional_handling = "add_flag" | |
| ss_generator.reverse_fn.strength_pm = cfg_strength_pm | |
| logger.info( | |
| "ss_generator parameters: inference_steps={}, cfg_strength={}, cfg_interval={}, rescale_t={}, cfg_strength_pm={}", | |
| inference_steps, | |
| cfg_strength, | |
| cfg_interval, | |
| rescale_t, | |
| cfg_strength_pm, | |
| ) | |
| def override_slat_generator_cfg_config( | |
| self, | |
| slat_generator, | |
| cfg_strength=5, | |
| inference_steps=25, | |
| rescale_t=3, | |
| cfg_interval=[0, 500], | |
| ): | |
| slat_generator.inference_steps = inference_steps | |
| slat_generator.reverse_fn.strength = cfg_strength | |
| slat_generator.reverse_fn.interval = cfg_interval | |
| slat_generator.rescale_t = rescale_t | |
| logger.info( | |
| "slat_generator parameters: inference_steps={}, cfg_strength={}, cfg_interval={}, rescale_t={}", | |
| inference_steps, | |
| cfg_strength, | |
| cfg_interval, | |
| rescale_t, | |
| ) | |
    def run(
        self,
        image: Union[None, Image.Image, np.ndarray],
        mask: Union[None, Image.Image, np.ndarray] = None,
        seed=42,
        stage1_only=False,
        with_mesh_postprocess=True,
        with_texture_baking=True,
        use_vertex_color=False,
        stage1_inference_steps=None,
        stage2_inference_steps=None,
        use_stage1_distillation=False,
        use_stage2_distillation=False,
        decode_formats=None,
    ) -> dict:
        """
        Run the full two-stage image-to-3D pipeline.

        Parameters:
        - image (Image): The input image to be processed (RGBA, or RGB with `mask`).
        - mask (optional): Binary mask merged into the image's alpha channel.
        - seed (int, optional): The random seed for reproducibility. Default is 42.
        - stage1_only (bool, optional): If True, only the sparse structure is sampled and returned. Default is False.
        - with_mesh_postprocess (bool, optional): If True, performs mesh post-processing. Default is True.
        - with_texture_baking (bool, optional): If True, applies texture baking to the 3D model. Default is True.
        - use_vertex_color (bool, optional): Forwarded to GLB postprocessing.
        - stage1_inference_steps / stage2_inference_steps: optional per-call step overrides.
        - use_stage1_distillation / use_stage2_distillation: use distilled (CFG-free) sampling.
        - decode_formats (optional): overrides self.decode_formats when not None.
        Returns:
        - dict: A dictionary containing the GLB file and additional data from the sparse structure sampling.
        """
        # This should only happen if called from demo
        image = self.merge_image_and_mask(image, mask)
        with self.device:
            ss_input_dict = self.preprocess_image(image, self.ss_preprocessor)
            slat_input_dict = self.preprocess_image(image, self.slat_preprocessor)
            torch.manual_seed(seed)
            # Stage 1: sample the sparse occupancy structure.
            ss_return_dict = self.sample_sparse_structure(
                ss_input_dict,
                inference_steps=stage1_inference_steps,
                use_distillation=use_stage1_distillation,
            )
            ss_return_dict.update(self.pose_decoder(ss_return_dict))
            if "scale" in ss_return_dict:
                # Compensate predicted scale for the voxel-grid downsampling factor.
                logger.info(f"Rescaling scale by {ss_return_dict['downsample_factor']}")
                ss_return_dict["scale"] = ss_return_dict["scale"] * ss_return_dict["downsample_factor"]
            if stage1_only:
                logger.info("Finished!")
                # Map 64^3 voxel indices to centered coordinates in [-0.5, 0.5).
                ss_return_dict["voxel"] = ss_return_dict["coords"][:, 1:] / 64 - 0.5
                return ss_return_dict
            coords = ss_return_dict["coords"]
            # Stage 2: sample the structured latent over the active coords.
            slat = self.sample_slat(
                slat_input_dict,
                coords,
                inference_steps=stage2_inference_steps,
                use_distillation=use_stage2_distillation,
            )
            outputs = self.decode_slat(
                slat, self.decode_formats if decode_formats is None else decode_formats
            )
            outputs = self.postprocess_slat_output(
                outputs, with_mesh_postprocess, with_texture_baking, use_vertex_color
            )
            logger.info("Finished!")
            return {
                **ss_return_dict,
                **outputs,
            }
| def postprocess_slat_output( | |
| self, outputs, with_mesh_postprocess, with_texture_baking, use_vertex_color | |
| ): | |
| # GLB files can be extracted from the outputs | |
| logger.info( | |
| f"Postprocessing mesh with option with_mesh_postprocess {with_mesh_postprocess}, with_texture_baking {with_texture_baking}..." | |
| ) | |
| if "mesh" in outputs: | |
| glb = postprocessing_utils.to_glb( | |
| outputs["gaussian"][0], | |
| outputs["mesh"][0], | |
| # Optional parameters | |
| simplify=0.95, # Ratio of triangles to remove in the simplification process | |
| texture_size=1024, # Size of the texture used for the GLB | |
| verbose=False, | |
| with_mesh_postprocess=with_mesh_postprocess, | |
| with_texture_baking=with_texture_baking, | |
| use_vertex_color=use_vertex_color, | |
| rendering_engine=self.rendering_engine, | |
| ) | |
| # glb.export("sample.glb") | |
| else: | |
| glb = None | |
| outputs["glb"] = glb | |
| if "gaussian" in outputs: | |
| outputs["gs"] = outputs["gaussian"][0] | |
| if "gaussian_4" in outputs: | |
| outputs["gs_4"] = outputs["gaussian_4"][0] | |
| return outputs | |
| def merge_image_and_mask( | |
| self, | |
| image: Union[np.ndarray, Image.Image], | |
| mask: Union[None, np.ndarray, Image.Image], | |
| ): | |
| if mask is not None: | |
| if isinstance(image, Image.Image): | |
| image = np.array(image) | |
| mask = np.array(mask) | |
| if mask.ndim == 2: | |
| mask = mask[..., None] | |
| logger.info(f"Replacing alpha channel with the provided mask") | |
| assert mask.shape[:2] == image.shape[:2] | |
| image = np.concatenate([image[..., :3], mask], axis=-1) | |
| image = np.array(image) | |
| return image | |
| def decode_slat( | |
| self, | |
| slat: sp.SparseTensor, | |
| formats: List[str] = ["mesh", "gaussian"], | |
| ) -> dict: | |
| """ | |
| Decode the structured latent. | |
| Args: | |
| slat (sp.SparseTensor): The structured latent. | |
| formats (List[str]): The formats to decode the structured latent to. | |
| Returns: | |
| dict: The decoded structured latent. | |
| """ | |
| logger.info("Decoding sparse latent...") | |
| ret = {} | |
| with torch.no_grad(): | |
| if "mesh" in formats: | |
| ret["mesh"] = self.models["slat_decoder_mesh"](slat) | |
| if "gaussian" in formats: | |
| ret["gaussian"] = self.models["slat_decoder_gs"](slat) | |
| if "gaussian_4" in formats: | |
| ret["gaussian_4"] = self.models["slat_decoder_gs_4"](slat) | |
| # if "radiance_field" in formats: | |
| # ret["radiance_field"] = self.models["slat_decoder_rf"](slat) | |
| return ret | |
| def is_mm_dit(self, model_name="ss_generator"): | |
| return hasattr(self.models[model_name].reverse_fn.backbone, "latent_mapping") | |
| def embed_condition(self, condition_embedder, *args, **kwargs): | |
| if condition_embedder is not None: | |
| tokens = condition_embedder(*args, **kwargs) | |
| return tokens, None, None | |
| return None, args, kwargs | |
| def get_condition_input(self, condition_embedder, input_dict, input_mapping): | |
| condition_args = self.map_input_keys(input_dict, input_mapping) | |
| condition_kwargs = { | |
| k: v for k, v in input_dict.items() if k not in input_mapping | |
| } | |
| logger.info("Running condition embedder ...") | |
| embedded_cond, condition_args, condition_kwargs = self.embed_condition( | |
| condition_embedder, *condition_args, **condition_kwargs | |
| ) | |
| logger.info("Condition embedder finishes!") | |
| if embedded_cond is not None: | |
| condition_args = (embedded_cond,) | |
| condition_kwargs = {} | |
| return condition_args, condition_kwargs | |
    def sample_sparse_structure(
        self, ss_input_dict: dict, inference_steps=None, use_distillation=False
    ):
        """Sample the sparse occupancy structure from preprocessed image inputs.

        Returns the generator's output dict augmented with:
        - "coords_original": active voxel coords before pruning/downsampling
        - "coords": final (possibly pruned and downsampled) voxel coords
        - "downsample_factor": factor returned by downsample_sparse_structure
        """
        ss_generator = self.models["ss_generator"]
        ss_decoder = self.models["ss_decoder"]
        if use_distillation:
            # Distilled sampling: shortcut enabled, classifier-free guidance off.
            ss_generator.no_shortcut = False
            ss_generator.reverse_fn.strength = 0
            ss_generator.reverse_fn.strength_pm = 0
        else:
            ss_generator.no_shortcut = True
            ss_generator.reverse_fn.strength = self.ss_cfg_strength
            ss_generator.reverse_fn.strength_pm = self.ss_cfg_strength_pm
        # Temporarily override the step count; restored before returning.
        prev_inference_steps = ss_generator.inference_steps
        if inference_steps:
            ss_generator.inference_steps = inference_steps
        image = ss_input_dict["image"]
        bs = image.shape[0]
        logger.info(
            "Sampling sparse structure: inference_steps={}, strength={}, interval={}, rescale_t={}, cfg_strength_pm={}",
            ss_generator.inference_steps,
            ss_generator.reverse_fn.strength,
            ss_generator.reverse_fn.interval,
            ss_generator.rescale_t,
            ss_generator.reverse_fn.strength_pm,
        )
        with torch.no_grad():
            with torch.autocast(device_type="cuda", dtype=self.shape_model_dtype):
                if self.is_mm_dit():
                    # One latent per modality; shape derived from each mapping's
                    # positional embedding and input layer.
                    latent_shape_dict = {
                        k: (bs,) + (v.pos_emb.shape[0], v.input_layer.in_features)
                        for k, v in ss_generator.reverse_fn.backbone.latent_mapping.items()
                    }
                else:
                    latent_shape_dict = (bs,) + (4096, 8)
                condition_args, condition_kwargs = self.get_condition_input(
                    self.condition_embedders["ss_condition_embedder"],
                    ss_input_dict,
                    self.ss_condition_input_mapping,
                )
                return_dict = ss_generator(
                    latent_shape_dict,
                    image.device,
                    *condition_args,
                    **condition_kwargs,
                )
                if not self.is_mm_dit():
                    # Non-MM-DiT generators return the latent directly; wrap it.
                    return_dict = {"shape": return_dict}
                shape_latent = return_dict["shape"]
                # Reshape the (B, 4096, 8) latent into an 8-channel 16^3 volume
                # for the occupancy decoder.
                ss = ss_decoder(
                    shape_latent.permute(0, 2, 1)
                    .contiguous()
                    .view(shape_latent.shape[0], 8, 16, 16, 16)
                )
                # Active voxels as (batch, x, y, z) indices where occupancy > 0.
                coords = torch.argwhere(ss > 0)[:, [0, 2, 3, 4]].int()
                # downsample output
                return_dict["coords_original"] = coords
                original_shape = coords.shape
                if self.downsample_ss_dist > 0:
                    coords = prune_sparse_structure(
                        coords,
                        max_neighbor_axes_dist=self.downsample_ss_dist,
                    )
                coords, downsample_factor = downsample_sparse_structure(coords)
                logger.info(
                    f"Downsampled coords from {original_shape[0]} to {coords.shape[0]}"
                )
                return_dict["coords"] = coords
                return_dict["downsample_factor"] = downsample_factor
        ss_generator.inference_steps = prev_inference_steps
        return return_dict
    def sample_slat(
        self,
        slat_input: dict,
        coords: torch.Tensor,
        inference_steps=25,
        use_distillation=False,
    ) -> sp.SparseTensor:
        """Sample the structured latent (SLAT) over the given active voxel coords.

        Returns a de-normalized sp.SparseTensor on the input image's device.
        """
        image = slat_input["image"]
        DEVICE = image.device
        slat_generator = self.models["slat_generator"]
        # One latent row of width 8 per active voxel.
        latent_shape = (image.shape[0],) + (coords.shape[0], 8)
        # Temporarily override the step count; restored before returning.
        prev_inference_steps = slat_generator.inference_steps
        if inference_steps:
            slat_generator.inference_steps = inference_steps
        if use_distillation:
            # Distilled sampling: shortcut enabled, classifier-free guidance off.
            slat_generator.no_shortcut = False
            slat_generator.reverse_fn.strength = 0
        else:
            slat_generator.no_shortcut = True
            slat_generator.reverse_fn.strength = self.slat_cfg_strength
        logger.info(
            "Sampling sparse latent: inference_steps={}, strength={}, interval={}, rescale_t={}",
            slat_generator.inference_steps,
            slat_generator.reverse_fn.strength,
            slat_generator.reverse_fn.interval,
            slat_generator.rescale_t,
        )
        with torch.autocast(device_type="cuda", dtype=self.dtype):
            with torch.no_grad():
                condition_args, condition_kwargs = self.get_condition_input(
                    self.condition_embedders["slat_condition_embedder"],
                    slat_input,
                    self.slat_condition_input_mapping,
                )
                # The generator consumes coords as a trailing numpy positional arg.
                condition_args += (coords.cpu().numpy(),)
                slat = slat_generator(
                    latent_shape, DEVICE, *condition_args, **condition_kwargs
                )
                slat = sp.SparseTensor(
                    coords=coords,
                    feats=slat[0],
                ).to(DEVICE)
                # De-normalize features using the stored SLAT statistics.
                slat = slat * self.slat_std.to(DEVICE) + self.slat_mean.to(DEVICE)
        slat_generator.inference_steps = prev_inference_steps
        return slat
| def _apply_transform(self, input: torch.Tensor, transform): | |
| if input is not None: | |
| input = transform(input) | |
| return input | |
| def _preprocess_image_and_mask( | |
| self, rgb_image, mask_image, img_mask_joint_transform | |
| ): | |
| for trans in img_mask_joint_transform: | |
| rgb_image, mask_image = trans(rgb_image, mask_image) | |
| return rgb_image, mask_image | |
| def map_input_keys(self, item, condition_input_mapping): | |
| output = [item[k] for k in condition_input_mapping] | |
| return output | |
| def image_to_float(self, image): | |
| image = np.array(image) | |
| image = image / 255 | |
| image = image.astype(np.float32) | |
| return image | |
| def preprocess_image( | |
| self, image: Union[Image.Image, np.ndarray], preprocessor | |
| ) -> torch.Tensor: | |
| # canonical type is numpy | |
| if not isinstance(input, np.ndarray): | |
| image = np.array(image) | |
| assert image.ndim == 3 # no batch dimension as of now | |
| assert image.shape[-1] == 4 # rgba format | |
| assert image.dtype == np.uint8 # [0,255] range | |
| rgba_image = torch.from_numpy(self.image_to_float(image)) | |
| rgba_image = rgba_image.permute(2, 0, 1).contiguous() | |
| rgb_image = rgba_image[:3] | |
| rgb_image_mask = (get_mask(rgba_image, None, "ALPHA_CHANNEL") > 0).float() | |
| processed_rgb_image, processed_mask = self._preprocess_image_and_mask( | |
| rgb_image, rgb_image_mask, preprocessor.img_mask_joint_transform | |
| ) | |
| # transform tensor to model input | |
| processed_rgb_image = self._apply_transform( | |
| processed_rgb_image, preprocessor.img_transform | |
| ) | |
| processed_mask = self._apply_transform( | |
| processed_mask, preprocessor.mask_transform | |
| ) | |
| # full image, with only processing from the image | |
| rgb_image = self._apply_transform(rgb_image, preprocessor.img_transform) | |
| rgb_image_mask = self._apply_transform( | |
| rgb_image_mask, preprocessor.mask_transform | |
| ) | |
| item = { | |
| "mask": processed_mask[None].to(self.device), | |
| "image": processed_rgb_image[None].to(self.device), | |
| "rgb_image": rgb_image[None].to(self.device), | |
| "rgb_image_mask": rgb_image_mask[None].to(self.device), | |
| } | |
| return item | |
| def _get_dtype(dtype): | |
| if dtype == "bfloat16": | |
| return torch.bfloat16 | |
| elif dtype == "float16": | |
| return torch.float16 | |
| elif dtype == "float32": | |
| return torch.float32 | |
| else: | |
| raise NotImplementedError | |