# Copyright (c) Meta Platforms, Inc. and affiliates. import os from tqdm import tqdm import torch from loguru import logger from functools import wraps from torch.utils._pytree import tree_map_only def set_attention_backend(): if torch.cuda.is_available(): gpu_name = torch.cuda.get_device_name(0) logger.info(f"GPU name is {gpu_name}") if "A100" in gpu_name or "H100" in gpu_name or "H200" in gpu_name: # logger.info("Use flash_attn") os.environ["ATTN_BACKEND"] = "flash_attn" os.environ["SPARSE_ATTN_BACKEND"] = "flash_attn" set_attention_backend() from typing import List, Union from hydra.utils import instantiate from omegaconf import OmegaConf import numpy as np from PIL import Image from sam3d_objects.pipeline import preprocess_utils from sam3d_objects.data.dataset.tdfy.img_and_mask_transforms import ( get_mask, ) from sam3d_objects.pipeline.inference_utils import ( get_pose_decoder, SLAT_MEAN, SLAT_STD, downsample_sparse_structure, prune_sparse_structure, ) from sam3d_objects.model.io import ( load_model_from_checkpoint, filter_and_remove_prefix_state_dict_fn, ) from sam3d_objects.model.backbone.tdfy_dit.modules import sparse as sp from sam3d_objects.model.backbone.tdfy_dit.utils import postprocessing_utils from safetensors.torch import load_file class InferencePipeline: def __init__( self, ss_generator_config_path, ss_generator_ckpt_path, slat_generator_config_path, slat_generator_ckpt_path, ss_decoder_config_path, ss_decoder_ckpt_path, slat_decoder_gs_config_path, slat_decoder_gs_ckpt_path, slat_decoder_mesh_config_path, slat_decoder_mesh_ckpt_path, slat_decoder_gs_4_config_path=None, slat_decoder_gs_4_ckpt_path=None, ss_encoder_config_path=None, ss_encoder_ckpt_path=None, decode_formats=["gaussian", "mesh"], dtype="bfloat16", pad_size=1.0, version="v0", device="cuda", ss_preprocessor=preprocess_utils.get_default_preprocessor(), slat_preprocessor=preprocess_utils.get_default_preprocessor(), ss_condition_input_mapping=["image"], slat_condition_input_mapping=["image"], pose_decoder_name="default", workspace_dir="", downsample_ss_dist=0, # the distance we use to downsample ss_inference_steps=25, ss_rescale_t=3, ss_cfg_strength=7, ss_cfg_interval=[0, 500], ss_cfg_strength_pm=0.0, slat_inference_steps=25, slat_rescale_t=3, slat_cfg_strength=5, slat_cfg_interval=[0, 500], rendering_engine: str = "nvdiffrast", # nvdiffrast OR pytorch3d, shape_model_dtype=None, compile_model=False, slat_mean=SLAT_MEAN, slat_std=SLAT_STD, ): self.rendering_engine = rendering_engine self.device = torch.device(device) self.compile_model = compile_model logger.info(f"self.device: {self.device}") logger.info(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', None)}") logger.info(f"Actually using GPU: {torch.cuda.current_device()}") with self.device: self.decode_formats = decode_formats self.pad_size = pad_size self.version = version self.ss_condition_input_mapping = ss_condition_input_mapping self.slat_condition_input_mapping = slat_condition_input_mapping self.workspace_dir = workspace_dir self.downsample_ss_dist = downsample_ss_dist self.ss_inference_steps = ss_inference_steps self.ss_rescale_t = ss_rescale_t self.ss_cfg_strength = ss_cfg_strength self.ss_cfg_interval = ss_cfg_interval self.ss_cfg_strength_pm = ss_cfg_strength_pm self.slat_inference_steps = slat_inference_steps self.slat_rescale_t = slat_rescale_t self.slat_cfg_strength = slat_cfg_strength self.slat_cfg_interval = slat_cfg_interval self.dtype = self._get_dtype(dtype) if shape_model_dtype is None: self.shape_model_dtype = self.dtype else: self.shape_model_dtype = self._get_dtype(shape_model_dtype) # Setup preprocessors self.pose_decoder = self.init_pose_decoder(ss_generator_config_path, pose_decoder_name) self.ss_preprocessor = self.init_ss_preprocessor(ss_preprocessor, ss_generator_config_path) self.slat_preprocessor = slat_preprocessor logger.info("Loading model weights...") ss_generator = self.init_ss_generator( ss_generator_config_path, ss_generator_ckpt_path ) slat_generator = self.init_slat_generator( slat_generator_config_path, slat_generator_ckpt_path ) ss_decoder = self.init_ss_decoder( ss_decoder_config_path, ss_decoder_ckpt_path ) ss_encoder = self.init_ss_encoder( ss_encoder_config_path, ss_encoder_ckpt_path ) slat_decoder_gs = self.init_slat_decoder_gs( slat_decoder_gs_config_path, slat_decoder_gs_ckpt_path ) slat_decoder_gs_4 = self.init_slat_decoder_gs( slat_decoder_gs_4_config_path, slat_decoder_gs_4_ckpt_path ) slat_decoder_mesh = self.init_slat_decoder_mesh( slat_decoder_mesh_config_path, slat_decoder_mesh_ckpt_path ) # Load conditioner embedder so that we only load it once ss_condition_embedder = self.init_ss_condition_embedder( ss_generator_config_path, ss_generator_ckpt_path ) slat_condition_embedder = self.init_slat_condition_embedder( slat_generator_config_path, slat_generator_ckpt_path ) self.condition_embedders = { "ss_condition_embedder": ss_condition_embedder, "slat_condition_embedder": slat_condition_embedder, } # override generator and condition embedder setting self.override_ss_generator_cfg_config( ss_generator, cfg_strength=ss_cfg_strength, inference_steps=ss_inference_steps, rescale_t=ss_rescale_t, cfg_interval=ss_cfg_interval, cfg_strength_pm=ss_cfg_strength_pm, ) self.override_slat_generator_cfg_config( slat_generator, cfg_strength=slat_cfg_strength, inference_steps=slat_inference_steps, rescale_t=slat_rescale_t, cfg_interval=slat_cfg_interval, ) self.models = torch.nn.ModuleDict( { "ss_generator": ss_generator, "slat_generator": slat_generator, "ss_encoder": ss_encoder, "ss_decoder": ss_decoder, "slat_decoder_gs": slat_decoder_gs, "slat_decoder_gs_4": slat_decoder_gs_4, "slat_decoder_mesh": slat_decoder_mesh, } ) logger.info("Loading model weights completed!") if self.compile_model: logger.info("Compiling model...") self._compile() logger.info("Model compilation completed!") self.slat_mean = torch.tensor(slat_mean) self.slat_std = torch.tensor(slat_std) def _compile(self): torch._dynamo.config.cache_size_limit = 64 torch._dynamo.config.accumulated_cache_size_limit = 2048 torch._dynamo.config.capture_scalar_outputs = True compile_mode = "max-autotune" logger.info(f"Compile mode {compile_mode}") def clone_output_wrapper(f): @wraps(f) def wrapped(*args, **kwargs): outputs = f(*args, **kwargs) return tree_map_only( torch.Tensor, lambda t: t.clone() if t.is_cuda else t, outputs ) return wrapped self.embed_condition = clone_output_wrapper( torch.compile( self.embed_condition, mode=compile_mode, fullgraph=True, # _preprocess_input in dino is not compatible with fullgraph ) ) self.models["ss_generator"].reverse_fn.inner_forward = clone_output_wrapper( torch.compile( self.models["ss_generator"].reverse_fn.inner_forward, mode=compile_mode, fullgraph=True, ) ) self.models["ss_decoder"].forward = clone_output_wrapper( torch.compile( self.models["ss_decoder"].forward, mode=compile_mode, fullgraph=True, ) ) self._warmup() def _warmup(self, num_warmup_iters=3): test_image = np.ones((512, 512, 4), dtype=np.uint8) * 255 test_image[:, :, :3] = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8) image = Image.fromarray(test_image) mask = None image = self.merge_image_and_mask(image, mask) for _ in tqdm(range(num_warmup_iters)): ss_input_dict = self.preprocess_image(image, self.ss_preprocessor) slat_input_dict = self.preprocess_image(image, self.slat_preprocessor) ss_return_dict = self.sample_sparse_structure(ss_input_dict) coords = ss_return_dict["coords"] slat = self.sample_slat(slat_input_dict, coords) def instantiate_and_load_from_pretrained( self, config, ckpt_path, state_dict_fn=None, state_dict_key="state_dict", device="cuda", ): model = instantiate(config) if ckpt_path.endswith(".safetensors"): state_dict = load_file(ckpt_path, device="cuda") if state_dict_fn is not None: state_dict = state_dict_fn(state_dict) model.load_state_dict(state_dict, strict=False) model.eval() else: model = load_model_from_checkpoint( model, ckpt_path, strict=True, device="cpu", freeze=True, eval=True, state_dict_key=state_dict_key, state_dict_fn=state_dict_fn, ) model = model.to(device) return model def init_pose_decoder(self, ss_generator_config_path, pose_decoder_name): if pose_decoder_name is None: pose_decoder_name = OmegaConf.load(os.path.join(self.workspace_dir, ss_generator_config_path))["module"]["pose_target_convention"] logger.info(f"Using pose decoder: {pose_decoder_name}") return get_pose_decoder(pose_decoder_name) def init_ss_preprocessor(self, ss_preprocessor, ss_generator_config_path): if ss_preprocessor is not None: return ss_preprocessor config = OmegaConf.load(os.path.join(self.workspace_dir, ss_generator_config_path))["tdfy"]["val_preprocessor"] return instantiate(config) def init_ss_generator(self, ss_generator_config_path, ss_generator_ckpt_path): config = OmegaConf.load( os.path.join(self.workspace_dir, ss_generator_config_path) )["module"]["generator"]["backbone"] state_dict_prefix_func = filter_and_remove_prefix_state_dict_fn( "_base_models.generator." ) return self.instantiate_and_load_from_pretrained( config, os.path.join(self.workspace_dir, ss_generator_ckpt_path), state_dict_fn=state_dict_prefix_func, device=self.device, ) def init_slat_generator(self, slat_generator_config_path, slat_generator_ckpt_path): config = OmegaConf.load( os.path.join(self.workspace_dir, slat_generator_config_path) )["module"]["generator"]["backbone"] state_dict_prefix_func = filter_and_remove_prefix_state_dict_fn( "_base_models.generator." ) return self.instantiate_and_load_from_pretrained( config, os.path.join(self.workspace_dir, slat_generator_ckpt_path), state_dict_fn=state_dict_prefix_func, device=self.device, ) def init_ss_encoder(self, ss_encoder_config_path, ss_encoder_ckpt_path): if ss_encoder_ckpt_path is not None: # override to avoid problem loading config = OmegaConf.load( os.path.join(self.workspace_dir, ss_encoder_config_path) ) if "pretrained_ckpt_path" in config: del config["pretrained_ckpt_path"] return self.instantiate_and_load_from_pretrained( config, os.path.join(self.workspace_dir, ss_encoder_ckpt_path), device=self.device, state_dict_key=None, ) else: return None def init_ss_decoder(self, ss_decoder_config_path, ss_decoder_ckpt_path): # override to avoid problem loading config = OmegaConf.load( os.path.join(self.workspace_dir, ss_decoder_config_path) ) if "pretrained_ckpt_path" in config: del config["pretrained_ckpt_path"] return self.instantiate_and_load_from_pretrained( config, os.path.join(self.workspace_dir, ss_decoder_ckpt_path), device=self.device, state_dict_key=None, ) def init_slat_decoder_gs( self, slat_decoder_gs_config_path, slat_decoder_gs_ckpt_path ): if slat_decoder_gs_config_path is None: return None else: return self.instantiate_and_load_from_pretrained( OmegaConf.load( os.path.join(self.workspace_dir, slat_decoder_gs_config_path) ), os.path.join(self.workspace_dir, slat_decoder_gs_ckpt_path), device=self.device, state_dict_key=None, ) def init_slat_decoder_mesh( self, slat_decoder_mesh_config_path, slat_decoder_mesh_ckpt_path ): return self.instantiate_and_load_from_pretrained( OmegaConf.load( os.path.join(self.workspace_dir, slat_decoder_mesh_config_path) ), os.path.join(self.workspace_dir, slat_decoder_mesh_ckpt_path), device=self.device, state_dict_key=None, ) def init_ss_condition_embedder( self, ss_generator_config_path, ss_generator_ckpt_path ): conf = OmegaConf.load( os.path.join(self.workspace_dir, ss_generator_config_path) ) if "condition_embedder" in conf["module"]: return self.instantiate_and_load_from_pretrained( conf["module"]["condition_embedder"]["backbone"], os.path.join(self.workspace_dir, ss_generator_ckpt_path), state_dict_fn=filter_and_remove_prefix_state_dict_fn( "_base_models.condition_embedder." ), device=self.device, ) else: return None def init_slat_condition_embedder( self, slat_generator_config_path, slat_generator_ckpt_path ): return self.init_ss_condition_embedder( slat_generator_config_path, slat_generator_ckpt_path ) def override_ss_generator_cfg_config( self, ss_generator, cfg_strength=7, inference_steps=25, rescale_t=3, cfg_interval=[0, 500], cfg_strength_pm=0.0, ): # override generator setting ss_generator.inference_steps = inference_steps ss_generator.reverse_fn.strength = cfg_strength ss_generator.reverse_fn.interval = cfg_interval ss_generator.rescale_t = rescale_t ss_generator.reverse_fn.backbone.condition_embedder.normalize_images = True ss_generator.reverse_fn.unconditional_handling = "add_flag" ss_generator.reverse_fn.strength_pm = cfg_strength_pm logger.info( "ss_generator parameters: inference_steps={}, cfg_strength={}, cfg_interval={}, rescale_t={}, cfg_strength_pm={}", inference_steps, cfg_strength, cfg_interval, rescale_t, cfg_strength_pm, ) def override_slat_generator_cfg_config( self, slat_generator, cfg_strength=5, inference_steps=25, rescale_t=3, cfg_interval=[0, 500], ): slat_generator.inference_steps = inference_steps slat_generator.reverse_fn.strength = cfg_strength slat_generator.reverse_fn.interval = cfg_interval slat_generator.rescale_t = rescale_t logger.info( "slat_generator parameters: inference_steps={}, cfg_strength={}, cfg_interval={}, rescale_t={}", inference_steps, cfg_strength, cfg_interval, rescale_t, ) def run( self, image: Union[None, Image.Image, np.ndarray], mask: Union[None, Image.Image, np.ndarray] = None, seed=42, stage1_only=False, with_mesh_postprocess=True, with_texture_baking=True, use_vertex_color=False, stage1_inference_steps=None, stage2_inference_steps=None, use_stage1_distillation=False, use_stage2_distillation=False, decode_formats=None, ) -> dict: """ Parameters: - image (Image): The input image to be processed. - seed (int, optional): The random seed for reproducibility. Default is 42. - stage1_only (bool, optional): If True, only the sparse structure is sampled and returned. Default is False. - with_mesh_postprocess (bool, optional): If True, performs mesh post-processing. Default is True. - with_texture_baking (bool, optional): If True, applies texture baking to the 3D model. Default is True. Returns: - dict: A dictionary containing the GLB file and additional data from the sparse structure sampling. """ # This should only happen if called from demo image = self.merge_image_and_mask(image, mask) with self.device: ss_input_dict = self.preprocess_image(image, self.ss_preprocessor) slat_input_dict = self.preprocess_image(image, self.slat_preprocessor) torch.manual_seed(seed) ss_return_dict = self.sample_sparse_structure( ss_input_dict, inference_steps=stage1_inference_steps, use_distillation=use_stage1_distillation, ) ss_return_dict.update(self.pose_decoder(ss_return_dict)) if "scale" in ss_return_dict: logger.info(f"Rescaling scale by {ss_return_dict['downsample_factor']}") ss_return_dict["scale"] = ss_return_dict["scale"] * ss_return_dict["downsample_factor"] if stage1_only: logger.info("Finished!") ss_return_dict["voxel"] = ss_return_dict["coords"][:, 1:] / 64 - 0.5 return ss_return_dict coords = ss_return_dict["coords"] slat = self.sample_slat( slat_input_dict, coords, inference_steps=stage2_inference_steps, use_distillation=use_stage2_distillation, ) outputs = self.decode_slat( slat, self.decode_formats if decode_formats is None else decode_formats ) outputs = self.postprocess_slat_output( outputs, with_mesh_postprocess, with_texture_baking, use_vertex_color ) logger.info("Finished!") return { **ss_return_dict, **outputs, } def postprocess_slat_output( self, outputs, with_mesh_postprocess, with_texture_baking, use_vertex_color ): # GLB files can be extracted from the outputs logger.info( f"Postprocessing mesh with option with_mesh_postprocess {with_mesh_postprocess}, with_texture_baking {with_texture_baking}..." ) if "mesh" in outputs: glb = postprocessing_utils.to_glb( outputs["gaussian"][0], outputs["mesh"][0], # Optional parameters simplify=0.95, # Ratio of triangles to remove in the simplification process texture_size=1024, # Size of the texture used for the GLB verbose=False, with_mesh_postprocess=with_mesh_postprocess, with_texture_baking=with_texture_baking, use_vertex_color=use_vertex_color, rendering_engine=self.rendering_engine, ) # glb.export("sample.glb") else: glb = None outputs["glb"] = glb if "gaussian" in outputs: outputs["gs"] = outputs["gaussian"][0] if "gaussian_4" in outputs: outputs["gs_4"] = outputs["gaussian_4"][0] return outputs def merge_image_and_mask( self, image: Union[np.ndarray, Image.Image], mask: Union[None, np.ndarray, Image.Image], ): if mask is not None: if isinstance(image, Image.Image): image = np.array(image) mask = np.array(mask) if mask.ndim == 2: mask = mask[..., None] logger.info(f"Replacing alpha channel with the provided mask") assert mask.shape[:2] == image.shape[:2] image = np.concatenate([image[..., :3], mask], axis=-1) image = np.array(image) return image def decode_slat( self, slat: sp.SparseTensor, formats: List[str] = ["mesh", "gaussian"], ) -> dict: """ Decode the structured latent. Args: slat (sp.SparseTensor): The structured latent. formats (List[str]): The formats to decode the structured latent to. Returns: dict: The decoded structured latent. """ logger.info("Decoding sparse latent...") ret = {} with torch.no_grad(): if "mesh" in formats: ret["mesh"] = self.models["slat_decoder_mesh"](slat) if "gaussian" in formats: ret["gaussian"] = self.models["slat_decoder_gs"](slat) if "gaussian_4" in formats: ret["gaussian_4"] = self.models["slat_decoder_gs_4"](slat) # if "radiance_field" in formats: # ret["radiance_field"] = self.models["slat_decoder_rf"](slat) return ret def is_mm_dit(self, model_name="ss_generator"): return hasattr(self.models[model_name].reverse_fn.backbone, "latent_mapping") def embed_condition(self, condition_embedder, *args, **kwargs): if condition_embedder is not None: tokens = condition_embedder(*args, **kwargs) return tokens, None, None return None, args, kwargs def get_condition_input(self, condition_embedder, input_dict, input_mapping): condition_args = self.map_input_keys(input_dict, input_mapping) condition_kwargs = { k: v for k, v in input_dict.items() if k not in input_mapping } logger.info("Running condition embedder ...") embedded_cond, condition_args, condition_kwargs = self.embed_condition( condition_embedder, *condition_args, **condition_kwargs ) logger.info("Condition embedder finishes!") if embedded_cond is not None: condition_args = (embedded_cond,) condition_kwargs = {} return condition_args, condition_kwargs def sample_sparse_structure( self, ss_input_dict: dict, inference_steps=None, use_distillation=False ): ss_generator = self.models["ss_generator"] ss_decoder = self.models["ss_decoder"] if use_distillation: ss_generator.no_shortcut = False ss_generator.reverse_fn.strength = 0 ss_generator.reverse_fn.strength_pm = 0 else: ss_generator.no_shortcut = True ss_generator.reverse_fn.strength = self.ss_cfg_strength ss_generator.reverse_fn.strength_pm = self.ss_cfg_strength_pm prev_inference_steps = ss_generator.inference_steps if inference_steps: ss_generator.inference_steps = inference_steps image = ss_input_dict["image"] bs = image.shape[0] logger.info( "Sampling sparse structure: inference_steps={}, strength={}, interval={}, rescale_t={}, cfg_strength_pm={}", ss_generator.inference_steps, ss_generator.reverse_fn.strength, ss_generator.reverse_fn.interval, ss_generator.rescale_t, ss_generator.reverse_fn.strength_pm, ) with torch.no_grad(): with torch.autocast(device_type="cuda", dtype=self.shape_model_dtype): if self.is_mm_dit(): latent_shape_dict = { k: (bs,) + (v.pos_emb.shape[0], v.input_layer.in_features) for k, v in ss_generator.reverse_fn.backbone.latent_mapping.items() } else: latent_shape_dict = (bs,) + (4096, 8) condition_args, condition_kwargs = self.get_condition_input( self.condition_embedders["ss_condition_embedder"], ss_input_dict, self.ss_condition_input_mapping, ) return_dict = ss_generator( latent_shape_dict, image.device, *condition_args, **condition_kwargs, ) if not self.is_mm_dit(): return_dict = {"shape": return_dict} shape_latent = return_dict["shape"] ss = ss_decoder( shape_latent.permute(0, 2, 1) .contiguous() .view(shape_latent.shape[0], 8, 16, 16, 16) ) coords = torch.argwhere(ss > 0)[:, [0, 2, 3, 4]].int() # downsample output return_dict["coords_original"] = coords original_shape = coords.shape if self.downsample_ss_dist > 0: coords = prune_sparse_structure( coords, max_neighbor_axes_dist=self.downsample_ss_dist, ) coords, downsample_factor = downsample_sparse_structure(coords) logger.info( f"Downsampled coords from {original_shape[0]} to {coords.shape[0]}" ) return_dict["coords"] = coords return_dict["downsample_factor"] = downsample_factor ss_generator.inference_steps = prev_inference_steps return return_dict def sample_slat( self, slat_input: dict, coords: torch.Tensor, inference_steps=25, use_distillation=False, ) -> sp.SparseTensor: image = slat_input["image"] DEVICE = image.device slat_generator = self.models["slat_generator"] latent_shape = (image.shape[0],) + (coords.shape[0], 8) prev_inference_steps = slat_generator.inference_steps if inference_steps: slat_generator.inference_steps = inference_steps if use_distillation: slat_generator.no_shortcut = False slat_generator.reverse_fn.strength = 0 else: slat_generator.no_shortcut = True slat_generator.reverse_fn.strength = self.slat_cfg_strength logger.info( "Sampling sparse latent: inference_steps={}, strength={}, interval={}, rescale_t={}", slat_generator.inference_steps, slat_generator.reverse_fn.strength, slat_generator.reverse_fn.interval, slat_generator.rescale_t, ) with torch.autocast(device_type="cuda", dtype=self.dtype): with torch.no_grad(): condition_args, condition_kwargs = self.get_condition_input( self.condition_embedders["slat_condition_embedder"], slat_input, self.slat_condition_input_mapping, ) condition_args += (coords.cpu().numpy(),) slat = slat_generator( latent_shape, DEVICE, *condition_args, **condition_kwargs ) slat = sp.SparseTensor( coords=coords, feats=slat[0], ).to(DEVICE) slat = slat * self.slat_std.to(DEVICE) + self.slat_mean.to(DEVICE) slat_generator.inference_steps = prev_inference_steps return slat def _apply_transform(self, input: torch.Tensor, transform): if input is not None: input = transform(input) return input def _preprocess_image_and_mask( self, rgb_image, mask_image, img_mask_joint_transform ): for trans in img_mask_joint_transform: rgb_image, mask_image = trans(rgb_image, mask_image) return rgb_image, mask_image def map_input_keys(self, item, condition_input_mapping): output = [item[k] for k in condition_input_mapping] return output def image_to_float(self, image): image = np.array(image) image = image / 255 image = image.astype(np.float32) return image def preprocess_image( self, image: Union[Image.Image, np.ndarray], preprocessor ) -> torch.Tensor: # canonical type is numpy if not isinstance(input, np.ndarray): image = np.array(image) assert image.ndim == 3 # no batch dimension as of now assert image.shape[-1] == 4 # rgba format assert image.dtype == np.uint8 # [0,255] range rgba_image = torch.from_numpy(self.image_to_float(image)) rgba_image = rgba_image.permute(2, 0, 1).contiguous() rgb_image = rgba_image[:3] rgb_image_mask = (get_mask(rgba_image, None, "ALPHA_CHANNEL") > 0).float() processed_rgb_image, processed_mask = self._preprocess_image_and_mask( rgb_image, rgb_image_mask, preprocessor.img_mask_joint_transform ) # transform tensor to model input processed_rgb_image = self._apply_transform( processed_rgb_image, preprocessor.img_transform ) processed_mask = self._apply_transform( processed_mask, preprocessor.mask_transform ) # full image, with only processing from the image rgb_image = self._apply_transform(rgb_image, preprocessor.img_transform) rgb_image_mask = self._apply_transform( rgb_image_mask, preprocessor.mask_transform ) item = { "mask": processed_mask[None].to(self.device), "image": processed_rgb_image[None].to(self.device), "rgb_image": rgb_image[None].to(self.device), "rgb_image_mask": rgb_image_mask[None].to(self.device), } return item @staticmethod def _get_dtype(dtype): if dtype == "bfloat16": return torch.bfloat16 elif dtype == "float16": return torch.float16 elif dtype == "float32": return torch.float32 else: raise NotImplementedError