from typing import *
from contextlib import contextmanager
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision import transforms
from PIL import Image
import rembg
from transformers import AutoModel
from .base import Pipeline
from . import samplers
from ..modules import sparse as sp
from ..modules.sparse.basic import SparseTensor, sparse_cat


class OmniPartImageTo3DPipeline(Pipeline):
    """
    Pipeline for inferring OmniPart image-to-3D models.

    Args:
        models (dict[str, nn.Module]): The models to use in the pipeline.
        sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure.
        slat_sampler (samplers.Sampler): The sampler for the structured latent.
        slat_normalization (dict): The normalization parameters for the structured latent.
        image_cond_model (str): The name of the image conditioning model.
    """
    def __init__(
        self,
        models: Dict[str, nn.Module] = None,
        sparse_structure_sampler: samplers.Sampler = None,
        slat_sampler: samplers.Sampler = None,
        slat_normalization: dict = None,
        image_cond_model: str = None,
    ):
        # Skip initialization if models is None (used in from_pretrained)
        if models is None:
            return
        super().__init__(models)
        self.sparse_structure_sampler = sparse_structure_sampler
        self.slat_sampler = slat_sampler
        self.sparse_structure_sampler_params = {}
        self.slat_sampler_params = {}
        self.slat_normalization = slat_normalization
        self.rembg_session = None
        self._init_image_cond_model(image_cond_model)

    @staticmethod
    def from_pretrained(path: str) -> "OmniPartImageTo3DPipeline":
        """
        Load a pretrained model.

        Args:
            path (str): The path to the model. Can be either a local path or a Hugging Face repository.

        Returns:
            OmniPartImageTo3DPipeline: Loaded pipeline instance
        """
        pipeline = super(OmniPartImageTo3DPipeline, OmniPartImageTo3DPipeline).from_pretrained(path)
        new_pipeline = OmniPartImageTo3DPipeline()
        new_pipeline.__dict__ = pipeline.__dict__
        args = pipeline._pretrained_args

        # Initialize samplers from saved arguments
        new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(
            **args['sparse_structure_sampler']['args'])
        new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']

        new_pipeline.slat_sampler = getattr(samplers, args['slat_sampler']['name'])(
            **args['slat_sampler']['args'])
        new_pipeline.slat_sampler_params = args['slat_sampler']['params']

        new_pipeline.slat_normalization = args['slat_normalization']
        new_pipeline._init_image_cond_model(args['image_cond_model'])
        return new_pipeline
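
    # Usage sketch (the path below is a placeholder, not a confirmed checkpoint
    # id; it can be a local directory or a Hugging Face repository):
    #
    #   pipeline = OmniPartImageTo3DPipeline.from_pretrained("path/to/omnipart-checkpoint")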

    def _init_image_cond_model(self, name: str):
        """
        Initialize the image conditioning model.

        Args:
            name (str): Name of the DINOv2 model to load
        """
        dinov2_model = torch.hub.load('facebookresearch/dinov2', name)
        dinov2_model.eval()
        self.models['image_cond_model'] = dinov2_model
        transform = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.image_cond_model_transform = transform

    def preprocess_image(self, input: Image.Image, size=(518, 518)) -> Image.Image:
        """
        Preprocess the input image for the model.

        Args:
            input (Image.Image): Input image
            size (tuple): Target size for resizing

        Returns:
            Image.Image: Preprocessed image
        """
        img = np.array(input)
        if img.shape[-1] == 4:
            # Handle the alpha channel: replace fully transparent pixels with
            # black, then drop the alpha channel
            mask_img = img[..., 3] == 0
            img[mask_img] = [0, 0, 0, 255]
            img = img[..., :3]
        img_rgb = Image.fromarray(img.astype('uint8'))
        # Resize to target size
        img_rgb = img_rgb.resize(size, resample=Image.Resampling.BILINEAR)
        return img_rgb
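
    # Usage sketch (hypothetical file name): RGBA inputs get their fully
    # transparent pixels flattened to black before resizing:
    #
    #   img = pipeline.preprocess_image(Image.open("object_rgba.png"))  # 518x518 RGB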

    def encode_image(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
        """
        Encode the image using the conditioning model.

        Args:
            image (Union[torch.Tensor, list[Image.Image]]): The image(s) to encode

        Returns:
            torch.Tensor: The encoded features
        """
        if isinstance(image, torch.Tensor):
            assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
        elif isinstance(image, list):
            assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
            # Convert PIL images to tensors
            image = [i.resize((518, 518), Image.LANCZOS) for i in image]
            image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
            image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
            image = torch.stack(image).to(self.device)
        else:
            raise ValueError(f"Unsupported type of image: {type(image)}")

        # Apply normalization and run through the DINOv2 model
        image = self.image_cond_model_transform(image).to(self.device)
        features = self.models['image_cond_model'](image, is_training=True)['x_prenorm']
        patchtokens = F.layer_norm(features, features.shape[-1:])
        return patchtokens

    def get_cond(self, image: Union[torch.Tensor, List[Image.Image]]) -> dict:
        """
        Get the conditioning information for the model.

        Args:
            image (Union[torch.Tensor, list[Image.Image]]): The image prompts.

        Returns:
            dict: Dictionary with conditioning information
        """
        cond = self.encode_image(image)
        neg_cond = torch.zeros_like(cond)  # All-zero negative conditioning for classifier-free guidance
        return {
            'cond': cond,
            'neg_cond': neg_cond,
        }
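
    # Sketch of the conditioning dict (shapes are indicative; the exact token
    # count depends on the DINOv2 variant loaded in _init_image_cond_model):
    #
    #   cond = pipeline.get_cond([img])
    #   cond['cond'].shape   # (1, num_patch_tokens, feat_dim)
    #   cond['neg_cond']     # zeros_like(cond['cond']), the CFG negative branch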

    def sample_sparse_structure(
        self,
        cond: dict,
        num_samples: int = 1,
        sampler_params: dict = {},
        save_coords: bool = False,
    ) -> torch.Tensor:
        """
        Sample sparse structures with the given conditioning.

        Args:
            cond (dict): The conditioning information.
            num_samples (int): The number of samples to generate.
            sampler_params (dict): Additional parameters for the sampler.
            save_coords (bool): Whether to save coordinates internally.

        Returns:
            torch.Tensor: Coordinates of the sparse structure
        """
        # Sample occupancy latent
        flow_model = self.models['sparse_structure_flow_model']
        reso = flow_model.resolution
        noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device)

        # Merge default and custom sampler parameters
        sampler_params = {**self.sparse_structure_sampler_params, **sampler_params}

        # Generate samples using the sampler
        z_s = self.sparse_structure_sampler.sample(
            flow_model,
            noise,
            **cond,
            **sampler_params,
            verbose=True
        ).samples

        # Decode the occupancy latent and threshold it; dropping the channel
        # dimension leaves (batch_idx, x, y, z) coordinates
        decoder = self.models['sparse_structure_decoder']
        coords = torch.argwhere(decoder(z_s) > 0)[:, [0, 2, 3, 4]].int()

        if save_coords:
            self.save_coordinates = coords

        return coords

    def get_coords(
        self,
        image: Union[Image.Image, List[Image.Image], torch.Tensor],
        num_samples: int = 1,
        seed: int = 42,
        sparse_structure_sampler_params: dict = {},
        preprocess_image: bool = True,
        save_coords: bool = False,
    ) -> torch.Tensor:
        """
        Get coordinates of the sparse structure from an input image.

        Args:
            image: Input image, list of images, or batched image tensor
            num_samples: Number of samples to generate
            seed: Random seed
            sparse_structure_sampler_params: Additional parameters for the sparse structure sampler
            preprocess_image: Whether to preprocess the image
            save_coords: Whether to save coordinates internally

        Returns:
            torch.Tensor: Coordinates of the sparse structure
        """
        if isinstance(image, Image.Image):
            if preprocess_image:
                image = self.preprocess_image(image)
            cond = self.get_cond([image])
        elif isinstance(image, torch.Tensor):
            cond = self.get_cond(image.unsqueeze(0))
        elif isinstance(image, list):
            if preprocess_image:
                image = [self.preprocess_image(i) for i in image]
            cond = self.get_cond(image)
        else:
            raise ValueError(f"Unsupported type of image: {type(image)}")
        torch.manual_seed(seed)
        return self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params, save_coords)

    def sample_slat(
        self,
        cond: dict,
        coords: torch.Tensor,
        part_layouts: List[List[slice]] = None,
        masks: torch.Tensor = None,
        sampler_params: dict = {},
        **kwargs
    ) -> sp.SparseTensor:
        """
        Sample a structured latent (SLat) on the given sparse structure.

        Args:
            cond (dict): The conditioning information.
            coords (torch.Tensor): Coordinates of the sparse structure.
            part_layouts (List[List[slice]]): Per-sample lists of part slices, if any.
            masks (torch.Tensor): Part masks, if any.
            sampler_params (dict): Additional parameters for the sampler.

        Returns:
            sp.SparseTensor: The sampled structured latent.
        """
        flow_model = self.models['slat_flow_model']

        # Create noise tensor with the same coordinates as the sparse structure
        noise = sp.SparseTensor(
            feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
            coords=coords,
        )

        # Merge default and custom sampler parameters
        sampler_params = {**self.slat_sampler_params, **sampler_params}

        # Add part information if provided
        if part_layouts is not None:
            kwargs['part_layouts'] = part_layouts
        if masks is not None:
            kwargs['masks'] = masks

        # Generate samples
        slat = self.slat_sampler.sample(
            flow_model,
            noise,
            **cond,
            **sampler_params,
            verbose=True,
            **kwargs
        ).samples

        # Un-normalize the features using the saved statistics
        feat_dim = slat.feats.shape[1]
        base_std = torch.tensor(self.slat_normalization['std']).to(slat.device)
        base_mean = torch.tensor(self.slat_normalization['mean']).to(slat.device)

        # Handle different dimensionality cases
        if feat_dim == len(base_std):
            # Dimensions match, apply directly
            std = base_std[None, :]
            mean = base_mean[None, :]
        elif feat_dim == 8 and len(base_std) == 9:
            # Use the first 8 dimensions when the latent is 8-dimensional but
            # the normalization parameters are 9-dimensional
            std = base_std[:8][None, :]
            mean = base_mean[:8][None, :]
            print(f"Warning: Normalizing {feat_dim}-dimensional features with first 8 dimensions of 9-dimensional parameters")
        else:
            # General dimension mismatch: identity statistics for the extra dimensions
            std = torch.ones((1, feat_dim), device=slat.device)
            mean = torch.zeros((1, feat_dim), device=slat.device)
            copy_dim = min(feat_dim, len(base_std))
            std[0, :copy_dim] = base_std[:copy_dim]
            mean[0, :copy_dim] = base_mean[:copy_dim]
            print(f"Warning: Feature dimensions mismatch. Using {copy_dim} dimensions for normalization")

        slat = slat * std + mean
        return slat
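
    # Broadcast sketch for the un-normalization above (hypothetical shapes):
    #
    #   slat.feats: (N, 8), std/mean: (1, 8)
    #   slat * std + mean  # inverts the (x - mean) / std scaling presumably
    #                      # applied when the latents were encoded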

    def get_slat(
        self,
        image: Union[Image.Image, List[Image.Image], torch.Tensor],
        coords: torch.Tensor,
        part_layouts: List[List[slice]],
        masks: torch.Tensor,
        seed: int = 42,
        slat_sampler_params: dict = {},
        formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
        preprocess_image: bool = True,
    ) -> dict:
        """
        Sample a structured latent conditioned on the image, split it into
        parts, and decode it to the requested output formats.

        Args:
            image: Input image, list of images, or batched image tensor
            coords (torch.Tensor): Coordinates of the sparse structure
            part_layouts (List[List[slice]]): Per-sample lists of part slices
            masks (torch.Tensor): Part masks
            seed (int): Random seed
            slat_sampler_params (dict): Additional parameters for the SLat sampler
            formats (List[str]): The formats to decode to
            preprocess_image (bool): Whether to preprocess the image

        Returns:
            dict: Decoded outputs in the requested formats
        """
        if isinstance(image, Image.Image):
            if preprocess_image:
                image = self.preprocess_image(image)
            cond = self.get_cond([image])
        elif isinstance(image, list):
            if preprocess_image:
                image = [self.preprocess_image(i) for i in image]
            cond = self.get_cond(image)
        elif isinstance(image, torch.Tensor):
            cond = self.get_cond(image.unsqueeze(0))
        else:
            raise ValueError(f"Unsupported type of image: {type(image)}")
        torch.manual_seed(seed)
        slat = self.sample_slat(cond, coords, part_layouts, masks, slat_sampler_params)
        return self.decode_slat(self.divide_slat(slat, part_layouts), formats)

    def decode_slat(
        self,
        slat: sp.SparseTensor,
        formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
    ) -> dict:
        """
        Decode the structured latent.

        Args:
            slat (sp.SparseTensor): The structured latent
            formats (List[str]): The formats to decode to

        Returns:
            dict: Decoded outputs in requested formats
        """
        ret = {}
        if 'mesh' in formats:
            ret['mesh'] = self.models['slat_decoder_mesh'](slat)
        if 'gaussian' in formats:
            ret['gaussian'] = self.models['slat_decoder_gs'](slat)
        if 'radiance_field' in formats:
            ret['radiance_field'] = self.models['slat_decoder_rf'](slat)
        return ret
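
    # Access sketch (the concrete output types depend on the decoder
    # implementations registered in self.models):
    #
    #   out = pipeline.decode_slat(slat, formats=['mesh', 'gaussian'])
    #   out['mesh'][0]      # decoded mesh for the first part
    #   out['gaussian'][0]  # decoded Gaussian splats for the first part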

    def divide_slat(
        self,
        slat: sp.SparseTensor,
        part_layouts: List[List[slice]],
    ) -> sp.SparseTensor:
        """
        Divide the structured latent into parts.

        Args:
            slat (sp.SparseTensor): The structured latent
            part_layouts (List[List[slice]]): Per-sample lists of slices, one slice per part

        Returns:
            sp.SparseTensor: Processed and divided latent, one batch element per part
        """
        sparse_part = []
        for part_id, part_layout in enumerate(part_layouts):
            for part_obj_id, part_slice in enumerate(part_layout):
                part_x_sparse_tensor = SparseTensor(
                    coords=slat[part_id].coords[part_slice],
                    feats=slat[part_id].feats[part_slice],
                )
                sparse_part.append(part_x_sparse_tensor)
        slat = sparse_cat(sparse_part)
        return self.remove_noise(slat)

    def remove_noise(self, z_batch):
        """
        Remove noise from latent vectors by filtering out points with low confidence.

        Args:
            z_batch: Latent vectors to process

        Returns:
            sp.SparseTensor: Processed latent with noise removed
        """
        processed_batch = []
        for i, z in enumerate(z_batch):
            coords = z.coords
            feats = z.feats

            # Only filter if the features carry a confidence dimension (9th channel)
            if feats.shape[1] == 9:
                # The last channel is a confidence logit; keep points whose
                # sigmoid-activated confidence is at least 0.5
                sigmoid_val = torch.sigmoid(feats[:, -1])

                # Calculate filtering statistics
                total_points = coords.shape[0]
                to_keep = sigmoid_val >= 0.5
                kept_points = to_keep.sum().item()
                discarded_points = total_points - kept_points
                discard_percentage = (discarded_points / total_points) * 100 if total_points > 0 else 0

                if kept_points == 0:
                    print(f"No points kept for part {i}")
                    continue
                print(f"Discarded {discarded_points}/{total_points} points ({discard_percentage:.2f}%)")

                # Filter coordinates and features, then drop the confidence channel
                coords = coords[to_keep]
                feats = feats[to_keep][:, :-1]
                processed_z = z.replace(coords=coords, feats=feats)
            else:
                processed_z = z
            processed_batch.append(processed_z)

        return sparse_cat(processed_batch)
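
    # Shape sketch of the confidence filtering above:
    #
    #   feats: (N, 9) --keep sigmoid(feats[:, -1]) >= 0.5--> (M, 9) --drop last--> (M, 8)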

    @contextmanager
    def inject_sampler_multi_image(
        self,
        sampler_name: str,
        num_images: int,
        num_steps: int,
        mode: Literal['stochastic', 'multidiffusion'] = 'stochastic',
    ):
        """
        Inject a sampler with multiple images as condition.

        Args:
            sampler_name (str): The name of the sampler to inject
            num_images (int): The number of images to condition on
            num_steps (int): The number of steps to run the sampler for
            mode (str): Sampling strategy ('stochastic' or 'multidiffusion')
        """
        sampler = getattr(self, sampler_name)
        setattr(sampler, '_old_inference_model', sampler._inference_model)

        if mode == 'stochastic':
            if num_images > num_steps:
                print(f"\033[93mWarning: number of conditioning images is greater than number of steps for {sampler_name}. "
                      "This may lead to performance degradation.\033[0m")

            # Create a schedule for which image to use at each step
            cond_indices = (np.arange(num_steps) % num_images).tolist()

            def _new_inference_model(self, model, x_t, t, cond, **kwargs):
                cond_idx = cond_indices.pop(0)
                cond_i = cond[cond_idx:cond_idx + 1]
                return self._old_inference_model(model, x_t, t, cond=cond_i, **kwargs)

        elif mode == 'multidiffusion':
            from .samplers import FlowEulerSampler

            def _new_inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs):
                # Average the predictions from all conditions; apply
                # classifier-free guidance only when t falls inside the CFG interval
                preds = [
                    FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i + 1], **kwargs)
                    for i in range(len(cond))
                ]
                pred = sum(preds) / len(preds)
                if cfg_interval[0] <= t <= cfg_interval[1]:
                    neg_pred = FlowEulerSampler._inference_model(self, model, x_t, t, neg_cond, **kwargs)
                    return (1 + cfg_strength) * pred - cfg_strength * neg_pred
                return pred

        else:
            raise ValueError(f"Unsupported mode: {mode}")

        sampler._inference_model = _new_inference_model.__get__(sampler, type(sampler))

        try:
            yield
        finally:
            # Restore the original inference model
            sampler._inference_model = sampler._old_inference_model
            delattr(sampler, '_old_inference_model')
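
    # Intended usage is as a context manager (a sketch; the attribute name must
    # match one of the samplers on this pipeline, and num_steps should match
    # the step count actually passed to the sampler):
    #
    #   with pipeline.inject_sampler_multi_image('slat_sampler', num_images=4, num_steps=25):
    #       slat = pipeline.sample_slat(cond, coords, part_layouts, masks)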