| from typing import * |
| import torch |
| import torch.nn as nn |
| import numpy as np |
| from transformers import CLIPTextModel, AutoTokenizer |
| import open3d as o3d |
| from .base import Pipeline |
| from . import samplers |
| from ..modules import sparse as sp |
|
|
|
|
class TrellisTextTo3DPipeline(Pipeline):
    """
    Pipeline for inferring Trellis text-to-3D models.

    Args:
        models (dict[str, nn.Module]): The models to use in the pipeline.
        sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure.
        slat_sampler (samplers.Sampler): The sampler for the structured latent.
        slat_normalization (dict): The normalization parameters for the structured latent.
        text_cond_model (str): The name of the text conditioning model.
    """

    # Output formats produced by `decode_slat` when none are requested explicitly.
    _DEFAULT_FORMATS = ('mesh', 'gaussian', 'radiance_field')

    def __init__(
        self,
        models: dict[str, nn.Module] = None,
        sparse_structure_sampler: samplers.Sampler = None,
        slat_sampler: samplers.Sampler = None,
        slat_normalization: dict = None,
        text_cond_model: str = None,
    ):
        # With `models=None` this produces an empty shell; `from_pretrained`
        # relies on that and fills in the instance attributes itself.
        if models is None:
            return
        super().__init__(models)
        self.sparse_structure_sampler = sparse_structure_sampler
        self.slat_sampler = slat_sampler
        self.sparse_structure_sampler_params = {}
        self.slat_sampler_params = {}
        self.slat_normalization = slat_normalization
        self._init_text_cond_model(text_cond_model)

    @staticmethod
    def from_pretrained(path: str) -> "TrellisTextTo3DPipeline":
        """
        Load a pretrained model.

        Args:
            path (str): The path to the model. Can be either local path or a Hugging Face repository.

        Returns:
            TrellisTextTo3DPipeline: The loaded pipeline, with samplers, sampler
            parameters, latent normalization and text encoder configured from
            the pretrained arguments.
        """
        pipeline = super(TrellisTextTo3DPipeline, TrellisTextTo3DPipeline).from_pretrained(path)
        # Re-home the loaded state on an instance of this subclass.
        new_pipeline = TrellisTextTo3DPipeline()
        new_pipeline.__dict__ = pipeline.__dict__
        args = pipeline._pretrained_args

        ss_cfg = args['sparse_structure_sampler']
        new_pipeline.sparse_structure_sampler = getattr(samplers, ss_cfg['name'])(**ss_cfg['args'])
        new_pipeline.sparse_structure_sampler_params = ss_cfg['params']

        slat_cfg = args['slat_sampler']
        new_pipeline.slat_sampler = getattr(samplers, slat_cfg['name'])(**slat_cfg['args'])
        new_pipeline.slat_sampler_params = slat_cfg['params']

        new_pipeline.slat_normalization = args['slat_normalization']

        new_pipeline._init_text_cond_model(args['text_cond_model'])

        return new_pipeline

    def _init_text_cond_model(self, name: str):
        """
        Initialize the text conditioning model.

        Loads a `CLIPTextModel` and its tokenizer from `name`, puts the model in
        eval mode on CUDA, and caches the embedding of the empty prompt under
        `'null_cond'` for use as the negative condition.

        Args:
            name (str): Model name or path understood by `from_pretrained`.
        """
        model = CLIPTextModel.from_pretrained(name)
        tokenizer = AutoTokenizer.from_pretrained(name)
        model.eval()
        model = model.cuda()
        self.text_cond_model = {
            'model': model,
            'tokenizer': tokenizer,
        }
        self.text_cond_model['null_cond'] = self.encode_text([''])

    @torch.no_grad()
    def encode_text(self, text: List[str]) -> torch.Tensor:
        """
        Encode the text.

        Each string is tokenized to exactly 77 tokens (padded / truncated) and
        run through the text model.

        Args:
            text (List[str]): The prompts to encode.

        Returns:
            torch.Tensor: The text model's `last_hidden_state`; presumably of
            shape (len(text), 77, hidden_dim) — TODO confirm.
        """
        assert isinstance(text, list) and all(isinstance(t, str) for t in text), "text must be a list of strings"
        model = self.text_cond_model['model']
        encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
        # Follow the text model's actual device rather than assuming CUDA, so
        # encoding keeps working if the model is moved after initialization.
        tokens = encoding['input_ids'].to(model.device)
        return model(input_ids=tokens).last_hidden_state

    def get_cond(self, prompt: List[str]) -> dict:
        """
        Get the conditioning information for the model.

        Args:
            prompt (List[str]): The text prompt.

        Returns:
            dict: `'cond'` holds the prompt embeddings and `'neg_cond'` the cached
            empty-prompt embedding (a single row regardless of len(prompt);
            presumably broadcast by the sampler — confirm).
        """
        return {
            'cond': self.encode_text(prompt),
            'neg_cond': self.text_cond_model['null_cond'],
        }

    @staticmethod
    def _merge_sampler_params(defaults: dict, overrides: Optional[dict]) -> dict:
        """Combine pipeline-level sampler params with per-call overrides (overrides win)."""
        params = {**defaults, **(overrides or {})}
        # `verbose` defaults to True but may now be overridden; previously it was
        # passed as a separate keyword, so supplying it raised a TypeError.
        params.setdefault('verbose', True)
        return params

    def sample_sparse_structure(
        self,
        cond: dict,
        num_samples: int = 1,
        sampler_params: Optional[dict] = None,
    ) -> torch.Tensor:
        """
        Sample sparse structures with the given conditioning.

        Args:
            cond (dict): The conditioning information.
            num_samples (int): The number of samples to generate.
            sampler_params (dict, optional): Additional parameters for the sampler;
                they override the pipeline's configured sparse-structure params.

        Returns:
            torch.Tensor: int32 coordinates of every position where the decoder
            output is positive. Column 1 of the `argwhere` result (presumably the
            channel axis) is dropped, leaving the sample index followed by three
            spatial indices.
        """
        flow_model = self.models['sparse_structure_flow_model']
        reso = flow_model.resolution
        noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device)
        params = self._merge_sampler_params(self.sparse_structure_sampler_params, sampler_params)
        z_s = self.sparse_structure_sampler.sample(
            flow_model,
            noise,
            **cond,
            **params,
        ).samples

        decoder = self.models['sparse_structure_decoder']
        coords = torch.argwhere(decoder(z_s) > 0)[:, [0, 2, 3, 4]].int()

        return coords

    def decode_slat(
        self,
        slat: sp.SparseTensor,
        formats: Optional[List[str]] = None,
    ) -> dict:
        """
        Decode the structured latent.

        Args:
            slat (sp.SparseTensor): The structured latent.
            formats (List[str], optional): The formats to decode the structured latent
                to; any of 'mesh', 'gaussian', 'radiance_field'. Defaults to all three.

        Returns:
            dict: The decoded structured latent, keyed by the requested format names.
        """
        if formats is None:
            formats = self._DEFAULT_FORMATS
        ret = {}
        if 'mesh' in formats:
            ret['mesh'] = self.models['slat_decoder_mesh'](slat)
        if 'gaussian' in formats:
            ret['gaussian'] = self.models['slat_decoder_gs'](slat)
        if 'radiance_field' in formats:
            ret['radiance_field'] = self.models['slat_decoder_rf'](slat)
        return ret

    def sample_slat(
        self,
        cond: dict,
        coords: torch.Tensor,
        sampler_params: Optional[dict] = None,
    ) -> sp.SparseTensor:
        """
        Sample structured latent with the given conditioning.

        Args:
            cond (dict): The conditioning information.
            coords (torch.Tensor): The coordinates of the sparse structure.
            sampler_params (dict, optional): Additional parameters for the sampler;
                they override the pipeline's configured structured-latent params.

        Returns:
            sp.SparseTensor: The sampled latent, de-normalized with the pipeline's
            `slat_normalization` mean and std.
        """
        flow_model = self.models['slat_flow_model']
        noise = sp.SparseTensor(
            feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
            coords=coords,
        )
        params = self._merge_sampler_params(self.slat_sampler_params, sampler_params)
        slat = self.slat_sampler.sample(
            flow_model,
            noise,
            **cond,
            **params,
        ).samples

        std = torch.tensor(self.slat_normalization['std'])[None].to(slat.device)
        mean = torch.tensor(self.slat_normalization['mean'])[None].to(slat.device)
        return slat * std + mean

    @torch.no_grad()
    def run(
        self,
        prompt: str,
        num_samples: int = 1,
        seed: int = 42,
        sparse_structure_sampler_params: Optional[dict] = None,
        slat_sampler_params: Optional[dict] = None,
        formats: Optional[List[str]] = None,
    ) -> dict:
        """
        Run the pipeline.

        Args:
            prompt (str): The text prompt.
            num_samples (int): The number of samples to generate.
            seed (int): The random seed.
            sparse_structure_sampler_params (dict, optional): Additional parameters for the sparse structure sampler.
            slat_sampler_params (dict, optional): Additional parameters for the structured latent sampler.
            formats (List[str], optional): The formats to decode the structured latent to.
                Defaults to 'mesh', 'gaussian' and 'radiance_field'.

        Returns:
            dict: The decoded outputs, keyed by format name.
        """
        cond = self.get_cond([prompt])
        # Seed after text encoding so that only the sampling stages are seeded.
        torch.manual_seed(seed)
        coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
        slat = self.sample_slat(cond, coords, slat_sampler_params)
        return self.decode_slat(slat, formats)

    def voxelize(self, mesh: o3d.geometry.TriangleMesh) -> torch.Tensor:
        """
        Voxelize a mesh on a 64^3 grid covering the cube [-0.5, 0.5]^3.

        The mesh is recentred on its bounding-box centre and uniformly scaled so
        that its longest side spans the cube. The caller's mesh is not modified.

        Args:
            mesh (o3d.geometry.TriangleMesh): The mesh to voxelize.

        Returns:
            torch.Tensor: int32 tensor of shape (N, 3) with the grid indices of the
            occupied voxels, on the pipeline's device.

        Raises:
            ValueError: If the mesh has no vertices or zero spatial extent.
        """
        vertices = np.asarray(mesh.vertices)
        if vertices.shape[0] == 0:
            raise ValueError("cannot voxelize a mesh with no vertices")
        aabb = np.stack([vertices.min(0), vertices.max(0)])
        center = (aabb[0] + aabb[1]) / 2
        scale = (aabb[1] - aabb[0]).max()
        if scale <= 0:
            raise ValueError("cannot voxelize a mesh with zero spatial extent")
        vertices = (vertices - center) / scale
        # Presumably keeps vertices strictly inside the bounds so none lands
        # exactly on the boundary of the voxel grid.
        vertices = np.clip(vertices, -0.5 + 1e-6, 0.5 - 1e-6)
        # Build a separate mesh instead of assigning to `mesh.vertices`, which
        # would silently rescale the caller's mesh.
        normalized = o3d.geometry.TriangleMesh(
            o3d.utility.Vector3dVector(vertices),
            o3d.utility.Vector3iVector(np.asarray(mesh.triangles)),
        )
        voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(
            normalized,
            voxel_size=1/64,
            min_bound=(-0.5, -0.5, -0.5),
            max_bound=(0.5, 0.5, 0.5),
        )
        # reshape keeps the result 2-D even when no voxel is occupied, so callers
        # can still concatenate / repeat along dim 1.
        indices = np.array(
            [voxel.grid_index for voxel in voxel_grid.get_voxels()], dtype=np.int64
        ).reshape(-1, 3)
        return torch.from_numpy(indices).int().to(self.device)

    @torch.no_grad()
    def run_variant(
        self,
        mesh: o3d.geometry.TriangleMesh,
        prompt: str,
        num_samples: int = 1,
        seed: int = 42,
        slat_sampler_params: Optional[dict] = None,
        formats: Optional[List[str]] = None,
    ) -> dict:
        """
        Run the pipeline for making variants of an asset.

        The sparse structure is taken from the voxelized `mesh` instead of being
        sampled, and only the structured latent is generated from the prompt.

        Args:
            mesh (o3d.geometry.TriangleMesh): The base mesh (not modified).
            prompt (str): The text prompt.
            num_samples (int): The number of samples to generate.
            seed (int): The random seed
            slat_sampler_params (dict, optional): Additional parameters for the structured latent sampler.
            formats (List[str], optional): The formats to decode the structured latent to.
                Defaults to 'mesh', 'gaussian' and 'radiance_field'.

        Returns:
            dict: The decoded outputs, keyed by format name.
        """
        cond = self.get_cond([prompt])
        coords = self.voxelize(mesh)
        num_voxels = coords.shape[0]
        # Prefix each copy of the voxel set with its sample index:
        # rows become (sample_idx, i, j, k).
        sample_idx = torch.arange(num_samples, device=coords.device).repeat_interleave(num_voxels)[:, None].int()
        coords = torch.cat([sample_idx, coords.repeat(num_samples, 1)], dim=1)
        torch.manual_seed(seed)
        slat = self.sample_slat(cond, coords, slat_sampler_params)
        return self.decode_slat(slat, formats)
|
|