diff --git a/trellis2/__init__.py b/trellis2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b02ac31563fec7c36fa4bd5d420ac4af2472bba8 --- /dev/null +++ b/trellis2/__init__.py @@ -0,0 +1,6 @@ +from . import models +from . import modules +from . import pipelines +from . import renderers +from . import representations +from . import utils diff --git a/trellis2/__pycache__/__init__.cpython-311.pyc b/trellis2/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c6cb0efbf766b42f07b8a0b6ad0ea6952312008 Binary files /dev/null and b/trellis2/__pycache__/__init__.cpython-311.pyc differ diff --git a/trellis2/datasets/__init__.py b/trellis2/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05c9e6ee3792183659b350373d4a2c7853b414b1 --- /dev/null +++ b/trellis2/datasets/__init__.py @@ -0,0 +1,46 @@ +import importlib + +__attributes = { + 'FlexiDualGridDataset': 'flexi_dual_grid', + 'SparseVoxelPbrDataset':'sparse_voxel_pbr', + + 'SparseStructureLatent': 'sparse_structure_latent', + 'TextConditionedSparseStructureLatent': 'sparse_structure_latent', + 'ImageConditionedSparseStructureLatent': 'sparse_structure_latent', + + 'SLat': 'structured_latent', + 'ImageConditionedSLat': 'structured_latent', + 'SLatShape': 'structured_latent_shape', + 'ImageConditionedSLatShape': 'structured_latent_shape', + 'SLatPbr': 'structured_latent_svpbr', + 'ImageConditionedSLatPbr': 'structured_latent_svpbr', +} + +__submodules = [] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no 
attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .flexi_dual_grid import FlexiDualGridDataset + from .sparse_voxel_pbr import SparseVoxelPbrDataset + + from .sparse_structure_latent import SparseStructureLatent, ImageConditionedSparseStructureLatent + from .structured_latent import SLat, ImageConditionedSLat + from .structured_latent_shape import SLatShape, ImageConditionedSLatShape + from .structured_latent_svpbr import SLatPbr, ImageConditionedSLatPbr + \ No newline at end of file diff --git a/trellis2/datasets/components.py b/trellis2/datasets/components.py new file mode 100644 index 0000000000000000000000000000000000000000..c30a136e9dc29f27998ac9f88816203fbaa23594 --- /dev/null +++ b/trellis2/datasets/components.py @@ -0,0 +1,192 @@ +from typing import * +import json +from abc import abstractmethod +import os +import json +import torch +import numpy as np +import pandas as pd +from PIL import Image +from torch.utils.data import Dataset + + +class StandardDatasetBase(Dataset): + """ + Base class for standard datasets. 
+ + Args: + roots (str): paths to the dataset + """ + + def __init__(self, + roots: str, + ): + super().__init__() + try: + self.roots = json.loads(roots) + root_type = 'obj' + except: + self.roots = roots.split(',') + root_type = 'list' + self.instances = [] + self.metadata = pd.DataFrame() + + self._stats = {} + if root_type == 'obj': + for key, root in self.roots.items(): + self._stats[key] = {} + metadata = pd.DataFrame(columns=['sha256']).set_index('sha256') + for _, r in root.items(): + metadata = metadata.combine_first(pd.read_csv(os.path.join(r, 'metadata.csv')).set_index('sha256')) + self._stats[key]['Total'] = len(metadata) + metadata, stats = self.filter_metadata(metadata) + self._stats[key].update(stats) + self.instances.extend([(root, sha256) for sha256 in metadata.index.values]) + self.metadata = pd.concat([self.metadata, metadata]) + else: + for root in self.roots: + key = os.path.basename(root) + self._stats[key] = {} + metadata = pd.read_csv(os.path.join(root, 'metadata.csv')) + self._stats[key]['Total'] = len(metadata) + metadata, stats = self.filter_metadata(metadata) + self._stats[key].update(stats) + self.instances.extend([(root, sha256) for sha256 in metadata['sha256'].values]) + metadata.set_index('sha256', inplace=True) + self.metadata = pd.concat([self.metadata, metadata]) + + @abstractmethod + def filter_metadata(self, metadata: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, int]]: + pass + + @abstractmethod + def get_instance(self, root, instance: str) -> Dict[str, Any]: + pass + + def __len__(self): + return len(self.instances) + + def __getitem__(self, index) -> Dict[str, Any]: + try: + root, instance = self.instances[index] + return self.get_instance(root, instance) + except Exception as e: + print(f'Error loading {instance}: {e}') + return self.__getitem__(np.random.randint(0, len(self))) + + def __str__(self): + lines = [] + lines.append(self.__class__.__name__) + lines.append(f' - Total instances: {len(self)}') + lines.append(f' - 
Sources:') + for key, stats in self._stats.items(): + lines.append(f' - {key}:') + for k, v in stats.items(): + lines.append(f' - {k}: {v}') + return '\n'.join(lines) + + +class ImageConditionedMixin: + def __init__(self, roots, *, image_size=518, **kwargs): + self.image_size = image_size + super().__init__(roots, **kwargs) + + def filter_metadata(self, metadata): + metadata, stats = super().filter_metadata(metadata) + metadata = metadata[metadata['cond_rendered'].notna()] + stats['Cond rendered'] = len(metadata) + return metadata, stats + + def get_instance(self, root, instance): + pack = super().get_instance(root, instance) + + image_root = os.path.join(root['render_cond'], instance) + with open(os.path.join(image_root, 'transforms.json')) as f: + metadata = json.load(f) + n_views = len(metadata['frames']) + view = np.random.randint(n_views) + metadata = metadata['frames'][view] + + image_path = os.path.join(image_root, metadata['file_path']) + image = Image.open(image_path) + + alpha = np.array(image.getchannel(3)) + bbox = np.array(alpha).nonzero() + bbox = [bbox[1].min(), bbox[0].min(), bbox[1].max(), bbox[0].max()] + center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2] + hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2 + aug_hsize = hsize + aug_center_offset = [0, 0] + aug_center = [center[0] + aug_center_offset[0], center[1] + aug_center_offset[1]] + aug_bbox = [int(aug_center[0] - aug_hsize), int(aug_center[1] - aug_hsize), int(aug_center[0] + aug_hsize), int(aug_center[1] + aug_hsize)] + image = image.crop(aug_bbox) + + image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS) + alpha = image.getchannel(3) + image = image.convert('RGB') + image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0 + alpha = torch.tensor(np.array(alpha)).float() / 255.0 + image = image * alpha.unsqueeze(0) + pack['cond'] = image + + return pack + + +class MultiImageConditionedMixin: + def __init__(self, roots, *, 
image_size=518, max_image_cond_view = 4, **kwargs): + self.image_size = image_size + self.max_image_cond_view = max_image_cond_view + super().__init__(roots, **kwargs) + + def filter_metadata(self, metadata): + metadata, stats = super().filter_metadata(metadata) + metadata = metadata[metadata['cond_rendered'].notna()] + stats['Cond rendered'] = len(metadata) + return metadata, stats + + def get_instance(self, root, instance): + pack = super().get_instance(root, instance) + + image_root = os.path.join(root['render_cond'], instance) + with open(os.path.join(image_root, 'transforms.json')) as f: + metadata = json.load(f) + + n_views = len(metadata['frames']) + n_sample_views = np.random.randint(1, self.max_image_cond_view+1) + + assert n_views >= n_sample_views, f'Not enough views to sample {n_sample_views} unique images.' + + sampled_views = np.random.choice(n_views, size=n_sample_views, replace=False) + + cond_images = [] + for v in sampled_views: + frame_info = metadata['frames'][v] + image_path = os.path.join(image_root, frame_info['file_path']) + image = Image.open(image_path) + + alpha = np.array(image.getchannel(3)) + bbox = np.array(alpha).nonzero() + bbox = [bbox[1].min(), bbox[0].min(), bbox[1].max(), bbox[0].max()] + center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2] + hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2 + aug_hsize = hsize + aug_center = center + aug_bbox = [ + int(aug_center[0] - aug_hsize), + int(aug_center[1] - aug_hsize), + int(aug_center[0] + aug_hsize), + int(aug_center[1] + aug_hsize), + ] + + img = image.crop(aug_bbox) + img = img.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS) + alpha = img.getchannel(3) + img = img.convert('RGB') + img = torch.tensor(np.array(img)).permute(2, 0, 1).float() / 255.0 + alpha = torch.tensor(np.array(alpha)).float() / 255.0 + img = img * alpha.unsqueeze(0) + + cond_images.append(img) + + pack['cond'] = [torch.stack(cond_images, dim=0)] # (V,3,H,W) + return pack diff 
--git a/trellis2/datasets/flexi_dual_grid.py b/trellis2/datasets/flexi_dual_grid.py new file mode 100644 index 0000000000000000000000000000000000000000..8b5322bca209bab8e608bc1a1f8894a2c5c67e84 --- /dev/null +++ b/trellis2/datasets/flexi_dual_grid.py @@ -0,0 +1,173 @@ +import os +import numpy as np +import pickle +import torch +import utils3d +from .components import StandardDatasetBase +from ..modules import sparse as sp +from ..renderers import MeshRenderer +from ..representations import Mesh +from ..utils.data_utils import load_balanced_group_indices +import o_voxel + + +class FlexiDualGridVisMixin: + @torch.no_grad() + def visualize_sample(self, x: dict): + mesh = x['mesh'] + + renderer = MeshRenderer({'near': 1, 'far': 3}) + renderer.rendering_options.resolution = 512 + renderer.rendering_options.ssaa = 4 + + # Build camera + yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2] + yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4) + yaws = [y + yaws_offset for y in yaws] + pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)] + + exts = [] + ints = [] + for yaw, pitch in zip(yaws, pitch): + orig = torch.tensor([ + np.sin(yaw) * np.cos(pitch), + np.cos(yaw) * np.cos(pitch), + np.sin(pitch), + ]).float().cuda() * 2 + fov = torch.deg2rad(torch.tensor(30)).cuda() + extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) + intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + exts.append(extrinsics) + ints.append(intrinsics) + + # Build each representation + images = [] + for m in mesh: + image = torch.zeros(3, 1024, 1024).cuda() + tile = [2, 2] + for j, (ext, intr) in enumerate(zip(exts, ints)): + image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = \ + renderer.render(m.cuda(), ext, intr)['normal'] + images.append(image) + images = torch.stack(images) + + return images + + +class FlexiDualGridDataset(FlexiDualGridVisMixin, 
StandardDatasetBase): + """ + Flexible Dual Grid Dataset + + Args: + roots (str): path to the dataset + resolution (int): resolution of the voxel grid + min_aesthetic_score (float): minimum aesthetic score of the instances to be included in the dataset + """ + + def __init__( + self, + roots, + resolution: int = 1024, + max_active_voxels: int = 1000000, + max_num_faces: int = None, + min_aesthetic_score: float = 5.0, + ): + self.resolution = resolution + self.min_aesthetic_score = min_aesthetic_score + self.max_active_voxels = max_active_voxels + self.max_num_faces = max_num_faces + self.value_range = (0, 1) + + super().__init__(roots) + + self.loads = [self.metadata.loc[sha256, f'dual_grid_size'] for _, sha256 in self.instances] + + def __str__(self): + lines = [ + super().__str__(), + f' - Resolution: {self.resolution}', + ] + return '\n'.join(lines) + + def filter_metadata(self, metadata): + stats = {} + metadata = metadata[metadata[f'dual_grid_converted'] == True] + stats['Dual Grid Converted'] = len(metadata) + if self.min_aesthetic_score is not None: + metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score] + stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata) + metadata = metadata[metadata[f'dual_grid_size'] <= self.max_active_voxels] + stats[f'Active Voxels <= {self.max_active_voxels}'] = len(metadata) + if self.max_num_faces is not None: + metadata = metadata[metadata['num_faces'] <= self.max_num_faces] + stats[f'Faces <= {self.max_num_faces}'] = len(metadata) + return metadata, stats + + def read_mesh(self, root, instance): + with open(os.path.join(root, f'{instance}.pickle'), 'rb') as f: + dump = pickle.load(f) + start = 0 + vertices = [] + faces = [] + for obj in dump['objects']: + if obj['vertices'].size == 0 or obj['faces'].size == 0: + continue + vertices.append(obj['vertices']) + faces.append(obj['faces'] + start) + start += len(obj['vertices']) + vertices = torch.from_numpy(np.concatenate(vertices, 
axis=0)).float() + faces = torch.from_numpy(np.concatenate(faces, axis=0)).long() + vertices_min = vertices.min(dim=0)[0] + vertices_max = vertices.max(dim=0)[0] + center = (vertices_min + vertices_max) / 2 + scale = 0.99999 / (vertices_max - vertices_min).max() + vertices = (vertices - center) * scale + assert torch.all(vertices >= -0.5) and torch.all(vertices <= 0.5), 'vertices out of range' + return {'mesh': [Mesh(vertices=vertices, faces=faces)]} + + def read_dual_grid(self, root, instance): + coords, attr = o_voxel.io.read_vxz(os.path.join(root, f'{instance}.vxz'), num_threads=4) + vertices = sp.SparseTensor( + (attr['vertices'] / 255.0).float(), + torch.cat([torch.zeros_like(coords[:, 0:1]), coords], dim=-1), + ) + intersected = vertices.replace(torch.cat([ + attr['intersected'] % 2, + attr['intersected'] // 2 % 2, + attr['intersected'] // 4 % 2, + ], dim=-1).bool()) + return {'vertices': vertices, 'intersected': intersected} + + def get_instance(self, root, instance): + mesh = self.read_mesh(root['mesh_dump'], instance) + dual_grid = self.read_dual_grid(root['dual_grid'], instance) + return {**mesh, **dual_grid} + + @staticmethod + def collate_fn(batch, split_size=None): + if split_size is None: + group_idx = [list(range(len(batch)))] + else: + group_idx = load_balanced_group_indices([b['vertices'].feats.shape[0] for b in batch], split_size) + packs = [] + for group in group_idx: + sub_batch = [batch[i] for i in group] + pack = {} + + keys = [k for k in sub_batch[0].keys()] + for k in keys: + if isinstance(sub_batch[0][k], torch.Tensor): + pack[k] = torch.stack([b[k] for b in sub_batch]) + elif isinstance(sub_batch[0][k], sp.SparseTensor): + pack[k] = sp.sparse_cat([b[k] for b in sub_batch], dim=0) + elif isinstance(sub_batch[0][k], list): + pack[k] = sum([b[k] for b in sub_batch], []) + else: + pack[k] = [b[k] for b in sub_batch] + + packs.append(pack) + + if split_size is None: + return packs[0] + return packs + \ No newline at end of file diff --git 
a/trellis2/datasets/sparse_structure_latent.py b/trellis2/datasets/sparse_structure_latent.py new file mode 100644 index 0000000000000000000000000000000000000000..535abbc36a86bba2eec73e763ec3807fc854e780 --- /dev/null +++ b/trellis2/datasets/sparse_structure_latent.py @@ -0,0 +1,160 @@ +import os +import json +from typing import * +import numpy as np +import torch +from ..representations import Voxel +from ..renderers import VoxelRenderer +from .components import StandardDatasetBase, ImageConditionedMixin +from .. import models +from ..utils.render_utils import yaw_pitch_r_fov_to_extrinsics_intrinsics + + +class SparseStructureLatentVisMixin: + def __init__( + self, + *args, + pretrained_ss_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16.json', + ss_dec_path: Optional[str] = None, + ss_dec_ckpt: Optional[str] = None, + **kwargs + ): + super().__init__(*args, **kwargs) + self.ss_dec = None + self.pretrained_ss_dec = pretrained_ss_dec + self.ss_dec_path = ss_dec_path + self.ss_dec_ckpt = ss_dec_ckpt + + def _loading_ss_dec(self): + if self.ss_dec is not None: + return + if self.ss_dec_path is not None: + cfg = json.load(open(os.path.join(self.ss_dec_path, 'config.json'), 'r')) + decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args']) + ckpt_path = os.path.join(self.ss_dec_path, 'ckpts', f'decoder_{self.ss_dec_ckpt}.pt') + decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True)) + else: + decoder = models.from_pretrained(self.pretrained_ss_dec) + self.ss_dec = decoder.cuda().eval() + + def _delete_ss_dec(self): + del self.ss_dec + self.ss_dec = None + + @torch.no_grad() + def decode_latent(self, z, batch_size=4): + self._loading_ss_dec() + ss = [] + if self.normalization: + z = z * self.std.to(z.device) + self.mean.to(z.device) + for i in range(0, z.shape[0], batch_size): + ss.append(self.ss_dec(z[i:i+batch_size])) + ss = torch.cat(ss, dim=0) + self._delete_ss_dec() + 
return ss + + @torch.no_grad() + def visualize_sample(self, x_0: Union[torch.Tensor, dict]): + x_0 = x_0 if isinstance(x_0, torch.Tensor) else x_0['x_0'] + x_0 = self.decode_latent(x_0.cuda()) + + renderer = VoxelRenderer() + renderer.rendering_options.resolution = 512 + renderer.rendering_options.ssaa = 4 + + # build camera + yaw = [0, np.pi/2, np.pi, 3*np.pi/2] + yaw_offset = -16 / 180 * np.pi + yaw = [y + yaw_offset for y in yaw] + pitch = [20 / 180 * np.pi for _ in range(4)] + exts, ints = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, 2, 30) + + images = [] + + # Build each representation + x_0 = x_0.cuda() + for i in range(x_0.shape[0]): + coords = torch.nonzero(x_0[i, 0] > 0, as_tuple=False) + resolution = x_0.shape[-1] + color = coords / resolution + rep = Voxel( + origin=[-0.5, -0.5, -0.5], + voxel_size=1/resolution, + coords=coords, + attrs=color, + layout={ + 'color': slice(0, 3), + } + ) + image = torch.zeros(3, 1024, 1024).cuda() + tile = [2, 2] + for j, (ext, intr) in enumerate(zip(exts, ints)): + res = renderer.render(rep, ext, intr, colors_overwrite=color) + image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color'] + images.append(image) + + return torch.stack(images) + + +class SparseStructureLatent(SparseStructureLatentVisMixin, StandardDatasetBase): + """ + Sparse structure latent dataset + + Args: + roots (str): path to the dataset + min_aesthetic_score (float): minimum aesthetic score + normalization (dict): normalization stats + pretrained_ss_dec (str): name of the pretrained sparse structure decoder + ss_dec_path (str): path to the sparse structure decoder, if given, will override the pretrained_ss_dec + ss_dec_ckpt (str): name of the sparse structure decoder checkpoint + """ + def __init__(self, + roots: str, + *, + min_aesthetic_score: float = 5.0, + normalization: Optional[dict] = None, + pretrained_ss_dec: str = 
'JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16', + ss_dec_path: Optional[str] = None, + ss_dec_ckpt: Optional[str] = None, + ): + self.min_aesthetic_score = min_aesthetic_score + self.normalization = normalization + self.value_range = (0, 1) + + super().__init__( + roots, + pretrained_ss_dec=pretrained_ss_dec, + ss_dec_path=ss_dec_path, + ss_dec_ckpt=ss_dec_ckpt, + ) + + if self.normalization is not None: + self.mean = torch.tensor(self.normalization['mean']).reshape(-1, 1, 1, 1) + self.std = torch.tensor(self.normalization['std']).reshape(-1, 1, 1, 1) + + def filter_metadata(self, metadata): + stats = {} + metadata = metadata[metadata['ss_latent_encoded'] == True] + stats['With latent'] = len(metadata) + metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score] + stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata) + return metadata, stats + + def get_instance(self, root, instance): + latent = np.load(os.path.join(root['ss_latent'], f'{instance}.npz')) + z = torch.tensor(latent['z']).float() + if self.normalization is not None: + z = (z - self.mean) / self.std + + pack = { + 'x_0': z, + } + return pack + + +class ImageConditionedSparseStructureLatent(ImageConditionedMixin, SparseStructureLatent): + """ + Image-conditioned sparse structure dataset + """ + pass + \ No newline at end of file diff --git a/trellis2/datasets/sparse_voxel_pbr.py b/trellis2/datasets/sparse_voxel_pbr.py new file mode 100644 index 0000000000000000000000000000000000000000..838036905279aeb86fcae256551b620dc8803d17 --- /dev/null +++ b/trellis2/datasets/sparse_voxel_pbr.py @@ -0,0 +1,298 @@ +import os +import io +from typing import Union +import numpy as np +import pickle +import torch +from PIL import Image +import o_voxel +import utils3d +from .components import StandardDatasetBase +from ..modules import sparse as sp +from ..renderers import VoxelRenderer +from ..representations import Voxel +from ..representations.mesh import 
MeshWithPbrMaterial, TextureFilterMode, TextureWrapMode, AlphaMode, PbrMaterial, Texture + +from ..utils.data_utils import load_balanced_group_indices + + +def is_power_of_two(n: int) -> bool: + return n > 0 and (n & (n - 1)) == 0 + + +def nearest_power_of_two(n: int) -> int: + if n < 1: + raise ValueError("n must be >= 1") + if is_power_of_two(n): + return n + lower = 2 ** (n.bit_length() - 1) + upper = 2 ** n.bit_length() + if n - lower < upper - n: + return lower + else: + return upper + + +class SparseVoxelPbrVisMixin: + @torch.no_grad() + def visualize_sample(self, x: Union[sp.SparseTensor, dict]): + x = x if isinstance(x, sp.SparseTensor) else x['x'] + + renderer = VoxelRenderer() + renderer.rendering_options.resolution = 512 + renderer.rendering_options.ssaa = 4 + + # Build camera + yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2] + yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4) + yaws = [y + yaws_offset for y in yaws] + pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)] + + exts = [] + ints = [] + for yaw, pitch in zip(yaws, pitch): + orig = torch.tensor([ + np.sin(yaw) * np.cos(pitch), + np.cos(yaw) * np.cos(pitch), + np.sin(pitch), + ]).float().cuda() * 2 + fov = torch.deg2rad(torch.tensor(30)).cuda() + extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) + intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + exts.append(extrinsics) + ints.append(intrinsics) + + images = {k: [] for k in self.layout} + + # Build each representation + x = x.cuda() + for i in range(x.shape[0]): + rep = Voxel( + origin=[-0.5, -0.5, -0.5], + voxel_size=1/self.resolution, + coords=x[i].coords[:, 1:].contiguous(), + attrs=None, + layout={ + 'color': slice(0, 3), + } + ) + for k in self.layout: + image = torch.zeros(3, 1024, 1024).cuda() + tile = [2, 2] + for j, (ext, intr) in enumerate(zip(exts, ints)): + attr = x[i].feats[:, self.layout[k]].expand(-1, 3) + res = 
renderer.render(rep, ext, intr, colors_overwrite=attr) + image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color'] + images[k].append(image) + + for k in self.layout: + images[k] = torch.stack(images[k]) + + return images + + +class SparseVoxelPbrDataset(SparseVoxelPbrVisMixin, StandardDatasetBase): + """ + Sparse Voxel PBR dataset. + + Args: + roots (str): path to the dataset + resolution (int): resolution of the voxel grid + min_aesthetic_score (float): minimum aesthetic score of the instances to be included in the dataset + """ + + def __init__( + self, + roots, + resolution: int = 1024, + max_active_voxels: int = 1000000, + max_num_faces: int = None, + min_aesthetic_score: float = 5.0, + attrs: list[str] = ['base_color', 'metallic', 'roughness', 'emissive', 'alpha'], + with_mesh: bool = True, + ): + self.resolution = resolution + self.min_aesthetic_score = min_aesthetic_score + self.max_active_voxels = max_active_voxels + self.max_num_faces = max_num_faces + self.with_mesh = with_mesh + self.value_range = (-1, 1) + self.channels = { + 'base_color': 3, + 'metallic': 1, + 'roughness': 1, + 'emissive': 3, + 'alpha': 1, + } + self.layout = {} + start = 0 + for attr in attrs: + self.layout[attr] = slice(start, start + self.channels[attr]) + start += self.channels[attr] + + super().__init__(roots) + + self.loads = [self.metadata.loc[sha256, f'num_pbr_voxels'] for _, sha256 in self.instances] + + def __str__(self): + lines = [ + super().__str__(), + f' - Resolution: {self.resolution}', + f' - Attributes: {list(self.layout.keys())}', + ] + return '\n'.join(lines) + + def filter_metadata(self, metadata): + stats = {} + metadata = metadata[metadata['pbr_voxelized'] == True] + stats['PBR Voxelized'] = len(metadata) + if self.min_aesthetic_score is not None: + metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score] + stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata) + 
metadata = metadata[metadata['num_pbr_voxels'] <= self.max_active_voxels] + stats[f'Active voxels <= {self.max_active_voxels}'] = len(metadata) + if self.max_num_faces is not None: + metadata = metadata[metadata['num_faces'] <= self.max_num_faces] + stats[f'Faces <= {self.max_num_faces}'] = len(metadata) + return metadata, stats + + @staticmethod + def _texture_from_dump(pack) -> Texture: + png_bytes = pack['image'] + image = Image.open(io.BytesIO(png_bytes)) + if image.width != image.height or not is_power_of_two(image.width): + size = nearest_power_of_two(max(image.width, image.height)) + image = image.resize((size, size), Image.LANCZOS) + texture = torch.tensor(np.array(image) / 255.0, dtype=torch.float32).reshape(image.height, image.width, -1) + filter_mode = { + 'Linear': TextureFilterMode.LINEAR, + 'Closest': TextureFilterMode.CLOSEST, + 'Cubic': TextureFilterMode.LINEAR, + 'Smart': TextureFilterMode.LINEAR, + }[pack['interpolation']] + wrap_mode = { + 'REPEAT': TextureWrapMode.REPEAT, + 'EXTEND': TextureWrapMode.CLAMP_TO_EDGE, + 'CLIP': TextureWrapMode.CLAMP_TO_EDGE, + 'MIRROR': TextureWrapMode.MIRRORED_REPEAT, + }[pack['extension']] + return Texture(texture, filter_mode=filter_mode, wrap_mode=wrap_mode) + + def read_mesh_with_texture(self, root, instance): + with open(os.path.join(root, f'{instance}.pickle'), 'rb') as f: + dump = pickle.load(f) + + # Fix dump alpha map + for mat in dump['materials']: + if mat['alphaTexture'] is not None and mat['alphaMode'] == 'OPAQUE': + mat['alphaMode'] = 'BLEND' + + # process material + materials = [] + for mat in dump['materials']: + materials.append(PbrMaterial( + base_color_texture=self._texture_from_dump(mat['baseColorTexture']) if mat['baseColorTexture'] is not None else None, + base_color_factor=mat['baseColorFactor'], + metallic_texture=self._texture_from_dump(mat['metallicTexture']) if mat['metallicTexture'] is not None else None, + metallic_factor=mat['metallicFactor'], + 
roughness_texture=self._texture_from_dump(mat['roughnessTexture']) if mat['roughnessTexture'] is not None else None, + roughness_factor=mat['roughnessFactor'], + alpha_texture=self._texture_from_dump(mat['alphaTexture']) if mat['alphaTexture'] is not None else None, + alpha_factor=mat['alphaFactor'], + alpha_mode={ + 'OPAQUE': AlphaMode.OPAQUE, + 'MASK': AlphaMode.MASK, + 'BLEND': AlphaMode.BLEND, + }[mat['alphaMode']], + alpha_cutoff=mat['alphaCutoff'], + )) + materials.append(PbrMaterial( + base_color_factor=[0.8, 0.8, 0.8], + alpha_factor=1.0, + metallic_factor=0.0, + roughness_factor=0.5, + alpha_mode=AlphaMode.OPAQUE, + alpha_cutoff=0.5, + )) # append default material + + # process mesh + start = 0 + vertices = [] + faces = [] + material_ids = [] + uv_coords = [] + for obj in dump['objects']: + if obj['vertices'].size == 0 or obj['faces'].size == 0: + continue + vertices.append(obj['vertices']) + faces.append(obj['faces'] + start) + obj['mat_ids'][obj['mat_ids'] == -1] = len(materials) - 1 + material_ids.append(obj['mat_ids']) + uv_coords.append(obj['uvs'] if obj['uvs'] is not None else np.zeros((obj['faces'].shape[0], 3, 2), dtype=np.float32)) + start += len(obj['vertices']) + + vertices = torch.from_numpy(np.concatenate(vertices, axis=0)).float() + faces = torch.from_numpy(np.concatenate(faces, axis=0)).long() + material_ids = torch.from_numpy(np.concatenate(material_ids, axis=0)).long() + uv_coords = torch.from_numpy(np.concatenate(uv_coords, axis=0)).float() + + # Normalize vertices + vertices_min = vertices.min(dim=0)[0] + vertices_max = vertices.max(dim=0)[0] + center = (vertices_min + vertices_max) / 2 + scale = 0.99999 / (vertices_max - vertices_min).max() + vertices = (vertices - center) * scale + assert torch.all(vertices >= -0.5) and torch.all(vertices <= 0.5), 'vertices out of range' + + return {'mesh': [MeshWithPbrMaterial( + vertices=vertices, + faces=faces, + material_ids=material_ids, + uv_coords=uv_coords, + materials=materials, + )]} + + def 
read_pbr_voxel(self, root, instance): + coords, attr = o_voxel.io.read_vxz(os.path.join(root, f'{instance}.vxz'), num_threads=4) + feats = torch.concat([attr[k] for k in self.layout], dim=-1) / 255.0 * 2 - 1 + x = sp.SparseTensor( + feats.float(), + torch.cat([torch.zeros_like(coords[:, 0:1]), coords], dim=-1), + ) + return {'x': x} + + def get_instance(self, root, instance): + if self.with_mesh: + mesh = self.read_mesh_with_texture(root['pbr_dump'], instance) + pbr_voxel = self.read_pbr_voxel(root['pbr_voxel'], instance) + return {**mesh, **pbr_voxel} + else: + return self.read_pbr_voxel(root['pbr_voxel'], instance) + + @staticmethod + def collate_fn(batch, split_size=None): + if split_size is None: + group_idx = [list(range(len(batch)))] + else: + group_idx = load_balanced_group_indices([b['x'].feats.shape[0] for b in batch], split_size) + packs = [] + for group in group_idx: + sub_batch = [batch[i] for i in group] + pack = {} + + keys = [k for k in sub_batch[0].keys()] + for k in keys: + if isinstance(sub_batch[0][k], torch.Tensor): + pack[k] = torch.stack([b[k] for b in sub_batch]) + elif isinstance(sub_batch[0][k], sp.SparseTensor): + pack[k] = sp.sparse_cat([b[k] for b in sub_batch], dim=0) + elif isinstance(sub_batch[0][k], list): + pack[k] = sum([b[k] for b in sub_batch], []) + else: + pack[k] = [b[k] for b in sub_batch] + + packs.append(pack) + + if split_size is None: + return packs[0] + return packs diff --git a/trellis2/datasets/structured_latent.py b/trellis2/datasets/structured_latent.py new file mode 100644 index 0000000000000000000000000000000000000000..664f7e41aaf6971b3f150c80ff872cb51c696632 --- /dev/null +++ b/trellis2/datasets/structured_latent.py @@ -0,0 +1,210 @@ +import json +import os +from typing import * +import numpy as np +import torch +import utils3d.torch +from .components import StandardDatasetBase, ImageConditionedMixin +from ..modules.sparse.basic import SparseTensor +from .. 
import models +from ..utils.render_utils import get_renderer +from ..utils.data_utils import load_balanced_group_indices + + +class SLatVisMixin: + def __init__( + self, + *args, + pretrained_slat_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16', + slat_dec_path: Optional[str] = None, + slat_dec_ckpt: Optional[str] = None, + **kwargs + ): + super().__init__(*args, **kwargs) + self.slat_dec = None + self.pretrained_slat_dec = pretrained_slat_dec + self.slat_dec_path = slat_dec_path + self.slat_dec_ckpt = slat_dec_ckpt + + def _loading_slat_dec(self): + if self.slat_dec is not None: + return + if self.slat_dec_path is not None: + cfg = json.load(open(os.path.join(self.slat_dec_path, 'config.json'), 'r')) + decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args']) + ckpt_path = os.path.join(self.slat_dec_path, 'ckpts', f'decoder_{self.slat_dec_ckpt}.pt') + decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True)) + else: + decoder = models.from_pretrained(self.pretrained_slat_dec) + self.slat_dec = decoder.cuda().eval() + + def _delete_slat_dec(self): + del self.slat_dec + self.slat_dec = None + + @torch.no_grad() + def decode_latent(self, z, batch_size=4): + self._loading_slat_dec() + reps = [] + if self.normalization is not None: + z = z * self.std.to(z.device) + self.mean.to(z.device) + for i in range(0, z.shape[0], batch_size): + reps.append(self.slat_dec(z[i:i+batch_size])) + reps = sum(reps, []) + self._delete_slat_dec() + return reps + + @torch.no_grad() + def visualize_sample(self, x_0: Union[SparseTensor, dict]): + x_0 = x_0 if isinstance(x_0, SparseTensor) else x_0['x_0'] + reps = self.decode_latent(x_0.cuda()) + + # Build camera + yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2] + yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4) + yaws = [y + yaws_offset for y in yaws] + pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)] + + exts = [] + 
ints = [] + for yaw, pitch in zip(yaws, pitch): + orig = torch.tensor([ + np.sin(yaw) * np.cos(pitch), + np.cos(yaw) * np.cos(pitch), + np.sin(pitch), + ]).float().cuda() * 2 + fov = torch.deg2rad(torch.tensor(40)).cuda() + extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) + intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + exts.append(extrinsics) + ints.append(intrinsics) + + renderer = get_renderer(reps[0]) + images = [] + for representation in reps: + image = torch.zeros(3, 1024, 1024).cuda() + tile = [2, 2] + for j, (ext, intr) in enumerate(zip(exts, ints)): + res = renderer.render(representation, ext, intr) + image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color'] + images.append(image) + images = torch.stack(images) + + return images + + +class SLat(SLatVisMixin, StandardDatasetBase): + """ + structured latent V2 dataset + + Args: + roots (str): path to the dataset + min_aesthetic_score (float): minimum aesthetic score + max_tokens (int): maximum number of tokens + latent_key (str): key of the latent to be used + normalization (dict): normalization stats + pretrained_slat_dec (str): name of the pretrained slat decoder + slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec + slat_dec_ckpt (str): name of the slat decoder checkpoint + """ + def __init__(self, + roots: str, + *, + min_aesthetic_score: float = 5.0, + max_tokens: int = 32768, + latent_key: str = 'shape_latent', + normalization: Optional[dict] = None, + pretrained_slat_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16', + slat_dec_path: Optional[str] = None, + slat_dec_ckpt: Optional[str] = None, + ): + self.normalization = normalization + self.min_aesthetic_score = min_aesthetic_score + self.max_tokens = max_tokens + self.latent_key = latent_key + self.value_range = (0, 
1) + + super().__init__( + roots, + pretrained_slat_dec=pretrained_slat_dec, + slat_dec_path=slat_dec_path, + slat_dec_ckpt=slat_dec_ckpt, + ) + + self.loads = [self.metadata.loc[sha256, f'{latent_key}_tokens'] for _, sha256 in self.instances] + + if self.normalization is not None: + self.mean = torch.tensor(self.normalization['mean']).reshape(1, -1) + self.std = torch.tensor(self.normalization['std']).reshape(1, -1) + + def filter_metadata(self, metadata): + stats = {} + metadata = metadata[metadata[f'{self.latent_key}_encoded'] == True] + stats['With latent'] = len(metadata) + metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score] + stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata) + metadata = metadata[metadata[f'{self.latent_key}_tokens'] <= self.max_tokens] + stats[f'Num tokens <= {self.max_tokens}'] = len(metadata) + return metadata, stats + + def get_instance(self, root, instance): + data = np.load(os.path.join(root[self.latent_key], f'{instance}.npz')) + coords = torch.tensor(data['coords']).int() + feats = torch.tensor(data['feats']).float() + if self.normalization is not None: + feats = (feats - self.mean) / self.std + return { + 'coords': coords, + 'feats': feats, + } + + @staticmethod + def collate_fn(batch, split_size=None): + if split_size is None: + group_idx = [list(range(len(batch)))] + else: + group_idx = load_balanced_group_indices([b['coords'].shape[0] for b in batch], split_size) + packs = [] + for group in group_idx: + sub_batch = [batch[i] for i in group] + pack = {} + coords = [] + feats = [] + layout = [] + start = 0 + for i, b in enumerate(sub_batch): + coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1)) + feats.append(b['feats']) + layout.append(slice(start, start + b['coords'].shape[0])) + start += b['coords'].shape[0] + coords = torch.cat(coords) + feats = torch.cat(feats) + pack['x_0'] = SparseTensor( + coords=coords, + feats=feats, + ) 
+ pack['x_0']._shape = torch.Size([len(group), *sub_batch[0]['feats'].shape[1:]]) + pack['x_0'].register_spatial_cache('layout', layout) + + # collate other data + keys = [k for k in sub_batch[0].keys() if k not in ['coords', 'feats']] + for k in keys: + if isinstance(sub_batch[0][k], torch.Tensor): + pack[k] = torch.stack([b[k] for b in sub_batch]) + elif isinstance(sub_batch[0][k], list): + pack[k] = sum([b[k] for b in sub_batch], []) + else: + pack[k] = [b[k] for b in sub_batch] + + packs.append(pack) + + if split_size is None: + return packs[0] + return packs + + +class ImageConditionedSLat(ImageConditionedMixin, SLat): + """ + Image conditioned structured latent dataset + """ + pass diff --git a/trellis2/datasets/structured_latent_shape.py b/trellis2/datasets/structured_latent_shape.py new file mode 100644 index 0000000000000000000000000000000000000000..8d636a1340a8646a7064493412223031f79706a6 --- /dev/null +++ b/trellis2/datasets/structured_latent_shape.py @@ -0,0 +1,96 @@ +import os +import json +from typing import * +import numpy as np +import torch +from .. 
import models +from .components import ImageConditionedMixin +from ..modules.sparse import SparseTensor +from .structured_latent import SLatVisMixin, SLat +from ..utils.render_utils import get_renderer, yaw_pitch_r_fov_to_extrinsics_intrinsics + + +class SLatShapeVisMixin(SLatVisMixin): + def _loading_slat_dec(self): + if self.slat_dec is not None: + return + if self.slat_dec_path is not None: + cfg = json.load(open(os.path.join(self.slat_dec_path, 'config.json'), 'r')) + decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args']) + ckpt_path = os.path.join(self.slat_dec_path, 'ckpts', f'decoder_{self.slat_dec_ckpt}.pt') + decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True)) + else: + decoder = models.from_pretrained(self.pretrained_slat_dec) + decoder.set_resolution(self.resolution) + self.slat_dec = decoder.cuda().eval() + + @torch.no_grad() + def visualize_sample(self, x_0: Union[SparseTensor, dict]): + x_0 = x_0 if isinstance(x_0, SparseTensor) else x_0['x_0'] + reps = self.decode_latent(x_0.cuda()) + + # build camera + yaw = [0, np.pi/2, np.pi, 3*np.pi/2] + yaw_offset = -16 / 180 * np.pi + yaw = [y + yaw_offset for y in yaw] + pitch = [20 / 180 * np.pi for _ in range(4)] + exts, ints = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, 2, 30) + + # render + renderer = get_renderer(reps[0]) + images = [] + for representation in reps: + image = torch.zeros(3, 1024, 1024).cuda() + tile = [2, 2] + for j, (ext, intr) in enumerate(zip(exts, ints)): + res = renderer.render(representation, ext, intr) + image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['normal'] + images.append(image) + images = torch.stack(images) + return images + + +class SLatShape(SLatShapeVisMixin, SLat): + """ + structured latent for shape generation + + Args: + roots (str): path to the dataset + resolution (int): resolution of the shape + min_aesthetic_score (float): 
minimum aesthetic score + max_tokens (int): maximum number of tokens + latent_key (str): key of the latent to be used + normalization (dict): normalization stats + pretrained_slat_dec (str): name of the pretrained slat decoder + slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec + slat_dec_ckpt (str): name of the slat decoder checkpoint + """ + def __init__(self, + roots: str, + *, + resolution: int, + min_aesthetic_score: float = 5.0, + max_tokens: int = 32768, + normalization: Optional[dict] = None, + pretrained_slat_dec: str = 'microsoft/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16', + slat_dec_path: Optional[str] = None, + slat_dec_ckpt: Optional[str] = None, + ): + super().__init__( + roots, + min_aesthetic_score=min_aesthetic_score, + max_tokens=max_tokens, + latent_key='shape_latent', + normalization=normalization, + pretrained_slat_dec=pretrained_slat_dec, + slat_dec_path=slat_dec_path, + slat_dec_ckpt=slat_dec_ckpt, + ) + self.resolution = resolution + + +class ImageConditionedSLatShape(ImageConditionedMixin, SLatShape): + """ + Image conditioned structured latent for shape generation + """ + pass diff --git a/trellis2/datasets/structured_latent_svpbr.py b/trellis2/datasets/structured_latent_svpbr.py new file mode 100644 index 0000000000000000000000000000000000000000..56735f471dc7cfd9d35ad648d8a60dd2de71702e --- /dev/null +++ b/trellis2/datasets/structured_latent_svpbr.py @@ -0,0 +1,290 @@ +import os +os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' +import json +from typing import * +import numpy as np +import torch +import cv2 +from .. 
import models +from .components import StandardDatasetBase, ImageConditionedMixin +from ..modules.sparse import SparseTensor, sparse_cat +from ..representations import MeshWithVoxel +from ..renderers import PbrMeshRenderer, EnvMap +from ..utils.data_utils import load_balanced_group_indices +from ..utils.render_utils import yaw_pitch_r_fov_to_extrinsics_intrinsics + + +class SLatPbrVisMixin: + def __init__( + self, + *args, + pretrained_pbr_slat_dec: str = 'JeffreyXiang/TRELLIS.2-4B/ckpts/tex_dec_next_dc_f16c32_fp16', + pbr_slat_dec_path: Optional[str] = None, + pbr_slat_dec_ckpt: Optional[str] = None, + pretrained_shape_slat_dec: str = 'JeffreyXiang/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16', + shape_slat_dec_path: Optional[str] = None, + shape_slat_dec_ckpt: Optional[str] = None, + **kwargs + ): + super().__init__(*args, **kwargs) + self.pbr_slat_dec = None + self.pretrained_pbr_slat_dec = pretrained_pbr_slat_dec + self.pbr_slat_dec_path = pbr_slat_dec_path + self.pbr_slat_dec_ckpt = pbr_slat_dec_ckpt + self.shape_slat_dec = None + self.pretrained_shape_slat_dec = pretrained_shape_slat_dec + self.shape_slat_dec_path = shape_slat_dec_path + self.shape_slat_dec_ckpt = shape_slat_dec_ckpt + + def _loading_slat_dec(self): + if self.pbr_slat_dec is not None and self.shape_slat_dec is not None: + return + if self.pbr_slat_dec_path is not None: + cfg = json.load(open(os.path.join(self.pbr_slat_dec_path, 'config.json'), 'r')) + decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args']) + ckpt_path = os.path.join(self.pbr_slat_dec_path, 'ckpts', f'decoder_{self.pbr_slat_dec_ckpt}.pt') + decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True)) + else: + decoder = models.from_pretrained(self.pretrained_pbr_slat_dec) + self.pbr_slat_dec = decoder.cuda().eval() + + if self.shape_slat_dec_path is not None: + cfg = json.load(open(os.path.join(self.shape_slat_dec_path, 'config.json'), 'r')) + decoder = 
getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args']) + ckpt_path = os.path.join(self.shape_slat_dec_path, 'ckpts', f'decoder_{self.shape_slat_dec_ckpt}.pt') + decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True)) + else: + decoder = models.from_pretrained(self.pretrained_shape_slat_dec) + decoder.set_resolution(self.resolution) + self.shape_slat_dec = decoder.cuda().eval() + + def _delete_slat_dec(self): + del self.pbr_slat_dec + self.pbr_slat_dec = None + del self.shape_slat_dec + self.shape_slat_dec = None + + @torch.no_grad() + def decode_latent(self, z, shape_z, batch_size=4): + self._loading_slat_dec() + reps = [] + if self.shape_slat_normalization is not None: + shape_z = shape_z * self.shape_slat_std.to(z.device) + self.shape_slat_mean.to(z.device) + if self.pbr_slat_normalization is not None: + z = z * self.pbr_slat_std.to(z.device) + self.pbr_slat_mean.to(z.device) + for i in range(0, z.shape[0], batch_size): + mesh, subs = self.shape_slat_dec(shape_z[i:i+batch_size], return_subs=True) + vox = self.pbr_slat_dec(z[i:i+batch_size], guide_subs=subs) * 0.5 + 0.5 + reps.extend([ + MeshWithVoxel( + m.vertices, m.faces, + origin = [-0.5, -0.5, -0.5], + voxel_size = 1 / self.resolution, + coords = v.coords[:, 1:], + attrs = v.feats, + voxel_shape = torch.Size([*v.shape, *v.spatial_shape]), + layout = self.layout, + ) + for m, v in zip(mesh, vox) + ]) + self._delete_slat_dec() + return reps + + @torch.no_grad() + def visualize_sample(self, sample: dict): + shape_z = sample['concat_cond'].cuda() + z = sample['x_0'].cuda() + reps = self.decode_latent(z, shape_z) + + # build camera + yaw = [0, np.pi/2, np.pi, 3*np.pi/2] + yaw_offset = -16 / 180 * np.pi + yaw = [y + yaw_offset for y in yaw] + pitch = [20 / 180 * np.pi for _ in range(4)] + exts, ints = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, 2, 30) + + # render + renderer = PbrMeshRenderer() + renderer.rendering_options.resolution = 512 + 
renderer.rendering_options.near = 1 + renderer.rendering_options.far = 100 + renderer.rendering_options.ssaa = 2 + renderer.rendering_options.peel_layers = 8 + envmap = EnvMap(torch.tensor( + cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), + dtype=torch.float32, device='cuda' + )) + + images = {} + for representation in reps: + image = {} + tile = [2, 2] + for j, (ext, intr) in enumerate(zip(exts, ints)): + res = renderer.render(representation, ext, intr, envmap=envmap) + for k, v in res.items(): + if k not in images: + images[k] = [] + if k not in image: + image[k] = torch.zeros(3, 1024, 1024).cuda() + image[k][:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = v + for k in images.keys(): + images[k].append(image[k]) + for k in images.keys(): + images[k] = torch.stack(images[k], dim=0) + return images + + +class SLatPbr(SLatPbrVisMixin, StandardDatasetBase): + """ + structured latent for sparse voxel pbr dataset + + Args: + roots (str): path to the dataset + latent_key (str): key of the latent to be used + min_aesthetic_score (float): minimum aesthetic score + normalization (dict): normalization stats + resolution (int): resolution of decoded sparse voxel + attrs (list): attributes to be decoded + pretained_slat_dec (str): name of the pretrained slat decoder + slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec + slat_dec_ckpt (str): name of the slat decoder checkpoint + """ + def __init__(self, + roots: str, + *, + resolution: int, + min_aesthetic_score: float = 5.0, + max_tokens: int = 32768, + full_pbr: bool = False, + pbr_slat_normalization: Optional[dict] = None, + shape_slat_normalization: Optional[dict] = None, + attrs: list[str] = ['base_color', 'metallic', 'roughness', 'emissive', 'alpha'], + pretrained_pbr_slat_dec: str = 'JeffreyXiang/TRELLIS.2-4B/ckpts/tex_dec_next_dc_f16c32_fp16', + pbr_slat_dec_path: Optional[str] = 
None, + pbr_slat_dec_ckpt: Optional[str] = None, + pretrained_shape_slat_dec: str = 'JeffreyXiang/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16', + shape_slat_dec_path: Optional[str] = None, + shape_slat_dec_ckpt: Optional[str] = None, + **kwargs + ): + self.resolution = resolution + self.pbr_slat_normalization = pbr_slat_normalization + self.shape_slat_normalization = shape_slat_normalization + self.min_aesthetic_score = min_aesthetic_score + self.max_tokens = max_tokens + self.full_pbr = full_pbr + self.value_range = (0, 1) + + super().__init__( + roots, + pretrained_pbr_slat_dec=pretrained_pbr_slat_dec, + pbr_slat_dec_path=pbr_slat_dec_path, + pbr_slat_dec_ckpt=pbr_slat_dec_ckpt, + pretrained_shape_slat_dec=pretrained_shape_slat_dec, + shape_slat_dec_path=shape_slat_dec_path, + shape_slat_dec_ckpt=shape_slat_dec_ckpt, + **kwargs + ) + + self.loads = [self.metadata.loc[sha256, 'pbr_latent_tokens'] for _, sha256 in self.instances] + + if self.pbr_slat_normalization is not None: + self.pbr_slat_mean = torch.tensor(self.pbr_slat_normalization['mean']).reshape(1, -1) + self.pbr_slat_std = torch.tensor(self.pbr_slat_normalization['std']).reshape(1, -1) + + if self.shape_slat_normalization is not None: + self.shape_slat_mean = torch.tensor(self.shape_slat_normalization['mean']).reshape(1, -1) + self.shape_slat_std = torch.tensor(self.shape_slat_normalization['std']).reshape(1, -1) + + self.attrs = attrs + self.channels = { + 'base_color': 3, + 'metallic': 1, + 'roughness': 1, + 'emissive': 3, + 'alpha': 1, + } + self.layout = {} + start = 0 + for attr in attrs: + self.layout[attr] = slice(start, start + self.channels[attr]) + start += self.channels[attr] + + def filter_metadata(self, metadata): + stats = {} + metadata = metadata[metadata['pbr_latent_encoded'] == True] + stats['With PBR latent'] = len(metadata) + metadata = metadata[metadata['shape_latent_encoded'] == True] + stats['With shape latent'] = len(metadata) + metadata = metadata[metadata['aesthetic_score'] 
>= self.min_aesthetic_score] + stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata) + metadata = metadata[metadata['pbr_latent_tokens'] <= self.max_tokens] + stats[f'Num tokens <= {self.max_tokens}'] = len(metadata) + if self.full_pbr: + metadata = metadata[metadata['num_basecolor_tex'] > 0] + metadata = metadata[metadata['num_metallic_tex'] > 0] + metadata = metadata[metadata['num_roughness_tex'] > 0] + stats['Full PBR'] = len(metadata) + return metadata, stats + + def get_instance(self, root, instance): + # PBR latent + data = np.load(os.path.join(root['pbr_latent'], f'{instance}.npz')) + coords = torch.tensor(data['coords']).int() + coords = torch.cat([torch.zeros_like(coords)[:, :1], coords], dim=1) + feats = torch.tensor(data['feats']).float() + if self.pbr_slat_normalization is not None: + feats = (feats - self.pbr_slat_mean) / self.pbr_slat_std + pbr_z = SparseTensor(feats, coords) + + # Shape latent + data = np.load(os.path.join(root['shape_latent'], f'{instance}.npz')) + coords = torch.tensor(data['coords']).int() + coords = torch.cat([torch.zeros_like(coords)[:, :1], coords], dim=1) + feats = torch.tensor(data['feats']).float() + if self.shape_slat_normalization is not None: + feats = (feats - self.shape_slat_mean) / self.shape_slat_std + shape_z = SparseTensor(feats, coords) + + assert torch.equal(shape_z.coords, pbr_z.coords), \ + f"Shape latent and PBR latent have different coordinates: {shape_z.coords.shape} vs {pbr_z.coords.shape}" + + return { + 'x_0': pbr_z, + 'concat_cond': shape_z, + } + + @staticmethod + def collate_fn(batch, split_size=None): + if split_size is None: + group_idx = [list(range(len(batch)))] + else: + group_idx = load_balanced_group_indices([b['x_0'].feats.shape[0] for b in batch], split_size) + packs = [] + for group in group_idx: + sub_batch = [batch[i] for i in group] + pack = {} + + keys = [k for k in sub_batch[0].keys()] + for k in keys: + if isinstance(sub_batch[0][k], torch.Tensor): + pack[k] = 
torch.stack([b[k] for b in sub_batch]) + elif isinstance(sub_batch[0][k], SparseTensor): + pack[k] = sparse_cat([b[k] for b in sub_batch], dim=0) + elif isinstance(sub_batch[0][k], list): + pack[k] = sum([b[k] for b in sub_batch], []) + else: + pack[k] = [b[k] for b in sub_batch] + + packs.append(pack) + + if split_size is None: + return packs[0] + return packs + + +class ImageConditionedSLatPbr(ImageConditionedMixin, SLatPbr): + """ + Image conditioned structured latent dataset + """ + pass diff --git a/trellis2/models/__init__.py b/trellis2/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3ed4799488d304f61dfe28c552f5daf9e1a452da --- /dev/null +++ b/trellis2/models/__init__.py @@ -0,0 +1,78 @@ +import importlib + +__attributes = { + # Sparse Structure + 'SparseStructureEncoder': 'sparse_structure_vae', + 'SparseStructureDecoder': 'sparse_structure_vae', + 'SparseStructureFlowModel': 'sparse_structure_flow', + + # SLat Generation + 'SLatFlowModel': 'structured_latent_flow', + 'ElasticSLatFlowModel': 'structured_latent_flow', + + # SC-VAEs + 'SparseUnetVaeEncoder': 'sc_vaes.sparse_unet_vae', + 'SparseUnetVaeDecoder': 'sc_vaes.sparse_unet_vae', + 'FlexiDualGridVaeEncoder': 'sc_vaes.fdg_vae', + 'FlexiDualGridVaeDecoder': 'sc_vaes.fdg_vae' +} + +__submodules = [] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +def from_pretrained(path: str, **kwargs): + """ + Load a model from a pretrained checkpoint. + + Args: + path: The path to the checkpoint. 
Can be either local path or a Hugging Face model name. + NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively. + **kwargs: Additional arguments for the model constructor. + """ + import os + import json + from safetensors.torch import load_file + is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors") + + if is_local: + config_file = f"{path}.json" + model_file = f"{path}.safetensors" + else: + from huggingface_hub import hf_hub_download + path_parts = path.split('/') + repo_id = f'{path_parts[0]}/{path_parts[1]}' + model_name = '/'.join(path_parts[2:]) + config_file = hf_hub_download(repo_id, f"{model_name}.json") + model_file = hf_hub_download(repo_id, f"{model_name}.safetensors") + + with open(config_file, 'r') as f: + config = json.load(f) + model = __getattr__(config['name'])(**config['args'], **kwargs) + model.load_state_dict(load_file(model_file), strict=False) + + return model + + +# For Pylance +if __name__ == '__main__': + from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder + from .sparse_structure_flow import SparseStructureFlowModel + from .structured_latent_flow import SLatFlowModel, ElasticSLatFlowModel + + from .sc_vaes.sparse_unet_vae import SparseUnetVaeEncoder, SparseUnetVaeDecoder + from .sc_vaes.fdg_vae import FlexiDualGridVaeEncoder, FlexiDualGridVaeDecoder diff --git a/trellis2/models/__pycache__/__init__.cpython-311.pyc b/trellis2/models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3be59143ad16d1b4846dbb13f687e4eb8d4763b Binary files /dev/null and b/trellis2/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/trellis2/models/__pycache__/sparse_elastic_mixin.cpython-311.pyc b/trellis2/models/__pycache__/sparse_elastic_mixin.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ef3bf15ca715fd15dc9483c3d1a9084ea49e586 Binary files 
/dev/null and b/trellis2/models/__pycache__/sparse_elastic_mixin.cpython-311.pyc differ diff --git a/trellis2/models/__pycache__/sparse_structure_flow.cpython-311.pyc b/trellis2/models/__pycache__/sparse_structure_flow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dae04cab4653c993c0ff75aabfdecc2ad8dc0e97 Binary files /dev/null and b/trellis2/models/__pycache__/sparse_structure_flow.cpython-311.pyc differ diff --git a/trellis2/models/__pycache__/sparse_structure_vae.cpython-311.pyc b/trellis2/models/__pycache__/sparse_structure_vae.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cdc129acaa4f83192bdc68113c5aed4ba229ff2c Binary files /dev/null and b/trellis2/models/__pycache__/sparse_structure_vae.cpython-311.pyc differ diff --git a/trellis2/models/__pycache__/structured_latent_flow.cpython-311.pyc b/trellis2/models/__pycache__/structured_latent_flow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5fc9105bbebaf54db9c0d3a654929471e74c6e5 Binary files /dev/null and b/trellis2/models/__pycache__/structured_latent_flow.cpython-311.pyc differ diff --git a/trellis2/models/sc_vaes/__pycache__/fdg_vae.cpython-311.pyc b/trellis2/models/sc_vaes/__pycache__/fdg_vae.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb942251720844a56a013175ed119b0e044b5e2f Binary files /dev/null and b/trellis2/models/sc_vaes/__pycache__/fdg_vae.cpython-311.pyc differ diff --git a/trellis2/models/sc_vaes/__pycache__/sparse_unet_vae.cpython-311.pyc b/trellis2/models/sc_vaes/__pycache__/sparse_unet_vae.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7ead550e70f7b50f35ce7faf704b46d7112ee90 Binary files /dev/null and b/trellis2/models/sc_vaes/__pycache__/sparse_unet_vae.cpython-311.pyc differ diff --git a/trellis2/models/sc_vaes/fdg_vae.py b/trellis2/models/sc_vaes/fdg_vae.py new file mode 100644 index 
0000000000000000000000000000000000000000..17552d06eac071b1cfb5a78ce10bdf70b165d230 --- /dev/null +++ b/trellis2/models/sc_vaes/fdg_vae.py @@ -0,0 +1,110 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ...modules import sparse as sp +from .sparse_unet_vae import ( + SparseResBlock3d, + SparseConvNeXtBlock3d, + + SparseResBlockDownsample3d, + SparseResBlockUpsample3d, + SparseResBlockS2C3d, + SparseResBlockC2S3d, +) +from .sparse_unet_vae import ( + SparseUnetVaeEncoder, + SparseUnetVaeDecoder, +) +from ...representations import Mesh +from o_voxel.convert import flexible_dual_grid_to_mesh + + +class FlexiDualGridVaeEncoder(SparseUnetVaeEncoder): + def __init__( + self, + model_channels: List[int], + latent_channels: int, + num_blocks: List[int], + block_type: List[str], + down_block_type: List[str], + block_args: List[Dict[str, Any]], + use_fp16: bool = False, + ): + super().__init__( + 6, + model_channels, + latent_channels, + num_blocks, + block_type, + down_block_type, + block_args, + use_fp16, + ) + + def forward(self, vertices: sp.SparseTensor, intersected: sp.SparseTensor, sample_posterior=False, return_raw=False): + x = vertices.replace(torch.cat([ + vertices.feats - 0.5, + intersected.feats.float() - 0.5, + ], dim=1)) + return super().forward(x, sample_posterior, return_raw) + + +class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder): + def __init__( + self, + resolution: int, + model_channels: List[int], + latent_channels: int, + num_blocks: List[int], + block_type: List[str], + up_block_type: List[str], + block_args: List[Dict[str, Any]], + voxel_margin: float = 0.5, + use_fp16: bool = False, + ): + self.resolution = resolution + self.voxel_margin = voxel_margin + + super().__init__( + 7, + model_channels, + latent_channels, + num_blocks, + block_type, + up_block_type, + block_args, + use_fp16, + ) + + def set_resolution(self, resolution: int) -> None: + self.resolution = resolution + + def forward(self, x: 
sp.SparseTensor, gt_intersected: sp.SparseTensor = None, **kwargs): + decoded = super().forward(x, **kwargs) + if self.training: + h, subs_gt, subs = decoded + vertices = h.replace((1 + 2 * self.voxel_margin) * F.sigmoid(h.feats[..., 0:3]) - self.voxel_margin) + intersected_logits = h.replace(h.feats[..., 3:6]) + quad_lerp = h.replace(F.softplus(h.feats[..., 6:7])) + mesh = [Mesh(*flexible_dual_grid_to_mesh( + v.coords[:, 1:], v.feats, i.feats, q.feats, + aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], + grid_size=self.resolution, + train=True + )) for v, i, q in zip(vertices, gt_intersected, quad_lerp)] + return mesh, vertices, intersected_logits, subs_gt, subs + else: + out_list = list(decoded) if isinstance(decoded, tuple) else [decoded] + h = out_list[0] + vertices = h.replace((1 + 2 * self.voxel_margin) * F.sigmoid(h.feats[..., 0:3]) - self.voxel_margin) + intersected = h.replace(h.feats[..., 3:6] > 0) + quad_lerp = h.replace(F.softplus(h.feats[..., 6:7])) + mesh = [Mesh(*flexible_dual_grid_to_mesh( + v.coords[:, 1:], v.feats, i.feats, q.feats, + aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], + grid_size=self.resolution, + train=False + )) for v, i, q in zip(vertices, intersected, quad_lerp)] + out_list[0] = mesh + return out_list[0] if len(out_list) == 1 else tuple(out_list) diff --git a/trellis2/models/sc_vaes/sparse_unet_vae.py b/trellis2/models/sc_vaes/sparse_unet_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..ea5d4731bcbd0b8a41dc6247a5eabbb7671307f9 --- /dev/null +++ b/trellis2/models/sc_vaes/sparse_unet_vae.py @@ -0,0 +1,522 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from ...modules.utils import convert_module_to_f16, convert_module_to_f32, zero_module +from ...modules import sparse as sp +from ...modules.norm import LayerNorm32 + + +class SparseResBlock3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + 
class SparseResBlock3d(nn.Module):
    """
    Sparse 3D residual block with optional resolution change.

    Structure: norm -> SiLU -> conv -> norm -> SiLU -> zero-init conv,
    plus a skip connection.  Resampling is done in one of two modes:
      - 'nearest': plain sparse voxel down/upsampling.
      - 'spatial2channel': trades spatial resolution for channel count
        (8 voxels <-> 8x channels), so conv1's width is scaled by 8.
    When upsampling, a per-voxel 8-way subdivision mask is predicted
    from the input features and returned alongside the output.
    """
    def __init__(
        self,
        channels: int,
        out_channels: Optional[int] = None,
        downsample: bool = False,
        upsample: bool = False,
        resample_mode: Literal['nearest', 'spatial2channel'] = 'nearest',
        use_checkpoint: bool = False,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.downsample = downsample
        self.upsample = upsample
        self.resample_mode = resample_mode
        self.use_checkpoint = use_checkpoint

        assert not (downsample and upsample), "Cannot downsample and upsample at the same time"

        self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
        # conv1 width depends on the resample mode: spatial2channel folds /
        # unfolds a factor of 8 between the spatial and channel dimensions.
        if resample_mode == 'nearest':
            self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
        elif resample_mode == 'spatial2channel' and not self.downsample:
            self.conv1 = sp.SparseConv3d(channels, self.out_channels * 8, 3)
        elif resample_mode == 'spatial2channel' and self.downsample:
            self.conv1 = sp.SparseConv3d(channels, self.out_channels // 8, 3)
        self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
        # Skip path: a linear projection for 'nearest', or a cheap channel
        # mean / repeat for 'spatial2channel' (no learned parameters).
        if resample_mode == 'nearest':
            self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
        elif resample_mode == 'spatial2channel' and self.downsample:
            self.skip_connection = lambda x: x.replace(x.feats.reshape(x.feats.shape[0], out_channels, channels * 8 // out_channels).mean(dim=-1))
        elif resample_mode == 'spatial2channel' and not self.downsample:
            self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(out_channels // (channels // 8), dim=1))
        self.updown = None
        if self.downsample:
            if resample_mode == 'nearest':
                self.updown = sp.SparseDownsample(2)
            elif resample_mode == 'spatial2channel':
                self.updown = sp.SparseSpatial2Channel(2)
        elif self.upsample:
            # Predict which of the 8 children of each voxel are occupied.
            self.to_subdiv = sp.SparseLinear(channels, 8)
            if resample_mode == 'nearest':
                self.updown = sp.SparseUpsample(2)
            elif resample_mode == 'spatial2channel':
                self.updown = sp.SparseChannel2Spatial(2)

    def _updown(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor:
        """Apply the configured resampling; upsampling needs the subdivision mask."""
        if self.downsample:
            x = self.updown(x)
        elif self.upsample:
            x = self.updown(x, subdiv.replace(subdiv.feats > 0))
        return x

    def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        subdiv = None
        if self.upsample:
            subdiv = self.to_subdiv(x)
        h = x.replace(self.norm1(x.feats))
        h = h.replace(F.silu(h.feats))
        # Order matters: spatial2channel convolves before resampling,
        # 'nearest' convolves after.
        if self.resample_mode == 'spatial2channel':
            h = self.conv1(h)
        h = self._updown(h, subdiv)
        x = self._updown(x, subdiv)
        if self.resample_mode == 'nearest':
            h = self.conv1(h)
        h = h.replace(self.norm2(h.feats))
        h = h.replace(F.silu(h.feats))
        h = self.conv2(h)
        h = h + self.skip_connection(x)
        if self.upsample:
            # Also hand back the subdivision logits for supervision.
            return h, subdiv
        return h

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
        else:
            return self._forward(x)


class SparseResBlockDownsample3d(nn.Module):
    """
    Sparse 3D residual block that halves spatial resolution with
    nearest-neighbor (voxel-merge) downsampling before the first conv.
    """
    def __init__(
        self,
        channels: int,
        out_channels: Optional[int] = None,
        use_checkpoint: bool = False,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_checkpoint = use_checkpoint

        self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
        self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
        self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
        self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
        self.updown = sp.SparseDownsample(2)

    def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        h = x.replace(self.norm1(x.feats))
        h = h.replace(F.silu(h.feats))
        # Downsample both the residual branch and the skip input so they
        # share the same coordinate set.
        h = self.updown(h)
        x = self.updown(x)
        h = self.conv1(h)
        h = h.replace(self.norm2(h.feats))
        h = h.replace(F.silu(h.feats))
        h = self.conv2(h)
        h = h + self.skip_connection(x)
        return h

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
        else:
            return self._forward(x)
class SparseResBlockUpsample3d(nn.Module):
    """
    Sparse 3D residual block that doubles spatial resolution with
    nearest-neighbor upsampling.

    If ``pred_subdiv`` is True, the block predicts an 8-way subdivision
    mask from its input and returns ``(output, subdiv_logits)``;
    otherwise the caller must supply a ``subdiv`` mask (see
    ``SparseUnetVaeDecoder.forward`` with ``guide_subs``).
    """
    def __init__(
        self,
        channels: int,
        out_channels: Optional[int] = None,
        use_checkpoint: bool = False,
        pred_subdiv: bool = True,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_checkpoint = use_checkpoint
        self.pred_subdiv = pred_subdiv

        self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
        self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
        self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
        self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
        if self.pred_subdiv:
            self.to_subdiv = sp.SparseLinear(channels, 8)
        self.updown = sp.SparseUpsample(2)

    def _forward(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor:
        if self.pred_subdiv:
            subdiv = self.to_subdiv(x)
        h = x.replace(self.norm1(x.feats))
        h = h.replace(F.silu(h.feats))
        # Binarize logits into an occupancy mask deciding which children
        # each voxel expands into.
        subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None
        h = self.updown(h, subdiv_binarized)
        x = self.updown(x, subdiv_binarized)
        h = self.conv1(h)
        h = h.replace(self.norm2(h.feats))
        h = h.replace(F.silu(h.feats))
        h = self.conv2(h)
        h = h + self.skip_connection(x)
        if self.pred_subdiv:
            return h, subdiv
        else:
            return h

    def forward(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor:
        # FIX: accept and forward `subdiv`, matching SparseResBlockC2S3d.
        # SparseUnetVaeDecoder calls up-blocks as `block(h, subdiv=...)`
        # when pred_subdiv=False; the previous signature (x only) raised
        # a TypeError on that call path.
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, subdiv, use_reentrant=False)
        else:
            return self._forward(x, subdiv)
class SparseResBlockS2C3d(nn.Module):
    """
    Sparse 3D residual block that downsamples via spatial2channel:
    every 2x2x2 voxel group is folded into the channel dimension, so
    conv1 emits out_channels // 8 features which become out_channels
    after folding.  The skip path folds the input the same way and
    mean-pools the extra channels (parameter-free).
    """
    def __init__(
        self,
        channels: int,
        out_channels: Optional[int] = None,
        use_checkpoint: bool = False,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_checkpoint = use_checkpoint

        self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
        # 1/8 of the target width; the fold below multiplies channels by 8.
        self.conv1 = sp.SparseConv3d(channels, self.out_channels // 8, 3)
        self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
        self.skip_connection = lambda x: x.replace(x.feats.reshape(x.feats.shape[0], out_channels, channels * 8 // out_channels).mean(dim=-1))
        self.updown = sp.SparseSpatial2Channel(2)

    def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        out = x.replace(self.norm1(x.feats))
        out = out.replace(F.silu(out.feats))
        out = self.conv1(out)
        # Fold 2x2x2 spatial neighborhoods into channels on both paths.
        out = self.updown(out)
        x = self.updown(x)
        out = out.replace(self.norm2(out.feats))
        out = out.replace(F.silu(out.feats))
        out = self.conv2(out)
        return out + self.skip_connection(x)

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
        return self._forward(x)
class SparseConvNeXtBlock3d(nn.Module):
    """
    ConvNeXt-style sparse 3D block: depth conv -> LayerNorm -> MLP,
    with a residual connection (the MLP's last linear is zero-init).
    """
    def __init__(
        self,
        channels: int,
        mlp_ratio: float = 4.0,
        use_checkpoint: bool = False,
    ):
        super().__init__()
        self.channels = channels
        self.use_checkpoint = use_checkpoint

        self.norm = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.conv = sp.SparseConv3d(channels, channels, 3)
        self.mlp = nn.Sequential(
            nn.Linear(channels, int(channels * mlp_ratio)),
            nn.SiLU(),
            zero_module(nn.Linear(int(channels * mlp_ratio), channels)),
        )

    def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        h = self.conv(x)
        h = h.replace(self.norm(h.feats))
        h = h.replace(self.mlp(h.feats))
        return h + x

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
        else:
            return self._forward(x)


class SparseUnetVaeEncoder(nn.Module):
    """
    Sparse Unet VAE encoder.

    Stacks per-resolution block groups (block classes resolved by name
    from this module's globals) with a downsampling block between
    stages, then projects to a diagonal-Gaussian latent (mean, logvar).

    Args:
        in_channels: Channels of the input sparse tensor.
        model_channels: Channel width per resolution stage.
        latent_channels: Channels of the latent; to_latent emits 2x for
            the mean/logvar split.
        num_blocks: Number of blocks per stage.
        block_type: Class name of the in-stage block, per stage.
        down_block_type: Class name of the between-stage downsampler.
        block_args: Extra kwargs forwarded to each stage's blocks.
        use_fp16: Run the torso in float16.
    """
    def __init__(
        self,
        in_channels: int,
        model_channels: List[int],
        latent_channels: int,
        num_blocks: List[int],
        block_type: List[str],
        down_block_type: List[str],
        block_args: List[Dict[str, Any]],
        use_fp16: bool = False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.num_blocks = num_blocks
        # FIX: this assignment was duplicated in the original.
        self.dtype = torch.float16 if use_fp16 else torch.float32

        self.input_layer = sp.SparseLinear(in_channels, model_channels[0])
        self.to_latent = sp.SparseLinear(model_channels[-1], 2 * latent_channels)

        self.blocks = nn.ModuleList([])
        for i in range(len(num_blocks)):
            self.blocks.append(nn.ModuleList([]))
            for j in range(num_blocks[i]):
                self.blocks[-1].append(
                    globals()[block_type[i]](
                        model_channels[i],
                        **block_args[i],
                    )
                )
            if i < len(num_blocks) - 1:
                self.blocks[-1].append(
                    globals()[down_block_type[i]](
                        model_channels[i],
                        model_channels[i+1],
                        **block_args[i],
                    )
                )

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        self.blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        self.blocks.apply(convert_module_to_f32)

    def initialize_weights(self) -> None:
        # Xavier-uniform for all linear layers, zero bias.
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

    def forward(self, x: sp.SparseTensor, sample_posterior=False, return_raw=False):
        """
        Encode `x` to a latent sample.

        Returns z, or (z, mean, logvar) when return_raw is True.  With
        sample_posterior=False, z is the posterior mean.
        """
        h = self.input_layer(x)
        h = h.type(self.dtype)
        for i, res in enumerate(self.blocks):
            for j, block in enumerate(res):
                h = block(h)
        h = h.type(x.dtype)
        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        h = self.to_latent(h)

        # Sample from the posterior distribution (reparameterization).
        mean, logvar = h.feats.chunk(2, dim=-1)
        if sample_posterior:
            std = torch.exp(0.5 * logvar)
            z = mean + std * torch.randn_like(std)
        else:
            z = mean
        z = h.replace(z)

        if return_raw:
            return z, mean, logvar
        else:
            return z
+ """ + def __init__( + self, + out_channels: int, + model_channels: List[int], + latent_channels: int, + num_blocks: List[int], + block_type: List[str], + up_block_type: List[str], + block_args: List[Dict[str, Any]], + use_fp16: bool = False, + pred_subdiv: bool = True, + ): + super().__init__() + self.out_channels = out_channels + self.model_channels = model_channels + self.num_blocks = num_blocks + self.use_fp16 = use_fp16 + self.pred_subdiv = pred_subdiv + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.low_vram = False + + self.output_layer = sp.SparseLinear(model_channels[-1], out_channels) + self.from_latent = sp.SparseLinear(latent_channels, model_channels[0]) + + self.blocks = nn.ModuleList([]) + for i in range(len(num_blocks)): + self.blocks.append(nn.ModuleList([])) + for j in range(num_blocks[i]): + self.blocks[-1].append( + globals()[block_type[i]]( + model_channels[i], + **block_args[i], + ) + ) + if i < len(num_blocks) - 1: + self.blocks[-1].append( + globals()[up_block_type[i]]( + model_channels[i], + model_channels[i+1], + pred_subdiv=pred_subdiv, + **block_args[i], + ) + ) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. 
+ """ + self.blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + def forward(self, x: sp.SparseTensor, guide_subs: Optional[List[sp.SparseTensor]] = None, return_subs: bool = False) -> sp.SparseTensor: + assert guide_subs is None or self.pred_subdiv == False, "Only decoders with pred_subdiv=False can be used with guide_subs" + assert return_subs == False or self.pred_subdiv == True, "Only decoders with pred_subdiv=True can be used with return_subs" + + h = self.from_latent(x) + h = h.type(self.dtype) + subs_gt = [] + subs = [] + for i, res in enumerate(self.blocks): + for j, block in enumerate(res): + if i < len(self.blocks) - 1 and j == len(res) - 1: + if self.pred_subdiv: + if self.training: + subs_gt.append(h.get_spatial_cache('subdivision')) + h, sub = block(h) + subs.append(sub) + else: + h = block(h, subdiv=guide_subs[i] if guide_subs is not None else None) + else: + h = block(h) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.output_layer(h) + if self.training and self.pred_subdiv: + return h, subs_gt, subs + else: + if return_subs: + return h, subs + else: + return h + + def upsample(self, x: sp.SparseTensor, upsample_times: int) -> torch.Tensor: + assert self.pred_subdiv == True, "Only decoders with pred_subdiv=True can be used with upsampling" + + h = self.from_latent(x) + h = h.type(self.dtype) + for i, res in enumerate(self.blocks): + if i == upsample_times: + return h.coords + for j, block in enumerate(res): + if i < len(self.blocks) - 1 and j == len(res) - 1: + h, sub = block(h) + else: + h = block(h) + \ No newline at end of file diff --git a/trellis2/models/sparse_elastic_mixin.py b/trellis2/models/sparse_elastic_mixin.py new file mode 100644 
from contextlib import contextmanager
from typing import *
import math
from ..modules import sparse as sp
from ..utils.elastic_utils import ElasticModuleMixin


class SparseTransformerElasticMixin(ElasticModuleMixin):
    """
    Elastic-memory mixin for sparse transformer models: trades compute
    for memory by enabling gradient checkpointing on a prefix of
    ``self.blocks`` proportional to the requested memory reduction.
    """

    def _get_input_size(self, x: sp.SparseTensor, *args, **kwargs):
        # The number of active voxels drives activation memory.
        return x.feats.shape[0]

    @contextmanager
    def with_mem_ratio(self, mem_ratio=1.0):
        """
        Context manager that checkpoints enough blocks to (approximately)
        reach ``mem_ratio`` of full activation memory; yields the exact
        ratio achieved.  Flags are always restored on exit.
        """
        if mem_ratio == 1.0:
            yield 1.0
            return
        num_blocks = len(self.blocks)
        num_checkpoint_blocks = min(math.ceil((1 - mem_ratio) * num_blocks) + 1, num_blocks)
        exact_mem_ratio = 1 - (num_checkpoint_blocks - 1) / num_blocks
        for i in range(num_blocks):
            self.blocks[i].use_checkpoint = i < num_checkpoint_blocks
        try:
            yield exact_mem_ratio
        finally:
            # FIX: restore in a finally-block.  Previously an exception in
            # the caller's body skipped the reset, leaving checkpointing
            # permanently enabled on all blocks.
            for i in range(num_blocks):
                self.blocks[i].use_checkpoint = False
+ """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + Args: + t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + dim: the dimension of the output. + max_period: controls the minimum frequency of the embeddings. + + Returns: + an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class SparseStructureFlowModel(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + pe_mode: Literal["ape", "rope"] = "ape", + rope_freq: Tuple[float, float] = (1.0, 10000.0), + dtype: str = 'float32', + use_checkpoint: bool = False, + share_mod: bool = False, + initialization: str = 'vanilla', + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + **kwargs + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels 
class SparseStructureFlowModel(nn.Module):
    """
    Dense DiT-style flow model over a fixed 3D grid of sparse-structure
    latents: the (C, R, R, R) volume is flattened to R^3 tokens,
    processed by modulated cross-attention transformer blocks
    conditioned on `cond`, and reshaped back.
    """
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        model_channels: int,
        cond_channels: int,
        out_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        pe_mode: Literal["ape", "rope"] = "ape",
        rope_freq: Tuple[float, float] = (1.0, 10000.0),
        dtype: str = 'float32',
        use_checkpoint: bool = False,
        share_mod: bool = False,
        initialization: str = 'vanilla',
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        **kwargs
    ):
        super().__init__()
        self.resolution = resolution
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.cond_channels = cond_channels
        self.out_channels = out_channels
        self.num_blocks = num_blocks
        self.num_heads = num_heads or model_channels // num_head_channels
        self.mlp_ratio = mlp_ratio
        self.pe_mode = pe_mode
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        self.initialization = initialization
        self.qk_rms_norm = qk_rms_norm
        self.qk_rms_norm_cross = qk_rms_norm_cross
        self.dtype = str_to_dtype(dtype)

        self.t_embedder = TimestepEmbedder(model_channels)
        if share_mod:
            # One shared adaLN modulation for all blocks.
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(model_channels, 6 * model_channels, bias=True)
            )

        # Precompute positional information for the fixed R^3 token grid.
        if pe_mode == "ape":
            pos_embedder = AbsolutePositionEmbedder(model_channels, 3)
            coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij')
            coords = torch.stack(coords, dim=-1).reshape(-1, 3)
            pos_emb = pos_embedder(coords)
            self.register_buffer("pos_emb", pos_emb)
        elif pe_mode == "rope":
            pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3)
            coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij')
            coords = torch.stack(coords, dim=-1).reshape(-1, 3)
            rope_phases = pos_embedder(coords)
            self.register_buffer("rope_phases", rope_phases)

        if pe_mode != "rope":
            # Blocks are always called with rope_phases; None disables RoPE.
            self.rope_phases = None

        self.input_layer = nn.Linear(in_channels, model_channels)

        self.blocks = nn.ModuleList([
            ModulatedTransformerCrossBlock(
                model_channels,
                cond_channels,
                num_heads=self.num_heads,
                mlp_ratio=self.mlp_ratio,
                attn_mode='full',
                use_checkpoint=self.use_checkpoint,
                use_rope=(pe_mode == "rope"),
                rope_freq=rope_freq,
                share_mod=share_mod,
                qk_rms_norm=self.qk_rms_norm,
                qk_rms_norm_cross=self.qk_rms_norm_cross,
            )
            for _ in range(num_blocks)
        ])

        self.out_layer = nn.Linear(model_channels, out_channels)

        self.initialize_weights()
        self.convert_to(self.dtype)

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to(self, dtype: torch.dtype) -> None:
        """
        Convert the torso of the model to the specified dtype.
        """
        self.dtype = dtype
        self.blocks.apply(partial(convert_module_to, dtype=dtype))

    def initialize_weights(self) -> None:
        """Initialize weights per the configured scheme ('vanilla' or 'scaled')."""
        if self.initialization == 'vanilla':
            # Xavier-uniform everywhere, zero bias.
            def _basic_init(module):
                if isinstance(module, nn.Linear):
                    torch.nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        nn.init.constant_(module.bias, 0)
            self.apply(_basic_init)

            # Timestep embedding MLP.
            nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
            nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

            # Zero-out adaLN modulation layers so blocks start as identity.
            if self.share_mod:
                nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
                nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
            else:
                for block in self.blocks:
                    nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
                    nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

            # Zero-out output layer.
            nn.init.constant_(self.out_layer.weight, 0)
            nn.init.constant_(self.out_layer.bias, 0)

        elif self.initialization == 'scaled':
            # Depth-scaled normal init (muP-style variance control).
            def _basic_init(module):
                if isinstance(module, nn.Linear):
                    torch.nn.init.normal_(module.weight, std=np.sqrt(2.0 / (5.0 * self.model_channels)))
                    if module.bias is not None:
                        nn.init.constant_(module.bias, 0)
            self.apply(_basic_init)

            # Smaller init for residual-out projections (to_out, ffn2).
            def _scaled_init(module):
                if isinstance(module, nn.Linear):
                    torch.nn.init.normal_(module.weight, std=1.0 / np.sqrt(5 * self.num_blocks * self.model_channels))
                    if module.bias is not None:
                        nn.init.constant_(module.bias, 0)
            for block in self.blocks:
                block.self_attn.to_out.apply(_scaled_init)
                block.cross_attn.to_out.apply(_scaled_init)
                block.mlp.mlp[2].apply(_scaled_init)

            # Input layer scaled so the initial representation has variance 1.
            nn.init.normal_(self.input_layer.weight, std=1.0 / np.sqrt(self.in_channels))
            nn.init.zeros_(self.input_layer.bias)

            # Timestep embedding MLP.
            nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
            nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

            # Zero-out adaLN modulation layers.
            if self.share_mod:
                nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
                nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
            else:
                for block in self.blocks:
                    nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
                    nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

            # Zero-out output layer.
            nn.init.constant_(self.out_layer.weight, 0)
            nn.init.constant_(self.out_layer.bias, 0)

    def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, in_channels, R, R, R) latent volume.
            t: (B,) timesteps.
            cond: conditioning tokens for cross-attention.

        Returns:
            (B, out_channels, R, R, R) predicted flow.
        """
        assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
            f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"

        # (B, C, R, R, R) -> (B, R^3, C) token sequence.
        h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous()

        h = self.input_layer(h)
        if self.pe_mode == "ape":
            h = h + self.pos_emb[None]
        t_emb = self.t_embedder(t)
        if self.share_mod:
            t_emb = self.adaLN_modulation(t_emb)
        t_emb = manual_cast(t_emb, self.dtype)
        h = manual_cast(h, self.dtype)
        cond = manual_cast(cond, self.dtype)
        for block in self.blocks:
            h = block(h, t_emb, cond, self.rope_phases)
        h = manual_cast(h, x.dtype)
        h = F.layer_norm(h, h.shape[-1:])
        h = self.out_layer(h)

        # (B, R^3, C_out) -> (B, C_out, R, R, R).
        h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution] * 3).contiguous()

        return h
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..modules.norm import GroupNorm32, ChannelLayerNorm32
from ..modules.spatial import pixel_shuffle_3d
from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32


def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module:
    """
    Return a normalization layer of the requested type
    ('group' -> GroupNorm32(32, ...), 'layer' -> ChannelLayerNorm32).
    """
    if norm_type == "group":
        return GroupNorm32(32, *args, **kwargs)
    elif norm_type == "layer":
        return ChannelLayerNorm32(*args, **kwargs)
    else:
        raise ValueError(f"Invalid norm type {norm_type}")


class ResBlock3d(nn.Module):
    """
    Dense 3D residual block: norm -> SiLU -> conv, twice, with a 1x1
    conv skip when channel counts differ (second conv is zero-init).
    """
    def __init__(
        self,
        channels: int,
        out_channels: Optional[int] = None,
        norm_type: Literal["group", "layer"] = "layer",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels

        self.norm1 = norm_layer(norm_type, channels)
        self.norm2 = norm_layer(norm_type, self.out_channels)
        self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
        self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1))
        self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.norm1(x)
        out = F.silu(out)
        out = self.conv1(out)
        out = self.norm2(out)
        out = F.silu(out)
        out = self.conv2(out)
        return out + self.skip_connection(x)


class DownsampleBlock3d(nn.Module):
    """
    Halve spatial resolution, either with a learned stride-2 conv
    ('conv') or parameter-free average pooling ('avgpool').
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mode: Literal["conv", "avgpool"] = "conv",
    ):
        assert mode in ["conv", "avgpool"], f"Invalid mode {mode}"

        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        if mode == "conv":
            self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2)
        elif mode == "avgpool":
            assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 'conv' mode is detected by the presence of the conv submodule.
        if hasattr(self, "conv"):
            return self.conv(x)
        else:
            return F.avg_pool3d(x, 2)


class UpsampleBlock3d(nn.Module):
    """
    Double spatial resolution, either with conv + 3D pixel-shuffle
    ('conv') or nearest-neighbor interpolation ('nearest').
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mode: Literal["conv", "nearest"] = "conv",
    ):
        assert mode in ["conv", "nearest"], f"Invalid mode {mode}"

        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        if mode == "conv":
            # 8x channels, unshuffled into 2x2x2 spatial expansion below.
            self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1)
        elif mode == "nearest":
            assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if hasattr(self, "conv"):
            x = self.conv(x)
            return pixel_shuffle_3d(x, 2)
        else:
            return F.interpolate(x, scale_factor=2, mode="nearest")
+ """ + def __init__( + self, + in_channels: int, + latent_channels: int, + num_res_blocks: int, + channels: List[int], + num_res_blocks_middle: int = 2, + norm_type: Literal["group", "layer"] = "layer", + use_fp16: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.latent_channels = latent_channels + self.num_res_blocks = num_res_blocks + self.channels = channels + self.num_res_blocks_middle = num_res_blocks_middle + self.norm_type = norm_type + self.use_fp16 = use_fp16 + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1) + + self.blocks = nn.ModuleList([]) + for i, ch in enumerate(channels): + self.blocks.extend([ + ResBlock3d(ch, ch) + for _ in range(num_res_blocks) + ]) + if i < len(channels) - 1: + self.blocks.append( + DownsampleBlock3d(ch, channels[i+1]) + ) + + self.middle_block = nn.Sequential(*[ + ResBlock3d(channels[-1], channels[-1]) + for _ in range(num_res_blocks_middle) + ]) + + self.out_layer = nn.Sequential( + norm_layer(norm_type, channels[-1]), + nn.SiLU(), + nn.Conv3d(channels[-1], latent_channels*2, 3, padding=1) + ) + + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.dtype = torch.float16 + self.blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. 
+ """ + self.use_fp16 = False + self.dtype = torch.float32 + self.blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x: torch.Tensor, sample_posterior: bool = False, return_raw: bool = False) -> torch.Tensor: + h = self.input_layer(x) + h = h.type(self.dtype) + + for block in self.blocks: + h = block(h) + h = self.middle_block(h) + + h = h.type(x.dtype) + h = self.out_layer(h) + + mean, logvar = h.chunk(2, dim=1) + + if sample_posterior: + std = torch.exp(0.5 * logvar) + z = mean + std * torch.randn_like(std) + else: + z = mean + + if return_raw: + return z, mean, logvar + return z + + +class SparseStructureDecoder(nn.Module): + """ + Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3). + + Args: + out_channels (int): Channels of the output. + latent_channels (int): Channels of the latent representation. + num_res_blocks (int): Number of residual blocks at each resolution. + channels (List[int]): Channels of the decoder blocks. + num_res_blocks_middle (int): Number of residual blocks in the middle. + norm_type (Literal["group", "layer"]): Type of normalization layer. + use_fp16 (bool): Whether to use FP16. 
+ """ + def __init__( + self, + out_channels: int, + latent_channels: int, + num_res_blocks: int, + channels: List[int], + num_res_blocks_middle: int = 2, + norm_type: Literal["group", "layer"] = "layer", + use_fp16: bool = False, + ): + super().__init__() + self.out_channels = out_channels + self.latent_channels = latent_channels + self.num_res_blocks = num_res_blocks + self.channels = channels + self.num_res_blocks_middle = num_res_blocks_middle + self.norm_type = norm_type + self.use_fp16 = use_fp16 + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1) + + self.middle_block = nn.Sequential(*[ + ResBlock3d(channels[0], channels[0]) + for _ in range(num_res_blocks_middle) + ]) + + self.blocks = nn.ModuleList([]) + for i, ch in enumerate(channels): + self.blocks.extend([ + ResBlock3d(ch, ch) + for _ in range(num_res_blocks) + ]) + if i < len(channels) - 1: + self.blocks.append( + UpsampleBlock3d(ch, channels[i+1]) + ) + + self.out_layer = nn.Sequential( + norm_layer(norm_type, channels[-1]), + nn.SiLU(), + nn.Conv3d(channels[-1], out_channels, 3, padding=1) + ) + + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.dtype = torch.float16 + self.blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. 
+ """ + self.use_fp16 = False + self.dtype = torch.float32 + self.blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.input_layer(x) + + h = h.type(self.dtype) + + h = self.middle_block(h) + for block in self.blocks: + h = block(h) + + h = h.type(x.dtype) + h = self.out_layer(h) + return h diff --git a/trellis2/models/structured_latent_flow.py b/trellis2/models/structured_latent_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..8039d27bb74dbe5f5bae3bd024baaed13a514724 --- /dev/null +++ b/trellis2/models/structured_latent_flow.py @@ -0,0 +1,207 @@ +from typing import * +from functools import partial +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..modules.utils import convert_module_to, manual_cast, str_to_dtype +from ..modules.transformer import AbsolutePositionEmbedder +from ..modules import sparse as sp +from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock +from .sparse_structure_flow import TimestepEmbedder +from .sparse_elastic_mixin import SparseTransformerElasticMixin + + +class SLatFlowModel(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + pe_mode: Literal["ape", "rope"] = "ape", + rope_freq: Tuple[float, float] = (1.0, 10000.0), + dtype: str = 'float32', + use_checkpoint: bool = False, + share_mod: bool = False, + initialization: str = 'vanilla', + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + 
self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.pe_mode = pe_mode + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.initialization = initialization + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self.dtype = str_to_dtype(dtype) + + self.t_embedder = TimestepEmbedder(model_channels) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + if pe_mode == "ape": + self.pos_embedder = AbsolutePositionEmbedder(model_channels) + + self.input_layer = sp.SparseLinear(in_channels, model_channels) + + self.blocks = nn.ModuleList([ + ModulatedSparseTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + rope_freq=rope_freq, + share_mod=self.share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + ) + for _ in range(num_blocks) + ]) + + self.out_layer = sp.SparseLinear(model_channels, out_channels) + + self.initialize_weights() + self.convert_to(self.dtype) + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to(self, dtype: torch.dtype) -> None: + """ + Convert the torso of the model to the specified dtype. 
+ """ + self.dtype = dtype + self.blocks.apply(partial(convert_module_to, dtype=dtype)) + + def initialize_weights(self) -> None: + if self.initialization == 'vanilla': + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + elif self.initialization == 'scaled': + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, std=np.sqrt(2.0 / (5.0 * self.model_channels))) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Scaled init for to_out and ffn2 + def _scaled_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, std=1.0 / np.sqrt(5 * self.num_blocks * self.model_channels)) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + for block in self.blocks: + block.self_attn.to_out.apply(_scaled_init) + block.cross_attn.to_out.apply(_scaled_init) + block.mlp.mlp[2].apply(_scaled_init) + + # Initialize input layer to make the initial representation have variance 1 + nn.init.normal_(self.input_layer.weight, std=1.0 / np.sqrt(self.in_channels)) + nn.init.zeros_(self.input_layer.bias) + + # Initialize timestep 
embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward( + self, + x: sp.SparseTensor, + t: torch.Tensor, + cond: Union[torch.Tensor, List[torch.Tensor]], + concat_cond: Optional[sp.SparseTensor] = None, + **kwargs + ) -> sp.SparseTensor: + if concat_cond is not None: + x = sp.sparse_cat([x, concat_cond], dim=-1) + if isinstance(cond, list): + cond = sp.VarLenTensor.from_tensor_list(cond) + + h = self.input_layer(x) + h = manual_cast(h, self.dtype) + t_emb = self.t_embedder(t) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = manual_cast(t_emb, self.dtype) + cond = manual_cast(cond, self.dtype) + + if self.pe_mode == "ape": + pe = self.pos_embedder(h.coords[:, 1:]) + h = h + manual_cast(pe, self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond) + + h = manual_cast(h, x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h) + return h + + +class ElasticSLatFlowModel(SparseTransformerElasticMixin, SLatFlowModel): + """ + SLat Flow Model with elastic memory management. + Used for training with low VRAM. 
+ """ + pass diff --git a/trellis2/modules/__pycache__/image_feature_extractor.cpython-311.pyc b/trellis2/modules/__pycache__/image_feature_extractor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93b2af1377368c009041d9ccdaeca3b9afd4b225 Binary files /dev/null and b/trellis2/modules/__pycache__/image_feature_extractor.cpython-311.pyc differ diff --git a/trellis2/modules/__pycache__/norm.cpython-311.pyc b/trellis2/modules/__pycache__/norm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2b9596b4f44ac597f1988125afd8d665a0d6c96 Binary files /dev/null and b/trellis2/modules/__pycache__/norm.cpython-311.pyc differ diff --git a/trellis2/modules/__pycache__/spatial.cpython-311.pyc b/trellis2/modules/__pycache__/spatial.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5604f0994fc90649104850a210b3fc2041ab260b Binary files /dev/null and b/trellis2/modules/__pycache__/spatial.cpython-311.pyc differ diff --git a/trellis2/modules/__pycache__/utils.cpython-311.pyc b/trellis2/modules/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f9b01b9137b86f4aa2af211496015d3420497cf Binary files /dev/null and b/trellis2/modules/__pycache__/utils.cpython-311.pyc differ diff --git a/trellis2/modules/attention/__init__.py b/trellis2/modules/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..99f938980e50e54b1357d9093f7ed1be23a905e1 --- /dev/null +++ b/trellis2/modules/attention/__init__.py @@ -0,0 +1,3 @@ +from .full_attn import * +from .modules import * +from .rope import * diff --git a/trellis2/modules/attention/__pycache__/__init__.cpython-311.pyc b/trellis2/modules/attention/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e952bb64fa85975e3d99bbb22311508399d2a0f Binary files /dev/null and 
b/trellis2/modules/attention/__pycache__/__init__.cpython-311.pyc differ diff --git a/trellis2/modules/attention/__pycache__/config.cpython-311.pyc b/trellis2/modules/attention/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ca48d84c9af8714adaf6c9785d675edf03e2d68 Binary files /dev/null and b/trellis2/modules/attention/__pycache__/config.cpython-311.pyc differ diff --git a/trellis2/modules/attention/__pycache__/full_attn.cpython-311.pyc b/trellis2/modules/attention/__pycache__/full_attn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d948958ac95acea9c2b04a1e3a3e1832d5219b90 Binary files /dev/null and b/trellis2/modules/attention/__pycache__/full_attn.cpython-311.pyc differ diff --git a/trellis2/modules/attention/__pycache__/modules.cpython-311.pyc b/trellis2/modules/attention/__pycache__/modules.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d068f5149e5f756acbe3effd4b312cc88eacbd8 Binary files /dev/null and b/trellis2/modules/attention/__pycache__/modules.cpython-311.pyc differ diff --git a/trellis2/modules/attention/__pycache__/rope.cpython-311.pyc b/trellis2/modules/attention/__pycache__/rope.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1eb6be204f7921156cba32dd79374ad5430a8baa Binary files /dev/null and b/trellis2/modules/attention/__pycache__/rope.cpython-311.pyc differ diff --git a/trellis2/modules/attention/config.py b/trellis2/modules/attention/config.py new file mode 100644 index 0000000000000000000000000000000000000000..a3d037a9c0ed346fd74c4568001ada1650770c53 --- /dev/null +++ b/trellis2/modules/attention/config.py @@ -0,0 +1,34 @@ +from typing import * +import sys + +# Default to 'sdpa' (PyTorch's built-in) on Windows since flash_attn isn't available +BACKEND = 'sdpa' if sys.platform == 'win32' else 'flash_attn' +DEBUG = False + +def __from_env(): + import os + + global BACKEND + 
global DEBUG + + env_attn_backend = os.environ.get('ATTN_BACKEND') + env_attn_debug = os.environ.get('ATTN_DEBUG') + + if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'flash_attn_3', 'sdpa', 'naive']: + BACKEND = env_attn_backend + if env_attn_debug is not None: + DEBUG = env_attn_debug == '1' + + print(f"[ATTENTION] Using backend: {BACKEND}") + + +__from_env() + + +def set_backend(backend: Literal['xformers', 'flash_attn']): + global BACKEND + BACKEND = backend + +def set_debug(debug: bool): + global DEBUG + DEBUG = debug diff --git a/trellis2/modules/attention/full_attn.py b/trellis2/modules/attention/full_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..7cb20f906ecdb1bd1b148255d53dbf4f5628ee40 --- /dev/null +++ b/trellis2/modules/attention/full_attn.py @@ -0,0 +1,145 @@ +from typing import * +import torch +import math +from . import config + + +__all__ = [ + 'scaled_dot_product_attention', +] + + +def _naive_sdpa(q, k, v): + """ + Naive implementation of scaled dot product attention. + """ + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + scale_factor = 1 / math.sqrt(q.size(-1)) + attn_weight = q @ k.transpose(-2, -1) * scale_factor + attn_weight = torch.softmax(attn_weight, dim=-1) + out = attn_weight @ v + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + return out + + +@overload +def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs. + """ + ... + +@overload +def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + q (torch.Tensor): A [N, L, H, C] tensor containing Qs. + kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs. + """ + ... 
+ +@overload +def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs. + k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks. + v (torch.Tensor): A [N, L, H, Co] tensor containing Vs. + + Note: + k and v are assumed to have the same coordinate map. + """ + ... + +def scaled_dot_product_attention(*args, **kwargs): + arg_names_dict = { + 1: ['qkv'], + 2: ['q', 'kv'], + 3: ['q', 'k', 'v'] + } + num_all_args = len(args) + len(kwargs) + assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" + for key in arg_names_dict[num_all_args][len(args):]: + assert key in kwargs, f"Missing argument {key}" + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]" + device = qkv.device + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" + assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" + device = q.device + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" + assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" + assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, 
expected [N, L, H, Co]" + device = q.device + + if config.BACKEND == 'xformers': + if 'xops' not in globals(): + import xformers.ops as xops + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = xops.memory_efficient_attention(q, k, v) + elif config.BACKEND == 'flash_attn': + if 'flash_attn' not in globals(): + import flash_attn + if num_all_args == 1: + out = flash_attn.flash_attn_qkvpacked_func(qkv) + elif num_all_args == 2: + out = flash_attn.flash_attn_kvpacked_func(q, kv) + elif num_all_args == 3: + out = flash_attn.flash_attn_func(q, k, v) + elif config.BACKEND == 'flash_attn_3': + if 'flash_attn_3' not in globals(): + import flash_attn_interface as flash_attn_3 + if num_all_args == 1: + out = flash_attn_3.flash_attn_qkvpacked_func(qkv) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = flash_attn_3.flash_attn_func(q, k, v) + elif num_all_args == 3: + out = flash_attn_3.flash_attn_func(q, k, v) + elif config.BACKEND == 'sdpa': + if 'sdpa' not in globals(): + from torch.nn.functional import scaled_dot_product_attention as sdpa + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + out = sdpa(q, k, v) # [N, H, L, C] + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + elif config.BACKEND == 'naive': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = _naive_sdpa(q, k, v) + else: + raise ValueError(f"Unknown attention module: {config.BACKEND}") + + return out diff --git a/trellis2/modules/attention/modules.py b/trellis2/modules/attention/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..9889f827fa05742f9cc8fb79bc262262fd799107 --- /dev/null +++ b/trellis2/modules/attention/modules.py @@ -0,0 +1,102 @@ +from typing import * 
+import torch +import torch.nn as nn +import torch.nn.functional as F +from .full_attn import scaled_dot_product_attention +from .rope import RotaryPositionEmbedder + + +class MultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype) + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int]=None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + qkv_bias: bool = True, + use_rope: bool = False, + rope_freq: Tuple[float, float] = (1.0, 10000.0), + qk_rms_norm: bool = False, + ): + super().__init__() + assert channels % num_heads == 0 + assert type in ["self", "cross"], f"Invalid attention type: {type}" + assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}" + assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" + + if attn_mode == "windowed": + raise NotImplementedError("Windowed attention is not yet implemented") + + self.channels = channels + self.head_dim = channels // num_heads + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_window = shift_window + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = 
MultiHeadRMSNorm(self.head_dim, num_heads) + self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + + self.to_out = nn.Linear(channels, channels) + + def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + B, L, C = x.shape + if self._type == "self": + qkv = self.to_qkv(x) + qkv = qkv.reshape(B, L, 3, self.num_heads, -1) + + if self.attn_mode == "full": + if self.qk_rms_norm or self.use_rope: + q, k, v = qkv.unbind(dim=2) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + if self.use_rope: + assert phases is not None, "Phases must be provided for RoPE" + q = RotaryPositionEmbedder.apply_rotary_embedding(q, phases) + k = RotaryPositionEmbedder.apply_rotary_embedding(k, phases) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(qkv) + elif self.attn_mode == "windowed": + raise NotImplementedError("Windowed attention is not yet implemented") + else: + Lkv = context.shape[1] + q = self.to_q(x) + kv = self.to_kv(context) + q = q.reshape(B, L, self.num_heads, -1) + kv = kv.reshape(B, Lkv, 2, self.num_heads, -1) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=2) + k = self.k_rms_norm(k) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(q, kv) + h = h.reshape(B, L, -1) + h = self.to_out(h) + return h diff --git a/trellis2/modules/attention/rope.py b/trellis2/modules/attention/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..93f83e249eb23a36a5d512515435850c989228ef --- /dev/null +++ b/trellis2/modules/attention/rope.py @@ -0,0 +1,48 @@ +from typing import * +import torch +import torch.nn as nn + + +class RotaryPositionEmbedder(nn.Module): + def __init__( + self, + head_dim: int, + dim: int = 3, + rope_freq: Tuple[float, float] = (1.0, 10000.0) + ): + super().__init__() + assert head_dim % 2 == 0, "Head dim must be divisible by 2" + self.head_dim 
= head_dim + self.dim = dim + self.rope_freq = rope_freq + self.freq_dim = head_dim // 2 // dim + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs)) + + def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: + self.freqs = self.freqs.to(indices.device) + phases = torch.outer(indices, self.freqs) + phases = torch.polar(torch.ones_like(phases), phases) + return phases + + @staticmethod + def apply_rotary_embedding(x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: + x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_rotated = x_complex * phases.unsqueeze(-2) + x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) + return x_embed + + def forward(self, indices: torch.Tensor) -> torch.Tensor: + """ + Args: + indices (torch.Tensor): [..., N, C] tensor of spatial positions + """ + assert indices.shape[-1] == self.dim, f"Last dim of indices must be {self.dim}" + phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) + if phases.shape[-1] < self.head_dim // 2: + padn = self.head_dim // 2 - phases.shape[-1] + phases = torch.cat([phases, torch.polar( + torch.ones(*phases.shape[:-1], padn, device=phases.device), + torch.zeros(*phases.shape[:-1], padn, device=phases.device) + )], dim=-1) + return phases \ No newline at end of file diff --git a/trellis2/modules/image_feature_extractor.py b/trellis2/modules/image_feature_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..b0699465a7882aacb2541f44bac1e550920b9dc8 --- /dev/null +++ b/trellis2/modules/image_feature_extractor.py @@ -0,0 +1,123 @@ +from typing import * +import torch +import torch.nn.functional as F +from torchvision import transforms +from transformers import DINOv3ViTModel +import numpy as np +from PIL import Image + + +class DinoV2FeatureExtractor: + """ + Feature extractor for DINOv2 models. 
+ """ + def __init__(self, model_name: str): + self.model_name = model_name + self.model = torch.hub.load('facebookresearch/dinov2', model_name, pretrained=True) + self.model.eval() + self.transform = transforms.Compose([ + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + + def to(self, device): + self.model.to(device) + + def cuda(self): + self.model.cuda() + + def cpu(self): + self.model.cpu() + + @torch.no_grad() + def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor: + """ + Extract features from the image. + + Args: + image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images. + + Returns: + A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension. + """ + if isinstance(image, torch.Tensor): + assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)" + elif isinstance(image, list): + assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images" + image = [i.resize((518, 518), Image.LANCZOS) for i in image] + image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] + image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] + image = torch.stack(image).cuda() + else: + raise ValueError(f"Unsupported type of image: {type(image)}") + + image = self.transform(image).cuda() + features = self.model(image, is_training=True)['x_prenorm'] + patchtokens = F.layer_norm(features, features.shape[-1:]) + return patchtokens + + +class DinoV3FeatureExtractor: + """ + Feature extractor for DINOv3 models. 
+ """ + def __init__(self, model_name: str, image_size=512): + self.model_name = model_name + # Try loading with local_files_only first (for cached gated models) + try: + self.model = DINOv3ViTModel.from_pretrained(model_name, local_files_only=True) + except Exception: + # Fall back to remote loading + self.model = DINOv3ViTModel.from_pretrained(model_name) + self.model.eval() + self.image_size = image_size + self.transform = transforms.Compose([ + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + + def to(self, device): + self.model.to(device) + + def cuda(self): + self.model.cuda() + + def cpu(self): + self.model.cpu() + + def extract_features(self, image: torch.Tensor) -> torch.Tensor: + image = image.to(self.model.embeddings.patch_embeddings.weight.dtype) + hidden_states = self.model.embeddings(image, bool_masked_pos=None) + position_embeddings = self.model.rope_embeddings(image) + + for i, layer_module in enumerate(self.model.layer): + hidden_states = layer_module( + hidden_states, + position_embeddings=position_embeddings, + ) + + return F.layer_norm(hidden_states, hidden_states.shape[-1:]) + + @torch.no_grad() + def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor: + """ + Extract features from the image. + + Args: + image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images. + + Returns: + A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension. 
+ """ + if isinstance(image, torch.Tensor): + assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)" + elif isinstance(image, list): + assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images" + image = [i.resize((self.image_size, self.image_size), Image.LANCZOS) for i in image] + image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] + image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] + image = torch.stack(image).cuda() + else: + raise ValueError(f"Unsupported type of image: {type(image)}") + + image = self.transform(image).cuda() + features = self.extract_features(image) + return features diff --git a/trellis2/modules/norm.py b/trellis2/modules/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..658b4f6e2f10732a3be58fa76ca67a0b7df2ce22 --- /dev/null +++ b/trellis2/modules/norm.py @@ -0,0 +1,32 @@ +import torch +import torch.nn as nn +from .utils import manual_cast + + +class LayerNorm32(nn.LayerNorm): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_dtype = x.dtype + x = manual_cast(x, torch.float32) + o = super().forward(x) + return manual_cast(o, x_dtype) + + +class GroupNorm32(nn.GroupNorm): + """ + A GroupNorm layer that converts to float32 before the forward pass. 
+ """ + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_dtype = x.dtype + x = manual_cast(x, torch.float32) + o = super().forward(x) + return manual_cast(o, x_dtype) + + +class ChannelLayerNorm32(LayerNorm32): + def forward(self, x: torch.Tensor) -> torch.Tensor: + DIM = x.dim() + x = x.permute(0, *range(2, DIM), 1).contiguous() + x = super().forward(x) + x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous() + return x + \ No newline at end of file diff --git a/trellis2/modules/sparse/__init__.py b/trellis2/modules/sparse/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb1ece28469d1f0970fd7b799a35d04bd5413d00 --- /dev/null +++ b/trellis2/modules/sparse/__init__.py @@ -0,0 +1,69 @@ +from . import config +import importlib + +__attributes = { + 'VarLenTensor': 'basic', + 'varlen_cat': 'basic', + 'varlen_unbind': 'basic', + 'SparseTensor': 'basic', + 'sparse_cat': 'basic', + 'sparse_unbind': 'basic', + 'SparseGroupNorm': 'norm', + 'SparseLayerNorm': 'norm', + 'SparseGroupNorm32': 'norm', + 'SparseLayerNorm32': 'norm', + 'SparseReLU': 'nonlinearity', + 'SparseSiLU': 'nonlinearity', + 'SparseGELU': 'nonlinearity', + 'SparseActivation': 'nonlinearity', + 'SparseLinear': 'linear', + 'sparse_scaled_dot_product_attention': 'attention', + 'SerializeMode': 'attention', + 'sparse_serialized_scaled_dot_product_self_attention': 'attention', + 'sparse_windowed_scaled_dot_product_self_attention': 'attention', + 'sparse_windowed_scaled_dot_product_cross_attention': 'attention', + 'SparseRotaryPositionEmbedder': 'attention', + 'SparseMultiHeadAttention': 'attention', + 'SparseConv3d': 'conv', + 'SparseInverseConv3d': 'conv', + 'SparseDownsample': 'spatial', + 'SparseUpsample': 'spatial', + 'SparseSubdivide': 'spatial', + 'SparseSpatial2Channel': 'spatial', + 'SparseChannel2Spatial': 'spatial', + 'sparse_nearest_interpolate': 'spatial', + 'sparse_trilinear_interpolate': 'spatial', + 'encode_seq': 'serialize', + 'decode_seq': 'serialize', +} + 
+__submodules = ['transformer', 'conv'] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .basic import * + from .norm import * + from .nonlinearity import * + from .linear import * + from .attention import * + from .conv import * + from .spatial import * + from .serialize import * + import transformer + import conv diff --git a/trellis2/modules/sparse/__pycache__/__init__.cpython-311.pyc b/trellis2/modules/sparse/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..256bc904cf5b1cc9ea7d043c1c8d82fa949a4a04 Binary files /dev/null and b/trellis2/modules/sparse/__pycache__/__init__.cpython-311.pyc differ diff --git a/trellis2/modules/sparse/__pycache__/basic.cpython-311.pyc b/trellis2/modules/sparse/__pycache__/basic.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30c0d336a8b3dd906bf2436181086da18a91e6e5 Binary files /dev/null and b/trellis2/modules/sparse/__pycache__/basic.cpython-311.pyc differ diff --git a/trellis2/modules/sparse/__pycache__/config.cpython-311.pyc b/trellis2/modules/sparse/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba8b545d67f5b0b2c228545b90e5ad3604db3b87 Binary files /dev/null and b/trellis2/modules/sparse/__pycache__/config.cpython-311.pyc differ diff --git a/trellis2/modules/sparse/__pycache__/linear.cpython-311.pyc b/trellis2/modules/sparse/__pycache__/linear.cpython-311.pyc new file mode 100644 index 
# (continuation of binary .pyc diff entries)
# trellis2/modules/sparse/__pycache__/nonlinearity.cpython-311.pyc
# trellis2/modules/sparse/attention/__pycache__/__init__.cpython-311.pyc
# trellis2/modules/sparse/attention/__pycache__/full_attn.cpython-311.pyc
# trellis2/modules/sparse/attention/__pycache__/modules.cpython-311.pyc
# trellis2/modules/sparse/attention/__pycache__/rope.cpython-311.pyc
# trellis2/modules/sparse/attention/__pycache__/windowed_attn.cpython-311.pyc
# NOTE(review): compiled __pycache__ artifacts should not be committed.

# --- new file: trellis2/modules/sparse/attention/__init__.py ---
from .full_attn import *
from .windowed_attn import *
from .modules import *

# --- new file: trellis2/modules/sparse/attention/full_attn.py ---
from typing import *
import torch
from .. import VarLenTensor
from .. import config


__all__ = [
    'sparse_scaled_dot_product_attention',
]


@overload
def sparse_scaled_dot_product_attention(qkv: VarLenTensor) -> VarLenTensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        qkv (VarLenTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
    """
    ...

@overload
def sparse_scaled_dot_product_attention(q: VarLenTensor, kv: Union[VarLenTensor, torch.Tensor]) -> VarLenTensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        q (VarLenTensor): A [N, *, H, C] sparse tensor containing Qs.
        kv (VarLenTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs.
    """
    ...

@overload
def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: VarLenTensor) -> torch.Tensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        q (torch.Tensor): A [N, L, H, C] dense tensor containing Qs.
        kv (VarLenTensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs.
    """
    ...

@overload
def sparse_scaled_dot_product_attention(q: VarLenTensor, k: VarLenTensor, v: VarLenTensor) -> VarLenTensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        q (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Qs.
        k (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Ks.
        v (VarLenTensor): A [N, *, H, Co] sparse tensor containing Vs.

    Note:
        k and v are assumed to have the same coordinate map.
    """
    ...

@overload
def sparse_scaled_dot_product_attention(q: VarLenTensor, k: torch.Tensor, v: torch.Tensor) -> VarLenTensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        q (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Qs.
        k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks.
        v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs.
    """
    ...

@overload
def sparse_scaled_dot_product_attention(q: torch.Tensor, k: VarLenTensor, v: VarLenTensor) -> torch.Tensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs.
        k (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Ks.
        v (VarLenTensor): A [N, *, H, Co] sparse tensor containing Vs.
    """
    ...

def sparse_scaled_dot_product_attention(*args, **kwargs):
    """Dispatch variable-length (sparse) scaled-dot-product attention.

    NOTE(review): this function is truncated at the end of this chunk; only the
    argument validation and unpacking visible here is reproduced — no missing
    code has been invented.
    """
    arg_names_dict = {
        1: ['qkv'],
        2: ['q', 'kv'],
        3: ['q', 'k', 'v']
    }
    num_all_args = len(args) + len(kwargs)
    assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
    for key in arg_names_dict[num_all_args][len(args):]:
        assert key in kwargs, f"Missing argument {key}"

    if num_all_args == 1:
        qkv = args[0] if len(args) > 0 else kwargs['qkv']
        assert isinstance(qkv, VarLenTensor), f"qkv must be a VarLenTensor, got {type(qkv)}"
        assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
        device = qkv.device

        # Per-sample sequence lengths from the VarLenTensor layout.
        s = qkv
        q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])]
        kv_seqlen = q_seqlen
        qkv = qkv.feats  # [T, 3, H, C]

    elif num_all_args == 2:
        q = args[0] if len(args) > 0 else kwargs['q']
        kv = args[1] if len(args) > 1 else kwargs['kv']
        assert isinstance(q, VarLenTensor) and isinstance(kv, (VarLenTensor, torch.Tensor)) or \
               isinstance(q, torch.Tensor) and isinstance(kv, VarLenTensor), \
               f"Invalid types, got {type(q)} and {type(kv)}"
        assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
        device = q.device

        if isinstance(q, VarLenTensor):
            assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]"
            s = q
            q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
            q = q.feats  # [T_Q, H, C]
        else:
            assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
            s = None
            N, L, H, C = q.shape
            q_seqlen = [L] * N
            q = q.reshape(N * L, H, C)  # [T_Q, H, C]

        if isinstance(kv, VarLenTensor):
            assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]"
            kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])]
            kv = kv.feats
# (the remainder of this function lies beyond the end of this chunk)
[T_KV, 2, H, C] + else: + assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" + N, L, _, H, C = kv.shape + kv_seqlen = [L] * N + kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C] + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + assert isinstance(q, VarLenTensor) and isinstance(k, (VarLenTensor, torch.Tensor)) and type(k) == type(v) or \ + isinstance(q, torch.Tensor) and isinstance(k, VarLenTensor) and isinstance(v, VarLenTensor), \ + f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}" + assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" + device = q.device + + if isinstance(q, VarLenTensor): + assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]" + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, Ci] + else: + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" + s = None + N, L, H, CI = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, CI) # [T_Q, H, Ci] + + if isinstance(k, VarLenTensor): + assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]" + assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]" + kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])] + k = k.feats # [T_KV, H, Ci] + v = v.feats # [T_KV, H, Co] + else: + assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" + assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" + N, L, H, CI, CO = *k.shape, v.shape[-1] + kv_seqlen = [L] * N + k = k.reshape(N * L, H, CI) # [T_KV, H, Ci] + v = v.reshape(N * L, H, CO) # [T_KV, H, Co] + + if config.ATTN == 'sdpa': + # PyTorch native 
scaled_dot_product_attention - process each batch element separately + from torch.nn.functional import scaled_dot_product_attention as sdpa_fn + if num_all_args == 1: + q, k, v = qkv.unbind(dim=1) # each is [T, H, C] + elif num_all_args == 2: + k, v = kv.unbind(dim=1) # each is [T_KV, H, C] + # Process each batch element separately using cumulative sequence lengths + cu_seqlens_q = [0] + list(torch.cumsum(torch.tensor(q_seqlen), dim=0).tolist()) + cu_seqlens_kv = [0] + list(torch.cumsum(torch.tensor(kv_seqlen), dim=0).tolist()) + outputs = [] + for i in range(len(q_seqlen)): + q_start, q_end = cu_seqlens_q[i], cu_seqlens_q[i + 1] + kv_start, kv_end = cu_seqlens_kv[i], cu_seqlens_kv[i + 1] + # [seq_len, H, C] -> [1, H, seq_len, C] for sdpa + qi = q[q_start:q_end].permute(1, 0, 2).unsqueeze(0) # [1, H, Lq, C] + ki = k[kv_start:kv_end].permute(1, 0, 2).unsqueeze(0) # [1, H, Lkv, C] + vi = v[kv_start:kv_end].permute(1, 0, 2).unsqueeze(0) # [1, H, Lkv, C] + oi = sdpa_fn(qi, ki, vi) # [1, H, Lq, C] + # [1, H, Lq, C] -> [Lq, H, C] + oi = oi.squeeze(0).permute(1, 0, 2) + outputs.append(oi) + out = torch.cat(outputs, dim=0) # [T, H, C] + elif config.ATTN == 'xformers': + if 'xops' not in globals(): + import xformers.ops as xops + if num_all_args == 1: + q, k, v = qkv.unbind(dim=1) + elif num_all_args == 2: + k, v = kv.unbind(dim=1) + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) + out = xops.memory_efficient_attention(q, k, v, mask)[0] + elif config.ATTN == 'flash_attn': + if 'flash_attn' not in globals(): + import flash_attn + cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) + if num_all_args in [2, 3]: + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + if num_all_args == 1: + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen)) + elif 
num_all_args == 2: + out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + elif num_all_args == 3: + out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + elif config.ATTN == 'flash_attn_3': + if 'flash_attn_3' not in globals(): + import flash_attn_interface as flash_attn_3 + cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) + if num_all_args == 1: + q, k, v = qkv.unbind(dim=1) + cu_seqlens_kv = cu_seqlens_q.clone() + max_q_seqlen = max_kv_seqlen = max(q_seqlen) + elif num_all_args == 2: + k, v = kv.unbind(dim=1) + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + max_q_seqlen = max(q_seqlen) + max_kv_seqlen = max(kv_seqlen) + elif num_all_args == 3: + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + max_q_seqlen = max(q_seqlen) + max_kv_seqlen = max(kv_seqlen) + out = flash_attn_3.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q_seqlen, max_kv_seqlen) + else: + raise ValueError(f"Unknown attention module: {config.ATTN}") + + if s is not None: + return s.replace(out) + else: + return out.reshape(N, L, H, -1) diff --git a/trellis2/modules/sparse/attention/modules.py b/trellis2/modules/sparse/attention/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..531a8c4530da875f5bc27bfbf35862dc7083ad24 --- /dev/null +++ b/trellis2/modules/sparse/attention/modules.py @@ -0,0 +1,141 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from .. 
import VarLenTensor, SparseTensor +from .full_attn import sparse_scaled_dot_product_attention +from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention +from .rope import SparseRotaryPositionEmbedder + + +class SparseMultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]: + x_type = x.dtype + x = x.float() + if isinstance(x, VarLenTensor): + x = x.replace(F.normalize(x.feats, dim=-1) * self.gamma * self.scale) + else: + x = F.normalize(x, dim=-1) * self.gamma * self.scale + return x.to(x_type) + + +class SparseMultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int] = None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "windowed", "double_windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + qkv_bias: bool = True, + use_rope: bool = False, + rope_freq: Tuple[int, int] = (1.0, 10000.0), + qk_rms_norm: bool = False, + ): + super().__init__() + assert channels % num_heads == 0 + assert type in ["self", "cross"], f"Invalid attention type: {type}" + assert attn_mode in ["full", "windowed", "double_windowed"], f"Invalid attention mode: {attn_mode}" + assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" + assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention" + if attn_mode == 'double_windowed': + assert window_size % 2 == 0, "Window size must be even for double windowed attention" + assert num_heads % 2 == 0, "Number of heads must be even for double windowed attention" + self.channels = channels + self.head_dim = channels // num_heads + self.ctx_channels = ctx_channels if ctx_channels is not None else 
channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_window = shift_window + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads) + self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads) + + self.to_out = nn.Linear(channels, channels) + + if use_rope: + self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq) + + @staticmethod + def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]: + if isinstance(x, VarLenTensor): + return x.replace(module(x.feats)) + else: + return module(x) + + @staticmethod + def _reshape_chs(x: Union[VarLenTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[VarLenTensor, torch.Tensor]: + if isinstance(x, VarLenTensor): + return x.reshape(*shape) + else: + return x.reshape(*x.shape[:2], *shape) + + def _fused_pre(self, x: Union[VarLenTensor, torch.Tensor], num_fused: int) -> Union[VarLenTensor, torch.Tensor]: + if isinstance(x, VarLenTensor): + x_feats = x.feats.unsqueeze(0) + else: + x_feats = x + x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1) + return x.replace(x_feats.squeeze(0)) if isinstance(x, VarLenTensor) else x_feats + + def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None) -> SparseTensor: + if self._type == "self": + qkv = self._linear(self.to_qkv, x) + qkv = self._fused_pre(qkv, num_fused=3) + if self.qk_rms_norm or self.use_rope: + q, k, v = qkv.unbind(dim=-3) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + if self.use_rope: + q, k = 
self.rope(q, k) + qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1)) + if self.attn_mode == "full": + h = sparse_scaled_dot_product_attention(qkv) + elif self.attn_mode == "windowed": + h = sparse_windowed_scaled_dot_product_self_attention( + qkv, self.window_size, shift_window=self.shift_window + ) + elif self.attn_mode == "double_windowed": + qkv0 = qkv.replace(qkv.feats[:, :, self.num_heads//2:]) + qkv1 = qkv.replace(qkv.feats[:, :, :self.num_heads//2]) + h0 = sparse_windowed_scaled_dot_product_self_attention( + qkv0, self.window_size, shift_window=(0, 0, 0) + ) + h1 = sparse_windowed_scaled_dot_product_self_attention( + qkv1, self.window_size, shift_window=tuple([self.window_size//2] * 3) + ) + h = qkv.replace(torch.cat([h0.feats, h1.feats], dim=1)) + else: + q = self._linear(self.to_q, x) + q = self._reshape_chs(q, (self.num_heads, -1)) + kv = self._linear(self.to_kv, context) + kv = self._fused_pre(kv, num_fused=2) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=-3) + k = self.k_rms_norm(k) + h = sparse_scaled_dot_product_attention(q, k, v) + else: + h = sparse_scaled_dot_product_attention(q, kv) + h = self._reshape_chs(h, (-1,)) + h = self._linear(self.to_out, h) + return h diff --git a/trellis2/modules/sparse/attention/rope.py b/trellis2/modules/sparse/attention/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..de7a253707122cff39795939b5ca7420bdd3f88e --- /dev/null +++ b/trellis2/modules/sparse/attention/rope.py @@ -0,0 +1,58 @@ +from typing import * +import torch +import torch.nn as nn +from ..basic import SparseTensor + + +class SparseRotaryPositionEmbedder(nn.Module): + def __init__( + self, + head_dim: int, + dim: int = 3, + rope_freq: Tuple[float, float] = (1.0, 10000.0) + ): + super().__init__() + assert head_dim % 2 == 0, "Head dim must be divisible by 2" + self.head_dim = head_dim + self.dim = dim + self.rope_freq = rope_freq + self.freq_dim = head_dim // 2 // dim + self.freqs = 
torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs)) + + def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: + self.freqs = self.freqs.to(indices.device) + phases = torch.outer(indices, self.freqs) + phases = torch.polar(torch.ones_like(phases), phases) + return phases + + def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: + x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_rotated = x_complex * phases.unsqueeze(-2) + x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) + return x_embed + + def forward(self, q: SparseTensor, k: Optional[SparseTensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + q (SparseTensor): [..., N, H, D] tensor of queries + k (SparseTensor): [..., N, H, D] tensor of keys + """ + assert q.coords.shape[-1] == self.dim + 1, "Last dimension of coords must be equal to dim+1" + phases_cache_name = f'rope_phase_{self.dim}d_freq{self.rope_freq[0]}-{self.rope_freq[1]}_hd{self.head_dim}' + phases = q.get_spatial_cache(phases_cache_name) + if phases is None: + coords = q.coords[..., 1:] + phases = self._get_phases(coords.reshape(-1)).reshape(*coords.shape[:-1], -1) + if phases.shape[-1] < self.head_dim // 2: + padn = self.head_dim // 2 - phases.shape[-1] + phases = torch.cat([phases, torch.polar( + torch.ones(*phases.shape[:-1], padn, device=phases.device), + torch.zeros(*phases.shape[:-1], padn, device=phases.device) + )], dim=-1) + q.register_spatial_cache(phases_cache_name, phases) + q_embed = q.replace(self._rotary_embedding(q.feats, phases)) + if k is None: + return q_embed + k_embed = k.replace(self._rotary_embedding(k.feats, phases)) + return q_embed, k_embed \ No newline at end of file diff --git a/trellis2/modules/sparse/attention/windowed_attn.py b/trellis2/modules/sparse/attention/windowed_attn.py new file mode 100644 index 
0000000000000000000000000000000000000000..b9d8d52b07185102b6681e83f0157257fde2b388 --- /dev/null +++ b/trellis2/modules/sparse/attention/windowed_attn.py @@ -0,0 +1,190 @@ +from typing import * +import torch +import math +from .. import SparseTensor +from .. import config + + +__all__ = [ + 'sparse_windowed_scaled_dot_product_self_attention', + 'sparse_windowed_scaled_dot_product_cross_attention', +] + + +def calc_window_partition( + tensor: SparseTensor, + window_size: Union[int, Tuple[int, ...]], + shift_window: Union[int, Tuple[int, ...]] = 0, +) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]: + """ + Calculate serialization and partitioning for a set of coordinates. + + Args: + tensor (SparseTensor): The input tensor. + window_size (int): The window size to use. + shift_window (Tuple[int, ...]): The shift of serialized coordinates. + + Returns: + (torch.Tensor): Forwards indices. + (torch.Tensor): Backwards indices. + (torch.Tensor): Sequence lengths. + (dict): Attn func args. 
+ """ + DIM = tensor.coords.shape[1] - 1 + shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window + window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size + shifted_coords = tensor.coords.clone().detach() + shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0) + + MAX_COORDS = [i + j for i, j in zip(tensor.spatial_shape, shift_window)] + NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)] + OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1] + + shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0) + shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1) + fwd_indices = torch.argsort(shifted_indices) + bwd_indices = torch.empty_like(fwd_indices) + bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device) + seq_lens = torch.bincount(shifted_indices) + mask = seq_lens != 0 + seq_lens = seq_lens[mask] + + if config.ATTN == 'xformers': + if 'xops' not in globals(): + import xformers.ops as xops + attn_func_args = { + 'attn_bias': xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) + } + elif config.ATTN == 'flash_attn': + attn_func_args = { + 'cu_seqlens': torch.cat([torch.tensor([0], device=tensor.device), torch.cumsum(seq_lens, dim=0)], dim=0).int(), + 'max_seqlen': torch.max(seq_lens) + } + + return fwd_indices, bwd_indices, seq_lens, attn_func_args + + +def sparse_windowed_scaled_dot_product_self_attention( + qkv: SparseTensor, + window_size: int, + shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> SparseTensor: + """ + Apply windowed scaled dot product self attention to a sparse tensor. + + Args: + qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + window_size (int): The window size to use. 
+ shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + + Returns: + (SparseTensor): [N, *, H, C] sparse tensor containing the output features. + """ + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + + serialization_spatial_cache_name = f'windowed_attention_{window_size}_{shift_window}' + serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) + if serialization_spatial_cache is None: + fwd_indices, bwd_indices, seq_lens, attn_func_args = calc_window_partition(qkv, window_size, shift_window) + qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, attn_func_args)) + else: + fwd_indices, bwd_indices, seq_lens, attn_func_args = serialization_spatial_cache + + qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + + if config.DEBUG: + start = 0 + qkv_coords = qkv.coords[fwd_indices] + for i in range(len(seq_lens)): + seq_coords = qkv_coords[start:start+seq_lens[i]] + assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \ + f"SparseWindowedScaledDotProductSelfAttention: window size exceeded" + start += seq_lens[i] + + if config.ATTN == 'xformers': + if 'xops' not in globals(): + import xformers.ops as xops + q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] + out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C] + elif config.ATTN == 'flash_attn': + if 'flash_attn' not in globals(): + import flash_attn + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, **attn_func_args) # [M, H, C] + + out = out[bwd_indices] # [T, H, C] + + if config.DEBUG: + qkv_coords = qkv_coords[bwd_indices] + assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" + + return qkv.replace(out) + 
+ +def sparse_windowed_scaled_dot_product_cross_attention( + q: SparseTensor, + kv: SparseTensor, + q_window_size: int, + kv_window_size: int, + q_shift_window: Tuple[int, int, int] = (0, 0, 0), + kv_shift_window: Tuple[int, int, int] = (0, 0, 0), +) -> SparseTensor: + """ + Apply windowed scaled dot product cross attention to two sparse tensors. + + Args: + q (SparseTensor): [N, *, H, C] sparse tensor containing Qs. + kv (SparseTensor): [N, *, 2, H, C] sparse tensor containing Ks and Vs. + q_window_size (int): The window size to use for Qs. + kv_window_size (int): The window size to use for Ks and Vs. + q_shift_window (Tuple[int, int, int]): The shift of serialized coordinates for Qs. + kv_shift_window (Tuple[int, int, int]): The shift of serialized coordinates for Ks and Vs. + + Returns: + (SparseTensor): [N, *, H, C] sparse tensor containing the output features. + """ + assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]" + assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]" + + q_serialization_spatial_cache_name = f'windowed_attention_{q_window_size}_{q_shift_window}' + q_serialization_spatial_cache = q.get_spatial_cache(q_serialization_spatial_cache_name) + if q_serialization_spatial_cache is None: + q_fwd_indices, q_bwd_indices, q_seq_lens, q_attn_func_args = calc_window_partition(q, q_window_size, q_shift_window) + q.register_spatial_cache(q_serialization_spatial_cache_name, (q_fwd_indices, q_bwd_indices, q_seq_lens, q_attn_func_args)) + else: + q_fwd_indices, q_bwd_indices, q_seq_lens, q_attn_func_args = q_serialization_spatial_cache + kv_serialization_spatial_cache_name = f'windowed_attention_{kv_window_size}_{kv_shift_window}' + kv_serialization_spatial_cache = kv.get_spatial_cache(kv_serialization_spatial_cache_name) + if kv_serialization_spatial_cache is None: + kv_fwd_indices, kv_bwd_indices, kv_seq_lens, kv_attn_func_args = calc_window_partition(kv, 
kv_window_size, kv_shift_window) + kv.register_spatial_cache(kv_serialization_spatial_cache_name, (kv_fwd_indices, kv_bwd_indices, kv_seq_lens, kv_attn_func_args)) + else: + kv_fwd_indices, kv_bwd_indices, kv_seq_lens, kv_attn_func_args = kv_serialization_spatial_cache + + assert len(q_seq_lens) == len(kv_seq_lens), "Number of sequences in q and kv must match" + + q_feats = q.feats[q_fwd_indices] # [M, H, C] + kv_feats = kv.feats[kv_fwd_indices] # [M, 2, H, C] + + if config.ATTN == 'xformers': + if 'xops' not in globals(): + import xformers.ops as xops + k, v = kv_feats.unbind(dim=1) # [M, H, C] + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] + mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seq_lens, kv_seq_lens) + out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)[0] # [M, H, C] + elif config.ATTN == 'flash_attn': + if 'flash_attn' not in globals(): + import flash_attn + out = flash_attn.flash_attn_varlen_kvpacked_func(q_feats, kv_feats, + cu_seqlens_q=q_attn_func_args['cu_seqlens'], cu_seqlens_k=kv_attn_func_args['cu_seqlens'], + max_seqlen_q=q_attn_func_args['max_seqlen'], max_seqlen_k=kv_attn_func_args['max_seqlen'], + ) # [M, H, C] + + out = out[q_bwd_indices] # [T, H, C] + + return q.replace(out) diff --git a/trellis2/modules/sparse/basic.py b/trellis2/modules/sparse/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..06106ae4c4553904f8c9c900d538e3c02a88092d --- /dev/null +++ b/trellis2/modules/sparse/basic.py @@ -0,0 +1,836 @@ +from typing import * +from fractions import Fraction +import torch +from . import config + + +__all__ = [ + 'VarLenTensor', + 'varlen_cat', + 'varlen_unbind', + 'SparseTensor', + 'sparse_cat', + 'sparse_unbind', +] + + +class VarLenTensor: + """ + Sequential tensor with variable length. + + Args: + feats (torch.Tensor): Features of the varlen tensor. 
+ layout (List[slice]): Layout of the varlen tensor for each batch + """ + def __init__(self, feats: torch.Tensor, layout: List[slice]=None): + self.feats = feats + self.layout = layout if layout is not None else [slice(0, feats.shape[0])] + self._cache = {} + + @staticmethod + def layout_from_seqlen(seqlen: list) -> List[slice]: + """ + Create a layout from a tensor of sequence lengths. + """ + layout = [] + start = 0 + for l in seqlen: + layout.append(slice(start, start + l)) + start += l + return layout + + @staticmethod + def from_tensor_list(tensor_list: List[torch.Tensor]) -> 'VarLenTensor': + """ + Create a VarLenTensor from a list of tensors. + """ + feats = torch.cat(tensor_list, dim=0) + layout = [] + start = 0 + for tensor in tensor_list: + layout.append(slice(start, start + tensor.shape[0])) + start += tensor.shape[0] + return VarLenTensor(feats, layout) + + def to_tensor_list(self) -> List[torch.Tensor]: + """ + Convert a VarLenTensor to a list of tensors. + """ + tensor_list = [] + for s in self.layout: + tensor_list.append(self.feats[s]) + return tensor_list + + def __len__(self) -> int: + return len(self.layout) + + @property + def shape(self) -> torch.Size: + return torch.Size([len(self.layout), *self.feats.shape[1:]]) + + def dim(self) -> int: + return len(self.shape) + + @property + def ndim(self) -> int: + return self.dim() + + @property + def dtype(self): + return self.feats.dtype + + @property + def device(self): + return self.feats.device + + @property + def seqlen(self) -> torch.LongTensor: + if 'seqlen' not in self._cache: + self._cache['seqlen'] = torch.tensor([l.stop - l.start for l in self.layout], dtype=torch.long, device=self.device) + return self._cache['seqlen'] + + @property + def cum_seqlen(self) -> torch.LongTensor: + if 'cum_seqlen' not in self._cache: + self._cache['cum_seqlen'] = torch.cat([ + torch.tensor([0], dtype=torch.long, device=self.device), + self.seqlen.cumsum(dim=0) + ], dim=0) + return self._cache['cum_seqlen'] + + 
@property + def batch_boardcast_map(self) -> torch.LongTensor: + """ + Get the broadcast map for the varlen tensor. + """ + if 'batch_boardcast_map' not in self._cache: + self._cache['batch_boardcast_map'] = torch.repeat_interleave( + torch.arange(len(self.layout), device=self.device), + self.seqlen, + ) + return self._cache['batch_boardcast_map'] + + @overload + def to(self, dtype: torch.dtype, *, non_blocking: bool = False, copy: bool = False) -> 'VarLenTensor': ... + + @overload + def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, *, non_blocking: bool = False, copy: bool = False) -> 'VarLenTensor': ... + + def to(self, *args, **kwargs) -> 'VarLenTensor': + device = None + dtype = None + if len(args) == 2: + device, dtype = args + elif len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype = args[0] + else: + device = args[0] + if 'dtype' in kwargs: + assert dtype is None, "to() received multiple values for argument 'dtype'" + dtype = kwargs['dtype'] + if 'device' in kwargs: + assert device is None, "to() received multiple values for argument 'device'" + device = kwargs['device'] + non_blocking = kwargs.get('non_blocking', False) + copy = kwargs.get('copy', False) + + new_feats = self.feats.to(device=device, dtype=dtype, non_blocking=non_blocking, copy=copy) + return self.replace(new_feats) + + def type(self, dtype): + new_feats = self.feats.type(dtype) + return self.replace(new_feats) + + def cpu(self) -> 'VarLenTensor': + new_feats = self.feats.cpu() + return self.replace(new_feats) + + def cuda(self) -> 'VarLenTensor': + new_feats = self.feats.cuda() + return self.replace(new_feats) + + def half(self) -> 'VarLenTensor': + new_feats = self.feats.half() + return self.replace(new_feats) + + def float(self) -> 'VarLenTensor': + new_feats = self.feats.float() + return self.replace(new_feats) + + def detach(self) -> 'VarLenTensor': + new_feats = self.feats.detach() + return self.replace(new_feats) + + def 
reshape(self, *shape) -> 'VarLenTensor': + new_feats = self.feats.reshape(self.feats.shape[0], *shape) + return self.replace(new_feats) + + def unbind(self, dim: int) -> List['VarLenTensor']: + return varlen_unbind(self, dim) + + def replace(self, feats: torch.Tensor) -> 'VarLenTensor': + new_tensor = VarLenTensor( + feats=feats, + layout=self.layout, + ) + new_tensor._cache = self._cache + return new_tensor + + def to_dense(self, max_length=None) -> torch.Tensor: + """ + Convert a VarLenTensor to a dense representation without for-loop. + + Returns: + dense (torch.Tensor): (N, L, C) dense tensor + mask (torch.BoolTensor): (N, L) mask indicating valid positions + """ + N = len(self) + L = max_length or self.seqlen.max().item() + spatial = self.feats.shape[1:] + idx = torch.arange(L, device=self.device).unsqueeze(0).expand(N, L) + mask = (idx < self.seqlen.unsqueeze(1)) + mapping = mask.reshape(-1).cumsum(dim=0) - 1 + dense = self.feats[mapping] + dense = dense.reshape(N, L, *spatial) + return dense, mask + + def __neg__(self) -> 'VarLenTensor': + return self.replace(-self.feats) + + def __elemwise__(self, other: Union[torch.Tensor, 'VarLenTensor'], op: callable) -> 'VarLenTensor': + if isinstance(other, torch.Tensor): + try: + other = torch.broadcast_to(other, self.shape) + other = other[self.batch_boardcast_map] + except: + pass + if isinstance(other, VarLenTensor): + other = other.feats + new_feats = op(self.feats, other) + new_tensor = self.replace(new_feats) + return new_tensor + + def __add__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.add) + + def __radd__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.add) + + def __sub__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.sub) + + def __rsub__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 
'VarLenTensor': + return self.__elemwise__(other, lambda x, y: torch.sub(y, x)) + + def __mul__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.mul) + + def __rmul__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.mul) + + def __truediv__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.div) + + def __rtruediv__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, lambda x, y: torch.div(y, x)) + + def __getitem__(self, idx): + if isinstance(idx, int): + idx = [idx] + elif isinstance(idx, slice): + idx = range(*idx.indices(self.shape[0])) + elif isinstance(idx, list): + assert all(isinstance(i, int) for i in idx), f"Only integer indices are supported: {idx}" + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: + assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}" + idx = idx.nonzero().squeeze(1) + elif idx.dtype in [torch.int32, torch.int64]: + assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}" + else: + raise ValueError(f"Unknown index type: {idx.dtype}") + else: + raise ValueError(f"Unknown index type: {type(idx)}") + + new_feats = [] + new_layout = [] + start = 0 + for new_idx, old_idx in enumerate(idx): + new_feats.append(self.feats[self.layout[old_idx]]) + new_layout.append(slice(start, start + len(new_feats[-1]))) + start += len(new_feats[-1]) + new_feats = torch.cat(new_feats, dim=0).contiguous() + new_tensor = VarLenTensor(feats=new_feats, layout=new_layout) + return new_tensor + + def reduce(self, op: str, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + if isinstance(dim, int): + dim = (dim,) + + if op =='mean': + red = self.feats.mean(dim=dim, keepdim=keepdim) + elif op =='sum': + red = 
self.feats.sum(dim=dim, keepdim=keepdim) + elif op == 'prod': + red = self.feats.prod(dim=dim, keepdim=keepdim) + else: + raise ValueError(f"Unsupported reduce operation: {op}") + + if dim is None or 0 in dim: + return red + + red = torch.segment_reduce(red, reduce=op, lengths=self.seqlen) + return red + + def mean(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + return self.reduce(op='mean', dim=dim, keepdim=keepdim) + + def sum(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + return self.reduce(op='sum', dim=dim, keepdim=keepdim) + + def prod(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + return self.reduce(op='prod', dim=dim, keepdim=keepdim) + + def std(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + mean = self.mean(dim=dim, keepdim=True) + mean2 = self.replace(self.feats ** 2).mean(dim=dim, keepdim=True) + std = (mean2 - mean ** 2).sqrt() + return std + + def __repr__(self) -> str: + return f"VarLenTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})" + + +def varlen_cat(inputs: List[VarLenTensor], dim: int = 0) -> VarLenTensor: + """ + Concatenate a list of varlen tensors. + + Args: + inputs (List[VarLenTensor]): List of varlen tensors to concatenate. + """ + if dim == 0: + new_feats = torch.cat([input.feats for input in inputs], dim=0) + start = 0 + new_layout = [] + for input in inputs: + for l in input.layout: + new_layout.append(slice(start, start + l.stop - l.start)) + start += l.stop - l.start + output = VarLenTensor(feats=new_feats, layout=new_layout) + else: + feats = torch.cat([input.feats for input in inputs], dim=dim) + output = inputs[0].replace(feats) + + return output + + +def varlen_unbind(input: VarLenTensor, dim: int) -> Union[List[VarLenTensor]]: + """ + Unbind a varlen tensor along a dimension. 
+ + Args: + input (VarLenTensor): Varlen tensor to unbind. + dim (int): Dimension to unbind. + """ + if dim == 0: + return [input[i] for i in range(len(input))] + else: + feats = input.feats.unbind(dim) + return [input.replace(f) for f in feats] + + +class SparseTensor(VarLenTensor): + """ + Sparse tensor with support for both torchsparse and spconv backends. + + Parameters: + - feats (torch.Tensor): Features of the sparse tensor. + - coords (torch.Tensor): Coordinates of the sparse tensor. + - shape (torch.Size): Shape of the sparse tensor. + - layout (List[slice]): Layout of the sparse tensor for each batch + - data (SparseTensorData): Sparse tensor data used for convolusion + + NOTE: + - Data corresponding to a same batch should be contiguous. + - Coords should be in [0, 1023] + """ + SparseTensorData = None + + @overload + def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, **kwargs): ... + + @overload + def __init__(self, data, shape: Optional[torch.Size] = None, **kwargs): ... 
    def __init__(self, *args, **kwargs):
        # Lazy import of sparse tensor backend
        if self.SparseTensorData is None:
            import importlib
            if config.CONV == 'torchsparse':
                self.SparseTensorData = importlib.import_module('torchsparse').SparseTensor
            elif config.CONV == 'spconv':
                self.SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor

        # method_id 0: construct from (feats, coords); method_id 1: wrap an
        # existing backend data object.
        method_id = 0
        if len(args) != 0:
            method_id = 0 if isinstance(args[0], torch.Tensor) else 1
        else:
            method_id = 1 if 'data' in kwargs else 0

        if method_id == 0:
            feats, coords, shape = args + (None,) * (3 - len(args))
            if 'feats' in kwargs:
                feats = kwargs['feats']
                del kwargs['feats']
            if 'coords' in kwargs:
                coords = kwargs['coords']
                del kwargs['coords']
            if 'shape' in kwargs:
                shape = kwargs['shape']
                del kwargs['shape']

            # NOTE(review): any remaining kwargs (e.g. 'scale'/'spatial_cache')
            # are forwarded to the backend constructor below on the
            # torchsparse/spconv paths — verify the backends tolerate them.
            if config.CONV == 'torchsparse':
                self.data = self.SparseTensorData(feats, coords, **kwargs)
            elif config.CONV == 'spconv':
                # spconv wants (features, indices, spatial_shape, batch_size);
                # spatial extent is inferred from the max coordinate.
                spatial_shape = list(coords.max(0)[0] + 1)
                self.data = self.SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape[1:], spatial_shape[0], **kwargs)
                self.data._features = feats
            else:
                # Backend-free fallback: plain dict of feats/coords.
                self.data = {
                    'feats': feats,
                    'coords': coords,
                }
        elif method_id == 1:
            data, shape = args + (None,) * (2 - len(args))
            if 'data' in kwargs:
                data = kwargs['data']
                del kwargs['data']
            if 'shape' in kwargs:
                shape = kwargs['shape']
                del kwargs['shape']

            self.data = data

        # NOTE(review): a 'layout' kwarg (passed by some conv backends) is not
        # consumed here — it stays in kwargs and is silently dropped; the
        # layout property recomputes it from coords instead. Confirm intended.
        self._shape = shape
        self._scale = kwargs.get('scale', (Fraction(1, 1), Fraction(1, 1), Fraction(1, 1)))
        self._spatial_cache = kwargs.get('spatial_cache', {})

        if config.DEBUG:
            # Sanity-check invariants: matching feats/coords row counts, derived
            # shape/layout agreement, and per-batch contiguity of coords.
            try:
                assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}"
                assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}"
                assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}"
                for i in range(self.shape[0]):
                    assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous"
            except Exception as e:
                print('Debugging information:')
                print(f"- Shape: {self.shape}")
                print(f"- Layout: {self.layout}")
                print(f"- Scale: {self._scale}")
                print(f"- Coords: {self.coords}")
                raise e

    @staticmethod
    def from_tensor_list(feats_list: List[torch.Tensor], coords_list: List[torch.Tensor]) -> 'SparseTensor':
        """
        Create a SparseTensor from a list of tensors.
        """
        feats = torch.cat(feats_list, dim=0)
        coords = []
        for i, coord in enumerate(coords_list):
            # Overwrite column 0 with the batch index (input coords are
            # expected to already carry a leading batch/dummy column).
            coord = torch.cat([torch.full_like(coord[:, :1], i), coord[:, 1:]], dim=1)
            coords.append(coord)
        coords = torch.cat(coords, dim=0)
        return SparseTensor(feats, coords)

    def to_tensor_list(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """
        Convert a SparseTensor to list of tensors.
        """
        feats_list = []
        coords_list = []
        for s in self.layout:
            feats_list.append(self.feats[s])
            coords_list.append(self.coords[s])
        return feats_list, coords_list

    def __len__(self) -> int:
        return len(self.layout)

    def __cal_shape(self, feats, coords):
        # (batch, *feature dims); batch count comes from the max batch index.
        shape = []
        shape.append(coords[:, 0].max().item() + 1)
        shape.extend([*feats.shape[1:]])
        return torch.Size(shape)

    def __cal_layout(self, coords, batch_size):
        # Relies on per-batch contiguity of coords (checked in DEBUG mode).
        seq_len = torch.bincount(coords[:, 0], minlength=batch_size)
        offset = torch.cumsum(seq_len, dim=0)
        layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)]
        return layout

    def __cal_spatial_shape(self, coords):
        # Spatial extent inferred from the max x/y/z coordinate.
        return torch.Size((coords[:, 1:].max(0)[0] + 1).tolist())

    @property
    def shape(self) -> torch.Size:
        if self._shape is None:
            self._shape = self.__cal_shape(self.feats, self.coords)
        return self._shape

    @property
    def layout(self) -> List[slice]:
        # Cached per-scale (unlike VarLenTensor, which stores layout directly).
        layout = self.get_spatial_cache('layout')
        if layout is None:
            layout = self.__cal_layout(self.coords, self.shape[0])
            self.register_spatial_cache('layout', layout)
        return layout

    @property
    def spatial_shape(self) -> torch.Size:
        spatial_shape = self.get_spatial_cache('shape')
        if spatial_shape is None:
            spatial_shape = self.__cal_spatial_shape(self.coords)
            self.register_spatial_cache('shape', spatial_shape)
        return spatial_shape

    @property
    def feats(self) -> torch.Tensor:
        # Backend-dependent attribute access.
        if config.CONV == 'torchsparse':
            return self.data.F
        elif config.CONV == 'spconv':
            return self.data.features
        else:
            return self.data['feats']

    @feats.setter
    def feats(self, value: torch.Tensor):
        if config.CONV == 'torchsparse':
            self.data.F = value
        elif config.CONV == 'spconv':
            self.data.features = value
        else:
            self.data['feats'] = value

    @property
    def coords(self) -> torch.Tensor:
        if config.CONV == 'torchsparse':
            return self.data.C
        elif config.CONV == 'spconv':
            return self.data.indices
        else:
            return self.data['coords']

    @coords.setter
    def coords(self, value: torch.Tensor):
        if config.CONV == 'torchsparse':
            self.data.C = value
        elif config.CONV == 'spconv':
            self.data.indices = value
        else:
            self.data['coords'] = value

    @property
    def dtype(self):
        return self.feats.dtype

    @property
    def device(self):
        return self.feats.device

    @property
    def seqlen(self) -> torch.LongTensor:
        # Overrides VarLenTensor to store in the scale-keyed spatial cache.
        seqlen = self.get_spatial_cache('seqlen')
        if seqlen is None:
            seqlen = torch.tensor([l.stop - l.start for l in self.layout], dtype=torch.long, device=self.device)
            self.register_spatial_cache('seqlen', seqlen)
        return seqlen

    @property
    def cum_seqlen(self) -> torch.LongTensor:
        cum_seqlen = self.get_spatial_cache('cum_seqlen')
        if cum_seqlen is None:
            cum_seqlen = torch.cat([
                torch.tensor([0], dtype=torch.long, device=self.device),
                self.seqlen.cumsum(dim=0)
            ], dim=0)
            self.register_spatial_cache('cum_seqlen', cum_seqlen)
        return cum_seqlen

    @property
    def batch_boardcast_map(self) -> torch.LongTensor:
        """
        Get the broadcast map for the varlen tensor.
        """
        batch_boardcast_map = self.get_spatial_cache('batch_boardcast_map')
        if batch_boardcast_map is None:
            batch_boardcast_map = torch.repeat_interleave(
                torch.arange(len(self.layout), device=self.device),
                self.seqlen,
            )
            self.register_spatial_cache('batch_boardcast_map', batch_boardcast_map)
        return batch_boardcast_map

    @overload
    def to(self, dtype: torch.dtype, *, non_blocking: bool = False, copy: bool = False) -> 'SparseTensor': ...

    @overload
    def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, *, non_blocking: bool = False, copy: bool = False) -> 'SparseTensor': ...

    def to(self, *args, **kwargs) -> 'SparseTensor':
        # Same argument parsing as VarLenTensor.to; additionally moves coords
        # (device only — coords keep their integer dtype).
        device = None
        dtype = None
        if len(args) == 2:
            device, dtype = args
        elif len(args) == 1:
            if isinstance(args[0], torch.dtype):
                dtype = args[0]
            else:
                device = args[0]
        if 'dtype' in kwargs:
            assert dtype is None, "to() received multiple values for argument 'dtype'"
            dtype = kwargs['dtype']
        if 'device' in kwargs:
            assert device is None, "to() received multiple values for argument 'device'"
            device = kwargs['device']
        non_blocking = kwargs.get('non_blocking', False)
        copy = kwargs.get('copy', False)

        new_feats = self.feats.to(device=device, dtype=dtype, non_blocking=non_blocking, copy=copy)
        new_coords = self.coords.to(device=device, non_blocking=non_blocking, copy=copy)
        return self.replace(new_feats, new_coords)

    def type(self, dtype):
        new_feats = self.feats.type(dtype)
        return self.replace(new_feats)

    def cpu(self) -> 'SparseTensor':
        new_feats = self.feats.cpu()
        new_coords = self.coords.cpu()
        return self.replace(new_feats, new_coords)

    def cuda(self) -> 'SparseTensor':
        new_feats = self.feats.cuda()
        new_coords = self.coords.cuda()
        return self.replace(new_feats, new_coords)

    def half(self) -> 'SparseTensor':
        new_feats = self.feats.half()
        return self.replace(new_feats)

    def float(self) -> 'SparseTensor':
        new_feats = self.feats.float()
        return self.replace(new_feats)

    def detach(self) -> 'SparseTensor':
        new_coords = self.coords.detach()
        new_feats = self.feats.detach()
        return self.replace(new_feats, new_coords)

    def reshape(self, *shape) -> 'SparseTensor':
        new_feats = self.feats.reshape(self.feats.shape[0], *shape)
        return self.replace(new_feats)

    def unbind(self, dim: int) -> List['SparseTensor']:
        return sparse_unbind(self, dim)

    def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor':
        """
        Return a new SparseTensor sharing this tensor's structure (stride,
        caches, indice bookkeeping) with new feats and optionally new coords.
        """
        if config.CONV == 'torchsparse':
            new_data = self.SparseTensorData(
                feats=feats,
                coords=self.data.coords if coords is None else coords,
                stride=self.data.stride,
                spatial_range=self.data.spatial_range,
            )
            new_data._caches = self.data._caches
        elif config.CONV == 'spconv':
            # Rebuild the SparseConvTensor, carrying over spconv's internal
            # bookkeeping (indice_dict, benchmark/timer state, etc.).
            new_data = self.SparseTensorData(
                self.data.features.reshape(self.data.features.shape[0], -1),
                self.data.indices,
                self.data.spatial_shape,
                self.data.batch_size,
                self.data.grid,
                self.data.voxel_num,
                self.data.indice_dict
            )
            new_data._features = feats
            new_data.benchmark = self.data.benchmark
            new_data.benchmark_record = self.data.benchmark_record
            new_data.thrust_allocator = self.data.thrust_allocator
            new_data._timer = self.data._timer
            new_data.force_algo = self.data.force_algo
            new_data.int8_scale = self.data.int8_scale
            if coords is not None:
                new_data.indices = coords
        else:
            new_data = {
                'feats': feats,
                'coords': self.data['coords'] if coords is None else coords,
            }
        # NOTE(review): _spatial_cache is shared by reference with the new
        # tensor — cached tensors keep their original device after .cuda()/.cpu().
        new_tensor = SparseTensor(
            new_data,
            shape=torch.Size([self._shape[0]] + list(feats.shape[1:])) if self._shape is not None else None,
            scale=self._scale,
            spatial_cache=self._spatial_cache
        )
        return new_tensor

    def to_dense(self) -> torch.Tensor:
        """Densify to a (N, C, D, H, W)-style tensor (backend-dependent path)."""
        if config.CONV == 'torchsparse':
            return self.data.dense()
        elif config.CONV == 'spconv':
            return self.data.dense()
        else:
            spatial_shape = self.spatial_shape
            ret = torch.zeros(*self.shape, *spatial_shape, dtype=self.dtype, device=self.device)
            idx = [self.coords[:, 0], slice(None)] + self.coords[:, 1:].unbind(1)
            ret[tuple(idx)] = self.feats
            return ret

    @staticmethod
    def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor':
        # Dense fill of the axis-aligned box aabb = (x0, y0, z0, x1, y1, z1)
        # for N batches of C channels; dim = (N, C).
        N, C = dim
        x = torch.arange(aabb[0], aabb[3] + 1)
        y = torch.arange(aabb[1], aabb[4] + 1)
        z = torch.arange(aabb[2], aabb[5] + 1)
        coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3)
        coords = torch.cat([
            torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1),
            coords.repeat(N, 1),
        ], dim=1).to(dtype=torch.int32, device=device)
        feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device)
        return SparseTensor(feats=feats, coords=coords)

    def __merge_sparse_cache(self, other: 'SparseTensor') -> dict:
        # Union of both operands' per-scale caches; self's entries win on clash
        # at the scale level, merged dict-wise within a scale.
        new_cache = {}
        for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())):
            if k in self._spatial_cache:
                new_cache[k] = self._spatial_cache[k]
            if k in other._spatial_cache:
                if k not in new_cache:
                    new_cache[k] = other._spatial_cache[k]
                else:
                    new_cache[k].update(other._spatial_cache[k])
        return new_cache

    def __elemwise__(self, other: Union[torch.Tensor, VarLenTensor], op: callable) -> 'SparseTensor':
        if isinstance(other, torch.Tensor):
            try:
                other = torch.broadcast_to(other, self.shape)
                other = other[self.batch_boardcast_map]
            except:
                pass
        if isinstance(other, VarLenTensor):
            other = other.feats
        new_feats = op(self.feats, other)
        new_tensor = self.replace(new_feats)
        if isinstance(other, SparseTensor):
            # NOTE(review): unreachable as written — `other` was reassigned to
            # its .feats tensor above, so it is never a SparseTensor here.
            new_tensor._spatial_cache = self.__merge_sparse_cache(other)
        return new_tensor

    def __getitem__(self, idx):
        # Batch-level indexing; coords' batch column is renumbered compactly.
        if isinstance(idx, int):
            idx = [idx]
        elif isinstance(idx, slice):
            idx = range(*idx.indices(self.shape[0]))
        elif isinstance(idx, list):
            assert all(isinstance(i, int) for i in idx), f"Only integer indices are supported: {idx}"
        elif isinstance(idx, torch.Tensor):
            if idx.dtype == torch.bool:
                assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}"
                idx = idx.nonzero().squeeze(1)
            elif idx.dtype in [torch.int32, torch.int64]:
                assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
            else:
                raise ValueError(f"Unknown index type: {idx.dtype}")
        else:
            raise ValueError(f"Unknown index type: {type(idx)}")

        new_coords = []
        new_feats = []
        new_layout = []
        new_shape = torch.Size([len(idx)] + list(self.shape[1:]))
        start = 0
        for new_idx, old_idx in enumerate(idx):
            new_coords.append(self.coords[self.layout[old_idx]].clone())
            new_coords[-1][:, 0] = new_idx
            new_feats.append(self.feats[self.layout[old_idx]])
            new_layout.append(slice(start, start + len(new_coords[-1])))
            start += len(new_coords[-1])
        new_coords = torch.cat(new_coords, dim=0).contiguous()
        new_feats = torch.cat(new_feats, dim=0).contiguous()
        new_tensor = SparseTensor(feats=new_feats, coords=new_coords, shape=new_shape)
        new_tensor.register_spatial_cache('layout', new_layout)
        return new_tensor

    def clear_spatial_cache(self) -> None:
        """
        Clear all spatial caches.
        """
        self._spatial_cache = {}

    def register_spatial_cache(self, key, value) -> None:
        """
        Register a spatial cache.
        The spatial cache can be any thing you want to cache.
        The registery and retrieval of the cache is based on current scale.
        """
        scale_key = str(self._scale)
        if scale_key not in self._spatial_cache:
            self._spatial_cache[scale_key] = {}
        self._spatial_cache[scale_key][key] = value

    def get_spatial_cache(self, key=None):
        """
        Get a spatial cache.
        """
        scale_key = str(self._scale)
        cur_scale_cache = self._spatial_cache.get(scale_key, {})
        if key is None:
            return cur_scale_cache
        return cur_scale_cache.get(key, None)

    def __repr__(self) -> str:
        return f"SparseTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})"

def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor:
    """
    Concatenate a list of sparse tensors.

    Args:
        inputs (List[SparseTensor]): List of sparse tensors to concatenate.
    """
    if dim == 0:
        # Offset each input's batch indices so batches stay globally unique.
        start = 0
        coords = []
        for input in inputs:
            coords.append(input.coords.clone())
            coords[-1][:, 0] += start
            start += input.shape[0]
        coords = torch.cat(coords, dim=0)
        feats = torch.cat([input.feats for input in inputs], dim=0)
        output = SparseTensor(
            coords=coords,
            feats=feats,
        )
    else:
        feats = torch.cat([input.feats for input in inputs], dim=dim)
        output = inputs[0].replace(feats)

    return output


def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
    """
    Unbind a sparse tensor along a dimension.

    Args:
        input (SparseTensor): Sparse tensor to unbind.
        dim (int): Dimension to unbind.
+ """ + if dim == 0: + return [input[i] for i in range(input.shape[0])] + else: + feats = input.feats.unbind(dim) + return [input.replace(f) for f in feats] diff --git a/trellis2/modules/sparse/config.py b/trellis2/modules/sparse/config.py new file mode 100644 index 0000000000000000000000000000000000000000..37ce0da39383251877add603ceff7b064ab7ffab --- /dev/null +++ b/trellis2/modules/sparse/config.py @@ -0,0 +1,45 @@ +from typing import * +import sys + +CONV = 'flex_gemm' +DEBUG = False +# Default to 'sdpa' (PyTorch's built-in) on Windows since flash_attn isn't available +ATTN = 'sdpa' if sys.platform == 'win32' else 'flash_attn' + +def __from_env(): + import os + + global CONV + global DEBUG + global ATTN + + env_sparse_conv_backend = os.environ.get('SPARSE_CONV_BACKEND') + env_sparse_debug = os.environ.get('SPARSE_DEBUG') + env_sparse_attn_backend = os.environ.get('SPARSE_ATTN_BACKEND') + if env_sparse_attn_backend is None: + env_sparse_attn_backend = os.environ.get('ATTN_BACKEND') + + if env_sparse_conv_backend is not None and env_sparse_conv_backend in ['none', 'spconv', 'torchsparse', 'flex_gemm']: + CONV = env_sparse_conv_backend + if env_sparse_debug is not None: + DEBUG = env_sparse_debug == '1' + if env_sparse_attn_backend is not None and env_sparse_attn_backend in ['xformers', 'flash_attn', 'flash_attn_3', 'sdpa']: + ATTN = env_sparse_attn_backend + + print(f"[SPARSE] Conv backend: {CONV}; Attention backend: {ATTN}") + + +__from_env() + + +def set_conv_backend(backend: Literal['none', 'spconv', 'torchsparse', 'flex_gemm']): + global CONV + CONV = backend + +def set_debug(debug: bool): + global DEBUG + DEBUG = debug + +def set_attn_backend(backend: Literal['xformers', 'flash_attn']): + global ATTN + ATTN = backend diff --git a/trellis2/modules/sparse/conv/__init__.py b/trellis2/modules/sparse/conv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..74e8feadfa0c2ce51768a913909f8010c301faa9 --- /dev/null +++ 
b/trellis2/modules/sparse/conv/__init__.py
@@ -0,0 +1,2 @@
from .conv import SparseConv3d, SparseInverseConv3d
from . import config
diff --git a/trellis2/modules/sparse/conv/__pycache__/__init__.cpython-311.pyc b/trellis2/modules/sparse/conv/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b38d9ce03ffc3a639cd50000a2335ef9db888b4b
Binary files /dev/null and b/trellis2/modules/sparse/conv/__pycache__/__init__.cpython-311.pyc differ
diff --git a/trellis2/modules/sparse/conv/__pycache__/config.cpython-311.pyc b/trellis2/modules/sparse/conv/__pycache__/config.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f5d448e988b5a6bcf1bb0dde93351d8b8b433c87
Binary files /dev/null and b/trellis2/modules/sparse/conv/__pycache__/config.cpython-311.pyc differ
diff --git a/trellis2/modules/sparse/conv/__pycache__/conv.cpython-311.pyc b/trellis2/modules/sparse/conv/__pycache__/conv.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..24c11bf1ca8c3d1811c94462f92ef5bdc5bcbcc5
Binary files /dev/null and b/trellis2/modules/sparse/conv/__pycache__/conv.cpython-311.pyc differ
diff --git a/trellis2/modules/sparse/conv/__pycache__/conv_flex_gemm.cpython-311.pyc b/trellis2/modules/sparse/conv/__pycache__/conv_flex_gemm.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..80ea2d44fe28e9c78b182fa1336074066e446ce8
Binary files /dev/null and b/trellis2/modules/sparse/conv/__pycache__/conv_flex_gemm.cpython-311.pyc differ
diff --git a/trellis2/modules/sparse/conv/config.py b/trellis2/modules/sparse/conv/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..07b2e8b4b11ae313d2a00fd934ae4e27792db229
--- /dev/null
+++ b/trellis2/modules/sparse/conv/config.py
@@ -0,0 +1,3 @@
SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native'
FLEX_GEMM_ALGO = 'masked_implicit_gemm_splitk' # 'explicit_gemm', 'implicit_gemm', 'implicit_gemm_splitk', 'masked_implicit_gemm', 'masked_implicit_gemm_splitk'
FLEX_GEMM_HASHMAP_RATIO = 2.0 # Ratio of hashmap size to input size
diff --git a/trellis2/modules/sparse/conv/conv.py b/trellis2/modules/sparse/conv/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..4deb537d5811c06a296565f55626eec871a6d0e6
--- /dev/null
+++ b/trellis2/modules/sparse/conv/conv.py
@@ -0,0 +1,30 @@
from .. import config
import importlib
import torch
import torch.nn as nn
from .. import SparseTensor


# Lazily-imported backend modules, keyed by config.CONV
# ('conv_torchsparse', 'conv_spconv', 'conv_flex_gemm', ...).
_backends = {}


class SparseConv3d(nn.Module):
    """
    Backend-dispatching sparse 3D convolution. The module chosen by
    config.CONV at construction time populates this module's parameters
    (via its sparse_conv3d_init) and runs the forward pass.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
        super(SparseConv3d, self).__init__()
        if config.CONV not in _backends:
            # Resolves to trellis2.modules.sparse.conv.conv_<backend>.
            _backends[config.CONV] = importlib.import_module(f'..conv_{config.CONV}', __name__)
        _backends[config.CONV].sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride, dilation, padding, bias, indice_key)

    def forward(self, x: SparseTensor) -> SparseTensor:
        # NOTE(review): dispatch re-reads config.CONV at forward time; switching
        # the backend after construction would mismatch the initialized state.
        return _backends[config.CONV].sparse_conv3d_forward(self, x)


class SparseInverseConv3d(nn.Module):
    """
    Backend-dispatching inverse (transposed) sparse 3D convolution.
    Mirrors SparseConv3d's lazy backend resolution.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
        super(SparseInverseConv3d, self).__init__()
        if config.CONV not in _backends:
            _backends[config.CONV] = importlib.import_module(f'..conv_{config.CONV}', __name__)
        _backends[config.CONV].sparse_inverse_conv3d_init(self, in_channels, out_channels, kernel_size, stride, dilation, bias, indice_key)

    def forward(self, x: SparseTensor) -> SparseTensor:
        return _backends[config.CONV].sparse_inverse_conv3d_forward(self, x)
diff --git a/trellis2/modules/sparse/conv/conv_flex_gemm.py b/trellis2/modules/sparse/conv/conv_flex_gemm.py
new file mode 100644
index 0000000000000000000000000000000000000000..373692ad8ec38c5bf908ca9d83b9596f580c6382
--- /dev/null
+++ b/trellis2/modules/sparse/conv/conv_flex_gemm.py
@@ -0,0
+1,68 @@
import math
import torch
import torch.nn as nn
from .. import SparseTensor
from . import config
import flex_gemm
from flex_gemm.ops.spconv import sparse_submanifold_conv3d


def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
    # flex_gemm backend: submanifold convolution only (no downsampling).
    assert stride == 1 and (padding is None), 'Currently flex_gemm implementation only support submanifold sparse convolution (stride=1, padding=None)'

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel_size = tuple(kernel_size) if isinstance(kernel_size, (list, tuple)) else (kernel_size, ) * 3
    # NOTE(review): the list/tuple branch is dead — stride was just asserted
    # to be the int 1, so this always yields (1, 1, 1).
    self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, ) * 3
    self.dilation = tuple(dilation) if isinstance(dilation, (list, tuple)) else (dilation, ) * 3

    self.weight = nn.Parameter(torch.empty((out_channels, in_channels, *self.kernel_size)))
    if bias:
        self.bias = nn.Parameter(torch.empty(out_channels))
    else:
        self.register_parameter("bias", None)

    # initialize parameters (same scheme as torch.nn.Conv3d)
    torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
    if self.bias is not None:
        fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
        if fan_in != 0:
            bound = 1 / math.sqrt(fan_in)
            torch.nn.init.uniform_(self.bias, -bound, bound)

    # Permute weight (Co, Ci, Kd, Kh, Kw) -> (Co, Kd, Kh, Kw, Ci)
    self.weight = nn.Parameter(self.weight.permute(0, 2, 3, 4, 1).contiguous())


def sparse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
    flex_gemm.ops.spconv.set_algorithm(config.FLEX_GEMM_ALGO)
    flex_gemm.ops.spconv.set_hashmap_ratio(config.FLEX_GEMM_HASHMAP_RATIO)

    # check if neighbor map is already computed; the map only depends on the
    # coordinate pattern, kernel size and dilation, so it is cached on the
    # tensor's spatial cache and reused across layers with matching keys.
    Co, Kd, Kh, Kw, Ci = self.weight.shape
    neighbor_cache_key = f'SubMConv3d_neighbor_cache_{Kw}x{Kh}x{Kd}_dilation{self.dilation}'
    neighbor_cache = x.get_spatial_cache(neighbor_cache_key)

    out, neighbor_cache_ = sparse_submanifold_conv3d(
        x.feats,
        x.coords,
        torch.Size([*x.shape, *x.spatial_shape]),
        self.weight,
        self.bias,
        neighbor_cache,
        self.dilation
    )

    if neighbor_cache is None:
        x.register_spatial_cache(neighbor_cache_key, neighbor_cache_)

    out = x.replace(out)
    return out


def sparse_inverse_conv3d_init(self, *args, **kwargs):
    raise NotImplementedError('SparseInverseConv3d with flex_gemm is not implemented yet')


def sparse_inverse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
    raise NotImplementedError('SparseInverseConv3d with flex_gemm is not implemented yet')
diff --git a/trellis2/modules/sparse/conv/conv_spconv.py b/trellis2/modules/sparse/conv/conv_spconv.py
new file mode 100644
index 0000000000000000000000000000000000000000..eacfa3dcfa0f675017692e086cb68298805e16cf
--- /dev/null
+++ b/trellis2/modules/sparse/conv/conv_spconv.py
@@ -0,0 +1,73 @@
import torch
import torch.nn as nn
from .. import SparseTensor
from . import config
import spconv.pytorch as spconv


def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
    algo = None
    if config.SPCONV_ALGO == 'native':
        algo = spconv.ConvAlgo.Native
    elif config.SPCONV_ALGO == 'implicit_gemm':
        algo = spconv.ConvAlgo.MaskImplicitGemm
    # stride==1 with no padding maps to a submanifold conv; otherwise a
    # strided (downsampling) sparse conv.
    if stride == 1 and (padding is None):
        self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo)
    else:
        self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo)
    self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
    self.padding = padding


def sparse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
    spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None)
    new_data = self.conv(x.data)
    new_shape = [x.shape[0], self.conv.out_channels]
    # NOTE(review): `layout` is forwarded to SparseTensor(...) below, but the
    # current SparseTensor.__init__ does not consume a 'layout' kwarg — verify
    # it is picked up (otherwise the layout is silently recomputed from coords).
    new_layout = None if spatial_changed else x.layout

    if spatial_changed and (x.shape[0] != 1):
        # spconv was non-1 stride will break the contiguous of the output tensor, sort by the coords
        fwd = new_data.indices[:, 0].argsort()
        bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device))
        sorted_feats = new_data.features[fwd]
        sorted_coords = new_data.indices[fwd]
        unsorted_data = new_data
        new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size)  # type: ignore

    out = SparseTensor(
        new_data, shape=torch.Size(new_shape), layout=new_layout,
        scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]),
        spatial_cache=x._spatial_cache,
    )

    if spatial_changed and (x.shape[0] != 1):
        # Keep spconv's original (unsorted) ordering around so the matching
        # inverse conv can restore it before running.
        out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data)
        out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd)

    return out


def sparse_inverse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
    # indice_key must match the forward conv whose indices are being inverted.
    self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key)
    self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)


def sparse_inverse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
    spatial_changed = any(s != 1 for s in self.stride)
    if spatial_changed:
        # recover the original spconv order
        data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data')
        bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd')
        data = data.replace_feature(x.feats[bwd])
    else:
        data = x.data

    new_data = self.conv(data)
    new_shape = [x.shape[0], self.conv.out_channels]
    new_layout = None if spatial_changed else x.layout
    out = SparseTensor(
        new_data, shape=torch.Size(new_shape), layout=new_layout,
        # NOTE(review): Fraction floor-division here vs true division in the
        # torchsparse backend — confirm `//` cannot lose precision for the
        # scales used (it differs from `/` whenever the result is non-integral).
        scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]),
        spatial_cache=x._spatial_cache,
    )
    return out
diff --git a/trellis2/modules/sparse/conv/conv_torchsparse.py b/trellis2/modules/sparse/conv/conv_torchsparse.py
new file mode 100644
index 0000000000000000000000000000000000000000..35caa4045623e42a236ef496050352b5c4a2b465
--- /dev/null
+++ b/trellis2/modules/sparse/conv/conv_torchsparse.py
@@ -0,0 +1,30 @@
import torch
import torch.nn as nn
from .. import SparseTensor
import torchsparse


def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
    # NOTE(review): `padding` and `indice_key` are accepted for interface
    # parity with the other backends but not forwarded (padding fixed to 0).
    self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias)


def sparse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
    out = self.conv(x.data)
    new_shape = [x.shape[0], self.conv.out_channels]
    # Layout survives only when stride is 1 in every axis (row count unchanged).
    out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
    out._spatial_cache = x._spatial_cache
    out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)])
    return out


def sparse_inverse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
    self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True)


def sparse_inverse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
    out = self.conv(x.data)
    new_shape = [x.shape[0], self.conv.out_channels]
    out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
    out._spatial_cache = x._spatial_cache
    # NOTE(review): true division here vs `//` in the spconv backend — the
    # Fraction scales stay exact with `/`, so this is likely the safer of the
    # two; confirm the intended convention and make both backends agree.
    out._scale = tuple([s / stride for s, stride in zip(x._scale, self.conv.stride)])
    return out
diff --git a/trellis2/modules/sparse/linear.py b/trellis2/modules/sparse/linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..f98f48e0f387a11ae1b36550873ec4686ed02aaa
--- /dev/null
+++ 
# diff: trellis2/modules/sparse/linear.py (new file, +1,15)
import torch
import torch.nn as nn
from . import VarLenTensor

__all__ = [
    'SparseLinear'
]


class SparseLinear(nn.Linear):
    """A dense linear projection applied to the packed feature matrix of a
    VarLenTensor; coordinates and layout are carried through untouched."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)

    def forward(self, input: VarLenTensor) -> VarLenTensor:
        projected = super().forward(input.feats)
        return input.replace(projected)


# diff: trellis2/modules/sparse/nonlinearity.py (new file, +1,35)
import torch
import torch.nn as nn
from . import VarLenTensor

__all__ = [
    'SparseReLU',
    'SparseSiLU',
    'SparseGELU',
    'SparseActivation'
]


class SparseReLU(nn.ReLU):
    """ReLU applied element-wise over VarLenTensor features."""

    def forward(self, input: VarLenTensor) -> VarLenTensor:
        activated = super().forward(input.feats)
        return input.replace(activated)


class SparseSiLU(nn.SiLU):
    """SiLU applied element-wise over VarLenTensor features."""

    def forward(self, input: VarLenTensor) -> VarLenTensor:
        activated = super().forward(input.feats)
        return input.replace(activated)


class SparseGELU(nn.GELU):
    """GELU applied element-wise over VarLenTensor features."""

    def forward(self, input: VarLenTensor) -> VarLenTensor:
        activated = super().forward(input.feats)
        return input.replace(activated)


class SparseActivation(nn.Module):
    """Wrap an arbitrary dense activation module so it operates on the
    feature matrix of a VarLenTensor."""

    def __init__(self, activation: nn.Module):
        super().__init__()
        self.activation = activation

    def forward(self, input: VarLenTensor) -> VarLenTensor:
        activated = self.activation(input.feats)
        return input.replace(activated)


# diff: trellis2/modules/sparse/norm.py (new file, +1,64)
import torch
import torch.nn as nn
from ..utils import manual_cast
from . import VarLenTensor
from . import config

__all__ = [
    'SparseGroupNorm',
    'SparseLayerNorm',
    'SparseGroupNorm32',
    'SparseLayerNorm32',
]


class SparseGroupNorm(nn.GroupNorm):
    """
    GroupNorm over a variable-length batch: each sample's feature rows are
    selected via `input.layout` and normalized independently.
    """
    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
        super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine)

    def forward(self, input: VarLenTensor) -> VarLenTensor:
        nfeats = torch.zeros_like(input.feats)
        for k in range(input.shape[0]):
            # (Nk, C) -> (1, C, Nk): GroupNorm expects channels in dim 1.
            bfeats = input.feats[input.layout[k]]
            bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
            bfeats = super().forward(bfeats)
            bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
            nfeats[input.layout[k]] = bfeats
        return input.replace(nfeats)


class SparseLayerNorm(nn.LayerNorm):
    """
    LayerNorm over a variable-length batch, one sample at a time.

    NOTE(review): features are reshaped to (1, C, Nk) before calling
    nn.LayerNorm, which normalizes the *trailing* dimensions — with channels
    in dim 1 this normalizes along the token axis, not the channel axis.
    Verify `normalized_shape` is chosen so this is the intended behavior.
    """
    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine)

    def forward(self, input: VarLenTensor) -> VarLenTensor:
        nfeats = torch.zeros_like(input.feats)
        for k in range(input.shape[0]):
            bfeats = input.feats[input.layout[k]]
            bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
            bfeats = super().forward(bfeats)
            bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
            nfeats[input.layout[k]] = bfeats
        return input.replace(nfeats)


class SparseGroupNorm32(SparseGroupNorm):
    """
    A GroupNorm layer that converts to float32 before the forward pass.
    """
    def forward(self, x: VarLenTensor) -> VarLenTensor:
        x_dtype = x.dtype
        x = manual_cast(x, torch.float32)
        o = super().forward(x)
        # Cast back so surrounding mixed-precision code sees the input dtype.
        return manual_cast(o, x_dtype)


class SparseLayerNorm32(SparseLayerNorm):
    """
    A LayerNorm layer that converts to float32 before the forward pass.
    """
    def forward(self, x: VarLenTensor) -> VarLenTensor:
        x_dtype = x.dtype
        x = manual_cast(x, torch.float32)
        o = super().forward(x)
        return manual_cast(o, x_dtype)


# diff: trellis2/modules/sparse/spatial/__init__.py (new file, +1,2)
from .basic import *
from .spatial2channel import *

# (binary __pycache__ *.pyc blobs in the original diff are not reviewable code)

# diff: trellis2/modules/sparse/spatial/basic.py (new file, +1,109)
from typing import *
import torch
import torch.nn as nn
from .. import SparseTensor

__all__ = [
    'SparseDownsample',
    'SparseUpsample',
]


class SparseDownsample(nn.Module):
    """
    Downsample a sparse tensor by a factor of `factor`.
    Implemented as average pooling.
    """
    def __init__(self, factor: int, mode: Literal['mean', 'max'] = 'mean'):
        super(SparseDownsample, self).__init__()
        self.factor = factor
        self.mode = mode
        assert self.mode in ['mean', 'max'], f'Invalid mode: {self.mode}'

    def forward(self, x: SparseTensor) -> SparseTensor:
        # Reuse the coordinate mapping if this tensor was downsampled before.
        cache = x.get_spatial_cache(f'downsample_{self.factor}')
        if cache is None:
            DIM = x.coords.shape[-1] - 1  # spatial dims; coords[:, 0] is the batch index

            coord = list(x.coords.unbind(dim=-1))
            for i in range(DIM):
                coord[i+1] = coord[i+1] // self.factor

            # Encode (batch, coords) into a single linear code so duplicate
            # coarse cells can be merged with one unique() call.
            MAX = [(s + self.factor - 1) // self.factor for s in x.spatial_shape]
            OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
            code = sum([c * o for c, o in zip(coord, OFFSET)])
            code, idx = code.unique(return_inverse=True)

            # Decode the unique linear codes back into coordinate columns.
            new_coords = torch.stack(
                [code // OFFSET[0]] +
                [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
                dim=-1
            )
        else:
            new_coords, idx = cache

        # Pool features of all fine voxels that fall into the same coarse cell.
        new_feats = torch.scatter_reduce(
            torch.zeros(new_coords.shape[0], x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype),
            dim=0,
            index=idx.unsqueeze(1).expand(-1, x.feats.shape[1]),
            src=x.feats,
            reduce=self.mode,
            include_self=False,
        )
        out = SparseTensor(new_feats, new_coords, x._shape)
        out._scale = tuple([s * self.factor for s in x._scale])
        out._spatial_cache = x._spatial_cache

        if cache is None:
            # Cache forward and inverse mappings so SparseUpsample can undo this.
            x.register_spatial_cache(f'downsample_{self.factor}', (new_coords, idx))
            out.register_spatial_cache(f'upsample_{self.factor}', (x.coords, idx))
            out.register_spatial_cache(f'shape', torch.Size(MAX))
            if self.training:
                # Record which of the factor**DIM children of each coarse cell
                # were occupied (boolean subdivision mask for supervision).
                subidx = x.coords[:, 1:] % self.factor
                subidx = sum([subidx[..., i] * self.factor ** i for i in range(DIM)])
                subdivision = torch.zeros((new_coords.shape[0], self.factor ** DIM), device=x.device, dtype=torch.bool)
                subdivision[idx, subidx] = True
                out.register_spatial_cache(f'subdivision', subdivision)

        return out


class SparseUpsample(nn.Module):
    """
    Upsample a sparse tensor by a factor of `factor`.
    Implemented as nearest neighbor interpolation.
    """
    def __init__(
        self, factor: int
    ):
        super(SparseUpsample, self).__init__()
        self.factor = factor

    def forward(self, x: SparseTensor, subdivision: Optional[SparseTensor] = None) -> SparseTensor:
        DIM = x.coords.shape[-1] - 1

        cache = x.get_spatial_cache(f'upsample_{self.factor}')
        if cache is None:
            if subdivision is None:
                raise ValueError('Cache not found. Provide subdivision tensor or pair SparseUpsample with SparseDownsample.')
            else:
                # Expand every coarse voxel into the children flagged by the
                # boolean subdivision mask.
                sub = subdivision.feats
                N_leaf = sub.sum(dim=-1)
                subidx = sub.nonzero()[:, -1]
                new_coords = x.coords.clone().detach()
                new_coords[:, 1:] *= self.factor
                new_coords = torch.repeat_interleave(new_coords, N_leaf, dim=0, output_size=subidx.shape[0])
                for i in range(DIM):
                    # Decode the child offset along axis i from the linear subidx.
                    new_coords[:, i+1] += subidx // self.factor ** i % self.factor
                idx = torch.repeat_interleave(torch.arange(x.coords.shape[0], device=x.device), N_leaf, dim=0, output_size=subidx.shape[0])
        else:
            new_coords, idx = cache

        # Nearest-neighbor upsampling: children copy their parent's features.
        new_feats = x.feats[idx]
        out = SparseTensor(new_feats, new_coords, x._shape)
        out._scale = tuple([s / self.factor for s in x._scale])
        if cache is not None:  # only keep cache when subdiv following it
            out._spatial_cache = x._spatial_cache

        return out


# diff: trellis2/modules/sparse/spatial/spatial2channel.py (new file, +1,93)
from typing import *
import torch
import torch.nn as nn
from .. import SparseTensor


class SparseSpatial2Channel(nn.Module):
    """
    Downsample a sparse tensor by a factor of `factor`.
    Implemented as rearranging its features from spatial to channel.
    """
    def __init__(self, factor: int = 2):
        super(SparseSpatial2Channel, self).__init__()
        self.factor = factor

    def forward(self, x: SparseTensor) -> SparseTensor:
        DIM = x.coords.shape[-1] - 1  # spatial dims; coords[:, 0] is the batch index
        cache = x.get_spatial_cache(f'spatial2channel_{self.factor}')
        if cache is None:
            coord = list(x.coords.unbind(dim=-1))
            for i in range(DIM):
                coord[i+1] = coord[i+1] // self.factor
            # Linear index of each fine voxel inside its coarse cell.
            subidx = x.coords[:, 1:] % self.factor
            subidx = sum([subidx[..., i] * self.factor ** i for i in range(DIM)])

            # Encode (batch, coords) into one linear code to merge duplicates.
            MAX = [(s + self.factor - 1) // self.factor for s in x.spatial_shape]
            OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
            code = sum([c * o for c, o in zip(coord, OFFSET)])
            code, idx = code.unique(return_inverse=True)

            new_coords = torch.stack(
                [code // OFFSET[0]] +
                [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
                dim=-1
            )
        else:
            new_coords, idx, subidx = cache

        # Scatter each fine voxel's features into its slot of the coarse
        # cell's (factor**DIM)-times-wider channel vector; empty slots stay 0.
        new_feats = torch.zeros(new_coords.shape[0] * self.factor ** DIM, x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype)
        new_feats[idx * self.factor ** DIM + subidx] = x.feats

        out = SparseTensor(new_feats.reshape(new_coords.shape[0], -1), new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] * self.factor ** DIM]))
        out._scale = tuple([s * self.factor for s in x._scale])
        out._spatial_cache = x._spatial_cache

        if cache is None:
            # Cache forward and inverse mappings so SparseChannel2Spatial can
            # undo this rearrangement.
            x.register_spatial_cache(f'spatial2channel_{self.factor}', (new_coords, idx, subidx))
            out.register_spatial_cache(f'channel2spatial_{self.factor}', (x.coords, idx, subidx))
            out.register_spatial_cache(f'shape', torch.Size(MAX))
            if self.training:
                # Boolean occupancy of each coarse cell's children (supervision).
                subdivision = torch.zeros((new_coords.shape[0], self.factor ** DIM), device=x.device, dtype=torch.bool)
                subdivision[idx, subidx] = True
                out.register_spatial_cache(f'subdivision', subdivision)

        return out


class SparseChannel2Spatial(nn.Module):
    """
    Upsample a sparse tensor by a factor of `factor`.
    Implemented as rearranging its features from channel to spatial.
    """
    def __init__(self, factor: int = 2):
        super(SparseChannel2Spatial, self).__init__()
        self.factor = factor

    def forward(self, x: SparseTensor, subdivision: Optional[SparseTensor] = None) -> SparseTensor:
        DIM = x.coords.shape[-1] - 1

        cache = x.get_spatial_cache(f'channel2spatial_{self.factor}')
        if cache is None:
            if subdivision is None:
                raise ValueError('Cache not found. Provide subdivision tensor or pair SparseChannel2Spatial with SparseSpatial2Channel.')
            else:
                sub = subdivision.feats  # [N, self.factor ** DIM]
                N_leaf = sub.sum(dim=-1)  # [N]
                subidx = sub.nonzero()[:, -1]
                new_coords = x.coords.clone().detach()
                new_coords[:, 1:] *= self.factor
                new_coords = torch.repeat_interleave(new_coords, N_leaf, dim=0, output_size=subidx.shape[0])
                for i in range(DIM):
                    # Decode the child offset along axis i from the linear subidx.
                    new_coords[:, i+1] += subidx // self.factor ** i % self.factor
                idx = torch.repeat_interleave(torch.arange(x.coords.shape[0], device=x.device), N_leaf, dim=0, output_size=subidx.shape[0])
        else:
            new_coords, idx, subidx = cache

        # Gather each child's slice back out of its parent's channel vector.
        x_feats = x.feats.reshape(x.feats.shape[0] * self.factor ** DIM, -1)
        new_feats = x_feats[idx * self.factor ** DIM + subidx]
        out = SparseTensor(new_feats, new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] // self.factor ** DIM]))
        out._scale = tuple([s / self.factor for s in x._scale])
        if cache is not None:  # only keep cache when subdiv following it
            out._spatial_cache = x._spatial_cache
        return out


# diff: trellis2/modules/sparse/transformer/__init__.py (new file) — header continues in the next chunk
# diff: trellis2/modules/sparse/transformer/__init__.py (new file, +1,2)
from .blocks import *
from .modulated import *

# (binary __pycache__ *.pyc blobs in the original diff are not reviewable code)

# diff: trellis2/modules/sparse/transformer/blocks.py (new file, +1,145)
from typing import *
import torch
import torch.nn as nn
from ..basic import VarLenTensor, SparseTensor
from ..linear import SparseLinear
from ..nonlinearity import SparseGELU
from ..attention import SparseMultiHeadAttention
from ...norm import LayerNorm32


class SparseFeedForwardNet(nn.Module):
    """Two-layer MLP over sparse features: expand by `mlp_ratio`, tanh-GELU,
    project back to `channels`."""
    def __init__(self, channels: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.mlp = nn.Sequential(
            SparseLinear(channels, int(channels * mlp_ratio)),
            SparseGELU(approximate="tanh"),
            SparseLinear(int(channels * mlp_ratio), channels),
        )

    def forward(self, x: VarLenTensor) -> VarLenTensor:
        return self.mlp(x)


class SparseTransformerBlock(nn.Module):
    """
    Sparse Transformer block (MSA + FFN).

    Pre-norm residual layout: x + attn(norm1(x)), then x + mlp(norm2(x)).
    `use_checkpoint` enables activation checkpointing of the whole block.
    """
    def __init__(
        self,
        channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "swin"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        rope_freq: Tuple[float, float] = (1.0, 10000.0),  # annotation fixed: defaults are floats
        qk_rms_norm: bool = False,
        qkv_bias: bool = True,
        ln_affine: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            rope_freq=rope_freq,
            qk_rms_norm=qk_rms_norm,
        )
        self.mlp = SparseFeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )

    def _forward(self, x: SparseTensor) -> SparseTensor:
        # Self-attention sub-block (pre-norm residual).
        h = x.replace(self.norm1(x.feats))
        h = self.attn(h)
        x = x + h
        # Feed-forward sub-block.
        h = x.replace(self.norm2(x.feats))
        h = self.mlp(h)
        x = x + h
        return x

    def forward(self, x: SparseTensor) -> SparseTensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
        else:
            return self._forward(x)


class SparseTransformerCrossBlock(nn.Module):
    """
    Sparse Transformer cross-attention block (MSA + MCA + FFN).

    Same pre-norm residual layout as SparseTransformerBlock with an extra
    cross-attention sub-block over `context` between self-attention and MLP.
    """
    def __init__(
        self,
        channels: int,
        ctx_channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "swin"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        qkv_bias: bool = True,
        ln_affine: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.self_attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        # Cross-attention is always dense ("full") over the context sequence.
        self.cross_attn = SparseMultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        self.mlp = SparseFeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )

    def _forward(self, x: SparseTensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
        h = x.replace(self.norm1(x.feats))
        h = self.self_attn(h)
        x = x + h
        h = x.replace(self.norm2(x.feats))
        h = self.cross_attn(h, context)
        x = x + h
        h = x.replace(self.norm3(x.feats))
        h = self.mlp(h)
        x = x + h
        return x

    def forward(self, x: SparseTensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
        else:
            return self._forward(x, context)


# diff: trellis2/modules/sparse/transformer/modulated.py (new file) — header continues in the next chunk
# diff: trellis2/modules/sparse/transformer/modulated.py (new file, +1,166)
from typing import *
import torch
import torch.nn as nn
from ..basic import VarLenTensor, SparseTensor
from ..attention import SparseMultiHeadAttention
from ...norm import LayerNorm32
from .blocks import SparseFeedForwardNet


class ModulatedSparseTransformerBlock(nn.Module):
    """
    Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning.

    The conditioning vector `mod` is projected (or, with `share_mod`, offset
    by a learned parameter) into six chunks: shift/scale/gate for the
    attention sub-block and shift/scale/gate for the MLP sub-block.
    """
    def __init__(
        self,
        channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "swin"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        rope_freq: Tuple[float, float] = (1.0, 10000.0),
        qk_rms_norm: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        # Norms are affine-free: scale/shift come from the modulation instead.
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            rope_freq=rope_freq,
            qk_rms_norm=qk_rms_norm,
        )
        self.mlp = SparseFeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )
        if not share_mod:
            # Per-block adaLN projection of the conditioning vector.
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )
        else:
            # Shared-mod variant: a learned offset added to an externally
            # computed modulation tensor.
            self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)

    def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
        if self.share_mod:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
        # Modulated self-attention sub-block.
        h = x.replace(self.norm1(x.feats))
        h = h * (1 + scale_msa) + shift_msa
        h = self.attn(h)
        h = h * gate_msa
        x = x + h
        # Modulated feed-forward sub-block.
        h = x.replace(self.norm2(x.feats))
        h = h * (1 + scale_mlp) + shift_mlp
        h = self.mlp(h)
        h = h * gate_mlp
        x = x + h
        return x

    def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
        else:
            return self._forward(x, mod)


class ModulatedSparseTransformerCrossBlock(nn.Module):
    """
    Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.

    Only the self-attention and MLP sub-blocks are modulated; the
    cross-attention branch is unmodulated, which is why `norm2` keeps its own
    affine parameters while `norm1`/`norm3` do not.
    """
    def __init__(
        self,
        channels: int,
        ctx_channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "swin"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        rope_freq: Tuple[float, float] = (1.0, 10000.0),
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,

    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.self_attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            rope_freq=rope_freq,
            qk_rms_norm=qk_rms_norm,
        )
        self.cross_attn = SparseMultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        self.mlp = SparseFeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )
        if not share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )
        else:
            self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)

    def _forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
        if self.share_mod:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
        # Modulated self-attention.
        h = x.replace(self.norm1(x.feats))
        h = h * (1 + scale_msa) + shift_msa
        h = self.self_attn(h)
        h = h * gate_msa
        x = x + h
        # Unmodulated cross-attention over the context.
        h = x.replace(self.norm2(x.feats))
        h = self.cross_attn(h, context)
        x = x + h
        # Modulated feed-forward.
        h = x.replace(self.norm3(x.feats))
        h = h * (1 + scale_mlp) + shift_mlp
        h = self.mlp(h)
        h = h * gate_mlp
        x = x + h
        return x

    def forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
        else:
            return self._forward(x, mod, context)


# diff: trellis2/modules/spatial.py (new file, +1,48)
import torch


def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
    """
    3D pixel shuffle.

    Rearranges (B, C, H, W, D) -> (B, C / s**3, H*s, W*s, D*s): every group
    of s**3 channels becomes an s x s x s spatial sub-block.
    """
    B, C, H, W, D = x.shape
    C_ = C // scale_factor**3
    x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D)
    # Interleave each spatial axis with its scale offset: (B, C_, H, s, W, s, D, s).
    x = x.permute(0, 1, 5, 2, 6, 3, 7, 4)
    x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor)
    return x


def patchify(x: torch.Tensor, patch_size: int):
    """
    Patchify a tensor.

    Folds each patch_size**D spatial block into the channel dimension
    (inverse of `unpatchify`).

    Args:
        x (torch.Tensor): (N, C, *spatial) tensor
        patch_size (int): Patch size
    """
    DIM = x.dim() - 2
    for d in range(2, DIM + 2):
        assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}"

    # Split every spatial axis into (blocks, patch) pairs...
    x = x.reshape(*x.shape[:2], *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], []))
    # ...then move all patch axes next to the channel axis and merge them in.
    x = x.permute(0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)]))
    x = x.reshape(x.shape[0], x.shape[1] * (patch_size ** DIM), *(x.shape[-DIM:]))
    return x


def unpatchify(x: torch.Tensor, patch_size: int):
    """
    Unpatchify a tensor.

    Inverse of `patchify`: expands channel groups of size patch_size**D back
    into spatial blocks.

    Args:
        x (torch.Tensor): (N, C, *spatial) tensor
        patch_size (int): Patch size
    """
    DIM = x.dim() - 2
    assert x.shape[1] % (patch_size ** DIM) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}"

    x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** DIM), *([patch_size] * DIM), *(x.shape[-DIM:]))
    # Interleave (spatial, patch) axis pairs, then merge each pair.
    x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], [])))
    x = x.reshape(x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)])
    return x


# diff: trellis2/modules/transformer/__init__.py (new file, +1,2)
from .blocks import *
from .modulated import *
from typing import *
import torch
import torch.nn as nn
from ..attention import MultiHeadAttention
from ..norm import LayerNorm32


class AbsolutePositionEmbedder(nn.Module):
    """
    Embeds spatial positions into vector representations.

    Each of the `in_channels` coordinates receives `channels // in_channels // 2`
    (sin, cos) frequency pairs; any remainder of `channels` is zero-padded.
    """
    def __init__(self, channels: int, in_channels: int = 3):
        super().__init__()
        self.channels = channels
        self.in_channels = in_channels
        self.freq_dim = channels // in_channels // 2
        self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
        self.freqs = 1.0 / (10000 ** self.freqs)

    def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor:
        """
        Create sinusoidal position embeddings.

        Args:
            x: a 1-D Tensor of N indices

        Returns:
            an (N, D) Tensor of positional embeddings.
        """
        # Lazily migrate the cached frequency table to the input's device.
        self.freqs = self.freqs.to(x.device)
        out = torch.outer(x, self.freqs)
        out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1)
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): (N, D) tensor of spatial positions

        Returns:
            (N, channels) tensor of embeddings; zero-padded when `channels`
            is not divisible by 2 * in_channels.
        """
        N, D = x.shape
        assert D == self.in_channels, "Input dimension must match number of input channels"
        embed = self._sin_cos_embedding(x.reshape(-1))
        embed = embed.reshape(N, -1)
        if embed.shape[1] < self.channels:
            embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1)
        return embed


class FeedForwardNet(nn.Module):
    """Two-layer MLP with tanh-approximate GELU, expanding by `mlp_ratio`."""
    def __init__(self, channels: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(channels, int(channels * mlp_ratio)),
            nn.GELU(approximate="tanh"),
            nn.Linear(int(channels * mlp_ratio), channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)


class TransformerBlock(nn.Module):
    """
    Transformer block (MSA + FFN), pre-norm residual layout.

    Args:
        channels: Feature dimension.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden-size multiplier for the FFN.
        attn_mode: "full" or "windowed" attention.
        window_size: Window size for windowed attention.
        shift_window: Per-axis window shift for windowed attention.
        use_checkpoint: Recompute activations in backward to save memory.
        use_rope: Apply rotary position embeddings.
        rope_freq: (min, max) RoPE frequency range.
        qk_rms_norm: RMS-normalize queries/keys.
        qkv_bias: Bias terms in QKV projections.
        ln_affine: Learnable affine parameters in the layer norms.
    """
    def __init__(
        self,
        channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        # NOTE: annotation fixed to match every sibling block's shift_window type.
        shift_window: Optional[Tuple[int, int, int]] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        # NOTE: annotation fixed — defaults are floats, not ints.
        rope_freq: Tuple[float, float] = (1.0, 10000.0),
        qk_rms_norm: bool = False,
        qkv_bias: bool = True,
        ln_affine: bool = True,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.attn = MultiHeadAttention(
            channels,
            num_heads=num_heads,
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            rope_freq=rope_freq,
            qk_rms_norm=qk_rms_norm,
        )
        self.mlp = FeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )

    def _forward(self, x: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Pre-norm residual: x + Attn(LN(x)), then x + FFN(LN(x)).
        h = self.norm1(x)
        h = self.attn(h, phases=phases)
        x = x + h
        h = self.norm2(x)
        h = self.mlp(h)
        x = x + h
        return x

    def forward(self, x: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, phases, use_reentrant=False)
        else:
            return self._forward(x, phases)


class TransformerCrossBlock(nn.Module):
    """
    Transformer cross-attention block (MSA + MCA + FFN).

    Args:
        channels: Feature dimension of `x`.
        ctx_channels: Feature dimension of the cross-attention context.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden-size multiplier for the FFN.
        attn_mode: Self-attention mode ("full"/"windowed"); cross-attention
            is always full.
        window_size: Window size for windowed self-attention.
        shift_window: Per-axis window shift for windowed self-attention.
        use_checkpoint: Recompute activations in backward to save memory.
        use_rope: Apply rotary position embeddings (self-attention only).
        rope_freq: (min, max) RoPE frequency range.
        qk_rms_norm: RMS-normalize queries/keys in self-attention.
        qk_rms_norm_cross: RMS-normalize queries/keys in cross-attention.
        qkv_bias: Bias terms in QKV projections.
        ln_affine: Learnable affine parameters in the layer norms.
    """
    def __init__(
        self,
        channels: int,
        ctx_channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        # NOTE: annotation fixed — defaults are floats, not ints.
        rope_freq: Tuple[float, float] = (1.0, 10000.0),
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        qkv_bias: bool = True,
        ln_affine: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.self_attn = MultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            rope_freq=rope_freq,
            qk_rms_norm=qk_rms_norm,
        )
        self.cross_attn = MultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        self.mlp = FeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )

    def _forward(self, x: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Pre-norm residual: self-attn, then cross-attn over context, then FFN.
        h = self.norm1(x)
        h = self.self_attn(h, phases=phases)
        x = x + h
        h = self.norm2(x)
        h = self.cross_attn(h, context)
        x = x + h
        h = self.norm3(x)
        h = self.mlp(h)
        x = x + h
        return x

    def forward(self, x: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, context, phases, use_reentrant=False)
        else:
            return self._forward(x, context, phases)
from typing import *
import torch
import torch.nn as nn
from ..attention import MultiHeadAttention
from ..norm import LayerNorm32
from .blocks import FeedForwardNet


class ModulatedTransformerBlock(nn.Module):
    """
    Transformer block (MSA + FFN) with adaptive layer norm conditioning.

    The conditioning vector `mod` produces six per-sample vectors
    (shift/scale/gate for the attention path and for the MLP path). With
    `share_mod=True` a learned parameter is added to an externally computed
    modulation instead of running a per-block adaLN projection.
    """
    def __init__(
        self,
        channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        # NOTE: annotation fixed — defaults are floats, not ints.
        rope_freq: Tuple[float, float] = (1.0, 10000.0),
        qk_rms_norm: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        # Affine parameters come from the modulation, so norms are parameter-free.
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.attn = MultiHeadAttention(
            channels,
            num_heads=num_heads,
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            rope_freq=rope_freq,
            qk_rms_norm=qk_rms_norm,
        )
        self.mlp = FeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )
        if not share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )
        else:
            self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)

    def _forward(self, x: torch.Tensor, mod: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.share_mod:
            # `mod` is assumed pre-projected to 6*channels by the caller — TODO confirm.
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
        h = self.norm1(x)
        h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        h = self.attn(h, phases=phases)
        h = h * gate_msa.unsqueeze(1)
        x = x + h
        h = self.norm2(x)
        h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        h = self.mlp(h)
        h = h * gate_mlp.unsqueeze(1)
        x = x + h
        return x

    def forward(self, x: torch.Tensor, mod: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, mod, phases, use_reentrant=False)
        else:
            return self._forward(x, mod, phases)


class ModulatedTransformerCrossBlock(nn.Module):
    """
    Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer
    norm conditioning.

    Only the self-attention and MLP paths are modulated; the cross-attention
    path uses an ordinary affine layer norm (norm2).
    """
    def __init__(
        self,
        channels: int,
        ctx_channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        # NOTE: annotation fixed — defaults are floats, not ints.
        rope_freq: Tuple[float, float] = (1.0, 10000.0),
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        # norm2 keeps affine params: the cross-attn branch is not modulated.
        self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.self_attn = MultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            rope_freq=rope_freq,
            qk_rms_norm=qk_rms_norm,
        )
        self.cross_attn = MultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        self.mlp = FeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )
        if not share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )
        else:
            self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)

    def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.share_mod:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
        h = self.norm1(x)
        h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        h = self.self_attn(h, phases=phases)
        h = h * gate_msa.unsqueeze(1)
        x = x + h
        h = self.norm2(x)
        h = self.cross_attn(h, context)
        x = x + h
        h = self.norm3(x)
        h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        h = self.mlp(h)
        h = h * gate_mlp.unsqueeze(1)
        x = x + h
        return x

    def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, phases, use_reentrant=False)
        else:
            return self._forward(x, mod, context, phases)
import torch
import torch.nn as nn
from ..modules import sparse as sp

# Primitive module types whose parameters participate in mixed-precision casting.
MIX_PRECISION_MODULES = (
    nn.Conv1d,
    nn.Conv2d,
    nn.Conv3d,
    nn.ConvTranspose1d,
    nn.ConvTranspose2d,
    nn.ConvTranspose3d,
    nn.Linear,
    sp.SparseConv3d,
    sp.SparseInverseConv3d,
    sp.SparseLinear,
)

# Accepted dtype alias strings; built once at import time instead of on
# every str_to_dtype() call.
_DTYPE_ALIASES = {
    'f16': torch.float16,
    'fp16': torch.float16,
    'float16': torch.float16,
    'bf16': torch.bfloat16,
    'bfloat16': torch.bfloat16,
    'f32': torch.float32,
    'fp32': torch.float32,
    'float32': torch.float32,
}


def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.
    """
    if isinstance(l, MIX_PRECISION_MODULES):
        for p in l.parameters():
            p.data = p.data.half()


def convert_module_to_f32(l):
    """
    Convert primitive modules to float32, undoing convert_module_to_f16().
    """
    if isinstance(l, MIX_PRECISION_MODULES):
        for p in l.parameters():
            p.data = p.data.float()


def convert_module_to(l, dtype):
    """
    Convert primitive modules to the given dtype.
    """
    if isinstance(l, MIX_PRECISION_MODULES):
        for p in l.parameters():
            p.data = p.data.to(dtype)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def modulate(x, shift, scale):
    """Apply adaLN-style modulation along the token dimension."""
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


def manual_cast(tensor, dtype):
    """
    Cast `tensor` to `dtype` only when autocast is NOT enabled; under
    autocast, dtype management is left to the autocast context.
    """
    if not torch.is_autocast_enabled():
        return tensor.type(dtype)
    return tensor


def str_to_dtype(dtype_str: str):
    """
    Map a dtype alias (e.g. 'fp16', 'bf16', 'float32') to a torch dtype.

    Raises:
        KeyError: if the alias is unknown.
    """
    return _DTYPE_ALIASES[dtype_str]
from typing import *
import torch
import torch.nn as nn
from .. import models


class Pipeline:
    """
    A base class for pipelines.

    Holds a dict of named sub-models (all switched to eval mode) and
    provides loading from a local folder or the Hugging Face Hub.
    """
    def __init__(
        self,
        models: Optional[dict[str, nn.Module]] = None,
    ):
        # Subclasses may call with None when deserializing; they are then
        # responsible for populating state themselves.
        if models is None:
            return
        self.models = models
        for model in self.models.values():
            model.eval()

    @classmethod
    def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> "Pipeline":
        """
        Load a pretrained pipeline.

        Args:
            path: Local folder or Hugging Face repo id containing `config_file`.
            config_file: Name of the pipeline config JSON.

        Returns:
            A pipeline instance with `_pretrained_args` set to the config args.
        """
        import os
        import json
        is_local = os.path.exists(f"{path}/{config_file}")

        if is_local:
            config_file = f"{path}/{config_file}"
        else:
            from huggingface_hub import hf_hub_download
            config_file = hf_hub_download(path, config_file)

        with open(config_file, 'r') as f:
            args = json.load(f)['args']

        _models = {}
        for k, v in args['models'].items():
            # Subclasses may restrict which sub-models to load.
            if hasattr(cls, 'model_names_to_load') and k not in cls.model_names_to_load:
                continue
            try:
                _models[k] = models.from_pretrained(f"{path}/{v}")
            except Exception:
                # Deliberate fallback: treat `v` as a standalone model path /
                # repo id rather than a path relative to `path`.
                _models[k] = models.from_pretrained(v)

        new_pipeline = cls(_models)
        new_pipeline._pretrained_args = args
        return new_pipeline

    @property
    def device(self) -> torch.device:
        """Best-effort device lookup: explicit override, then sub-models."""
        if hasattr(self, '_device'):
            return self._device
        for model in self.models.values():
            if hasattr(model, 'device'):
                return model.device
        for model in self.models.values():
            if hasattr(model, 'parameters'):
                return next(model.parameters()).device
        raise RuntimeError("No device found.")

    def to(self, device: torch.device) -> None:
        """Move every sub-model to `device`."""
        for model in self.models.values():
            model.to(device)

    def cuda(self) -> None:
        self.to(torch.device("cuda"))

    def cpu(self) -> None:
        self.to(torch.device("cpu"))
from typing import *
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
from contextlib import contextmanager


@contextmanager
def _force_cpu_linspace():
    """Force torch.linspace to create CPU tensors, working around meta tensor issues."""
    original_linspace = torch.linspace

    def cpu_linspace(*args, **kwargs):
        kwargs['device'] = 'cpu'
        return original_linspace(*args, **kwargs)

    torch.linspace = cpu_linspace
    try:
        yield
    finally:
        # Always restore the global, even if model construction raises.
        torch.linspace = original_linspace


def _patch_birefnet_class():
    """Patch BiRefNet class to add missing all_tied_weights_keys property for transformers compatibility."""
    from transformers.dynamic_module_utils import get_class_from_dynamic_module
    import importlib

    try:
        # Get the model class from the remote module
        module_name = "briaai/RMBG-2.0"
        cls = get_class_from_dynamic_module(
            f"{module_name}--birefnet.BiRefNet",
            module_name,
            trust_remote_code=True,
        )

        # Add missing attribute if not present
        if not hasattr(cls, 'all_tied_weights_keys'):
            @property
            def all_tied_weights_keys(self):
                # Return empty dict if _tied_weights_keys not defined
                if hasattr(self, '_tied_weights_keys') and self._tied_weights_keys:
                    return {k: k for k in self._tied_weights_keys}
                return {}
            cls.all_tied_weights_keys = all_tied_weights_keys
    except Exception:
        pass  # Silently fail - will be caught during actual model loading


class BiRefNet:
    """
    Background-removal wrapper around a BiRefNet segmentation model.

    NOTE: trust_remote_code=True executes model code fetched from the Hub —
    only use with trusted repositories.
    """
    def __init__(self, model_name: str = "ZhengPeng7/BiRefNet"):
        # RMBG-2.0's SwinTransformer uses torch.linspace().item() which fails on meta tensors
        # Force linspace to create CPU tensors during model initialization
        _patch_birefnet_class()
        with _force_cpu_linspace():
            self.model = AutoModelForImageSegmentation.from_pretrained(
                model_name,
                trust_remote_code=True,
                device_map=None,
                low_cpu_mem_usage=False,
            )
        self.model.eval()
        self.transform_image = transforms.Compose(
            [
                transforms.Resize((1024, 1024)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

    def to(self, device: str):
        self.model.to(device)

    def cuda(self):
        self.model.cuda()

    def cpu(self):
        self.model.cpu()

    def __call__(self, image: Image.Image) -> Image.Image:
        """
        Predict a foreground mask and attach it to `image` as its alpha
        channel. Mutates and returns the input image.
        """
        image_size = image.size
        # BUGFIX: run inference on whatever device the model currently lives
        # on instead of hardcoding "cuda" (the class exposes .cpu(), so the
        # model is not guaranteed to be on a GPU).
        device = next(self.model.parameters()).device
        input_images = self.transform_image(image).unsqueeze(0).to(device)
        # Prediction
        with torch.no_grad():
            preds = self.model(input_images)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()
        pred_pil = transforms.ToPILImage()(pred)
        mask = pred_pil.resize(image_size)
        image.putalpha(mask)
        return image
from typing import *
from abc import ABC, abstractmethod


class Sampler(ABC):
    """
    Abstract base class for samplers.
    """

    @abstractmethod
    def sample(
        self,
        model,
        **kwargs
    ):
        """
        Sample from a model.
        """
        pass


class ClassifierFreeGuidanceSamplerMixin:
    """
    Mixin that blends conditional and unconditional model predictions
    (classifier-free guidance), with optional CFG rescaling.
    """

    def _inference_model(self, model, x_t, t, cond, neg_cond, guidance_strength, guidance_rescale=0.0, **kwargs):
        # Degenerate strengths collapse to a single forward pass.
        if guidance_strength == 1:
            return super()._inference_model(model, x_t, t, cond, **kwargs)
        if guidance_strength == 0:
            return super()._inference_model(model, x_t, t, neg_cond, **kwargs)

        positive = super()._inference_model(model, x_t, t, cond, **kwargs)
        negative = super()._inference_model(model, x_t, t, neg_cond, **kwargs)
        guided = guidance_strength * positive + (1 - guidance_strength) * negative

        # CFG rescale: pull the guided x_0's per-sample std back toward the
        # std of the purely-conditional prediction.
        if guidance_rescale > 0:
            x0_pos = self._pred_to_xstart(x_t, t, positive)
            x0_cfg = self._pred_to_xstart(x_t, t, guided)
            std_pos = x0_pos.std(dim=list(range(1, x0_pos.ndim)), keepdim=True)
            std_cfg = x0_cfg.std(dim=list(range(1, x0_cfg.ndim)), keepdim=True)
            x0_rescaled = x0_cfg * (std_pos / std_cfg)
            x0 = guidance_rescale * x0_rescaled + (1 - guidance_rescale) * x0_cfg
            guided = self._xstart_to_pred(x_t, t, x0)

        return guided
from typing import *
import torch
import numpy as np
from tqdm import tqdm
from easydict import EasyDict as edict
from .base import Sampler
from .classifier_free_guidance_mixin import ClassifierFreeGuidanceSamplerMixin
from .guidance_interval_mixin import GuidanceIntervalSamplerMixin


class FlowEulerSampler(Sampler):
    """
    Generate samples from a flow-matching model using Euler sampling.

    Conversions below follow the interpolation
    x_t = (1 - t) * x_0 + (sigma_min + (1 - sigma_min) * t) * eps,
    whose velocity is v = (1 - sigma_min) * eps - x_0.

    Args:
        sigma_min: The minimum scale of noise in flow.
    """
    def __init__(
        self,
        sigma_min: float,
    ):
        self.sigma_min = sigma_min

    def _eps_to_xstart(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * eps) / (1 - t)

    def _xstart_to_eps(self, x_t, t, x_0):
        assert x_t.shape == x_0.shape
        return (x_t - (1 - t) * x_0) / (self.sigma_min + (1 - self.sigma_min) * t)

    def _v_to_xstart_eps(self, x_t, t, v):
        assert x_t.shape == v.shape
        eps = (1 - t) * v + x_t
        x_0 = (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * v
        return x_0, eps

    def _pred_to_xstart(self, x_t, t, pred):
        return (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * pred

    def _xstart_to_pred(self, x_t, t, x_0):
        return ((1 - self.sigma_min) * x_t - x_0) / (self.sigma_min + (1 - self.sigma_min) * t)

    def _inference_model(self, model, x_t, t, cond=None, **kwargs):
        # Models are trained on timesteps scaled to [0, 1000].
        t = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32)
        return model(x_t, t, cond, **kwargs)

    def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs):
        pred_v = self._inference_model(model, x_t, t, cond, **kwargs)
        pred_x_0, pred_eps = self._v_to_xstart_eps(x_t=x_t, t=t, v=pred_v)
        return pred_x_0, pred_eps, pred_v

    @torch.no_grad()
    def sample_once(
        self,
        model,
        x_t,
        t: float,
        t_prev: float,
        cond: Optional[Any] = None,
        **kwargs
    ):
        """
        Sample x_{t-1} from the model using Euler method.

        Args:
            model: The model to sample from.
            x_t: The [N x C x ...] tensor of noisy inputs at time t.
            t: The current timestep.
            t_prev: The previous timestep.
            cond: conditional information.
            **kwargs: Additional arguments for model inference.

        Returns:
            a dict containing the following
            - 'pred_x_prev': x_{t-1}.
            - 'pred_x_0': a prediction of x_0.
        """
        pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs)
        # Single explicit-Euler step along the predicted velocity.
        pred_x_prev = x_t - (t - t_prev) * pred_v
        return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0})

    @torch.no_grad()
    def sample(
        self,
        model,
        noise,
        cond: Optional[Any] = None,
        steps: int = 50,
        rescale_t: float = 1.0,
        verbose: bool = True,
        tqdm_desc: str = "Sampling",
        **kwargs
    ):
        """
        Generate samples from the model using Euler method.

        Args:
            model: The model to sample from.
            noise: The initial noise tensor.
            cond: conditional information.
            steps: The number of steps to sample.
            rescale_t: The rescale factor for t.
            verbose: If True, show a progress bar.
            tqdm_desc: A customized tqdm desc.
            **kwargs: Additional arguments for model_inference.

        Returns:
            a dict containing the following
            - 'samples': the model samples.
            - 'pred_x_t': a list of prediction of x_t.
            - 'pred_x_0': a list of prediction of x_0.
        """
        # Timestep schedule from 1 to 0, optionally warped by rescale_t.
        schedule = np.linspace(1, 0, steps + 1)
        schedule = rescale_t * schedule / (1 + (rescale_t - 1) * schedule)
        schedule = schedule.tolist()

        x = noise
        ret = edict({"samples": None, "pred_x_t": [], "pred_x_0": []})
        step_pairs = list(zip(schedule[:-1], schedule[1:]))
        for t, t_prev in tqdm(step_pairs, desc=tqdm_desc, disable=not verbose):
            out = self.sample_once(model, x, t, t_prev, cond, **kwargs)
            x = out.pred_x_prev
            ret.pred_x_t.append(out.pred_x_prev)
            ret.pred_x_0.append(out.pred_x_0)
        ret.samples = x
        return ret


class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler):
    """
    Generate samples from a flow-matching model using Euler sampling with classifier-free guidance.
    """
    @torch.no_grad()
    def sample(
        self,
        model,
        noise,
        cond,
        neg_cond,
        steps: int = 50,
        rescale_t: float = 1.0,
        guidance_strength: float = 3.0,
        verbose: bool = True,
        **kwargs
    ):
        """
        Generate samples from the model using Euler method.

        Args:
            model: The model to sample from.
            noise: The initial noise tensor.
            cond: conditional information.
            neg_cond: negative conditional information.
            steps: The number of steps to sample.
            rescale_t: The rescale factor for t.
            guidance_strength: The strength of classifier-free guidance.
            verbose: If True, show a progress bar.
            **kwargs: Additional arguments for model_inference.

        Returns:
            a dict containing the following
            - 'samples': the model samples.
            - 'pred_x_t': a list of prediction of x_t.
            - 'pred_x_0': a list of prediction of x_0.
        """
        # The CFG mixin intercepts _inference_model via neg_cond/guidance_strength.
        return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, guidance_strength=guidance_strength, **kwargs)


class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler):
    """
    Generate samples from a flow-matching model using Euler sampling with classifier-free guidance and interval.
    """
    @torch.no_grad()
    def sample(
        self,
        model,
        noise,
        cond,
        neg_cond,
        steps: int = 50,
        rescale_t: float = 1.0,
        guidance_strength: float = 3.0,
        guidance_interval: Tuple[float, float] = (0.0, 1.0),
        verbose: bool = True,
        **kwargs
    ):
        """
        Generate samples from the model using Euler method.

        Args:
            model: The model to sample from.
            noise: The initial noise tensor.
            cond: conditional information.
            neg_cond: negative conditional information.
            steps: The number of steps to sample.
            rescale_t: The rescale factor for t.
            guidance_strength: The strength of classifier-free guidance.
            guidance_interval: The interval for classifier-free guidance.
            verbose: If True, show a progress bar.
            **kwargs: Additional arguments for model_inference.

        Returns:
            a dict containing the following
            - 'samples': the model samples.
            - 'pred_x_t': a list of prediction of x_t.
            - 'pred_x_0': a list of prediction of x_0.
        """
        return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, guidance_strength=guidance_strength, guidance_interval=guidance_interval, **kwargs)


class GuidanceIntervalSamplerMixin:
    """
    A mixin class for samplers that apply classifier-free guidance with interval.
    """

    def _inference_model(self, model, x_t, t, cond, guidance_strength, guidance_interval, **kwargs):
        # Inside the interval, apply the requested guidance; outside it,
        # strength 1 degenerates to a plain conditional pass.
        lo, hi = guidance_interval
        strength = guidance_strength if lo <= t <= hi else 1
        return super()._inference_model(model, x_t, t, cond, guidance_strength=strength, **kwargs)
+ + Args: + models (dict[str, nn.Module]): The models to use in the pipeline. + sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure. + shape_slat_sampler (samplers.Sampler): The sampler for the structured latent. + tex_slat_sampler (samplers.Sampler): The sampler for the texture latent. + sparse_structure_sampler_params (dict): The parameters for the sparse structure sampler. + shape_slat_sampler_params (dict): The parameters for the structured latent sampler. + tex_slat_sampler_params (dict): The parameters for the texture latent sampler. + shape_slat_normalization (dict): The normalization parameters for the structured latent. + tex_slat_normalization (dict): The normalization parameters for the texture latent. + image_cond_model (Callable): The image conditioning model. + rembg_model (Callable): The model for removing background. + low_vram (bool): Whether to use low-VRAM mode. + """ + model_names_to_load = [ + 'sparse_structure_flow_model', + 'sparse_structure_decoder', + 'shape_slat_flow_model_512', + 'shape_slat_flow_model_1024', + 'shape_slat_decoder', + 'tex_slat_flow_model_512', + 'tex_slat_flow_model_1024', + 'tex_slat_decoder', + ] + + def __init__( + self, + models: dict[str, nn.Module] = None, + sparse_structure_sampler: samplers.Sampler = None, + shape_slat_sampler: samplers.Sampler = None, + tex_slat_sampler: samplers.Sampler = None, + sparse_structure_sampler_params: dict = None, + shape_slat_sampler_params: dict = None, + tex_slat_sampler_params: dict = None, + shape_slat_normalization: dict = None, + tex_slat_normalization: dict = None, + image_cond_model: Callable = None, + rembg_model: Callable = None, + low_vram: bool = True, + default_pipeline_type: str = '1024_cascade', + ): + if models is None: + return + super().__init__(models) + self.sparse_structure_sampler = sparse_structure_sampler + self.shape_slat_sampler = shape_slat_sampler + self.tex_slat_sampler = tex_slat_sampler + 
self.sparse_structure_sampler_params = sparse_structure_sampler_params + self.shape_slat_sampler_params = shape_slat_sampler_params + self.tex_slat_sampler_params = tex_slat_sampler_params + self.shape_slat_normalization = shape_slat_normalization + self.tex_slat_normalization = tex_slat_normalization + self.image_cond_model = image_cond_model + self.rembg_model = rembg_model + self.low_vram = low_vram + self.default_pipeline_type = default_pipeline_type + self.pbr_attr_layout = { + 'base_color': slice(0, 3), + 'metallic': slice(3, 4), + 'roughness': slice(4, 5), + 'alpha': slice(5, 6), + } + self._device = 'cpu' + + @classmethod + def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> "Trellis2ImageTo3DPipeline": + """ + Load a pretrained model. + + Args: + path (str): The path to the model. Can be either local path or a Hugging Face repository. + """ + pipeline = super().from_pretrained(path, config_file) + args = pipeline._pretrained_args + + pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args']) + pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params'] + + pipeline.shape_slat_sampler = getattr(samplers, args['shape_slat_sampler']['name'])(**args['shape_slat_sampler']['args']) + pipeline.shape_slat_sampler_params = args['shape_slat_sampler']['params'] + + pipeline.tex_slat_sampler = getattr(samplers, args['tex_slat_sampler']['name'])(**args['tex_slat_sampler']['args']) + pipeline.tex_slat_sampler_params = args['tex_slat_sampler']['params'] + + pipeline.shape_slat_normalization = args['shape_slat_normalization'] + pipeline.tex_slat_normalization = args['tex_slat_normalization'] + + pipeline.image_cond_model = getattr(image_feature_extractor, args['image_cond_model']['name'])(**args['image_cond_model']['args']) + pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args']) + + pipeline.low_vram 
= args.get('low_vram', True) + pipeline.default_pipeline_type = args.get('default_pipeline_type', '1024_cascade') + pipeline.pbr_attr_layout = { + 'base_color': slice(0, 3), + 'metallic': slice(3, 4), + 'roughness': slice(4, 5), + 'alpha': slice(5, 6), + } + pipeline._device = 'cpu' + + return pipeline + + def to(self, device: torch.device) -> None: + self._device = device + if not self.low_vram: + super().to(device) + self.image_cond_model.to(device) + if self.rembg_model is not None: + self.rembg_model.to(device) + + def preprocess_image(self, input: Image.Image) -> Image.Image: + """ + Preprocess the input image. + """ + # if has alpha channel, use it directly; otherwise, remove background + has_alpha = False + if input.mode == 'RGBA': + alpha = np.array(input)[:, :, 3] + if not np.all(alpha == 255): + has_alpha = True + max_size = max(input.size) + scale = min(1, 1024 / max_size) + if scale < 1: + input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS) + if has_alpha: + output = input + else: + input = input.convert('RGB') + if self.low_vram: + self.rembg_model.to(self.device) + output = self.rembg_model(input) + if self.low_vram: + self.rembg_model.cpu() + output_np = np.array(output) + alpha = output_np[:, :, 3] + bbox = np.argwhere(alpha > 0.8 * 255) + bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0]) + center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 + size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) + size = int(size * 1) + bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2 + output = output.crop(bbox) # type: ignore + output = np.array(output).astype(np.float32) / 255 + output = output[:, :, :3] * output[:, :, 3:4] + output = Image.fromarray((output * 255).astype(np.uint8)) + return output + + def get_cond(self, image: Union[torch.Tensor, list[Image.Image]], resolution: int, include_neg_cond: bool = True) -> dict: + """ 
+ Get the conditioning information for the model. + + Args: + image (Union[torch.Tensor, list[Image.Image]]): The image prompts. + + Returns: + dict: The conditioning information + """ + self.image_cond_model.image_size = resolution + if self.low_vram: + self.image_cond_model.to(self.device) + cond = self.image_cond_model(image) + if self.low_vram: + self.image_cond_model.cpu() + if not include_neg_cond: + return {'cond': cond} + neg_cond = torch.zeros_like(cond) + return { + 'cond': cond, + 'neg_cond': neg_cond, + } + + def sample_sparse_structure( + self, + cond: dict, + resolution: int, + num_samples: int = 1, + sampler_params: dict = {}, + ) -> torch.Tensor: + """ + Sample sparse structures with the given conditioning. + + Args: + cond (dict): The conditioning information. + resolution (int): The resolution of the sparse structure. + num_samples (int): The number of samples to generate. + sampler_params (dict): Additional parameters for the sampler. + """ + # Sample sparse structure latent + flow_model = self.models['sparse_structure_flow_model'] + reso = flow_model.resolution + in_channels = flow_model.in_channels + noise = torch.randn(num_samples, in_channels, reso, reso, reso).to(self.device) + sampler_params = {**self.sparse_structure_sampler_params, **sampler_params} + if self.low_vram: + flow_model.to(self.device) + z_s = self.sparse_structure_sampler.sample( + flow_model, + noise, + **cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling sparse structure", + ).samples + if self.low_vram: + flow_model.cpu() + + # Decode sparse structure latent + decoder = self.models['sparse_structure_decoder'] + if self.low_vram: + decoder.to(self.device) + decoded = decoder(z_s)>0 + if self.low_vram: + decoder.cpu() + if resolution != decoded.shape[2]: + ratio = decoded.shape[2] // resolution + decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5 + coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int() + + return coords + + def 
sample_shape_slat( + self, + cond: dict, + flow_model, + coords: torch.Tensor, + sampler_params: dict = {}, + ) -> SparseTensor: + """ + Sample structured latent with the given conditioning. + + Args: + cond (dict): The conditioning information. + coords (torch.Tensor): The coordinates of the sparse structure. + sampler_params (dict): Additional parameters for the sampler. + """ + # Sample structured latent + noise = SparseTensor( + feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device), + coords=coords, + ) + sampler_params = {**self.shape_slat_sampler_params, **sampler_params} + if self.low_vram: + flow_model.to(self.device) + slat = self.shape_slat_sampler.sample( + flow_model, + noise, + **cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling shape SLat", + ).samples + if self.low_vram: + flow_model.cpu() + + std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + + return slat + + def sample_shape_slat_cascade( + self, + lr_cond: dict, + cond: dict, + flow_model_lr, + flow_model, + lr_resolution: int, + resolution: int, + coords: torch.Tensor, + sampler_params: dict = {}, + max_num_tokens: int = 49152, + ) -> SparseTensor: + """ + Sample structured latent with the given conditioning. + + Args: + cond (dict): The conditioning information. + coords (torch.Tensor): The coordinates of the sparse structure. + sampler_params (dict): Additional parameters for the sampler. 
+ """ + # LR + noise = SparseTensor( + feats=torch.randn(coords.shape[0], flow_model_lr.in_channels).to(self.device), + coords=coords, + ) + sampler_params = {**self.shape_slat_sampler_params, **sampler_params} + if self.low_vram: + flow_model_lr.to(self.device) + slat = self.shape_slat_sampler.sample( + flow_model_lr, + noise, + **lr_cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling shape SLat", + ).samples + if self.low_vram: + flow_model_lr.cpu() + std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + + # Upsample + if self.low_vram: + self.models['shape_slat_decoder'].to(self.device) + self.models['shape_slat_decoder'].low_vram = True + hr_coords = self.models['shape_slat_decoder'].upsample(slat, upsample_times=4) + if self.low_vram: + self.models['shape_slat_decoder'].cpu() + self.models['shape_slat_decoder'].low_vram = False + hr_resolution = resolution + while True: + quant_coords = torch.cat([ + hr_coords[:, :1], + ((hr_coords[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(), + ], dim=1) + coords = quant_coords.unique(dim=0) + num_tokens = coords.shape[0] + if num_tokens < max_num_tokens or hr_resolution == 1024: + if hr_resolution != resolution: + print(f"Due to the limited number of tokens, the resolution is reduced to {hr_resolution}.") + break + hr_resolution -= 128 + + # Sample structured latent + noise = SparseTensor( + feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device), + coords=coords, + ) + sampler_params = {**self.shape_slat_sampler_params, **sampler_params} + if self.low_vram: + flow_model.to(self.device) + slat = self.shape_slat_sampler.sample( + flow_model, + noise, + **cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling shape SLat", + ).samples + if self.low_vram: + flow_model.cpu() + + std = 
torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + + return slat, hr_resolution + + def decode_shape_slat( + self, + slat: SparseTensor, + resolution: int, + ) -> Tuple[List[Mesh], List[SparseTensor]]: + """ + Decode the structured latent. + + Args: + slat (SparseTensor): The structured latent. + + Returns: + List[Mesh]: The decoded meshes. + List[SparseTensor]: The decoded substructures. + """ + self.models['shape_slat_decoder'].set_resolution(resolution) + if self.low_vram: + self.models['shape_slat_decoder'].to(self.device) + self.models['shape_slat_decoder'].low_vram = True + ret = self.models['shape_slat_decoder'](slat, return_subs=True) + if self.low_vram: + self.models['shape_slat_decoder'].cpu() + self.models['shape_slat_decoder'].low_vram = False + return ret + + def sample_tex_slat( + self, + cond: dict, + flow_model, + shape_slat: SparseTensor, + sampler_params: dict = {}, + ) -> SparseTensor: + """ + Sample structured latent with the given conditioning. + + Args: + cond (dict): The conditioning information. + shape_slat (SparseTensor): The structured latent for shape + sampler_params (dict): Additional parameters for the sampler. 
+ """ + # Sample structured latent + std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device) + mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(shape_slat.device) + shape_slat = (shape_slat - mean) / std + + in_channels = flow_model.in_channels if isinstance(flow_model, nn.Module) else flow_model[0].in_channels + noise = shape_slat.replace(feats=torch.randn(shape_slat.coords.shape[0], in_channels - shape_slat.feats.shape[1]).to(self.device)) + sampler_params = {**self.tex_slat_sampler_params, **sampler_params} + if self.low_vram: + flow_model.to(self.device) + slat = self.tex_slat_sampler.sample( + flow_model, + noise, + concat_cond=shape_slat, + **cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling texture SLat", + ).samples + if self.low_vram: + flow_model.cpu() + + std = torch.tensor(self.tex_slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.tex_slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + + return slat + + def decode_tex_slat( + self, + slat: SparseTensor, + subs: List[SparseTensor], + ) -> SparseTensor: + """ + Decode the structured latent. + + Args: + slat (SparseTensor): The structured latent. + + Returns: + SparseTensor: The decoded texture voxels + """ + if self.low_vram: + self.models['tex_slat_decoder'].to(self.device) + ret = self.models['tex_slat_decoder'](slat, guide_subs=subs) * 0.5 + 0.5 + if self.low_vram: + self.models['tex_slat_decoder'].cpu() + return ret + + @torch.no_grad() + def decode_latent( + self, + shape_slat: SparseTensor, + tex_slat: SparseTensor, + resolution: int, + ) -> List[MeshWithVoxel]: + """ + Decode the latent codes. + + Args: + shape_slat (SparseTensor): The structured latent for shape. + tex_slat (SparseTensor): The structured latent for texture. + resolution (int): The resolution of the output. 
+ """ + meshes, subs = self.decode_shape_slat(shape_slat, resolution) + tex_voxels = self.decode_tex_slat(tex_slat, subs) + out_mesh = [] + for m, v in zip(meshes, tex_voxels): + m.fill_holes() + out_mesh.append( + MeshWithVoxel( + m.vertices, m.faces, + origin = [-0.5, -0.5, -0.5], + voxel_size = 1 / resolution, + coords = v.coords[:, 1:], + attrs = v.feats, + voxel_shape = torch.Size([*v.shape, *v.spatial_shape]), + layout=self.pbr_attr_layout + ) + ) + return out_mesh + + @torch.no_grad() + def run( + self, + image: Image.Image, + num_samples: int = 1, + seed: int = 42, + sparse_structure_sampler_params: dict = {}, + shape_slat_sampler_params: dict = {}, + tex_slat_sampler_params: dict = {}, + preprocess_image: bool = True, + return_latent: bool = False, + pipeline_type: Optional[str] = None, + max_num_tokens: int = 49152, + ) -> List[MeshWithVoxel]: + """ + Run the pipeline. + + Args: + image (Image.Image): The image prompt. + num_samples (int): The number of samples to generate. + seed (int): The random seed. + sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler. + shape_slat_sampler_params (dict): Additional parameters for the shape SLat sampler. + tex_slat_sampler_params (dict): Additional parameters for the texture SLat sampler. + preprocess_image (bool): Whether to preprocess the image. + return_latent (bool): Whether to return the latent codes. + pipeline_type (str): The type of the pipeline. Options: '512', '1024', '1024_cascade', '1536_cascade'. + max_num_tokens (int): The maximum number of tokens to use. + """ + # Check pipeline type + pipeline_type = pipeline_type or self.default_pipeline_type + if pipeline_type == '512': + assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found." + assert 'tex_slat_flow_model_512' in self.models, "No 512 resolution texture SLat flow model found." 
+ elif pipeline_type == '1024': + assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found." + assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found." + elif pipeline_type == '1024_cascade': + assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found." + assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found." + assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found." + elif pipeline_type == '1536_cascade': + assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found." + assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found." + assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found." + else: + raise ValueError(f"Invalid pipeline type: {pipeline_type}") + + if preprocess_image: + image = self.preprocess_image(image) + torch.manual_seed(seed) + cond_512 = self.get_cond([image], 512) + cond_1024 = self.get_cond([image], 1024) if pipeline_type != '512' else None + ss_res = {'512': 32, '1024': 64, '1024_cascade': 32, '1536_cascade': 32}[pipeline_type] + coords = self.sample_sparse_structure( + cond_512, ss_res, + num_samples, sparse_structure_sampler_params + ) + if pipeline_type == '512': + shape_slat = self.sample_shape_slat( + cond_512, self.models['shape_slat_flow_model_512'], + coords, shape_slat_sampler_params + ) + tex_slat = self.sample_tex_slat( + cond_512, self.models['tex_slat_flow_model_512'], + shape_slat, tex_slat_sampler_params + ) + res = 512 + elif pipeline_type == '1024': + shape_slat = self.sample_shape_slat( + cond_1024, self.models['shape_slat_flow_model_1024'], + coords, shape_slat_sampler_params + ) + tex_slat = self.sample_tex_slat( + cond_1024, self.models['tex_slat_flow_model_1024'], + shape_slat, 
tex_slat_sampler_params + ) + res = 1024 + elif pipeline_type == '1024_cascade': + shape_slat, res = self.sample_shape_slat_cascade( + cond_512, cond_1024, + self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'], + 512, 1024, + coords, shape_slat_sampler_params, + max_num_tokens + ) + tex_slat = self.sample_tex_slat( + cond_1024, self.models['tex_slat_flow_model_1024'], + shape_slat, tex_slat_sampler_params + ) + elif pipeline_type == '1536_cascade': + shape_slat, res = self.sample_shape_slat_cascade( + cond_512, cond_1024, + self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'], + 512, 1536, + coords, shape_slat_sampler_params, + max_num_tokens + ) + tex_slat = self.sample_tex_slat( + cond_1024, self.models['tex_slat_flow_model_1024'], + shape_slat, tex_slat_sampler_params + ) + torch.cuda.empty_cache() + out_mesh = self.decode_latent(shape_slat, tex_slat, res) + if return_latent: + return out_mesh, (shape_slat, tex_slat, res) + else: + return out_mesh + + @contextmanager + def inject_sampler_multi_image( + self, + sampler_name: str, + num_images: int, + num_steps: int, + mode: Literal['stochastic', 'multidiffusion'] = 'multidiffusion', + ): + """ + Inject a sampler with multiple images as condition. + + Args: + sampler_name (str): The name of the sampler to inject. + num_images (int): The number of images to condition on. + num_steps (int): The number of steps to run the sampler for. + """ + sampler = getattr(self, sampler_name) + setattr(sampler, f'_old_inference_model', sampler._inference_model) + + if mode == 'stochastic': + if num_images > num_steps: + print( + f"\033[93mWarning: number of conditioning images is greater than number of steps for {sampler_name}. 
" + "This may lead to performance degradation.\033[0m") + + cond_indices = (np.arange(num_steps) % num_images).tolist() + + def _new_inference_model(self, model, x_t, t, cond, **kwargs): + nonlocal cond_indices + if cond_indices: + cond_idx = cond_indices.pop(0) + else: + cond_indices = (np.arange(num_steps) % num_images).tolist() + cond_idx = cond_indices.pop(0) + cond_i = cond[cond_idx:cond_idx + 1] + return self._old_inference_model(model, x_t, t, cond=cond_i, **kwargs) + + elif mode == 'multidiffusion': + from .samplers import FlowEulerSampler + def _new_inference_model(self, model, x_t, t, cond, **kwargs): + # if cfg_interval[0] <= t <= cfg_interval[1]: + # preds = [] + # for i in range(len(cond)): + # preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i + 1], **kwargs)) + # pred = sum(preds) / len(preds) + # neg_pred = FlowEulerSampler._inference_model(self, model, x_t, t, neg_cond, **kwargs) + # return (1 + cfg_strength) * pred - cfg_strength * neg_pred + # else: + + # Filter out guidance-related kwargs that the base sampler doesn't handle + filtered_kwargs = {k: v for k, v in kwargs.items() + if k not in ('neg_cond', 'guidance_strength', 'guidance_interval', 'guidance_rescale')} + preds = [] + for i in range(len(cond)): + preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i + 1], **filtered_kwargs)) + pred = sum(preds) / len(preds) + return pred + + else: + raise ValueError(f"Unsupported mode: {mode}") + + sampler._inference_model = _new_inference_model.__get__(sampler, type(sampler)) + + yield + + sampler._inference_model = sampler._old_inference_model + delattr(sampler, f'_old_inference_model') + + @torch.no_grad() + def run_multi_image( + self, + images: List[Image.Image], + num_samples: int = 1, + seed: int = 42, + sparse_structure_sampler_params: dict = {}, + shape_slat_sampler_params: dict = {}, + tex_slat_sampler_params: dict = {}, + preprocess_image: bool = True, + return_latent: bool = False, + 
pipeline_type: Optional[str] = None, + max_num_tokens: int = 49152, + mode: Literal['stochastic', 'multidiffusion'] = 'multidiffusion', + ) -> List[MeshWithVoxel]: + """ + Run the multi-image pipeline. + + Args: + images (List[Image.Image]): The multi-image prompt. + num_samples (int): The number of samples to generate. + seed (int): The random seed. + sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler. + shape_slat_sampler_params (dict): Additional parameters for the shape SLat sampler. + tex_slat_sampler_params (dict): Additional parameters for the texture SLat sampler. + preprocess_image (bool): Whether to preprocess the image. + return_latent (bool): Whether to return the latent codes. + pipeline_type (str): The type of the pipeline. Options: '512', '1024', '1024_cascade', '1536_cascade'. + max_num_tokens (int): The maximum number of tokens to use. + mode: The multi-image conditioning mode. + """ + # Check pipeline type + pipeline_type = pipeline_type or self.default_pipeline_type + if pipeline_type == '512': + assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found." + assert 'tex_slat_flow_model_512' in self.models, "No 512 resolution texture SLat flow model found." + elif pipeline_type == '1024': + assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found." + assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found." + elif pipeline_type == '1024_cascade': + assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found." + assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found." + assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found." 
+ elif pipeline_type == '1536_cascade': + assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found." + assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found." + assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found." + else: + raise ValueError(f"Invalid pipeline type: {pipeline_type}") + + if preprocess_image: + images = [self.preprocess_image(image) for image in images] + torch.manual_seed(seed) + cond_512 = self.get_cond(images, 512) + cond_512['neg_cond'] = cond_512['neg_cond'][:1] + if pipeline_type != '512': + cond_1024 = self.get_cond(images, 1024) + cond_1024['neg_cond'] = cond_1024['neg_cond'][:1] + else: + cond_1024 = None + + ss_res = {'512': 32, '1024': 64, '1024_cascade': 32, '1536_cascade': 32}[pipeline_type] + ss_steps = {**self.sparse_structure_sampler_params, **sparse_structure_sampler_params}.get('steps') + with self.inject_sampler_multi_image('sparse_structure_sampler', len(images), ss_steps, mode=mode): + coords = self.sample_sparse_structure( + cond_512, ss_res, + num_samples, sparse_structure_sampler_params + ) + + shape_slat_steps = {**self.shape_slat_sampler_params, **shape_slat_sampler_params}.get('steps') + tex_slat_steps = {**self.tex_slat_sampler_params, **tex_slat_sampler_params}.get('steps') + if pipeline_type == '512': + with ( + self.inject_sampler_multi_image('shape_slat_sampler', len(images), shape_slat_steps, mode=mode), + self.inject_sampler_multi_image('tex_slat_sampler', len(images), tex_slat_steps, mode=mode), + ): + shape_slat = self.sample_shape_slat( + cond_512, self.models['shape_slat_flow_model_512'], + coords, shape_slat_sampler_params + ) + tex_slat = self.sample_tex_slat( + cond_512, self.models['tex_slat_flow_model_512'], + shape_slat, tex_slat_sampler_params + ) + res = 512 + elif pipeline_type == '1024': + with ( + self.inject_sampler_multi_image('shape_slat_sampler', len(images), 
shape_slat_steps, mode=mode), + self.inject_sampler_multi_image('tex_slat_sampler', len(images), tex_slat_steps, mode=mode), + ): + shape_slat = self.sample_shape_slat( + cond_1024, self.models['shape_slat_flow_model_1024'], + coords, shape_slat_sampler_params + ) + tex_slat = self.sample_tex_slat( + cond_1024, self.models['tex_slat_flow_model_1024'], + shape_slat, tex_slat_sampler_params + ) + res = 1024 + elif pipeline_type == '1024_cascade': + with ( + self.inject_sampler_multi_image('shape_slat_sampler', len(images), shape_slat_steps, mode=mode), + self.inject_sampler_multi_image('tex_slat_sampler', len(images), tex_slat_steps, mode=mode), + ): + shape_slat, res = self.sample_shape_slat_cascade( + cond_512, cond_1024, + self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'], + 512, 1024, + coords, shape_slat_sampler_params, + max_num_tokens + ) + tex_slat = self.sample_tex_slat( + cond_1024, self.models['tex_slat_flow_model_1024'], + shape_slat, tex_slat_sampler_params + ) + elif pipeline_type == '1536_cascade': + with ( + self.inject_sampler_multi_image('shape_slat_sampler', len(images), shape_slat_steps, mode=mode), + self.inject_sampler_multi_image('tex_slat_sampler', len(images), tex_slat_steps, mode=mode), + ): + shape_slat, res = self.sample_shape_slat_cascade( + cond_512, cond_1024, + self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'], + 512, 1536, + coords, shape_slat_sampler_params, + max_num_tokens + ) + tex_slat = self.sample_tex_slat( + cond_1024, self.models['tex_slat_flow_model_1024'], + shape_slat, tex_slat_sampler_params + ) + torch.cuda.empty_cache() + out_mesh = self.decode_latent(shape_slat, tex_slat, res) + if return_latent: + return out_mesh, (shape_slat, tex_slat, res) + else: + return out_mesh diff --git a/trellis2/pipelines/trellis2_texturing.py b/trellis2/pipelines/trellis2_texturing.py new file mode 100644 index 
0000000000000000000000000000000000000000..61aab2951bef7a420910057266d2e7d1d46682e0 --- /dev/null +++ b/trellis2/pipelines/trellis2_texturing.py @@ -0,0 +1,408 @@ +from typing import * +import torch +import torch.nn as nn +import numpy as np +from PIL import Image +import trimesh +from .base import Pipeline +from . import samplers, rembg +from ..modules.sparse import SparseTensor +from ..modules import image_feature_extractor +import o_voxel +import cumesh +import nvdiffrast.torch as dr +import cv2 +import flex_gemm + + +class Trellis2TexturingPipeline(Pipeline): + """ + Pipeline for inferring Trellis2 image-to-3D models. + + Args: + models (dict[str, nn.Module]): The models to use in the pipeline. + tex_slat_sampler (samplers.Sampler): The sampler for the texture latent. + tex_slat_sampler_params (dict): The parameters for the texture latent sampler. + shape_slat_normalization (dict): The normalization parameters for the structured latent. + tex_slat_normalization (dict): The normalization parameters for the texture latent. + image_cond_model (Callable): The image conditioning model. + rembg_model (Callable): The model for removing background. + low_vram (bool): Whether to use low-VRAM mode. 
+ """ + model_names_to_load = [ + 'shape_slat_encoder', + 'tex_slat_decoder', + 'tex_slat_flow_model_512', + 'tex_slat_flow_model_1024' + ] + + def __init__( + self, + models: dict[str, nn.Module] = None, + tex_slat_sampler: samplers.Sampler = None, + tex_slat_sampler_params: dict = None, + shape_slat_normalization: dict = None, + tex_slat_normalization: dict = None, + image_cond_model: Callable = None, + rembg_model: Callable = None, + low_vram: bool = True, + ): + if models is None: + return + super().__init__(models) + self.tex_slat_sampler = tex_slat_sampler + self.tex_slat_sampler_params = tex_slat_sampler_params + self.shape_slat_normalization = shape_slat_normalization + self.tex_slat_normalization = tex_slat_normalization + self.image_cond_model = image_cond_model + self.rembg_model = rembg_model + self.low_vram = low_vram + self.pbr_attr_layout = { + 'base_color': slice(0, 3), + 'metallic': slice(3, 4), + 'roughness': slice(4, 5), + 'alpha': slice(5, 6), + } + self._device = 'cpu' + + @classmethod + def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> "Trellis2TexturingPipeline": + """ + Load a pretrained model. + + Args: + path (str): The path to the model. Can be either local path or a Hugging Face repository. 
+ """ + pipeline = super().from_pretrained(path, config_file) + args = pipeline._pretrained_args + + pipeline.tex_slat_sampler = getattr(samplers, args['tex_slat_sampler']['name'])(**args['tex_slat_sampler']['args']) + pipeline.tex_slat_sampler_params = args['tex_slat_sampler']['params'] + + pipeline.shape_slat_normalization = args['shape_slat_normalization'] + pipeline.tex_slat_normalization = args['tex_slat_normalization'] + + pipeline.image_cond_model = getattr(image_feature_extractor, args['image_cond_model']['name'])(**args['image_cond_model']['args']) + pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args']) + + pipeline.low_vram = args.get('low_vram', True) + pipeline.pbr_attr_layout = { + 'base_color': slice(0, 3), + 'metallic': slice(3, 4), + 'roughness': slice(4, 5), + 'alpha': slice(5, 6), + } + pipeline._device = 'cpu' + return pipeline + + def to(self, device: torch.device) -> None: + self._device = device + if not self.low_vram: + super().to(device) + self.image_cond_model.to(device) + if self.rembg_model is not None: + self.rembg_model.to(device) + + def preprocess_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh: + """ + Preprocess the input mesh. + """ + vertices = mesh.vertices + vertices_min = vertices.min(axis=0) + vertices_max = vertices.max(axis=0) + center = (vertices_min + vertices_max) / 2 + scale = 0.99999 / (vertices_max - vertices_min).max() + vertices = (vertices - center) * scale + tmp = vertices[:, 1].copy() + vertices[:, 1] = -vertices[:, 2] + vertices[:, 2] = tmp + assert np.all(vertices >= -0.5) and np.all(vertices <= 0.5), 'vertices out of range' + return trimesh.Trimesh(vertices=vertices, faces=mesh.faces, process=False) + + def preprocess_image(self, input: Image.Image) -> Image.Image: + """ + Preprocess the input image. 
+ """ + # if has alpha channel, use it directly; otherwise, remove background + has_alpha = False + if input.mode == 'RGBA': + alpha = np.array(input)[:, :, 3] + if not np.all(alpha == 255): + has_alpha = True + max_size = max(input.size) + scale = min(1, 1024 / max_size) + if scale < 1: + input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS) + if has_alpha: + output = input + else: + input = input.convert('RGB') + if self.low_vram: + self.rembg_model.to(self.device) + output = self.rembg_model(input) + if self.low_vram: + self.rembg_model.cpu() + output_np = np.array(output) + alpha = output_np[:, :, 3] + bbox = np.argwhere(alpha > 0.8 * 255) + bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0]) + center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 + size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) + size = int(size * 1) + bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2 + output = output.crop(bbox) # type: ignore + output = np.array(output).astype(np.float32) / 255 + output = output[:, :, :3] * output[:, :, 3:4] + output = Image.fromarray((output * 255).astype(np.uint8)) + return output + + def get_cond(self, image: Union[torch.Tensor, list[Image.Image]], resolution: int, include_neg_cond: bool = True) -> dict: + """ + Get the conditioning information for the model. + + Args: + image (Union[torch.Tensor, list[Image.Image]]): The image prompts. 
+ + Returns: + dict: The conditioning information + """ + self.image_cond_model.image_size = resolution + if self.low_vram: + self.image_cond_model.to(self.device) + cond = self.image_cond_model(image) + if self.low_vram: + self.image_cond_model.cpu() + if not include_neg_cond: + return {'cond': cond} + neg_cond = torch.zeros_like(cond) + return { + 'cond': cond, + 'neg_cond': neg_cond, + } + + def encode_shape_slat( + self, + mesh: trimesh.Trimesh, + resolution: int = 1024, + ) -> SparseTensor: + """ + Encode the meshes to structured latent. + + Args: + mesh (trimesh.Trimesh): The mesh to encode. + resolution (int): The resolution of mesh + + Returns: + SparseTensor: The encoded structured latent. + """ + vertices = torch.from_numpy(mesh.vertices).float() + faces = torch.from_numpy(mesh.faces).long() + + voxel_indices, dual_vertices, intersected = o_voxel.convert.mesh_to_flexible_dual_grid( + vertices.cpu(), faces.cpu(), + grid_size=resolution, + aabb=[[-0.5,-0.5,-0.5],[0.5,0.5,0.5]], + face_weight=1.0, + boundary_weight=0.2, + regularization_weight=1e-2, + timing=True, + ) + + vertices = SparseTensor( + feats=dual_vertices * resolution - voxel_indices, + coords=torch.cat([torch.zeros_like(voxel_indices[:, 0:1]), voxel_indices], dim=-1) + ).to(self.device) + intersected = vertices.replace(intersected).to(self.device) + + if self.low_vram: + self.models['shape_slat_encoder'].to(self.device) + shape_slat = self.models['shape_slat_encoder'](vertices, intersected) + if self.low_vram: + self.models['shape_slat_encoder'].cpu() + return shape_slat + + def sample_tex_slat( + self, + cond: dict, + flow_model, + shape_slat: SparseTensor, + sampler_params: dict = {}, + ) -> SparseTensor: + """ + Sample structured latent with the given conditioning. + + Args: + cond (dict): The conditioning information. + shape_slat (SparseTensor): The structured latent for shape + sampler_params (dict): Additional parameters for the sampler. 
+ """ + # Sample structured latent + std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device) + mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(shape_slat.device) + shape_slat = (shape_slat - mean) / std + + in_channels = flow_model.in_channels if isinstance(flow_model, nn.Module) else flow_model[0].in_channels + noise = shape_slat.replace(feats=torch.randn(shape_slat.coords.shape[0], in_channels - shape_slat.feats.shape[1]).to(self.device)) + sampler_params = {**self.tex_slat_sampler_params, **sampler_params} + if self.low_vram: + flow_model.to(self.device) + slat = self.tex_slat_sampler.sample( + flow_model, + noise, + concat_cond=shape_slat, + **cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling texture SLat", + ).samples + if self.low_vram: + flow_model.cpu() + + std = torch.tensor(self.tex_slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.tex_slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + + return slat + + def decode_tex_slat( + self, + slat: SparseTensor, + ) -> SparseTensor: + """ + Decode the structured latent. + + Args: + slat (SparseTensor): The structured latent. 
+ + Returns: + SparseTensor: The decoded texture voxels + """ + if self.low_vram: + self.models['tex_slat_decoder'].to(self.device) + ret = self.models['tex_slat_decoder'](slat) * 0.5 + 0.5 + if self.low_vram: + self.models['tex_slat_decoder'].cpu() + return ret + + def postprocess_mesh( + self, + mesh: trimesh.Trimesh, + pbr_voxel: SparseTensor, + resolution: int = 1024, + texture_size: int = 1024, + ) -> trimesh.Trimesh: + vertices = mesh.vertices + faces = mesh.faces + normals = mesh.vertex_normals + vertices_torch = torch.from_numpy(vertices).float().cuda() + faces_torch = torch.from_numpy(faces).int().cuda() + if hasattr(mesh, 'visual') and hasattr(mesh.visual, 'uv') and mesh.visual.uv is not None: + uvs = mesh.visual.uv.copy() + uvs[:, 1] = 1 - uvs[:, 1] + uvs_torch = torch.from_numpy(uvs).float().cuda() + else: + _cumesh = cumesh.CuMesh() + _cumesh.init(vertices_torch, faces_torch) + vertices_torch, faces_torch, uvs_torch, vmap = _cumesh.uv_unwrap(return_vmaps=True) + vertices_torch = vertices_torch.cuda() + faces_torch = faces_torch.cuda() + uvs_torch = uvs_torch.cuda() + vertices = vertices_torch.cpu().numpy() + faces = faces_torch.cpu().numpy() + uvs = uvs_torch.cpu().numpy() + normals = normals[vmap.cpu().numpy()] + + # rasterize + ctx = dr.RasterizeCudaContext() + uvs_torch = torch.cat([uvs_torch * 2 - 1, torch.zeros_like(uvs_torch[:, :1]), torch.ones_like(uvs_torch[:, :1])], dim=-1).unsqueeze(0) + rast, _ = dr.rasterize( + ctx, uvs_torch, faces_torch, + resolution=[texture_size, texture_size], + ) + mask = rast[0, ..., 3] > 0 + pos = dr.interpolate(vertices_torch.unsqueeze(0), rast, faces_torch)[0][0] + + attrs = torch.zeros(texture_size, texture_size, pbr_voxel.shape[1], device=self.device) + attrs[mask] = flex_gemm.ops.grid_sample.grid_sample_3d( + pbr_voxel.feats, + pbr_voxel.coords, + shape=torch.Size([*pbr_voxel.shape, *pbr_voxel.spatial_shape]), + grid=((pos[mask] + 0.5) * resolution).reshape(1, -1, 3), + mode='trilinear', + ) + + # construct 
mesh + mask = mask.cpu().numpy() + base_color = np.clip(attrs[..., self.pbr_attr_layout['base_color']].cpu().numpy() * 255, 0, 255).astype(np.uint8) + metallic = np.clip(attrs[..., self.pbr_attr_layout['metallic']].cpu().numpy() * 255, 0, 255).astype(np.uint8) + roughness = np.clip(attrs[..., self.pbr_attr_layout['roughness']].cpu().numpy() * 255, 0, 255).astype(np.uint8) + alpha = np.clip(attrs[..., self.pbr_attr_layout['alpha']].cpu().numpy() * 255, 0, 255).astype(np.uint8) + + # extend + mask = (~mask).astype(np.uint8) + base_color = cv2.inpaint(base_color, mask, 3, cv2.INPAINT_TELEA) + metallic = cv2.inpaint(metallic, mask, 1, cv2.INPAINT_TELEA)[..., None] + roughness = cv2.inpaint(roughness, mask, 1, cv2.INPAINT_TELEA)[..., None] + alpha = cv2.inpaint(alpha, mask, 1, cv2.INPAINT_TELEA)[..., None] + + material = trimesh.visual.material.PBRMaterial( + baseColorTexture=Image.fromarray(np.concatenate([base_color, alpha], axis=-1)), + baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8), + metallicRoughnessTexture=Image.fromarray(np.concatenate([np.zeros_like(metallic), roughness, metallic], axis=-1)), + metallicFactor=1.0, + roughnessFactor=1.0, + alphaMode='OPAQUE', + doubleSided=True, + ) + + # Swap Y and Z axes, invert Y (common conversion for GLB compatibility) + vertices[:, 1], vertices[:, 2] = vertices[:, 2], -vertices[:, 1] + normals[:, 1], normals[:, 2] = normals[:, 2], -normals[:, 1] + uvs[:, 1] = 1 - uvs[:, 1] # Flip UV V-coordinate + + textured_mesh = trimesh.Trimesh( + vertices=vertices, + faces=faces, + vertex_normals=normals, + process=False, + visual=trimesh.visual.TextureVisuals(uv=uvs, material=material) + ) + + return textured_mesh + + + @torch.no_grad() + def run( + self, + mesh: trimesh.Trimesh, + image: Image.Image, + seed: int = 42, + tex_slat_sampler_params: dict = {}, + preprocess_image: bool = True, + resolution: int = 1024, + texture_size: int = 2048, + ) -> trimesh.Trimesh: + """ + Run the pipeline. 
+ + Args: + mesh (trimesh.Trimesh): The mesh to texture. + image (Image.Image): The image prompt. + seed (int): The random seed. + tex_slat_sampler_params (dict): Additional parameters for the texture latent sampler. + preprocess_image (bool): Whether to preprocess the image. + """ + if preprocess_image: + image = self.preprocess_image(image) + mesh = self.preprocess_mesh(mesh) + torch.manual_seed(seed) + cond = self.get_cond([image], 512) if resolution == 512 else self.get_cond([image], 1024) + shape_slat = self.encode_shape_slat(mesh, resolution) + tex_model = self.models['tex_slat_flow_model_512'] if resolution == 512 else self.models['tex_slat_flow_model_1024'] + tex_slat = self.sample_tex_slat( + cond, tex_model, + shape_slat, tex_slat_sampler_params + ) + pbr_voxel = self.decode_tex_slat(tex_slat) + out_mesh = self.postprocess_mesh(mesh, pbr_voxel, resolution, texture_size) + return out_mesh diff --git a/trellis2/renderers/__init__.py b/trellis2/renderers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6c883fed69683905b9cdf52158029163cc9e520f --- /dev/null +++ b/trellis2/renderers/__init__.py @@ -0,0 +1,33 @@ +import importlib + +__attributes = { + 'MeshRenderer': 'mesh_renderer', + 'VoxelRenderer': 'voxel_renderer', + 'PbrMeshRenderer': 'pbr_mesh_renderer', + 'EnvMap': 'pbr_mesh_renderer', +} + +__submodules = [] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .mesh_renderer import MeshRenderer + from .voxel_renderer import VoxelRenderer 
+ from .pbr_mesh_renderer import PbrMeshRenderer, EnvMap + \ No newline at end of file diff --git a/trellis2/renderers/__pycache__/__init__.cpython-311.pyc b/trellis2/renderers/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b84d9796884f45c56f6518d2072c7a3f4e92081 Binary files /dev/null and b/trellis2/renderers/__pycache__/__init__.cpython-311.pyc differ diff --git a/trellis2/renderers/mesh_renderer.py b/trellis2/renderers/mesh_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..f35817804055adac70e590a8da3bd50974f8c984 --- /dev/null +++ b/trellis2/renderers/mesh_renderer.py @@ -0,0 +1,414 @@ +from typing import * +import torch +from easydict import EasyDict as edict +from ..representations.mesh import Mesh, MeshWithVoxel, MeshWithPbrMaterial, TextureFilterMode, AlphaMode, TextureWrapMode +import torch.nn.functional as F + + +def intrinsics_to_projection( + intrinsics: torch.Tensor, + near: float, + far: float, + ) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix + near (float): near plane to clip + far (float): far plane to clip + Returns: + (torch.Tensor): [4, 4] OpenGL perspective matrix + """ + fx, fy = intrinsics[0, 0], intrinsics[1, 1] + cx, cy = intrinsics[0, 2], intrinsics[1, 2] + ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[0, 0] = 2 * fx + ret[1, 1] = 2 * fy + ret[0, 2] = 2 * cx - 1 + ret[1, 2] = - 2 * cy + 1 + ret[2, 2] = (far + near) / (far - near) + ret[2, 3] = 2 * near * far / (near - far) + ret[3, 2] = 1. + return ret + + +class MeshRenderer: + """ + Renderer for the Mesh representation. + + Args: + rendering_options (dict): Rendering options. 
+ """ + def __init__(self, rendering_options={}, device='cuda'): + if 'dr' not in globals(): + import nvdiffrast.torch as dr + + self.rendering_options = edict({ + "resolution": None, + "near": None, + "far": None, + "ssaa": 1, + "chunk_size": None, + "antialias": True, + "clamp_barycentric_coords": False, + }) + self.rendering_options.update(rendering_options) + self.glctx = dr.RasterizeCudaContext(device=device) + self.device=device + + def render( + self, + mesh : Mesh, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + return_types = ["mask", "normal", "depth"], + transformation : Optional[torch.Tensor] = None + ) -> edict: + """ + Render the mesh. + + Args: + mesh : meshmodel + extrinsics (torch.Tensor): (4, 4) camera extrinsics + intrinsics (torch.Tensor): (3, 3) camera intrinsics + return_types (list): list of return types, can be "attr", "mask", "depth", "coord", "normal" + + Returns: + edict based on return_types containing: + attr (torch.Tensor): [C, H, W] rendered attr image + depth (torch.Tensor): [H, W] rendered depth image + normal (torch.Tensor): [3, H, W] rendered normal image + mask (torch.Tensor): [H, W] rendered mask image + """ + if 'dr' not in globals(): + import nvdiffrast.torch as dr + + resolution = self.rendering_options["resolution"] + near = self.rendering_options["near"] + far = self.rendering_options["far"] + ssaa = self.rendering_options["ssaa"] + chunk_size = self.rendering_options["chunk_size"] + antialias = self.rendering_options["antialias"] + clamp_barycentric_coords = self.rendering_options["clamp_barycentric_coords"] + + if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0: + ret_dict = edict() + for type in return_types: + if type == "mask" : + ret_dict[type] = torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device) + elif type == "depth": + ret_dict[type] = torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device) + elif type == "normal": + ret_dict[type] = 
torch.full((3, resolution, resolution), 0.5, dtype=torch.float32, device=self.device) + elif type == "coord": + ret_dict[type] = torch.zeros((3, resolution, resolution), dtype=torch.float32, device=self.device) + elif type == "attr": + if isinstance(mesh, MeshWithVoxel): + ret_dict[type] = torch.zeros((mesh.attrs.shape[-1], resolution, resolution), dtype=torch.float32, device=self.device) + else: + ret_dict[type] = torch.zeros((mesh.vertex_attrs.shape[-1], resolution, resolution), dtype=torch.float32, device=self.device) + return ret_dict + + perspective = intrinsics_to_projection(intrinsics, near, far) + + full_proj = (perspective @ extrinsics).unsqueeze(0) + extrinsics = extrinsics.unsqueeze(0) + + vertices = mesh.vertices.unsqueeze(0) + vertices_homo = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1) + if transformation is not None: + vertices_homo = torch.bmm(vertices_homo, transformation.unsqueeze(0).transpose(-1, -2)) + vertices = vertices_homo[..., :3].contiguous() + vertices_camera = torch.bmm(vertices_homo, extrinsics.transpose(-1, -2)) + vertices_clip = torch.bmm(vertices_homo, full_proj.transpose(-1, -2)) + faces = mesh.faces + + if 'normal' in return_types: + v0 = vertices_camera[0, mesh.faces[:, 0], :3] + v1 = vertices_camera[0, mesh.faces[:, 1], :3] + v2 = vertices_camera[0, mesh.faces[:, 2], :3] + e0 = v1 - v0 + e1 = v2 - v0 + face_normal = torch.cross(e0, e1, dim=1) + face_normal = F.normalize(face_normal, dim=1) + face_normal = torch.where(torch.sum(face_normal * v0, dim=1, keepdim=True) > 0, face_normal, -face_normal) + + out_dict = edict() + if chunk_size is None: + rast, rast_db = dr.rasterize( + self.glctx, vertices_clip, faces, (resolution * ssaa, resolution * ssaa) + ) + if clamp_barycentric_coords: + rast[..., :2] = torch.clamp(rast[..., :2], 0, 1) + rast[..., :2] /= torch.where(rast[..., :2].sum(dim=-1, keepdim=True) > 1, rast[..., :2].sum(dim=-1, keepdim=True), torch.ones_like(rast[..., :2])) + for type in return_types: + 
img = None + if type == "mask" : + img = (rast[..., -1:] > 0).float() + if antialias: img = dr.antialias(img, rast, vertices_clip, faces) + elif type == "depth": + img = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces)[0] + if antialias: img = dr.antialias(img, rast, vertices_clip, faces) + elif type == "normal" : + img = dr.interpolate(face_normal.unsqueeze(0), rast, torch.arange(face_normal.shape[0], dtype=torch.int, device=self.device).unsqueeze(1).repeat(1, 3).contiguous())[0] + if antialias: img = dr.antialias(img, rast, vertices_clip, faces) + img = (img + 1) / 2 + elif type == "coord": + img = dr.interpolate(vertices, rast, faces)[0] + if antialias: img = dr.antialias(img, rast, vertices_clip, faces) + elif type == "attr": + if isinstance(mesh, MeshWithVoxel): + if 'grid_sample_3d' not in globals(): + from flex_gemm.ops.grid_sample import grid_sample_3d + mask = rast[..., -1:] > 0 + xyz = dr.interpolate(vertices, rast, faces)[0] + xyz = ((xyz - mesh.origin) / mesh.voxel_size).reshape(1, -1, 3) + img = grid_sample_3d( + mesh.attrs, + torch.cat([torch.zeros_like(mesh.coords[..., :1]), mesh.coords], dim=-1), + mesh.voxel_shape, + xyz, + mode='trilinear' + ) + img = img.reshape(1, resolution * ssaa, resolution * ssaa, mesh.attrs.shape[-1]) * mask + elif isinstance(mesh, MeshWithPbrMaterial): + tri_id = rast[0, :, :, -1:] + mask = tri_id > 0 + uv_coords = mesh.uv_coords.reshape(1, -1, 2) + texc, texd = dr.interpolate( + uv_coords, + rast, + torch.arange(mesh.uv_coords.shape[0] * 3, dtype=torch.int, device=self.device).reshape(-1, 3), + rast_db=rast_db, + diff_attrs='all' + ) + # Fix problematic texture coordinates + texc = torch.nan_to_num(texc, nan=0.0, posinf=1e3, neginf=-1e3) + texc = torch.clamp(texc, min=-1e3, max=1e3) + texd = torch.nan_to_num(texd, nan=0.0, posinf=1e3, neginf=-1e3) + texd = torch.clamp(texd, min=-1e3, max=1e3) + mid = mesh.material_ids[(tri_id - 1).long()] + imgs = { + 'base_color': torch.zeros((resolution * ssaa, 
resolution * ssaa, 3), dtype=torch.float32, device=self.device), + 'metallic': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device), + 'roughness': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device), + 'alpha': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device) + } + for id, mat in enumerate(mesh.materials): + mat_mask = (mid == id).float() * mask.float() + mat_texc = texc * mat_mask + mat_texd = texd * mat_mask + + if mat.base_color_texture is not None: + base_color = dr.texture( + mat.base_color_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.base_color_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.base_color_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + imgs['base_color'] += base_color * mat.base_color_factor * mat_mask + else: + imgs['base_color'] += mat.base_color_factor * mat_mask + + if mat.metallic_texture is not None: + metallic = dr.texture( + mat.metallic_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.metallic_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.metallic_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + imgs['metallic'] += metallic * mat.metallic_factor * mat_mask + else: + imgs['metallic'] += mat.metallic_factor * mat_mask + + if mat.roughness_texture is not None: + roughness = dr.texture( + mat.roughness_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.roughness_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.roughness_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + imgs['roughness'] += roughness * mat.roughness_factor * mat_mask + else: + imgs['roughness'] 
+= mat.roughness_factor * mat_mask + + if mat.alpha_mode == AlphaMode.OPAQUE: + imgs['alpha'] += 1.0 * mat_mask + else: + if mat.alpha_texture is not None: + alpha = dr.texture( + mat.alpha_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.alpha_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.alpha_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + if mat.alpha_mode == AlphaMode.MASK: + imgs['alpha'] += (alpha * mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask + elif mat.alpha_mode == AlphaMode.BLEND: + imgs['alpha'] += alpha * mat.alpha_factor * mat_mask + else: + if mat.alpha_mode == AlphaMode.MASK: + imgs['alpha'] += (mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask + elif mat.alpha_mode == AlphaMode.BLEND: + imgs['alpha'] += mat.alpha_factor * mat_mask + + img = torch.cat([imgs[name] for name in imgs.keys()], dim=-1).unsqueeze(0) + else: + img = dr.interpolate(mesh.vertex_attrs.unsqueeze(0), rast, faces)[0] + if antialias: img = dr.antialias(img, rast, vertices_clip, faces) + + out_dict[type] = img + else: + z_buffer = torch.full((1, resolution * ssaa, resolution * ssaa), torch.inf, device=self.device, dtype=torch.float32) + for i in range(0, faces.shape[0], chunk_size): + faces_chunk = faces[i:i+chunk_size] + rast, rast_db = dr.rasterize( + self.glctx, vertices_clip, faces_chunk, (resolution * ssaa, resolution * ssaa) + ) + z_filter = torch.logical_and( + rast[..., 3] != 0, + rast[..., 2] < z_buffer + ) + z_buffer[z_filter] = rast[z_filter][..., 2] + + for type in return_types: + img = None + if type == "mask" : + img = (rast[..., -1:] > 0).float() + elif type == "depth": + img = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces_chunk)[0] + elif type == "normal" : + face_normal_chunk = face_normal[i:i+chunk_size] + img = dr.interpolate(face_normal_chunk.unsqueeze(0), rast, torch.arange(face_normal_chunk.shape[0], 
dtype=torch.int, device=self.device).unsqueeze(1).repeat(1, 3).contiguous())[0] + img = (img + 1) / 2 + elif type == "coord": + img = dr.interpolate(vertices, rast, faces_chunk)[0] + elif type == "attr": + if isinstance(mesh, MeshWithVoxel): + if 'grid_sample_3d' not in globals(): + from flex_gemm.ops.grid_sample import grid_sample_3d + mask = rast[..., -1:] > 0 + xyz = dr.interpolate(vertices, rast, faces_chunk)[0] + xyz = ((xyz - mesh.origin) / mesh.voxel_size).reshape(1, -1, 3) + img = grid_sample_3d( + mesh.attrs, + torch.cat([torch.zeros_like(mesh.coords[..., :1]), mesh.coords], dim=-1), + mesh.voxel_shape, + xyz, + mode='trilinear' + ) + img = img.reshape(1, resolution * ssaa, resolution * ssaa, mesh.attrs.shape[-1]) * mask + elif isinstance(mesh, MeshWithPbrMaterial): + tri_id = rast[0, :, :, -1:] + mask = tri_id > 0 + uv_coords = mesh.uv_coords.reshape(1, -1, 2) + texc, texd = dr.interpolate( + uv_coords, + rast, + torch.arange(mesh.uv_coords.shape[0] * 3, dtype=torch.int, device=self.device).reshape(-1, 3), + rast_db=rast_db, + diff_attrs='all' + ) + # Fix problematic texture coordinates + texc = torch.nan_to_num(texc, nan=0.0, posinf=1e3, neginf=-1e3) + texc = torch.clamp(texc, min=-1e3, max=1e3) + texd = torch.nan_to_num(texd, nan=0.0, posinf=1e3, neginf=-1e3) + texd = torch.clamp(texd, min=-1e3, max=1e3) + mid = mesh.material_ids[(tri_id - 1).long()] + imgs = { + 'base_color': torch.zeros((resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device), + 'metallic': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device), + 'roughness': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device), + 'alpha': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device) + } + for id, mat in enumerate(mesh.materials): + mat_mask = (mid == id).float() * mask.float() + mat_texc = texc * mat_mask + mat_texd = texd * mat_mask + + if 
mat.base_color_texture is not None: + base_color = dr.texture( + mat.base_color_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.base_color_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.base_color_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + imgs['base_color'] += base_color * mat.base_color_factor * mat_mask + else: + imgs['base_color'] += mat.base_color_factor * mat_mask + + if mat.metallic_texture is not None: + metallic = dr.texture( + mat.metallic_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.metallic_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.metallic_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + imgs['metallic'] += metallic * mat.metallic_factor * mat_mask + else: + imgs['metallic'] += mat.metallic_factor * mat_mask + + if mat.roughness_texture is not None: + roughness = dr.texture( + mat.roughness_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.roughness_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.roughness_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + imgs['roughness'] += roughness * mat.roughness_factor * mat_mask + else: + imgs['roughness'] += mat.roughness_factor * mat_mask + + if mat.alpha_mode == AlphaMode.OPAQUE: + imgs['alpha'] += 1.0 * mat_mask + else: + if mat.alpha_texture is not None: + alpha = dr.texture( + mat.alpha_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.alpha_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.alpha_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + if mat.alpha_mode == AlphaMode.MASK: + imgs['alpha'] += (alpha * mat.alpha_factor > 
mat.alpha_cutoff).float() * mat_mask + elif mat.alpha_mode == AlphaMode.BLEND: + imgs['alpha'] += alpha * mat.alpha_factor * mat_mask + else: + if mat.alpha_mode == AlphaMode.MASK: + imgs['alpha'] += (mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask + elif mat.alpha_mode == AlphaMode.BLEND: + imgs['alpha'] += mat.alpha_factor * mat_mask + + img = torch.cat([imgs[name] for name in imgs.keys()], dim=-1).unsqueeze(0) + else: + img = dr.interpolate(mesh.vertex_attrs.unsqueeze(0), rast, faces_chunk)[0] + + if type not in out_dict: + out_dict[type] = img + else: + out_dict[type][z_filter] = img[z_filter] + + for type in return_types: + img = out_dict[type] + if ssaa > 1: + img = F.interpolate(img.permute(0, 3, 1, 2), (resolution, resolution), mode='bilinear', align_corners=False, antialias=True) + img = img.squeeze() + else: + img = img.permute(0, 3, 1, 2).squeeze() + out_dict[type] = img + + if isinstance(mesh, (MeshWithVoxel, MeshWithPbrMaterial)) and 'attr' in return_types: + for k, s in mesh.layout.items(): + out_dict[k] = out_dict['attr'][s] + del out_dict['attr'] + + return out_dict diff --git a/trellis2/renderers/pbr_mesh_renderer.py b/trellis2/renderers/pbr_mesh_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..4180524153e61fe234652108088696b15ee66e46 --- /dev/null +++ b/trellis2/renderers/pbr_mesh_renderer.py @@ -0,0 +1,490 @@ +from typing import * +import torch +from easydict import EasyDict as edict +import numpy as np +import utils3d +from ..representations.mesh import Mesh, MeshWithVoxel, MeshWithPbrMaterial, TextureFilterMode, AlphaMode, TextureWrapMode +import torch.nn.functional as F + + +def cube_to_dir(s, x, y): + if s == 0: rx, ry, rz = torch.ones_like(x), -x, -y + elif s == 1: rx, ry, rz = -torch.ones_like(x), x, -y + elif s == 2: rx, ry, rz = x, y, torch.ones_like(x) + elif s == 3: rx, ry, rz = x, -y, -torch.ones_like(x) + elif s == 4: rx, ry, rz = x, torch.ones_like(x), -y + elif s == 5: rx, ry, rz = -x, 
-torch.ones_like(x), -y + return torch.stack((rx, ry, rz), dim=-1) + + +def latlong_to_cubemap(latlong_map, res): + if 'dr' not in globals(): + import nvdiffrast.torch as dr + cubemap = torch.zeros(6, res[0], res[1], latlong_map.shape[-1], dtype=torch.float32, device='cuda') + for s in range(6): + gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'), + torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'), + indexing='ij') + v = F.normalize(cube_to_dir(s, gx, gy), dim=-1) + + tu = torch.atan2(v[..., 0:1], -v[..., 2:3]) / (2 * np.pi) + 0.5 + tv = torch.acos(torch.clamp(v[..., 1:2], min=-1, max=1)) / np.pi + texcoord = torch.cat((tu, tv), dim=-1) + + cubemap[s, ...] = dr.texture(latlong_map[None, ...], texcoord[None, ...], filter_mode='linear')[0] + return cubemap + + +class EnvMap: + def __init__(self, image: torch.Tensor): + self.image = image + + @property + def _backend(self): + if not hasattr(self, '_nvdiffrec_envlight'): + if 'EnvironmentLight' not in globals(): + from nvdiffrec_render.light import EnvironmentLight + cubemap = latlong_to_cubemap(self.image, [512, 512]) + self._nvdiffrec_envlight = EnvironmentLight(cubemap) + self._nvdiffrec_envlight.build_mips() + return self._nvdiffrec_envlight + + def shade(self, gb_pos, gb_normal, kd, ks, view_pos, specular=True): + return self._backend.shade(gb_pos, gb_normal, kd, ks, view_pos, specular) + + def sample(self, directions: torch.Tensor): + if 'dr' not in globals(): + import nvdiffrast.torch as dr + return dr.texture( + self._backend.base.unsqueeze(0), + directions.unsqueeze(0), + boundary_mode='cube', + )[0] + + +def intrinsics_to_projection( + intrinsics: torch.Tensor, + near: float, + far: float, + ) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix + near (float): near plane to clip + far (float): far plane to clip + Returns: + 
        (torch.Tensor): [4, 4] OpenGL perspective matrix
    """
    fx, fy = intrinsics[0, 0], intrinsics[1, 1]
    cx, cy = intrinsics[0, 2], intrinsics[1, 2]
    ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
    # NOTE(review): the 2*fx and 2*cx-1 terms imply the intrinsics are
    # normalized by the image size (principal point in [0, 1]) — confirm
    # against callers before reusing with pixel-unit intrinsics.
    ret[0, 0] = 2 * fx
    ret[1, 1] = 2 * fy
    ret[0, 2] = 2 * cx - 1
    ret[1, 2] = - 2 * cy + 1
    # Depth terms; camera looks down +z (ret[3, 2] = 1, no sign flip).
    ret[2, 2] = (far + near) / (far - near)
    ret[2, 3] = 2 * near * far / (near - far)
    ret[3, 2] = 1.
    return ret


def screen_space_ambient_occlusion(
    depth: torch.Tensor,
    normal: torch.Tensor,
    perspective: torch.Tensor,
    radius: float = 0.1,
    bias: float = 1e-6,
    samples: int = 64,
    intensity: float = 1.0,
) -> torch.Tensor:
    """
    Screen space ambient occlusion (SSAO).

    Monte-Carlo estimate: for each pixel, hemisphere-distributed sample points
    around the view-space position are projected back to the screen; a sample
    whose stored depth is in front of the sample point counts as an occluder.
    Stochastic (uses torch.randn / torch.rand each call).

    Args:
        depth (torch.Tensor): [H, W, 1] depth image
        normal (torch.Tensor): [H, W, 3] normal image
        perspective (torch.Tensor): [4, 4] camera projection matrix
        radius (float): radius of the SSAO kernel
        bias (float): bias to avoid self-occlusion
        samples (int): number of samples to use for the SSAO kernel
        intensity (float): intensity of the SSAO effect
    Returns:
        (torch.Tensor): [H, W, 1] SSAO occlusion factor, clamped to [0, 1]
    """
    device = depth.device
    H, W, _ = depth.shape

    # Projection terms, matching the layout built by intrinsics_to_projection.
    fx = perspective[0, 0]
    fy = perspective[1, 1]
    cx = perspective[0, 2]
    cy = perspective[1, 2]

    # Pixel-center NDC grid in [-1, 1]; unproject to view space using depth.
    y_grid, x_grid = torch.meshgrid(
        (torch.arange(H, device=device) + 0.5) / H * 2 - 1,
        (torch.arange(W, device=device) + 0.5) / W * 2 - 1,
        indexing='ij'
    )
    x_view = (x_grid.float() - cx) * depth[..., 0] / fx
    y_view = (y_grid.float() - cy) * depth[..., 0] / fy
    view_pos = torch.stack([x_view, y_view, depth[..., 0]], dim=-1) # [H, W, 3]

    depth_feat = depth.permute(2, 0, 1).unsqueeze(0)  # [1, 1, H, W] for grid_sample
    occlusion = torch.zeros((H, W), device=device)

    # start sampling
    for _ in range(samples):
        # sample normal distribution, if inside, flip the sign
        # (uniform sphere direction folded into the normal's hemisphere)
        rnd_vec = torch.randn(H, W, 3, device=device)
        rnd_vec = F.normalize(rnd_vec, p=2, dim=-1)
        dot_val = torch.sum(rnd_vec * normal, dim=-1, keepdim=True)
        sample_dir = torch.sign(dot_val) * rnd_vec
        # Quadratic falloff: biases sample distances toward the shaded pixel.
        scale = torch.rand(H, W, 1, device=device)
        scale = scale * scale
        sample_pos = view_pos + sample_dir * radius * scale
        sample_z = sample_pos[..., 2]

        # project to screen space; grid_sample expects NDC coords in [-1, 1]
        z_safe = torch.clamp(sample_pos[..., 2], min=1e-5)
        proj_u = (sample_pos[..., 0] * fx / z_safe) + cx
        proj_v = (sample_pos[..., 1] * fy / z_safe) + cy
        grid = torch.stack([proj_u, proj_v], dim=-1).unsqueeze(0)
        geo_z = F.grid_sample(depth_feat, grid, mode='nearest', padding_mode='border').squeeze()
        # Occluded if the stored surface lies in front of the sample point,
        # but only within `radius` to avoid halos across depth discontinuities.
        range_check = torch.abs(geo_z - sample_z) < radius
        is_occluded = (geo_z <= sample_z - bias) & range_check
        occlusion += is_occluded.float()

    f_occ = occlusion / samples * intensity
    f_occ = torch.clamp(f_occ, 0.0, 1.0)

    return f_occ.unsqueeze(-1)


def aces_tonemapping(x: torch.Tensor) -> torch.Tensor:
    """
    Applies ACES tone mapping curve to an HDR image tensor.
    Input: x - HDR tensor, shape (..., 3), range [0, +inf)
    Output: LDR tensor, same shape, range [0, 1]
    """
    # Rational-fit constants for the ACES filmic curve.
    a = 2.51
    b = 0.03
    c = 2.43
    d = 0.59
    e = 0.14

    # Apply the ACES fitted curve
    mapped = (x * (a * x + b)) / (x * (c * x + d) + e)

    # Clamp to [0, 1] for display or saving
    return torch.clamp(mapped, 0.0, 1.0)


def gamma_correction(x: torch.Tensor, gamma: float = 2.2) -> torch.Tensor:
    """
    Applies gamma correction (linear -> display encoding) to an image tensor,
    clamping the result to [0, 1].
    """
    return torch.clamp(x ** (1.0 / gamma), 0.0, 1.0)


class PbrMeshRenderer:
    """
    Renderer for the PBR mesh.

    Args:
        rendering_options (dict): Rendering options.
    """
    def __init__(self, rendering_options={}, device='cuda'):
        # Lazy import so nvdiffrast is only required when a renderer is built.
        # NOTE(review): `import` inside a function binds a *local* name, so this
        # globals() guard is always true and never caches `dr` at module level;
        # re-import is cheap via sys.modules, but the guard is misleading.
        if 'dr' not in globals():
            import nvdiffrast.torch as dr

        # Defaults; `resolution`, `near` and `far` must be supplied by the caller.
        self.rendering_options = edict({
            "resolution": None,
            "near": None,
            "far": None,
            "ssaa": 1,         # super-sampling factor applied before the final downscale
            "peel_layers": 8,  # number of depth-peeling passes (transparency layers)
        })
        self.rendering_options.update(rendering_options)
        self.glctx = dr.RasterizeCudaContext(device=device)
        self.device=device

    def render(
        self,
        mesh : Mesh,
        extrinsics: torch.Tensor,
        intrinsics: torch.Tensor,
        envmap : Union[EnvMap, Dict[str, EnvMap]],
        use_envmap_bg : bool = False,
        transformation : Optional[torch.Tensor] = None
    ) -> edict:
        """
        Render the mesh.

        Args:
            mesh : mesh model
            extrinsics (torch.Tensor): (4, 4) camera extrinsics
            intrinsics (torch.Tensor): (3, 3) camera intrinsics
            envmap (Union[EnvMap, Dict[str, EnvMap]]): environment map or a dictionary of environment maps
            use_envmap_bg (bool): whether to use envmap as background
            transformation (torch.Tensor): (4, 4) transformation matrix

        Returns:
            edict based on return_types containing:
                shaded (torch.Tensor): [3, H, W] shaded color image
                    (one `shaded_<key>` entry per named envmap; plain "shaded" for the '' key)
                normal (torch.Tensor): [3, H, W] normal image
                base_color (torch.Tensor): [3, H, W] base color image
                metallic (torch.Tensor): [H, W] metallic image
                roughness (torch.Tensor): [H, W] roughness image
        """
        # See NOTE(review) in __init__ about this lazy-import pattern.
        if 'dr' not in globals():
            import nvdiffrast.torch as dr

        # Normalize to {name: envmap}; '' selects the plain "shaded" output key.
        if not isinstance(envmap, dict):
            envmap = {'' : envmap}
        num_envmaps = len(envmap)

        resolution = self.rendering_options["resolution"]
        near = self.rendering_options["near"]
        far = self.rendering_options["far"]
        ssaa = self.rendering_options["ssaa"]

        # Degenerate mesh: return all-zero buffers with the output shapes.
        if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0:
            out_dict = edict(
                normal=torch.zeros((3, resolution, resolution), dtype=torch.float32, device=self.device),
                mask=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device),
                base_color=torch.zeros((3, resolution, resolution), dtype=torch.float32, device=self.device),
                metallic=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device),
                roughness=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device),
                alpha=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device),
                clay=torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device),
            )
            for i, k in enumerate(envmap.keys()):
                shaded_key = f"shaded_{k}" if k != '' else "shaded"
                out_dict[shaded_key] = torch.zeros((3, resolution, resolution), dtype=torch.float32, device=self.device)
            return out_dict

        # Per-pixel camera rays at the super-sampled resolution.
        rays_o, rays_d = utils3d.torch.get_image_rays(
            extrinsics, intrinsics, resolution * ssaa, resolution * ssaa
        )

        perspective = intrinsics_to_projection(intrinsics, near, far)

        full_proj = (perspective @ extrinsics).unsqueeze(0)
        extrinsics = extrinsics.unsqueeze(0)

        # Vertices in world, camera and clip space (batch dim of 1 for nvdiffrast).
        vertices = mesh.vertices.unsqueeze(0)
        vertices_orig = vertices.clone()  # untransformed copy, used for voxel-attr lookup
        vertices_homo = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1)
        if transformation is not None:
            vertices_homo = torch.bmm(vertices_homo, transformation.unsqueeze(0).transpose(-1, -2))
            vertices = vertices_homo[..., :3].contiguous()
        vertices_camera = torch.bmm(vertices_homo, extrinsics.transpose(-1, -2))
        vertices_clip = torch.bmm(vertices_homo, full_proj.transpose(-1, -2))
        faces = mesh.faces

        # Flat per-face normals from the triangle edge cross product.
        v0 = vertices[0, mesh.faces[:, 0], :3]
        v1 = vertices[0, mesh.faces[:, 1], :3]
        v2 = vertices[0, mesh.faces[:, 2], :3]
        e0 = v1 - v0
        e1 = v2 - v0
        face_normal = torch.cross(e0, e1, dim=1)
        face_normal = F.normalize(face_normal, dim=1)

        # Accumulators for front-to-back alpha compositing across peel layers.
        out_dict = edict()
        shaded = torch.zeros((num_envmaps, resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device)
        depth = torch.full((resolution * ssaa, resolution * ssaa, 1), 1e10, dtype=torch.float32, device=self.device)
        normal = torch.zeros((resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device)
        max_w = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device)
        alpha = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device)
        with dr.DepthPeeler(self.glctx, vertices_clip, faces, (resolution * ssaa, resolution * ssaa)) as peeler:
            for _ in range(self.rendering_options["peel_layers"]):
                rast, rast_db = peeler.rasterize_next_layer()

                # Pos (world-space surface position per pixel)
                pos = dr.interpolate(vertices, rast, faces)[0][0]

                # Depth (camera-space z per pixel)
                gb_depth = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces)[0][0]

                # Normal: per-face normals via interpolate with an identity
                # face->attribute index map; flipped to face the camera rays.
                gb_normal = dr.interpolate(face_normal.unsqueeze(0), rast, torch.arange(face_normal.shape[0], dtype=torch.int, device=self.device).unsqueeze(1).repeat(1, 3).contiguous())[0][0]
                gb_normal = torch.where(
                    torch.sum(gb_normal * (pos - rays_o), dim=-1, keepdim=True) > 0,
                    -gb_normal,
                    gb_normal
                )
                gb_cam_normal = (extrinsics[..., :3, :3].reshape(1, 1, 3, 3) @ gb_normal.unsqueeze(-1)).squeeze(-1)
                if _ == 0:
                    # First peel layer is the front-most surface: record its
                    # normal (remapped to [0, 1]) and coverage mask.
                    out_dict.normal = -gb_cam_normal * 0.5 + 0.5
                    mask = (rast[0, ..., -1:] > 0).float()
                    out_dict.mask = mask

                # PBR attributes
                if isinstance(mesh, MeshWithVoxel):
                    # Attributes live in a sparse voxel grid: sample trilinearly
                    # at the *untransformed* surface positions.
                    if 'grid_sample_3d' not in globals():
                        from flex_gemm.ops.grid_sample import grid_sample_3d
                    mask = rast[..., -1:] > 0
                    xyz = dr.interpolate(vertices_orig, rast, faces)[0]
                    xyz = ((xyz - mesh.origin) / mesh.voxel_size).reshape(1, -1, 3)
                    img = grid_sample_3d(
                        mesh.attrs,
                        torch.cat([torch.zeros_like(mesh.coords[..., :1]), mesh.coords], dim=-1),
                        mesh.voxel_shape,
                        xyz,
                        mode='trilinear'
                    )
                    img = img.reshape(1, resolution * ssaa, resolution * ssaa, mesh.attrs.shape[-1]) * mask
                    gb_basecolor = img[0, ..., mesh.layout['base_color']]
                    gb_metallic = img[0, ..., mesh.layout['metallic']]
                    gb_roughness = img[0, ..., mesh.layout['roughness']]
                    gb_alpha = img[0, ..., mesh.layout['alpha']]
                elif isinstance(mesh, MeshWithPbrMaterial):
                    # Attributes come from per-material textures; triangle ids
                    # in `rast` are 1-based (0 = background).
                    tri_id = rast[0, :, :, -1:]
                    mask = tri_id > 0
                    uv_coords = mesh.uv_coords.reshape(1, -1, 2)
                    texc, texd = dr.interpolate(
                        uv_coords,
                        rast,
                        torch.arange(mesh.uv_coords.shape[0] * 3, dtype=torch.int, device=self.device).reshape(-1, 3),
                        rast_db=rast_db,
                        diff_attrs='all'
                    )
                    # Fix problematic texture coordinates
                    texc = torch.nan_to_num(texc, nan=0.0, posinf=1e3, neginf=-1e3)
                    texc = torch.clamp(texc, min=-1e3, max=1e3)
                    texd = torch.nan_to_num(texd, nan=0.0, posinf=1e3, neginf=-1e3)
                    texd = torch.clamp(texd, min=-1e3, max=1e3)
                    mid = mesh.material_ids[(tri_id - 1).long()]
                    gb_basecolor = torch.zeros((resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device)
                    gb_metallic = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device)
                    gb_roughness = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device)
                    gb_alpha = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device)
                    # Accumulate each material's contribution under its own mask.
                    for id, mat in enumerate(mesh.materials):
                        mat_mask = (mid == id).float() * mask.float()
                        mat_texc = texc * mat_mask
                        mat_texd = texd * mat_mask

                        if mat.base_color_texture is not None:
                            bc = dr.texture(
                                mat.base_color_texture.image.unsqueeze(0),
                                mat_texc,
                                mat_texd,
                                filter_mode='linear-mipmap-linear' if mat.base_color_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
                                boundary_mode='clamp' if mat.base_color_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
                            )[0]
                            gb_basecolor += bc * mat.base_color_factor * mat_mask
                        else:
                            gb_basecolor += mat.base_color_factor * mat_mask

                        if mat.metallic_texture is not None:
                            m = dr.texture(
                                mat.metallic_texture.image.unsqueeze(0),
                                mat_texc,
                                mat_texd,
                                filter_mode='linear-mipmap-linear' if mat.metallic_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
                                boundary_mode='clamp' if mat.metallic_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
                            )[0]
                            gb_metallic += m * mat.metallic_factor * mat_mask
                        else:
                            gb_metallic += mat.metallic_factor * mat_mask

                        if mat.roughness_texture is not None:
                            r = dr.texture(
                                mat.roughness_texture.image.unsqueeze(0),
                                mat_texc,
                                mat_texd,
                                filter_mode='linear-mipmap-linear' if mat.roughness_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
                                boundary_mode='clamp' if mat.roughness_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
                            )[0]
                            gb_roughness += r * mat.roughness_factor * mat_mask
                        else:
                            gb_roughness += mat.roughness_factor * mat_mask

                        if mat.alpha_mode == AlphaMode.OPAQUE:
                            gb_alpha += 1.0 * mat_mask
                        else:
                            if mat.alpha_texture is not None:
                                a = dr.texture(
                                    mat.alpha_texture.image.unsqueeze(0),
                                    mat_texc,
                                    mat_texd,
                                    filter_mode='linear-mipmap-linear' if mat.alpha_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest',
                                    boundary_mode='clamp' if mat.alpha_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap'
                                )[0]
                                if mat.alpha_mode == AlphaMode.MASK:
                                    gb_alpha += (a * mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask
                                elif mat.alpha_mode == AlphaMode.BLEND:
                                    gb_alpha += a * mat.alpha_factor * mat_mask
                            else:
                                # NOTE(review): `mat.alpha_factor` is a Python
                                # float, so this comparison yields a bool and
                                # `.float()` would raise AttributeError if this
                                # branch is ever taken — latent bug.
                                if mat.alpha_mode == AlphaMode.MASK:
                                    gb_alpha += (mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask
                                elif mat.alpha_mode == AlphaMode.BLEND:
                                    gb_alpha += mat.alpha_factor * mat_mask
                if _ == 0:
                    out_dict.base_color = gb_basecolor
                    out_dict.metallic = gb_metallic
                    out_dict.roughness = gb_roughness
                    out_dict.alpha = gb_alpha

                # Shading: base color to linear (gamma 2.2), pack ORM, evaluate
                # every environment map for this layer.
                gb_basecolor = torch.clamp(gb_basecolor, 0.0, 1.0) ** 2.2
                gb_metallic = torch.clamp(gb_metallic, 0.0, 1.0)
                gb_roughness = torch.clamp(gb_roughness, 0.0, 1.0)
                gb_alpha = torch.clamp(gb_alpha, 0.0, 1.0)
                gb_orm = torch.cat([
                    torch.zeros_like(gb_metallic),
                    gb_roughness,
                    gb_metallic,
                ], dim=-1)
                gb_shaded = torch.stack([
                    e.shade(
                        pos.unsqueeze(0),
                        gb_normal.unsqueeze(0),
                        gb_basecolor.unsqueeze(0),
                        gb_orm.unsqueeze(0),
                        rays_o,
                        specular=True,
                    )[0]
                    for e in envmap.values()
                ], dim=0)

                # Compositing (front-to-back): weight = remaining transmittance
                # times this layer's alpha.
                w = (1 - alpha) * gb_alpha
                # Keep depth/normal of the most strongly contributing layer
                # (consumed by the SSAO pass below).
                depth = torch.where(w > max_w, gb_depth, depth)
                normal = torch.where(w > max_w, gb_cam_normal, normal)
                max_w = torch.maximum(max_w, w)
                shaded += w * gb_shaded
                alpha += w

        # Ambient occlusion
        f_occ = screen_space_ambient_occlusion(
            depth, normal, perspective, intensity=1.5
        )
        shaded *= (1 - f_occ)
        out_dict.clay = (1 - f_occ)

        # Background
        if use_envmap_bg:
            bg = torch.stack([e.sample(rays_d) for e in envmap.values()], dim=0)
            shaded += (1 - alpha) * bg

        for i, k in enumerate(envmap.keys()):
            shaded_key = f"shaded_{k}" if k != '' else "shaded"
            out_dict[shaded_key] = shaded[i]

        # SSAA: buffers are [H, W, C] here; downsample and move channels first.
        for k in out_dict.keys():
            if ssaa > 1:
                out_dict[k] = F.interpolate(out_dict[k].unsqueeze(0).permute(0, 3, 1, 2), (resolution, resolution), mode='bilinear', align_corners=False, antialias=True)
            else:
                out_dict[k] = out_dict[k].permute(2, 0, 1)
            out_dict[k] = out_dict[k].squeeze()

        # Post processing: tone-map then gamma-encode the shaded images.
        for k in envmap.keys():
            shaded_key = f"shaded_{k}" if k != '' else "shaded"
            out_dict[shaded_key] = aces_tonemapping(out_dict[shaded_key])
            out_dict[shaded_key] = gamma_correction(out_dict[shaded_key])

        return out_dict
diff --git a/trellis2/renderers/voxel_renderer.py b/trellis2/renderers/voxel_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee691d307e44248b1a10f226035106544be87579
--- /dev/null
+++ b/trellis2/renderers/voxel_renderer.py
@@ -0,0 +1,68 @@
import torch
from easydict import EasyDict as edict
from ..representations import Voxel
# NOTE(review): duplicate import of edict below (harmless, could be removed).
from easydict import EasyDict as edict


class VoxelRenderer:
    """
    Renderer for the Voxel representation.

    Args:
        rendering_options (dict): Rendering options.
    """

    def __init__(self, rendering_options={}) -> None:
        # Defaults; `resolution` must be supplied via rendering_options.
        self.rendering_options = edict({
            "resolution": None,
            "near": 0.1,
            "far": 10.0,
            "ssaa": 1,
        })
        self.rendering_options.update(rendering_options)

    def render(
        self,
        voxel: Voxel,
        extrinsics: torch.Tensor,
        intrinsics: torch.Tensor,
        colors_overwrite: torch.Tensor = None
    ) -> edict:
        """
        Render the voxels.

        Args:
            voxel (Voxel): Voxel representation.
            extrinsics (torch.Tensor): (4, 4) camera extrinsics
            intrinsics (torch.Tensor): (3, 3) camera intrinsics
            colors_overwrite (torch.Tensor): (N, 3) override color

        Returns:
            dict containing:
                color (torch.Tensor): (3, H, W) rendered color image
                    (only when colors_overwrite is given; otherwise one entry
                    per key in voxel.layout)
                depth (torch.Tensor): (H, W) rendered depth
                alpha (torch.Tensor): (H, W) rendered alpha
                ...
        """
        # lazy import
        # NOTE(review): `import` inside a function binds a *local* name, so
        # this globals() guard never becomes false; the repeated import is
        # cheap via sys.modules, but the guard does not cache as intended.
        if 'o_voxel' not in globals():
            import o_voxel
        renderer = o_voxel.rasterize.VoxelRenderer(self.rendering_options)
        positions = voxel.position
        attrs = voxel.attrs if colors_overwrite is None else colors_overwrite
        voxel_size = voxel.voxel_size

        # Render
        render_ret = renderer.render(positions, attrs, voxel_size, extrinsics, intrinsics)

        ret = {
            'depth': render_ret['depth'],
            'alpha': render_ret['alpha'],
        }
        if colors_overwrite is not None:
            ret['color'] = render_ret['attr']
        else:
            # Split the stacked attribute channels back out per layout entry.
            # NOTE(review): slicing here indexes dim 0 of render_ret['attr'] —
            # assumes the rasterizer returns channels-first; confirm.
            for k, s in voxel.layout.items():
                ret[k] = render_ret['attr'][s]

        return ret
diff --git a/trellis2/representations/__init__.py b/trellis2/representations/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6771ea8eb7d2bdd59b6d4423f8ad45c449b4564
--- /dev/null
+++ b/trellis2/representations/__init__.py
@@ -0,0 +1,31 @@
import importlib

# Lazy-load table: public attribute name -> submodule that defines it (PEP 562).
__attributes = {
    'Mesh': 'mesh',
    'Voxel': 'voxel',
    'MeshWithVoxel': 'mesh',
    'MeshWithPbrMaterial': 'mesh',
}

__submodules = []

__all__ = list(__attributes.keys()) + __submodules

def __getattr__(name):
    # Import the defining submodule on first access and cache the attribute
    # in this module's globals so subsequent lookups bypass this hook.
    if name not in globals():
        if name in __attributes:
            module_name =
__attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .mesh import Mesh, MeshWithVoxel, MeshWithPbrMaterial + from .voxel import Voxel diff --git a/trellis2/representations/__pycache__/__init__.cpython-311.pyc b/trellis2/representations/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f57bc0b9b077f37478072bbb94bb474f6731552d Binary files /dev/null and b/trellis2/representations/__pycache__/__init__.cpython-311.pyc differ diff --git a/trellis2/representations/mesh/__init__.py b/trellis2/representations/mesh/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d62daac11015853edeebe4278ee1990d13956994 --- /dev/null +++ b/trellis2/representations/mesh/__init__.py @@ -0,0 +1 @@ +from .base import Mesh, MeshWithVoxel, MeshWithPbrMaterial, TextureFilterMode, TextureWrapMode, AlphaMode, PbrMaterial, Texture diff --git a/trellis2/representations/mesh/__pycache__/__init__.cpython-311.pyc b/trellis2/representations/mesh/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4381c00879ef0723cd098b766db9558336c218a8 Binary files /dev/null and b/trellis2/representations/mesh/__pycache__/__init__.cpython-311.pyc differ diff --git a/trellis2/representations/mesh/__pycache__/base.cpython-311.pyc b/trellis2/representations/mesh/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9747da02ba4170a82e2d6d42e7004796284aaf6 Binary files /dev/null and b/trellis2/representations/mesh/__pycache__/base.cpython-311.pyc differ diff --git a/trellis2/representations/mesh/base.py 
b/trellis2/representations/mesh/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..c80ac586970406df18f0a581e94d328c7afdfbb3
--- /dev/null
+++ b/trellis2/representations/mesh/base.py
@@ -0,0 +1,234 @@
from typing import *
import torch
from ..voxel import Voxel
import cumesh
from flex_gemm.ops.grid_sample import grid_sample_3d


class Mesh:
    """Triangle mesh held as torch tensors (float32 vertices, int32 faces)."""
    def __init__(self,
        vertices,
        faces,
        vertex_attrs=None
        ):
        self.vertices = vertices.float()
        self.faces = faces.int()
        self.vertex_attrs = vertex_attrs

    @property
    def device(self):
        # The vertex tensor's device is the mesh's device.
        return self.vertices.device

    def to(self, device, non_blocking=False):
        # Returns a new Mesh; does not mutate self.
        return Mesh(
            self.vertices.to(device, non_blocking=non_blocking),
            self.faces.to(device, non_blocking=non_blocking),
            self.vertex_attrs.to(device, non_blocking=non_blocking) if self.vertex_attrs is not None else None,
        )

    def cuda(self, non_blocking=False):
        return self.to('cuda', non_blocking=non_blocking)

    def cpu(self):
        return self.to('cpu')

    def fill_holes(self, max_hole_perimeter=3e-2):
        # In-place: fills boundary loops with perimeter below the threshold
        # using cumesh on the GPU; results are copied back to self.device.
        vertices = self.vertices.cuda()
        faces = self.faces.cuda()

        mesh = cumesh.CuMesh()
        mesh.init(vertices, faces)
        mesh.get_edges()
        mesh.get_boundary_info()
        if mesh.num_boundaries == 0:
            # Watertight already; nothing to fill.
            return
        mesh.get_vertex_edge_adjacency()
        mesh.get_vertex_boundary_adjacency()
        mesh.get_manifold_boundary_adjacency()
        mesh.read_manifold_boundary_adjacency()
        mesh.get_boundary_connected_components()
        mesh.get_boundary_loops()
        if mesh.num_boundary_loops == 0:
            return
        mesh.fill_holes(max_hole_perimeter=max_hole_perimeter)
        new_vertices, new_faces = mesh.read()

        self.vertices = new_vertices.to(self.device)
        self.faces = new_faces.to(self.device)

    def remove_faces(self, face_mask: torch.Tensor):
        # In-place: drops the faces selected by face_mask (cumesh semantics).
        vertices = self.vertices.cuda()
        faces = self.faces.cuda()

        mesh = cumesh.CuMesh()
        mesh.init(vertices, faces)
        mesh.remove_faces(face_mask)
        new_vertices, new_faces = mesh.read()

        self.vertices = new_vertices.to(self.device)
        self.faces = new_faces.to(self.device)

    def simplify(self, target=1000000, verbose: bool=False, options: dict={}):
        # In-place decimation toward `target` faces via cumesh.
        # NOTE(review): mutable default `options={}` is shared across calls —
        # safe only while cumesh treats it as read-only; confirm.
        vertices = self.vertices.cuda()
        faces = self.faces.cuda()

        mesh = cumesh.CuMesh()
        mesh.init(vertices, faces)
        mesh.simplify(target, verbose=verbose, options=options)
        new_vertices, new_faces = mesh.read()

        self.vertices = new_vertices.to(self.device)
        self.faces = new_faces.to(self.device)


class TextureFilterMode:
    """Integer constants for texture sampling filters."""
    CLOSEST = 0
    LINEAR = 1


class TextureWrapMode:
    """Integer constants for texture coordinate wrapping."""
    CLAMP_TO_EDGE = 0
    REPEAT = 1
    MIRRORED_REPEAT = 2


class AlphaMode:
    """Integer constants for material alpha handling (glTF-style names)."""
    OPAQUE = 0
    MASK = 1
    BLEND = 2


class Texture:
    """An image tensor plus its sampling filter and wrap modes."""
    def __init__(
        self,
        image: torch.Tensor,
        filter_mode: TextureFilterMode = TextureFilterMode.LINEAR,
        wrap_mode: TextureWrapMode = TextureWrapMode.REPEAT
    ):
        self.image = image
        self.filter_mode = filter_mode
        self.wrap_mode = wrap_mode

    def to(self, device, non_blocking=False):
        # Returns a new Texture with the image moved; modes are plain ints.
        return Texture(
            self.image.to(device, non_blocking=non_blocking),
            self.filter_mode,
            self.wrap_mode,
        )


class PbrMaterial:
    """PBR material: optional textures plus scalar/color factors (glTF-like)."""
    def __init__(
        self,
        base_color_texture: Optional[Texture] = None,
        base_color_factor: Union[torch.Tensor, List[float]] = [1.0, 1.0, 1.0],
        metallic_texture: Optional[Texture] = None,
        metallic_factor: float = 1.0,
        roughness_texture: Optional[Texture] = None,
        roughness_factor: float = 1.0,
        alpha_texture: Optional[Texture] = None,
        alpha_factor: float = 1.0,
        alpha_mode: AlphaMode = AlphaMode.OPAQUE,
        alpha_cutoff: float = 0.5,
    ):
        self.base_color_texture = base_color_texture
        # The mutable list default is safe here: torch.tensor() copies it into
        # a fresh tensor each call. Only the first 3 channels (RGB) are kept.
        self.base_color_factor = torch.tensor(base_color_factor, dtype=torch.float32)[:3]
        self.metallic_texture = metallic_texture
        self.metallic_factor = metallic_factor
        self.roughness_texture = roughness_texture
        self.roughness_factor = roughness_factor
        self.alpha_texture = alpha_texture
        self.alpha_factor = alpha_factor
        self.alpha_mode = alpha_mode
        # Threshold used when alpha_mode == AlphaMode.MASK.
        self.alpha_cutoff = alpha_cutoff
    def to(self, device, non_blocking=False):
        # Returns a new PbrMaterial with all tensors/textures moved to `device`;
        # scalar factors and modes are copied by value.
        return PbrMaterial(
            base_color_texture=self.base_color_texture.to(device, non_blocking=non_blocking) if self.base_color_texture is not None else None,
            base_color_factor=self.base_color_factor.to(device, non_blocking=non_blocking),
            metallic_texture=self.metallic_texture.to(device, non_blocking=non_blocking) if self.metallic_texture is not None else None,
            metallic_factor=self.metallic_factor,
            roughness_texture=self.roughness_texture.to(device, non_blocking=non_blocking) if self.roughness_texture is not None else None,
            roughness_factor=self.roughness_factor,
            alpha_texture=self.alpha_texture.to(device, non_blocking=non_blocking) if self.alpha_texture is not None else None,
            alpha_factor=self.alpha_factor,
            alpha_mode=self.alpha_mode,
            alpha_cutoff=self.alpha_cutoff,
        )


class MeshWithPbrMaterial(Mesh):
    """Mesh whose faces carry per-face material ids, UVs and PBR materials."""
    def __init__(self,
        vertices,
        faces,
        material_ids,
        uv_coords,
        materials: List[PbrMaterial],
        ):
        # NOTE(review): does not call Mesh.__init__, so `vertex_attrs` is
        # never set; anything reading .vertex_attrs on this class would fail.
        self.vertices = vertices.float()
        self.faces = faces.int()
        self.material_ids = material_ids # [M]
        self.uv_coords = uv_coords # [M, 3, 2]
        self.materials = materials
        # Channel packing of the flattened attribute vector.
        self.layout = {
            'base_color': slice(0, 3),
            'metallic': slice(3, 4),
            'roughness': slice(4, 5),
            'alpha': slice(5, 6),
        }

    def to(self, device, non_blocking=False):
        return MeshWithPbrMaterial(
            self.vertices.to(device, non_blocking=non_blocking),
            self.faces.to(device, non_blocking=non_blocking),
            self.material_ids.to(device, non_blocking=non_blocking),
            self.uv_coords.to(device, non_blocking=non_blocking),
            [material.to(device, non_blocking=non_blocking) for material in self.materials],
        )


class MeshWithVoxel(Mesh, Voxel):
    """Mesh whose surface attributes are stored in a sparse voxel grid."""
    def __init__(self,
        vertices: torch.Tensor,
        faces: torch.Tensor,
        origin: list,
        voxel_size: float,
        coords: torch.Tensor,
        attrs: torch.Tensor,
        voxel_shape: torch.Size,
        layout: Dict = {},
        ):
        # Order matters: self.device is Mesh's property over self.vertices,
        # so vertices must be assigned before origin is built on that device.
        self.vertices = vertices.float()
        self.faces = faces.int()
        self.origin = torch.tensor(origin, dtype=torch.float32, device=self.device)
        self.voxel_size = voxel_size
        self.coords = coords
        self.attrs = attrs
        self.voxel_shape = voxel_shape
        # NOTE(review): mutable default `layout={}` is shared across instances
        # that don't pass one — safe only if never mutated in place.
        self.layout = layout

    def to(self, device, non_blocking=False):
        return MeshWithVoxel(
            self.vertices.to(device, non_blocking=non_blocking),
            self.faces.to(device, non_blocking=non_blocking),
            self.origin.tolist(),
            self.voxel_size,
            self.coords.to(device, non_blocking=non_blocking),
            self.attrs.to(device, non_blocking=non_blocking),
            self.voxel_shape,
            self.layout,
        )

    def query_attrs(self, xyz):
        # Trilinearly sample the sparse voxel attributes at world-space points
        # (converted to grid coordinates first).
        grid = ((xyz - self.origin) / self.voxel_size).reshape(1, -1, 3)
        vertex_attrs = grid_sample_3d(
            self.attrs,
            torch.cat([torch.zeros_like(self.coords[..., :1]), self.coords], dim=-1),
            self.voxel_shape,
            grid,
            mode='trilinear'
        )[0]
        return vertex_attrs

    def query_vertex_attrs(self):
        # Convenience: sample the voxel attributes at every mesh vertex.
        return self.query_attrs(self.vertices)
diff --git a/trellis2/representations/voxel/__init__.py b/trellis2/representations/voxel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5792ea14b2371a96c4130371eb976f0aff4b5dd
--- /dev/null
+++ b/trellis2/representations/voxel/__init__.py
@@ -0,0 +1 @@
from .voxel_model import Voxel
\ No newline at end of file
diff --git a/trellis2/representations/voxel/__pycache__/__init__.cpython-311.pyc b/trellis2/representations/voxel/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6a5927e04ae9df48498765644d218a5aefa7148b
Binary files /dev/null and b/trellis2/representations/voxel/__pycache__/__init__.cpython-311.pyc differ
diff --git a/trellis2/representations/voxel/__pycache__/voxel_model.cpython-311.pyc b/trellis2/representations/voxel/__pycache__/voxel_model.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..091b536ba7998ee444b1f33c61b2acb87e40c51a
Binary files /dev/null and
b/trellis2/representations/voxel/__pycache__/voxel_model.cpython-311.pyc differ
diff --git a/trellis2/representations/voxel/voxel_model.py b/trellis2/representations/voxel/voxel_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a4208a688c20ee183a9a2f72691ae08e99015a3
--- /dev/null
+++ b/trellis2/representations/voxel/voxel_model.py
@@ -0,0 +1,54 @@
from typing import Dict
import torch


class Voxel:
    """
    Sparse voxel grid: integer coordinates, per-voxel attribute channels and a
    layout mapping attribute names to channel slices.
    """
    def __init__(
        self,
        origin: list,
        voxel_size: float,
        coords: torch.Tensor = None,
        attrs: torch.Tensor = None,
        layout: Dict = {},
        device: torch.device = 'cuda'
    ):
        self.origin = torch.tensor(origin, dtype=torch.float32, device=device)
        self.voxel_size = voxel_size
        self.coords = coords
        self.attrs = attrs
        # NOTE(review): mutable default `layout={}` is shared across instances
        # that don't pass one; load() replaces it, but in-place mutation of the
        # default would leak between instances.
        self.layout = layout
        self.device = device

    @property
    def position(self):
        # World-space centers of the occupied voxels.
        return (self.coords + 0.5) * self.voxel_size + self.origin[None, :]

    def split_attrs(self):
        # Unpack the stacked channel tensor into per-name tensors via layout.
        return {
            k: self.attrs[:, self.layout[k]]
            for k in self.layout
        }

    def save(self, path):
        # lazy import
        # NOTE(review): `import` in a function binds a local name, so this
        # globals() guard never caches o_voxel at module level (same in load).
        if 'o_voxel' not in globals():
            import o_voxel
        o_voxel.io.write(
            path,
            self.coords,
            self.split_attrs(),
        )

    def load(self, path):
        # lazy import
        if 'o_voxel' not in globals():
            import o_voxel
        coord, attrs = o_voxel.io.read(path)
        self.coords = coord.int().to(self.device)
        # Re-stack the named attribute tensors into one channel dimension.
        self.attrs = torch.cat([attrs[k] for k in attrs], dim=1).to(self.device)
        # build layout: one slice per attribute, in file order, by channel width
        start = 0
        self.layout = {}
        for k in attrs:
            self.layout[k] = slice(start, start + attrs[k].shape[1])
            start += attrs[k].shape[1]
diff --git a/trellis2/trainers/__init__.py b/trellis2/trainers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c30b39f3edefc2ca9886c7900af74f30b67dd1c4
--- /dev/null
+++ b/trellis2/trainers/__init__.py
@@ -0,0 +1,68 @@
import importlib

# Lazy-load table: public trainer name -> submodule that defines it (PEP 562).
__attributes = {
    'BasicTrainer': 'basic',

    'SparseStructureVaeTrainer': 'vae.sparse_structure_vae',
    'ShapeVaeTrainer': 'vae.shape_vae',
    'PbrVaeTrainer': 'vae.pbr_vae',

    'FlowMatchingTrainer': 'flow_matching.flow_matching',
    'FlowMatchingCFGTrainer': 'flow_matching.flow_matching',
    'TextConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching',
    'ImageConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching',

    'SparseFlowMatchingTrainer': 'flow_matching.sparse_flow_matching',
    'SparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
    'TextConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
    'ImageConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
    'MultiImageConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',

    'DinoV2FeatureExtractor': 'flow_matching.mixins.image_conditioned',
    'DinoV3FeatureExtractor': 'flow_matching.mixins.image_conditioned',
}

__submodules = []

__all__ = list(__attributes.keys()) + __submodules

def __getattr__(name):
    # Import the defining submodule on first access and cache the attribute in
    # this module's globals so subsequent lookups bypass this hook entirely.
    if name not in globals():
        if name in __attributes:
            module_name = __attributes[name]
            module = importlib.import_module(f".{module_name}", __name__)
            globals()[name] = getattr(module, name)
        elif name in __submodules:
            module = importlib.import_module(f".{name}", __name__)
            globals()[name] = module
        else:
            raise AttributeError(f"module {__name__} has no attribute {name}")
    return globals()[name]


# For Pylance
# (Never executed on import — the guard only exposes the names to static
# analysis tools; runtime access goes through __getattr__ above.)
if __name__ == '__main__':
    from .basic import BasicTrainer

    from .vae.sparse_structure_vae import SparseStructureVaeTrainer
    from .vae.shape_vae import ShapeVaeTrainer
    from .vae.pbr_vae import PbrVaeTrainer

    from .flow_matching.flow_matching import (
        FlowMatchingTrainer,
        FlowMatchingCFGTrainer,
        TextConditionedFlowMatchingCFGTrainer,
        ImageConditionedFlowMatchingCFGTrainer,
    )

    from .flow_matching.sparse_flow_matching import (
        SparseFlowMatchingTrainer,
        SparseFlowMatchingCFGTrainer,
        TextConditionedSparseFlowMatchingCFGTrainer,
        ImageConditionedSparseFlowMatchingCFGTrainer,
    )

    from .flow_matching.mixins.image_conditioned import (
        DinoV2FeatureExtractor,
        DinoV3FeatureExtractor,
    )
diff --git a/trellis2/trainers/basic.py b/trellis2/trainers/basic.py
new file mode 100644
index 0000000000000000000000000000000000000000..891b9ad84085c34c8830635c251536168c477bf9
--- /dev/null
+++ b/trellis2/trainers/basic.py
@@ -0,0 +1,910 @@
from abc import abstractmethod
import os
import time
import json
import copy
import threading
from functools import partial
from contextlib import nullcontext

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
import numpy as np

from torchvision import utils
from torch.utils.tensorboard import SummaryWriter

from .utils import *
from ..utils.general_utils import *
from ..utils.data_utils import recursive_to_device, cycle, ResumableSampler
from ..utils.dist_utils import *
from ..utils import grad_clip_utils, elastic_utils


class BasicTrainer:
    """
    Trainer for basic training loop.

    Args:
        models (dict[str, nn.Module]): Models to train.
        dataset (torch.utils.data.Dataset): Dataset.
        output_dir (str): Output directory.
        load_dir (str): Load directory.
        step (int): Step to load.
        batch_size (int): Batch size.
        batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
        batch_split (int): Split batch with gradient accumulation.
        max_steps (int): Max steps.
        optimizer (dict): Optimizer config.
        lr_scheduler (dict): Learning rate scheduler config.
        elastic (dict): Elastic memory management config.
        grad_clip (float or dict): Gradient clip config.
        ema_rate (float or list): Exponential moving average rates.
        mix_precision_mode (str):
            - None: No mixed precision.
            - 'inflat_all': Hold an inflated fp32 master param for all params.
            - 'amp': Automatic mixed precision.
        mix_precision_dtype (str): Mixed precision dtype.
        fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
        parallel_mode (str): Parallel mode. Options are 'ddp'.
        finetune_ckpt (dict): Finetune checkpoint.
        log_param_stats (bool): Log parameter stats.
        i_print (int): Print interval.
        i_log (int): Log interval.
        i_sample (int): Sample interval.
        i_save (int): Save interval.
        i_ddpcheck (int): DDP check interval.
    """
    def __init__(self,
        models,
        dataset,
        *,
        output_dir,
        load_dir,
        step,
        max_steps,
        batch_size=None,
        batch_size_per_gpu=None,
        batch_split=None,
        optimizer={},
        lr_scheduler=None,
        elastic=None,
        grad_clip=None,
        ema_rate=0.9999,
        fp16_mode=None,
        mix_precision_mode='inflat_all',
        mix_precision_dtype='float16',
        fp16_scale_growth=1e-3,
        parallel_mode='ddp',
        finetune_ckpt=None,
        log_param_stats=False,
        prefetch_data=True,
        snapshot_batch_size=4,
        i_print=1000,
        i_log=500,
        i_sample=10000,
        i_save=10000,
        i_ddpcheck=10000,
        **kwargs
    ):
        assert batch_size is not None or batch_size_per_gpu is not None, 'Either batch_size or batch_size_per_gpu must be specified.'

        self.models = models
        self.dataset = dataset
        # Gradient-accumulation splits per optimizer step (1 = no accumulation).
        self.batch_split = batch_split if batch_split is not None else 1
        self.max_steps = max_steps
        self.optimizer_config = optimizer
        self.lr_scheduler_config = lr_scheduler
        self.elastic_controller_config = elastic
        self.grad_clip = grad_clip
        # Normalize to a list so multiple EMA rates are handled uniformly.
        self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else ema_rate
        # Legacy shim: the old fp16_mode argument overrides the newer
        # mix_precision_* arguments when provided.
        if fp16_mode is not None:
            mix_precision_dtype = 'float16'
            mix_precision_mode = fp16_mode
        self.mix_precision_mode = mix_precision_mode
        self.mix_precision_dtype = str_to_dtype(mix_precision_dtype)
        self.fp16_scale_growth = fp16_scale_growth
        self.parallel_mode = parallel_mode
        self.log_param_stats = log_param_stats
        self.prefetch_data = prefetch_data
        self.snapshot_batch_size = snapshot_batch_size
        self.log = []
        if self.prefetch_data:
            self._data_prefetched = None

        self.output_dir = output_dir
        self.i_print = i_print
        self.i_log = i_log
        self.i_sample = i_sample
        self.i_save = i_save
        self.i_ddpcheck = i_ddpcheck

        if dist.is_initialized():
            # Multi-GPU params
            self.world_size = dist.get_world_size()
            self.rank = dist.get_rank()
            self.local_rank = dist.get_rank() % torch.cuda.device_count()
            self.is_master = self.rank == 0
        else:
            # Single-GPU params
            self.world_size = 1
            self.rank = 0
            self.local_rank = 0
            self.is_master = True

        # Reconcile global and per-GPU batch sizes (per-GPU wins if given).
        self.batch_size = batch_size if batch_size_per_gpu is None else batch_size_per_gpu * self.world_size
        self.batch_size_per_gpu = batch_size_per_gpu if batch_size_per_gpu is not None else batch_size // self.world_size
        assert self.batch_size % self.world_size == 0, 'Batch size must be divisible by the number of GPUs.'
        assert self.batch_size_per_gpu % self.batch_split == 0, 'Batch size per GPU must be divisible by batch split.'

        # Subclass/extension hooks defined elsewhere in this class (not shown
        # in this view): build models/optimizer, then the dataloader.
        self.init_models_and_more(**kwargs)
        self.prepare_dataloader(**kwargs)

        # Load checkpoint
        self.step = 0
        if load_dir is not None and step is not None:
            self.load(load_dir, step)
        elif finetune_ckpt is not None:
            self.finetune_from(finetune_ckpt)

        if self.is_master:
            os.makedirs(os.path.join(self.output_dir, 'ckpts'), exist_ok=True)
            os.makedirs(os.path.join(self.output_dir, 'samples'), exist_ok=True)
            self.writer = SummaryWriter(os.path.join(self.output_dir, 'tb_logs'))

        if self.parallel_mode == 'ddp' and self.world_size > 1:
            self.check_ddp()

        if self.is_master:
            print('\n\nTrainer initialized.')
            print(self)

    def __str__(self):
        # Human-readable summary of the trainer configuration.
        lines = []
        lines.append(self.__class__.__name__)
        lines.append(f' - Models:')
        for name, model in self.models.items():
            lines.append(f' - {name}: {model.__class__.__name__}')
        lines.append(f' - Dataset: {indent(str(self.dataset), 2)}')
        lines.append(f' - Dataloader:')
        lines.append(f' - Sampler: {self.dataloader.sampler.__class__.__name__}')
        lines.append(f' - Num workers: {self.dataloader.num_workers}')
        lines.append(f' - Number of steps: {self.max_steps}')
        lines.append(f' - Number of GPUs: {self.world_size}')
        lines.append(f' - Batch size: {self.batch_size}')
        lines.append(f' - Batch size per GPU: {self.batch_size_per_gpu}')
        lines.append(f' - Batch split: {self.batch_split}')
        lines.append(f' - Optimizer: {self.optimizer.__class__.__name__}')
        lines.append(f' - Learning rate: {self.optimizer.param_groups[0]["lr"]}')
        if self.lr_scheduler_config is not None:
            lines.append(f' - LR scheduler: {self.lr_scheduler.__class__.__name__}')
        if self.elastic_controller_config is not None:
            lines.append(f' - Elastic memory: {indent(str(self.elastic_controller), 2)}')
        if self.grad_clip is not None:
            lines.append(f' - Gradient clip: {indent(str(self.grad_clip), 2)}')
        lines.append(f' - EMA rate: {self.ema_rate}')
        lines.append(f' - Mixed precision dtype: {self.mix_precision_dtype}')
lines.append(f' - Mixed precision mode: {self.mix_precision_mode}') + if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16: + lines.append(f' - FP16 scale growth: {self.fp16_scale_growth}') + lines.append(f' - Parallel mode: {self.parallel_mode}') + return '\n'.join(lines) + + @property + def device(self): + for _, model in self.models.items(): + if hasattr(model, 'device'): + return model.device + return next(list(self.models.values())[0].parameters()).device + + def init_models_and_more(self, **kwargs): + """ + Initialize models and more. + """ + if self.world_size > 1: + # Prepare distributed data parallel + self.training_models = { + name: DDP( + model, + device_ids=[self.local_rank], + output_device=self.local_rank, + bucket_cap_mb=128, + find_unused_parameters=False + ) + for name, model in self.models.items() + } + else: + self.training_models = self.models + + # Build master params + self.model_params = sum( + [[p for p in model.parameters() if p.requires_grad] for model in self.models.values()] + , []) + if self.mix_precision_mode == 'amp': + self.master_params = self.model_params + if self.mix_precision_dtype == torch.float16: + self.scaler = torch.GradScaler() + elif self.mix_precision_mode == 'inflat_all': + self.master_params = make_master_params(self.model_params) + if self.mix_precision_dtype == torch.float16: + self.log_scale = 20.0 + elif self.mix_precision_mode is None: + self.master_params = self.model_params + else: + raise NotImplementedError(f'Mix precision mode {self.mix_precision_mode} is not implemented.') + + # Build EMA params + if self.is_master: + self.ema_params = [copy.deepcopy(self.master_params) for _ in self.ema_rate] + + # Initialize optimizer + if hasattr(torch.optim, self.optimizer_config['name']): + self.optimizer = getattr(torch.optim, self.optimizer_config['name'])(self.master_params, **self.optimizer_config['args']) + else: + self.optimizer = 
globals()[self.optimizer_config['name']](self.master_params, **self.optimizer_config['args']) + + # Initalize learning rate scheduler + if self.lr_scheduler_config is not None: + if hasattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name']): + self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name'])(self.optimizer, **self.lr_scheduler_config['args']) + else: + self.lr_scheduler = globals()[self.lr_scheduler_config['name']](self.optimizer, **self.lr_scheduler_config['args']) + + # Initialize elastic memory controller + if self.elastic_controller_config is not None: + assert any([isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)) for model in self.models.values()]), \ + 'No elastic module found in models, please inherit from ElasticModule or ElasticModuleMixin' + self.elastic_controller = getattr(elastic_utils, self.elastic_controller_config['name'])(**self.elastic_controller_config['args']) + for model in self.models.values(): + if isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)): + model.register_memory_controller(self.elastic_controller) + + # Initialize gradient clipper + if self.grad_clip is not None: + if isinstance(self.grad_clip, (float, int)): + self.grad_clip = float(self.grad_clip) + else: + self.grad_clip = getattr(grad_clip_utils, self.grad_clip['name'])(**self.grad_clip['args']) + + def prepare_dataloader(self, **kwargs): + """ + Prepare dataloader. 
+ """ + self.data_sampler = ResumableSampler( + self.dataset, + shuffle=True, + ) + self.dataloader = DataLoader( + self.dataset, + batch_size=self.batch_size_per_gpu, + num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())), + pin_memory=True, + drop_last=True, + persistent_workers=True, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + sampler=self.data_sampler, + ) + self.data_iterator = cycle(self.dataloader) + + def _master_params_to_state_dicts(self, master_params): + """ + Convert master params to dict of state_dicts. + """ + if self.mix_precision_mode == 'inflat_all': + master_params = unflatten_master_params(self.model_params, master_params) + state_dicts = {name: model.state_dict() for name, model in self.models.items()} + master_params_names = sum( + [[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()] + , []) + for i, (model_name, param_name) in enumerate(master_params_names): + state_dicts[model_name][param_name] = master_params[i] + return state_dicts + + def _state_dicts_to_master_params(self, master_params, state_dicts): + """ + Convert a state_dict to master params. + """ + master_params_names = sum( + [[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()] + , []) + params = [state_dicts[name][param_name] for name, param_name in master_params_names] + if self.mix_precision_mode == 'inflat_all': + model_params_to_master_params(params, master_params) + else: + for i, param in enumerate(params): + master_params[i].data.copy_(param.data) + + def load(self, load_dir, step=0): + """ + Load a checkpoint. + Should be called by all processes. 
+ """ + if self.is_master: + print(f'\nLoading checkpoint from step {step}...', end='') + + model_ckpts = {} + for name, model in self.models.items(): + model_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'{name}_step{step:07d}.pt')), map_location=self.device, weights_only=True) + model_ckpts[name] = model_ckpt + model.load_state_dict(model_ckpt) + self._state_dicts_to_master_params(self.master_params, model_ckpts) + del model_ckpts + + if self.is_master: + for i, ema_rate in enumerate(self.ema_rate): + ema_ckpts = {} + for name, model in self.models.items(): + ema_ckpt = torch.load(os.path.join(load_dir, 'ckpts', f'{name}_ema{ema_rate}_step{step:07d}.pt'), map_location=self.device, weights_only=True) + ema_ckpts[name] = ema_ckpt + self._state_dicts_to_master_params(self.ema_params[i], ema_ckpts) + del ema_ckpts + + misc_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'misc_step{step:07d}.pt')), map_location=torch.device('cpu'), weights_only=False) + self.optimizer.load_state_dict(misc_ckpt['optimizer']) + self.step = misc_ckpt['step'] + self.data_sampler.load_state_dict(misc_ckpt['data_sampler']) + if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16: + self.scaler.load_state_dict(misc_ckpt['scaler']) + elif self.mix_precision_mode == 'inflat_all' and self.mix_precision_dtype == torch.float16: + self.log_scale = misc_ckpt['log_scale'] + if self.lr_scheduler_config is not None: + self.lr_scheduler.load_state_dict(misc_ckpt['lr_scheduler']) + if self.elastic_controller_config is not None: + self.elastic_controller.load_state_dict(misc_ckpt['elastic_controller']) + if self.grad_clip is not None and not isinstance(self.grad_clip, float): + self.grad_clip.load_state_dict(misc_ckpt['grad_clip']) + del misc_ckpt + + if self.world_size > 1: + dist.barrier() + if self.is_master: + print(' Done.') + + if self.world_size > 1: + self.check_ddp() + + def save(self, non_blocking=True): + """ + Save a checkpoint. 
+ Should be called only by the rank 0 process. + """ + assert self.is_master, 'save() should be called only by the rank 0 process.' + print(f'\nSaving checkpoint at step {self.step}...', end='') + + model_ckpts = self._master_params_to_state_dicts(self.master_params) + for name, model_ckpt in model_ckpts.items(): + model_ckpt = {k: v.cpu() for k, v in model_ckpt.items()} # Move to CPU for saving + if non_blocking: + threading.Thread( + target=torch.save, + args=(model_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_step{self.step:07d}.pt')), + ).start() + else: + torch.save(model_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_step{self.step:07d}.pt')) + + for i, ema_rate in enumerate(self.ema_rate): + ema_ckpts = self._master_params_to_state_dicts(self.ema_params[i]) + for name, ema_ckpt in ema_ckpts.items(): + ema_ckpt = {k: v.cpu() for k, v in ema_ckpt.items()} # Move to CPU for saving + if non_blocking: + threading.Thread( + target=torch.save, + args=(ema_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_ema{ema_rate}_step{self.step:07d}.pt')), + ).start() + else: + torch.save(ema_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_ema{ema_rate}_step{self.step:07d}.pt')) + + misc_ckpt = { + 'optimizer': self.optimizer.state_dict(), + 'step': self.step, + 'data_sampler': self.data_sampler.state_dict(), + } + if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16: + misc_ckpt['scaler'] = self.scaler.state_dict() + elif self.mix_precision_mode == 'inflat_all' and self.mix_precision_dtype == torch.float16: + misc_ckpt['log_scale'] = self.log_scale + if self.lr_scheduler_config is not None: + misc_ckpt['lr_scheduler'] = self.lr_scheduler.state_dict() + if self.elastic_controller_config is not None: + misc_ckpt['elastic_controller'] = self.elastic_controller.state_dict() + if self.grad_clip is not None and not isinstance(self.grad_clip, float): + misc_ckpt['grad_clip'] = self.grad_clip.state_dict() + if non_blocking: + 
threading.Thread( + target=torch.save, + args=(misc_ckpt, os.path.join(self.output_dir, 'ckpts', f'misc_step{self.step:07d}.pt')), + ).start() + else: + torch.save(misc_ckpt, os.path.join(self.output_dir, 'ckpts', f'misc_step{self.step:07d}.pt')) + print(' Done.') + + def finetune_from(self, finetune_ckpt): + """ + Finetune from a checkpoint. + Should be called by all processes. + """ + if self.is_master: + print('\nFinetuning from:') + for name, path in finetune_ckpt.items(): + print(f' - {name}: {path}') + + model_ckpts = {} + for name, model in self.models.items(): + model_state_dict = model.state_dict() + if name in finetune_ckpt: + model_ckpt = torch.load(read_file_dist(finetune_ckpt[name]), map_location=self.device, weights_only=True) + for k, v in model_ckpt.items(): + if k not in model_state_dict: + if self.is_master: + print(f'Warning: {k} not found in model_state_dict, skipped.') + model_ckpt[k] = None + elif model_ckpt[k].shape != model_state_dict[k].shape: + if self.is_master: + print(f'Warning: {k} shape mismatch, {model_ckpt[k].shape} vs {model_state_dict[k].shape}, skipped.') + model_ckpt[k] = model_state_dict[k] + model_ckpt = {k: v for k, v in model_ckpt.items() if v is not None} + model_ckpts[name] = model_ckpt + model.load_state_dict(model_ckpt) + else: + if self.is_master: + print(f'Warning: {name} not found in finetune_ckpt, skipped.') + model_ckpts[name] = model_state_dict + self._state_dicts_to_master_params(self.master_params, model_ckpts) + if self.is_master: + for i, ema_rate in enumerate(self.ema_rate): + self._state_dicts_to_master_params(self.ema_params[i], model_ckpts) + del model_ckpts + + if self.world_size > 1: + dist.barrier() + if self.is_master: + print('Done.') + + if self.world_size > 1: + self.check_ddp() + + @abstractmethod + def run_snapshot(self, num_samples, batch_size=4, verbose=False, **kwargs): + """ + Run a snapshot of the model. 
+ """ + pass + + @torch.no_grad() + def visualize_sample(self, sample): + """ + Convert a sample to an image. + """ + if hasattr(self.dataset, 'visualize_sample'): + return self.dataset.visualize_sample(sample) + else: + return sample + + @torch.no_grad() + def snapshot_dataset(self, num_samples=100, batch_size=4): + """ + Sample images from the dataset. + """ + dataloader = torch.utils.data.DataLoader( + self.dataset, + batch_size=batch_size, + num_workers=1, + shuffle=True, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + ) + save_cfg = {} + for i in range(0, num_samples, batch_size): + data = next(iter(dataloader)) + data = {k: v[:min(num_samples - i, batch_size)] for k, v in data.items()} + data = recursive_to_device(data, self.device) + vis = self.visualize_sample(data) + if isinstance(vis, dict): + for k, v in vis.items(): + if f'dataset_{k}' not in save_cfg: + save_cfg[f'dataset_{k}'] = [] + save_cfg[f'dataset_{k}'].append(v) + else: + if 'dataset' not in save_cfg: + save_cfg['dataset'] = [] + save_cfg['dataset'].append(vis) + for name, image in save_cfg.items(): + utils.save_image( + torch.cat(image, dim=0), + os.path.join(self.output_dir, 'samples', f'{name}.jpg'), + nrow=int(np.sqrt(num_samples)), + normalize=True, + value_range=self.dataset.value_range, + ) + + @torch.no_grad() + def snapshot(self, suffix=None, num_samples=64, batch_size=4, verbose=False): + """ + Sample images from the model. + NOTE: This function should be called by all processes. 
+ """ + if self.is_master: + print(f'\nSampling {num_samples} images...', end='') + + if suffix is None: + suffix = f'step{self.step:07d}' + + # Assign tasks + num_samples_per_process = int(np.ceil(num_samples / self.world_size)) + amp_context = partial(torch.autocast, device_type='cuda', dtype=self.mix_precision_dtype) if self.mix_precision_mode == 'amp' else nullcontext + with amp_context(): + samples = self.run_snapshot(num_samples_per_process, batch_size=batch_size, verbose=verbose) + + # Preprocess images + for key in list(samples.keys()): + if samples[key]['type'] == 'sample': + vis = self.visualize_sample(samples[key]['value']) + if isinstance(vis, dict): + for k, v in vis.items(): + samples[f'{key}_{k}'] = {'value': v, 'type': 'image'} + del samples[key] + else: + samples[key] = {'value': vis, 'type': 'image'} + + # Gather results + if self.world_size > 1: + for key in samples.keys(): + samples[key]['value'] = samples[key]['value'].contiguous() + if self.is_master: + all_images = [torch.empty_like(samples[key]['value']) for _ in range(self.world_size)] + else: + all_images = [] + dist.gather(samples[key]['value'], all_images, dst=0) + if self.is_master: + samples[key]['value'] = torch.cat(all_images, dim=0)[:num_samples] + + # Save images + if self.is_master: + os.makedirs(os.path.join(self.output_dir, 'samples', suffix), exist_ok=True) + for key in samples.keys(): + if samples[key]['type'] == 'image': + utils.save_image( + samples[key]['value'], + os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'), + nrow=int(np.sqrt(num_samples)), + normalize=True, + value_range=self.dataset.value_range, + ) + elif samples[key]['type'] == 'number': + min = samples[key]['value'].min() + max = samples[key]['value'].max() + images = (samples[key]['value'] - min) / (max - min) + images = utils.make_grid( + images, + nrow=int(np.sqrt(num_samples)), + normalize=False, + ) + save_image_with_notes( + images, + os.path.join(self.output_dir, 'samples', suffix, 
f'{key}_{suffix}.jpg'), + notes=f'{key} min: {min}, max: {max}', + ) + + if self.is_master: + print(' Done.') + + def update_ema(self): + """ + Update exponential moving average. + Should only be called by the rank 0 process. + """ + assert self.is_master, 'update_ema() should be called only by the rank 0 process.' + for i, ema_rate in enumerate(self.ema_rate): + for master_param, ema_param in zip(self.master_params, self.ema_params[i]): + ema_param.detach().mul_(ema_rate).add_(master_param, alpha=1.0 - ema_rate) + + def check_ddp(self): + """ + Check if DDP is working properly. + Should be called by all process. + """ + if self.is_master: + print('\nPerforming DDP check...') + + if self.is_master: + print('Checking if parameters are consistent across processes...') + dist.barrier() + try: + for p in self.master_params: + # split to avoid OOM + for i in range(0, p.numel(), 10000000): + sub_size = min(10000000, p.numel() - i) + sub_p = p.detach().view(-1)[i:i+sub_size] + # gather from all processes + sub_p_gather = [torch.empty_like(sub_p) for _ in range(self.world_size)] + dist.all_gather(sub_p_gather, sub_p) + # check if equal + assert all([torch.equal(sub_p, sub_p_gather[i]) for i in range(self.world_size)]), 'parameters are not consistent across processes' + except AssertionError as e: + if self.is_master: + print(f'\n\033[91mError: {e}\033[0m') + print('DDP check failed.') + raise e + + dist.barrier() + if self.is_master: + print('Done.') + + @abstractmethod + def training_losses(**mb_data): + """ + Compute training losses. + """ + pass + + def load_data(self): + """ + Load data. 
+ """ + if self.prefetch_data: + if self._data_prefetched is None: + self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True) + data = self._data_prefetched + self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True) + else: + data = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True) + + # if the data is a dict, we need to split it into multiple dicts with batch_size_per_gpu + if isinstance(data, dict): + if self.batch_split == 1: + data_list = [data] + else: + batch_size = list(data.values())[0].shape[0] + data_list = [ + {k: v[i * batch_size // self.batch_split:(i + 1) * batch_size // self.batch_split] for k, v in data.items()} + for i in range(self.batch_split) + ] + elif isinstance(data, list): + data_list = data + else: + raise ValueError('Data must be a dict or a list of dicts.') + + return data_list + + def run_step(self, data_list): + """ + Run a training step. + """ + step_log = {'loss': {}, 'status': {}} + amp_context = partial(torch.autocast, device_type='cuda', dtype=self.mix_precision_dtype) if self.mix_precision_mode == 'amp' else nullcontext + elastic_controller_context = self.elastic_controller.record if self.elastic_controller_config is not None else nullcontext + + # Train + losses = [] + statuses = [] + elastic_controller_logs = [] + zero_grad(self.model_params) + for i, mb_data in enumerate(data_list): + ## sync at the end of each batch split + sync_contexts = [self.training_models[name].no_sync for name in self.training_models] if i != len(data_list) - 1 and self.world_size > 1 else [nullcontext] + with nested_contexts(*sync_contexts), elastic_controller_context(): + with amp_context(): + loss, status = self.training_losses(**mb_data) + l = loss['loss'] / len(data_list) + ## backward + if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16: + self.scaler.scale(l).backward() + elif self.mix_precision_mode 
== 'inflat_all' and self.mix_precision_dtype == torch.float16: + scaled_l = l * (2 ** self.log_scale) + scaled_l.backward() + else: + l.backward() + ## log + losses.append(dict_foreach(loss, lambda x: x.item() if isinstance(x, torch.Tensor) else x)) + statuses.append(dict_foreach(status, lambda x: x.item() if isinstance(x, torch.Tensor) else x)) + if self.elastic_controller_config is not None: + elastic_controller_logs.append(self.elastic_controller.log()) + ## gradient clip + if self.grad_clip is not None: + if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16: + self.scaler.unscale_(self.optimizer) + elif self.mix_precision_mode == 'inflat_all': + model_grads_to_master_grads(self.model_params, self.master_params) + if self.mix_precision_dtype == torch.float16: + self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale)) + if isinstance(self.grad_clip, float): + grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params, self.grad_clip) + else: + grad_norm = self.grad_clip(self.master_params) + if torch.isfinite(grad_norm): + statuses[-1]['grad_norm'] = grad_norm.item() + ## step + if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16: + prev_scale = self.scaler.get_scale() + self.scaler.step(self.optimizer) + self.scaler.update() + elif self.mix_precision_mode == 'inflat_all': + if self.mix_precision_dtype == torch.float16: + prev_scale = 2 ** self.log_scale + if not any(not p.grad.isfinite().all() for p in self.model_params): + if self.grad_clip is None: + model_grads_to_master_grads(self.model_params, self.master_params) + self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale)) + self.optimizer.step() + master_params_to_model_params(self.model_params, self.master_params) + self.log_scale += self.fp16_scale_growth + else: + self.log_scale -= 1 + else: + prev_scale = 1.0 + if self.grad_clip is None: + model_grads_to_master_grads(self.model_params, self.master_params) + if not any(not 
p.grad.isfinite().all() for p in self.master_params): + self.optimizer.step() + master_params_to_model_params(self.model_params, self.master_params) + else: + print('\n\033[93mWarning: NaN detected in gradients. Skipping update.\033[0m') + else: + prev_scale = 1.0 + if not any(not p.grad.isfinite().all() for p in self.model_params): + self.optimizer.step() + else: + print('\n\033[93mWarning: NaN detected in gradients. Skipping update.\033[0m') + ## adjust learning rate + if self.lr_scheduler_config is not None: + statuses[-1]['lr'] = self.lr_scheduler.get_last_lr()[0] + self.lr_scheduler.step() + + # Logs + step_log['loss'] = dict_reduce(losses, lambda x: np.mean(x)) + step_log['status'] = dict_reduce(statuses, lambda x: np.mean(x), special_func={'min': lambda x: np.min(x), 'max': lambda x: np.max(x)}) + if self.elastic_controller_config is not None: + step_log['elastic'] = dict_reduce(elastic_controller_logs, lambda x: np.mean(x)) + if self.grad_clip is not None: + step_log['grad_clip'] = self.grad_clip if isinstance(self.grad_clip, float) else self.grad_clip.log() + + # Check grad and norm of each param + if self.log_param_stats: + param_norms = {} + param_grads = {} + for model_name, model in self.models.items(): + for name, param in model.named_parameters(): + if param.requires_grad: + param_norms[f'{model_name}.{name}'] = param.norm().item() + if param.grad is not None and torch.isfinite(param.grad).all(): + param_grads[f'{model_name}.{name}'] = param.grad.norm().item() / prev_scale + step_log['param_norms'] = param_norms + step_log['param_grads'] = param_grads + + # Update exponential moving average + if self.is_master: + self.update_ema() + + return step_log + + def save_logs(self): + log_str = '\n'.join([ + f'{step}: {json.dumps(dict_foreach(log, lambda x: float(x)))}' for step, log in self.log + ]) + with open(os.path.join(self.output_dir, 'log.txt'), 'a') as log_file: + log_file.write(log_str + '\n') + + # show with mlflow + log_show = [l for _, l in 
            self.log if not dict_any(l, lambda x: np.isnan(x))]
        log_show = dict_reduce(log_show, lambda x: np.mean(x))
        log_show = dict_flatten(log_show, sep='/')
        for key, value in log_show.items():
            self.writer.add_scalar(key, value, self.step)
        self.log = []

    def check_abort(self):
        """
        Check if training should be aborted due to certain conditions.
        """
        # 1. If log_scale in inflat_all mode is less than 0
        if self.mix_precision_dtype == torch.float16 and \
            self.mix_precision_mode == 'inflat_all' and \
            self.log_scale < 0:
            if self.is_master:
                print ('\n\n\033[91m')
                print (f'ABORT: log_scale in inflat_all mode is less than 0 at step {self.step}.')
                print ('This indicates that the model is diverging. You should look into the model and the data.')
                print ('\033[0m')
                self.save(non_blocking=False)
                self.save_logs()
            if self.world_size > 1:
                dist.barrier()
            raise ValueError('ABORT: log_scale in inflat_all mode is less than 0.')

    def run(self):
        """
        Run training: snapshot the dataset/model, then iterate load_data/run_step
        until max_steps, periodically printing, sampling, logging and saving.
        Should be called by all processes.
        """
        if self.is_master:
            print('\nStarting training...')
            self.snapshot_dataset(batch_size=self.snapshot_batch_size)
        if self.step == 0:
            self.snapshot(suffix='init', batch_size=self.snapshot_batch_size)
        else:   # resume
            self.snapshot(suffix=f'resume_step{self.step:07d}', batch_size=self.snapshot_batch_size)

        time_last_print = 0.0
        time_elapsed = 0.0
        while self.step < self.max_steps:
            time_start = time.time()

            data_list = self.load_data()
            step_log = self.run_step(data_list)

            time_end = time.time()
            time_elapsed += time_end - time_start

            self.step += 1

            # Print progress
            if self.is_master and self.step % self.i_print == 0:
                speed = self.i_print / (time_elapsed - time_last_print) * 3600
                columns = [
                    f'Step: {self.step}/{self.max_steps} ({self.step / self.max_steps * 100:.2f}%)',
                    f'Elapsed: {time_elapsed / 3600:.2f} h',
                    f'Speed: {speed:.2f} steps/h',
                    f'ETA: {(self.max_steps - self.step) / speed:.2f} h',
                ]
                print(' | '.join([c.ljust(25) for c in columns]), flush=True)
                time_last_print = time_elapsed

            # Check ddp
            if self.parallel_mode == 'ddp' and self.world_size > 1 and self.i_ddpcheck is not None and self.step % self.i_ddpcheck == 0:
                self.check_ddp()

            # Sample images
            if self.step % self.i_sample == 0:
                self.snapshot()

            if self.is_master:
                self.log.append((self.step, {}))

                # Log time
                self.log[-1][1]['time'] = {
                    'step': time_end - time_start,
                    'elapsed': time_elapsed,
                }

                # Log losses
                if step_log is not None:
                    self.log[-1][1].update(step_log)

                # Log scale
                if self.mix_precision_dtype == torch.float16:
                    if self.mix_precision_mode == 'amp':
                        self.log[-1][1]['scale'] = self.scaler.get_scale()
                    elif self.mix_precision_mode == 'inflat_all':
                        self.log[-1][1]['log_scale'] = self.log_scale

                # Save log
                if self.step % self.i_log == 0:
                    self.save_logs()

                # Save checkpoint
                if self.step % self.i_save == 0:
                    self.save()

            # Check abort
            # NOTE(review): called on all ranks — check_abort performs a barrier
            # and raises collectively when the loss scale has collapsed.
            self.check_abort()

        self.snapshot(suffix='final', batch_size=self.snapshot_batch_size)
        if self.world_size > 1:
            dist.barrier()
        if self.is_master:
            self.writer.close()
            print('Training finished.')

    def profile(self, wait=2, warmup=3, active=5):
        """
        Profile the training loop.
+ """ + with torch.profiler.profile( + schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(self.output_dir, 'profile')), + profile_memory=True, + with_stack=True, + ) as prof: + for _ in range(wait + warmup + active): + self.run_step() + prof.step() diff --git a/trellis2/trainers/flow_matching/flow_matching.py b/trellis2/trainers/flow_matching/flow_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..2850ce4190cdda085618ab0ac651c02a33e6da3b --- /dev/null +++ b/trellis2/trainers/flow_matching/flow_matching.py @@ -0,0 +1,353 @@ +from typing import * +import copy +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +import numpy as np +from easydict import EasyDict as edict + +from ..basic import BasicTrainer +from ...pipelines import samplers +from ...utils.general_utils import dict_reduce +from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin +from .mixins.text_conditioned import TextConditionedMixin +from .mixins.image_conditioned import ImageConditionedMixin + + +class FlowMatchingTrainer(BasicTrainer): + """ + Trainer for diffusion model with flow matching objective. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. 
+ - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + """ + def __init__( + self, + *args, + t_schedule: dict = { + 'name': 'logitNormal', + 'args': { + 'mean': 0.0, + 'std': 1.0, + } + }, + sigma_min: float = 1e-5, + **kwargs + ): + super().__init__(*args, **kwargs) + self.t_schedule = t_schedule + self.sigma_min = sigma_min + + def diffuse(self, x_0: torch.Tensor, t: torch.Tensor, noise: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + + Args: + x_0: The [N x C x ...] tensor of noiseless inputs. + t: The [N] tensor of diffusion steps [0-1]. + noise: If specified, use this noise instead of generating new noise. + + Returns: + x_t, the noisy version of x_0 under timestep t. + """ + if noise is None: + noise = torch.randn_like(x_0) + assert noise.shape == x_0.shape, "noise must have same shape as x_0" + + t = t.view(-1, *[1 for _ in range(len(x_0.shape) - 1)]) + x_t = (1 - t) * x_0 + (self.sigma_min + (1 - self.sigma_min) * t) * noise + + return x_t + + def reverse_diffuse(self, x_t: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor: + """ + Get original image from noisy version under timestep t. 
+ """ + assert noise.shape == x_t.shape, "noise must have same shape as x_t" + t = t.view(-1, *[1 for _ in range(len(x_t.shape) - 1)]) + x_0 = (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * noise) / (1 - t) + return x_0 + + def get_v(self, x_0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + """ + Compute the velocity of the diffusion process at time t. + """ + return (1 - self.sigma_min) * noise - x_0 + + def get_cond(self, cond, **kwargs): + """ + Get the conditioning data. + """ + return cond + + def get_inference_cond(self, cond, **kwargs): + """ + Get the conditioning data for inference. + """ + return {'cond': cond, **kwargs} + + def get_sampler(self, **kwargs) -> samplers.FlowEulerSampler: + """ + Get the sampler for the diffusion process. + """ + return samplers.FlowEulerSampler(self.sigma_min) + + def vis_cond(self, **kwargs): + """ + Visualize the conditioning data. + """ + return {} + + def sample_t(self, batch_size: int) -> torch.Tensor: + """ + Sample timesteps. + """ + if self.t_schedule['name'] == 'uniform': + t = torch.rand(batch_size) + elif self.t_schedule['name'] == 'logitNormal': + mean = self.t_schedule['args']['mean'] + std = self.t_schedule['args']['std'] + t = torch.sigmoid(torch.randn(batch_size) * std + mean) + else: + raise ValueError(f"Unknown t_schedule: {self.t_schedule['name']}") + return t + + def training_losses( + self, + x_0: torch.Tensor, + cond=None, + **kwargs + ) -> Tuple[Dict, Dict]: + """ + Compute training losses for a single timestep. + + Args: + x_0: The [N x C x ...] tensor of noiseless inputs. + cond: The [N x ...] tensor of additional conditions. + kwargs: Additional arguments to pass to the backbone. + + Returns: + a dict with the key "loss" containing a tensor of shape [N]. + may also contain other keys for different terms. 
+ """ + noise = torch.randn_like(x_0) + t = self.sample_t(x_0.shape[0]).to(x_0.device).float() + x_t = self.diffuse(x_0, t, noise=noise) + cond = self.get_cond(cond, **kwargs) + + pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs) + assert pred.shape == noise.shape == x_0.shape + target = self.get_v(x_0, noise, t) + terms = edict() + terms["mse"] = F.mse_loss(pred, target) + terms["loss"] = terms["mse"] + + # log loss with time bins + mse_per_instance = np.array([ + F.mse_loss(pred[i], target[i]).item() + for i in range(x_0.shape[0]) + ]) + time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1 + for i in range(10): + if (time_bin == i).sum() != 0: + terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()} + + return terms, {} + + @torch.no_grad() + def run_snapshot( + self, + num_samples: int, + batch_size: int, + verbose: bool = False, + ) -> Dict: + dataloader = DataLoader( + copy.deepcopy(self.dataset), + batch_size=batch_size, + shuffle=True, + num_workers=0, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + ) + + # inference + sampler = self.get_sampler() + sample_gt = [] + sample = [] + cond_vis = [] + for i in range(0, num_samples, batch_size): + batch = min(batch_size, num_samples - i) + data = next(iter(dataloader)) + data = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()} + noise = torch.randn_like(data['x_0']) + sample_gt.append(data['x_0']) + cond_vis.append(self.vis_cond(**data)) + del data['x_0'] + args = self.get_inference_cond(**data) + res = sampler.sample( + self.models['denoiser'], + noise=noise, + **args, + steps=50, guidance_strength=3.0, verbose=verbose, + ) + sample.append(res.samples) + + sample_gt = torch.cat(sample_gt, dim=0) + sample = torch.cat(sample, dim=0) + sample_dict = { + 'sample_gt': {'value': sample_gt, 'type': 'sample'}, + 'sample': {'value': sample, 'type': 'sample'}, + } + 
sample_dict.update(dict_reduce(cond_vis, None, { + 'value': lambda x: torch.cat(x, dim=0), + 'type': lambda x: x[0], + })) + + return sample_dict + + +class FlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, FlowMatchingTrainer): + """ + Trainer for diffusion model with flow matching objective and classifier-free guidance. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + p_uncond (float): Probability of dropping conditions. + """ + pass + + +class TextConditionedFlowMatchingCFGTrainer(TextConditionedMixin, FlowMatchingCFGTrainer): + """ + Trainer for text-conditioned diffusion model with flow matching objective and classifier-free guidance. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. 
+ output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + p_uncond (float): Probability of dropping conditions. + text_cond_model(str): Text conditioning model. + """ + pass + + +class ImageConditionedFlowMatchingCFGTrainer(ImageConditionedMixin, FlowMatchingCFGTrainer): + """ + Trainer for image-conditioned diffusion model with flow matching objective and classifier-free guidance. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. 
+ lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + p_uncond (float): Probability of dropping conditions. + image_cond_model (str): Image conditioning model. + """ + pass diff --git a/trellis2/trainers/flow_matching/mixins/classifier_free_guidance.py b/trellis2/trainers/flow_matching/mixins/classifier_free_guidance.py new file mode 100644 index 0000000000000000000000000000000000000000..903ab81b39276087eac72d9c55941fef47141498 --- /dev/null +++ b/trellis2/trainers/flow_matching/mixins/classifier_free_guidance.py @@ -0,0 +1,59 @@ +import torch +import numpy as np +from ....utils.general_utils import dict_foreach +from ....pipelines import samplers + + +class ClassifierFreeGuidanceMixin: + def __init__(self, *args, p_uncond: float = 0.1, **kwargs): + super().__init__(*args, **kwargs) + self.p_uncond = p_uncond + + def get_cond(self, cond, neg_cond=None, **kwargs): + """ + Get the conditioning data. 
+ """ + assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance" + + if self.p_uncond > 0: + # randomly drop the class label + def get_batch_size(cond): + if isinstance(cond, torch.Tensor): + return cond.shape[0] + elif isinstance(cond, list): + return len(cond) + else: + raise ValueError(f"Unsupported type of cond: {type(cond)}") + + ref_cond = cond if not isinstance(cond, dict) else cond[list(cond.keys())[0]] + B = get_batch_size(ref_cond) + + def select(cond, neg_cond, mask): + if isinstance(cond, torch.Tensor): + mask = torch.tensor(mask, device=cond.device).reshape(-1, *[1] * (cond.ndim - 1)) + return torch.where(mask, neg_cond, cond) + elif isinstance(cond, list): + return [nc if m else c for c, nc, m in zip(cond, neg_cond, mask)] + else: + raise ValueError(f"Unsupported type of cond: {type(cond)}") + + mask = list(np.random.rand(B) < self.p_uncond) + if not isinstance(cond, dict): + cond = select(cond, neg_cond, mask) + else: + cond = dict_foreach([cond, neg_cond], lambda x: select(x[0], x[1], mask)) + + return cond + + def get_inference_cond(self, cond, neg_cond=None, **kwargs): + """ + Get the conditioning data for inference. + """ + assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance" + return {'cond': cond, 'neg_cond': neg_cond, **kwargs} + + def get_sampler(self, **kwargs) -> samplers.FlowEulerCfgSampler: + """ + Get the sampler for the diffusion process. 
+ """ + return samplers.FlowEulerCfgSampler(self.sigma_min) diff --git a/trellis2/trainers/flow_matching/mixins/image_conditioned.py b/trellis2/trainers/flow_matching/mixins/image_conditioned.py new file mode 100644 index 0000000000000000000000000000000000000000..c8f4c7241ef2a9d4f6287ef6fd07cb46b84d7359 --- /dev/null +++ b/trellis2/trainers/flow_matching/mixins/image_conditioned.py @@ -0,0 +1,249 @@ +from typing import * +import torch +import torch.nn.functional as F +from torchvision import transforms +from transformers import DINOv3ViTModel +import numpy as np +from PIL import Image + +from ....utils import dist_utils + + +class DinoV2FeatureExtractor: + """ + Feature extractor for DINOv2 models. + """ + def __init__(self, model_name: str): + self.model_name = model_name + self.model = torch.hub.load('facebookresearch/dinov2', model_name, pretrained=True) + self.model.eval() + self.transform = transforms.Compose([ + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + + def to(self, device): + self.model.to(device) + + def cuda(self): + self.model.cuda() + + def cpu(self): + self.model.cpu() + + @torch.no_grad() + def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor: + """ + Extract features from the image. + + Args: + image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images. + + Returns: + A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension. 
+ """ + if isinstance(image, torch.Tensor): + assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)" + elif isinstance(image, list): + assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images" + image = [i.resize((518, 518), Image.LANCZOS) for i in image] + image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] + image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] + image = torch.stack(image).cuda() + else: + raise ValueError(f"Unsupported type of image: {type(image)}") + + image = self.transform(image).cuda() + features = self.model(image, is_training=True)['x_prenorm'] + patchtokens = F.layer_norm(features, features.shape[-1:]) + return patchtokens + + +class DinoV3FeatureExtractor: + """ + Feature extractor for DINOv3 models. + """ + def __init__(self, model_name: str, image_size=512): + self.model_name = model_name + self.model = DINOv3ViTModel.from_pretrained(model_name) + self.model.eval() + self.image_size = image_size + self.transform = transforms.Compose([ + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + + def to(self, device): + self.model.to(device) + + def cuda(self): + self.model.cuda() + + def cpu(self): + self.model.cpu() + + def extract_features(self, image: torch.Tensor) -> torch.Tensor: + image = image.to(self.model.embeddings.patch_embeddings.weight.dtype) + hidden_states = self.model.embeddings(image, bool_masked_pos=None) + position_embeddings = self.model.rope_embeddings(image) + + for i, layer_module in enumerate(self.model.layer): + hidden_states = layer_module( + hidden_states, + position_embeddings=position_embeddings, + ) + + return F.layer_norm(hidden_states, hidden_states.shape[-1:]) + + @torch.no_grad() + def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor: + """ + Extract features from the image. 
+ + Args: + image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images. + + Returns: + A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension. + """ + if isinstance(image, torch.Tensor): + assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)" + elif isinstance(image, list): + assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images" + image = [i.resize((self.image_size, self.image_size), Image.LANCZOS) for i in image] + image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] + image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] + image = torch.stack(image).cuda() + else: + raise ValueError(f"Unsupported type of image: {type(image)}") + + image = self.transform(image).cuda() + features = self.extract_features(image) + return features + + +class ImageConditionedMixin: + """ + Mixin for image-conditioned models. + + Args: + image_cond_model: The image conditioning model. + """ + def __init__(self, *args, image_cond_model: dict, **kwargs): + super().__init__(*args, **kwargs) + self.image_cond_model_config = image_cond_model + self.image_cond_model = None # the model is init lazily + + def _init_image_cond_model(self): + """ + Initialize the image conditioning model. + """ + with dist_utils.local_master_first(): + self.image_cond_model = globals()[self.image_cond_model_config['name']](**self.image_cond_model_config.get('args', {})) + self.image_cond_model.cuda() + + @torch.no_grad() + def encode_image(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor: + """ + Encode the image. + """ + if self.image_cond_model is None: + self._init_image_cond_model() + features = self.image_cond_model(image) + return features + + def get_cond(self, cond, **kwargs): + """ + Get the conditioning data. 
+ """ + cond = self.encode_image(cond) + kwargs['neg_cond'] = torch.zeros_like(cond) + cond = super().get_cond(cond, **kwargs) + return cond + + def get_inference_cond(self, cond, **kwargs): + """ + Get the conditioning data for inference. + """ + cond = self.encode_image(cond) + kwargs['neg_cond'] = torch.zeros_like(cond) + cond = super().get_inference_cond(cond, **kwargs) + return cond + + def vis_cond(self, cond, **kwargs): + """ + Visualize the conditioning data. + """ + return {'image': {'value': cond, 'type': 'image'}} + + +class MultiImageConditionedMixin: + """ + Mixin for multiple-image-conditioned models. + + Args: + image_cond_model: The image conditioning model. + """ + def __init__(self, *args, image_cond_model: dict, **kwargs): + super().__init__(*args, **kwargs) + self.image_cond_model_config = image_cond_model + self.image_cond_model = None # the model is init lazily + + def _init_image_cond_model(self): + """ + Initialize the image conditioning model. + """ + with dist_utils.local_master_first(): + self.image_cond_model = globals()[self.image_cond_model_config['name']](**self.image_cond_model_config.get('args', {})) + + @torch.no_grad() + def encode_images(self, images: Union[List[torch.Tensor], List[List[Image.Image]]]) -> List[torch.Tensor]: + """ + Encode the image. + """ + if self.image_cond_model is None: + self._init_image_cond_model() + seqlen = [len(i) for i in images] + images = torch.cat(images, dim=0) if isinstance(images[0], torch.Tensor) else sum(images, []) + features = self.image_cond_model(images) + features = torch.split(features, seqlen) + features = [feature.reshape(-1, feature.shape[-1]) for feature in features] + return features + + def get_cond(self, cond, **kwargs): + """ + Get the conditioning data. 
+ """ + cond = self.encode_images(cond) + kwargs['neg_cond'] = [ + torch.zeros_like(cond[0][:1, :]) for _ in range(len(cond)) + ] + cond = super().get_cond(cond, **kwargs) + return cond + + def get_inference_cond(self, cond, **kwargs): + """ + Get the conditioning data for inference. + """ + cond = self.encode_images(cond) + kwargs['neg_cond'] = [ + torch.zeros_like(cond[0][:1, :]) for _ in range(len(cond)) + ] + cond = super().get_inference_cond(cond, **kwargs) + return cond + + def vis_cond(self, cond, **kwargs): + """ + Visualize the conditioning data. + """ + H, W = cond[0].shape[-2:] + vis = [] + for images in cond: + canvas = torch.zeros(3, H * 2, W * 2, device=images.device, dtype=images.dtype) + for i, image in enumerate(images): + if i == 4: + break + kh = i // 2 + kw = i % 2 + canvas[:, kh*H:(kh+1)*H, kw*W:(kw+1)*W] = image + vis.append(canvas) + vis = torch.stack(vis) + return {'image': {'value': vis, 'type': 'image'}} diff --git a/trellis2/trainers/flow_matching/mixins/text_conditioned.py b/trellis2/trainers/flow_matching/mixins/text_conditioned.py new file mode 100644 index 0000000000000000000000000000000000000000..f7300c7474fb979aef6097a3ecb5dc0acfffce5f --- /dev/null +++ b/trellis2/trainers/flow_matching/mixins/text_conditioned.py @@ -0,0 +1,68 @@ +from typing import * +import os +os.environ['TOKENIZERS_PARALLELISM'] = 'true' +import torch +from transformers import AutoTokenizer, CLIPTextModel + +from ....utils import dist_utils + + +class TextConditionedMixin: + """ + Mixin for text-conditioned models. + + Args: + text_cond_model: The text conditioning model. + """ + def __init__(self, *args, text_cond_model: str = 'openai/clip-vit-large-patch14', **kwargs): + super().__init__(*args, **kwargs) + self.text_cond_model_name = text_cond_model + self.text_cond_model = None # the model is init lazily + + def _init_text_cond_model(self): + """ + Initialize the text conditioning model. 
+ """ + # load model + with dist_utils.local_master_first(): + model = CLIPTextModel.from_pretrained(self.text_cond_model_name) + tokenizer = AutoTokenizer.from_pretrained(self.text_cond_model_name) + model.eval() + model = model.cuda() + self.text_cond_model = { + 'model': model, + 'tokenizer': tokenizer, + } + self.text_cond_model['null_cond'] = self.encode_text(['']) + + @torch.no_grad() + def encode_text(self, text: List[str]) -> torch.Tensor: + """ + Encode the text. + """ + assert isinstance(text, list) and isinstance(text[0], str), "TextConditionedMixin only supports list of strings as cond" + if self.text_cond_model is None: + self._init_text_cond_model() + encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt') + tokens = encoding['input_ids'].cuda() + embeddings = self.text_cond_model['model'](input_ids=tokens).last_hidden_state + + return embeddings + + def get_cond(self, cond, **kwargs): + """ + Get the conditioning data. + """ + cond = self.encode_text(cond) + kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1) + cond = super().get_cond(cond, **kwargs) + return cond + + def get_inference_cond(self, cond, **kwargs): + """ + Get the conditioning data for inference. 
+ """ + cond = self.encode_text(cond) + kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1) + cond = super().get_inference_cond(cond, **kwargs) + return cond diff --git a/trellis2/trainers/flow_matching/sparse_flow_matching.py b/trellis2/trainers/flow_matching/sparse_flow_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..164fc1dab20ee9a0ccca206e5c2d7d7f3dcb1719 --- /dev/null +++ b/trellis2/trainers/flow_matching/sparse_flow_matching.py @@ -0,0 +1,325 @@ +from typing import * +import os +import copy +import functools +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +import numpy as np +from easydict import EasyDict as edict + +from ...modules import sparse as sp +from ...utils.general_utils import dict_reduce +from ...utils.data_utils import recursive_to_device, cycle, BalancedResumableSampler +from .flow_matching import FlowMatchingTrainer +from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin +from .mixins.text_conditioned import TextConditionedMixin +from .mixins.image_conditioned import ImageConditionedMixin, MultiImageConditionedMixin + + +class SparseFlowMatchingTrainer(FlowMatchingTrainer): + """ + Trainer for sparse diffusion model with flow matching objective. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. 
+ fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + """ + + def prepare_dataloader(self, **kwargs): + """ + Prepare dataloader. + """ + self.data_sampler = BalancedResumableSampler( + self.dataset, + shuffle=True, + batch_size=self.batch_size_per_gpu, + ) + self.dataloader = DataLoader( + self.dataset, + batch_size=self.batch_size_per_gpu, + num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())), + pin_memory=True, + drop_last=True, + persistent_workers=True, + collate_fn=functools.partial(self.dataset.collate_fn, split_size=self.batch_split), + sampler=self.data_sampler, + ) + self.data_iterator = cycle(self.dataloader) + + def training_losses( + self, + x_0: sp.SparseTensor, + cond=None, + **kwargs + ) -> Tuple[Dict, Dict]: + """ + Compute training losses for a single timestep. + + Args: + x_0: The [N x ... x C] sparse tensor of the inputs. + cond: The [N x ...] tensor of additional conditions. + kwargs: Additional arguments to pass to the backbone. + + Returns: + a dict with the key "loss" containing a tensor of shape [N]. + may also contain other keys for different terms. 
+ """ + noise = x_0.replace(torch.randn_like(x_0.feats)) + t = self.sample_t(x_0.shape[0]).to(x_0.device).float() + x_t = self.diffuse(x_0, t, noise=noise) + cond = self.get_cond(cond, **kwargs) + + pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs) + assert pred.shape == noise.shape == x_0.shape + target = self.get_v(x_0, noise, t) + terms = edict() + terms["mse"] = F.mse_loss(pred.feats, target.feats) + terms["loss"] = terms["mse"] + + # log loss with time bins + mse_per_instance = np.array([ + F.mse_loss(pred.feats[x_0.layout[i]], target.feats[x_0.layout[i]]).item() + for i in range(x_0.shape[0]) + ]) + time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1 + for i in range(10): + if (time_bin == i).sum() != 0: + terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()} + + return terms, {} + + @torch.no_grad() + def run_snapshot( + self, + num_samples: int, + batch_size: int, + verbose: bool = False, + ) -> Dict: + dataloader = DataLoader( + copy.deepcopy(self.dataset), + batch_size=num_samples, + shuffle=True, + num_workers=0, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + ) + data = next(iter(dataloader)) + + # inference + sampler = self.get_sampler() + sample = [] + cond_vis = [] + for i in range(0, num_samples, batch_size): + batch_data = {k: v[i:i+batch_size] for k, v in data.items()} + batch_data = recursive_to_device(batch_data, 'cuda') + noise = batch_data['x_0'].replace(torch.randn_like(batch_data['x_0'].feats)) + cond_vis.append(self.vis_cond(**batch_data)) + del batch_data['x_0'] + args = self.get_inference_cond(**batch_data) + res = sampler.sample( + self.models['denoiser'], + noise=noise, + **args, + steps=12, guidance_strength=3.0, verbose=verbose, + ) + sample.append(res.samples) + sample = sp.sparse_cat(sample) + + sample_gt = {k: v for k, v in data.items()} + sample = {k: v if k != 'x_0' else sample for k, v in data.items()} + sample_dict = { + 'sample_gt': 
{'value': sample_gt, 'type': 'sample'}, + 'sample': {'value': sample, 'type': 'sample'}, + } + sample_dict.update(dict_reduce(cond_vis, None, { + 'value': lambda x: torch.cat(x, dim=0), + 'type': lambda x: x[0], + })) + + return sample_dict + + +class SparseFlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, SparseFlowMatchingTrainer): + """ + Trainer for sparse diffusion model with flow matching objective and classifier-free guidance. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + p_uncond (float): Probability of dropping conditions. 
+ """ + pass + + +class TextConditionedSparseFlowMatchingCFGTrainer(TextConditionedMixin, SparseFlowMatchingCFGTrainer): + """ + Trainer for sparse text-conditioned diffusion model with flow matching objective and classifier-free guidance. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + p_uncond (float): Probability of dropping conditions. + text_cond_model(str): Text conditioning model. + """ + pass + + +class ImageConditionedSparseFlowMatchingCFGTrainer(ImageConditionedMixin, SparseFlowMatchingCFGTrainer): + """ + Trainer for sparse image-conditioned diffusion model with flow matching objective and classifier-free guidance. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. 
+ output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + p_uncond (float): Probability of dropping conditions. + image_cond_model (str): Image conditioning model. + """ + pass + + +class MultiImageConditionedSparseFlowMatchingCFGTrainer(MultiImageConditionedMixin, SparseFlowMatchingCFGTrainer): + """ + Trainer for sparse image-conditioned diffusion model with flow matching objective and classifier-free guidance. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. 
+ lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + p_uncond (float): Probability of dropping conditions. + image_cond_model (str): Image conditioning model. + """ + pass diff --git a/trellis2/trainers/utils.py b/trellis2/trainers/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..485cfa092274d2fab632380d2c303e36916a6d00 --- /dev/null +++ b/trellis2/trainers/utils.py @@ -0,0 +1,91 @@ +import torch +import torch.nn as nn + +# FP16 utils +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + + +def str_to_dtype(dtype_str: str): + return { + 'f16': torch.float16, + 'fp16': torch.float16, + 'float16': torch.float16, + 'bf16': torch.bfloat16, + 'bfloat16': torch.bfloat16, + 'f32': torch.float32, + 'fp32': torch.float32, + 'float32': torch.float32, + }[dtype_str] + + +def make_master_params(model_params): + """ + Copy model parameters into a inflated tensor of full-precision parameters. 
    Returns:
        A single-element list holding one flat ``nn.Parameter`` containing
        every model parameter cast to float32.
    """
    # Flatten all (detached, fp32-cast) params into one contiguous tensor so
    # the optimizer updates a single master copy.
    master_params = _flatten_dense_tensors(
        [param.detach().float() for param in model_params]
    )
    master_params = nn.Parameter(master_params)
    master_params.requires_grad = True
    return [master_params]


def unflatten_master_params(model_params, master_params):
    """
    Unflatten the master parameters to look like model_params.

    Args:
        model_params: Iterable of model parameters (supplies the target shapes).
        master_params: Single-element list produced by make_master_params().

    Returns:
        A list of tensors shaped like each entry of model_params.
    """
    return _unflatten_dense_tensors(master_params[0].detach(), model_params)


def model_params_to_master_params(model_params, master_params):
    """
    Copy the model parameter data into the master parameters.
    """
    # detach() so the in-place copy_ does not record autograd history.
    master_params[0].detach().copy_(
        _flatten_dense_tensors([param.detach().float() for param in model_params])
    )


def master_params_to_model_params(model_params, master_params):
    """
    Copy the master parameter data back into the model parameters.
    """
    # Unflatten the single master tensor into per-parameter chunks and copy
    # each chunk back into the corresponding model parameter in place.
    for param, master_param in zip(
        model_params, _unflatten_dense_tensors(master_params[0].detach(), model_params)
    ):
        param.detach().copy_(master_param)


def model_grads_to_master_grads(model_params, master_params):
    """
    Copy the gradients from the model parameters into the master parameters
    from make_master_params().
    """
    # Flatten per-parameter grads into one fp32 tensor and attach it as the
    # .grad of the single master parameter.
    master_params[0].grad = _flatten_dense_tensors(
        [param.grad.data.detach().float() for param in model_params]
    )


def zero_grad(model_params):
    # Zero gradients in place (rather than setting grad=None) so the
    # flattening utilities above can keep reusing the allocated .grad tensors.
    for param in model_params:
        if param.grad is not None:
            if param.grad.grad_fn is not None:
                param.grad.detach_()
            else:
                param.grad.requires_grad_(False)
            param.grad.zero_()


# LR Schedulers
from torch.optim.lr_scheduler import LambdaLR

class LinearWarmupLRScheduler(LambdaLR):
    # Linearly ramps the LR multiplier from 1/warmup_steps up to 1.0 over the
    # first `warmup_steps` steps, then holds it at 1.0.
    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
        self.warmup_steps = warmup_steps
        super(LinearWarmupLRScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, current_step):
        # Multiplicative factor applied to each base LR at `current_step`.
        if current_step < self.warmup_steps:
            return float(current_step + 1) / self.warmup_steps
        return 1.0
    
\ No newline at end of file
diff --git a/trellis2/trainers/vae/pbr_vae.py b/trellis2/trainers/vae/pbr_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..c477027f0bc529964c0f793dff38b7269cfe0af3
--- /dev/null
+++ b/trellis2/trainers/vae/pbr_vae.py
@@ -0,0 +1,281 @@
from typing import *
import os
import copy
import functools
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import utils3d
from easydict import EasyDict as edict

from ..basic import BasicTrainer
from ...modules import sparse as sp
from ...renderers import MeshRenderer
from ...representations import Mesh, MeshWithPbrMaterial, MeshWithVoxel
from ...utils.data_utils import recursive_to_device, cycle, BalancedResumableSampler
from ...utils.loss_utils import l1_loss, l2_loss, ssim, lpips


class PbrVaeTrainer(BasicTrainer):
    """
    Trainer for PBR attributes VAE

    Args:
        models (dict[str, nn.Module]): Models to train.
        dataset (torch.utils.data.Dataset): Dataset.
        output_dir (str): Output directory.
        load_dir (str): Load directory.
        step (int): Step to load.
        batch_size (int): Batch size.
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + loss_type (str): Loss type. + lambda_kl (float): KL loss weight. + lambda_ssim (float): SSIM loss weight. + lambda_lpips (float): LPIPS loss weight. + """ + + def __init__( + self, + *args, + loss_type: str = 'l1', + lambda_kl: float = 1e-6, + lambda_ssim: float = 0.2, + lambda_lpips: float = 0.2, + lambda_render: float = 1.0, + render_resolution: float = 1024, + camera_randomization_config: dict = { + 'radius_range': [2, 100], + }, + **kwargs + ): + super().__init__(*args, **kwargs) + self.loss_type = loss_type + self.lambda_kl = lambda_kl + self.lambda_ssim = lambda_ssim + self.lambda_lpips = lambda_lpips + self.lambda_render = lambda_render + self.camera_randomization_config = camera_randomization_config + + self.renderer = MeshRenderer({'near': 1, 'far': 3, 'resolution': render_resolution}, device=self.device) + + def prepare_dataloader(self, **kwargs): + """ + Prepare dataloader. 
        """
        # Load-balanced, resumable sampling so each rank gets comparable work
        # and training can resume mid-epoch.
        self.data_sampler = BalancedResumableSampler(
            self.dataset,
            shuffle=True,
            batch_size=self.batch_size_per_gpu,
        )
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=self.batch_size_per_gpu,
            # Spread CPU workers evenly across ranks.
            # NOTE(review): divides by torch.cuda.device_count() — assumes at
            # least one visible GPU; confirm for CPU-only runs.
            num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())),
            pin_memory=True,
            drop_last=True,
            persistent_workers=True,
            collate_fn=functools.partial(self.dataset.collate_fn, split_size=self.batch_split),
            sampler=self.data_sampler,
        )
        # Infinite iterator that keeps the sampler's resume bookkeeping in sync.
        self.data_iterator = cycle(self.dataloader)

    def _randomize_camera(self, num_samples: int):
        # sample radius and fov
        # Radius is drawn so that 1/r^2 is uniform over the configured range,
        # biasing samples toward closer viewpoints.
        r_min, r_max = self.camera_randomization_config['radius_range']
        k_min = 1 / r_max**2
        k_max = 1 / r_min**2
        ks = torch.rand(num_samples, device=self.device) * (k_max - k_min) + k_min
        radius = 1 / torch.sqrt(ks)
        # FOV chosen so a sphere of radius 0.5 at the sampled distance exactly
        # fills the view: half-angle = arcsin(0.5 / radius).
        fov = 2 * torch.arcsin(0.5 / radius)
        # Random direction on the sphere, scaled to the sampled radius.
        origin = radius.unsqueeze(-1) * F.normalize(torch.randn(num_samples, 3, device=self.device), dim=-1)

        # build camera
        extrinsics = utils3d.torch.extrinsics_look_at(origin, torch.zeros_like(origin), torch.tensor([0, 0, 1], dtype=torch.float32, device=self.device))
        intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
        # Per-view near plane drawn within 1 unit in front of the camera radius.
        near = [np.random.uniform(r - 1, r) for r in radius.tolist()]

        return {
            'extrinsics': extrinsics,
            'intrinsics': intrinsics,
            'near': near,
        }

    def _render_batch(self, reps: List[Mesh], extrinsics: torch.Tensor, intrinsics: torch.Tensor, near: List,
                      ) -> Dict[str, torch.Tensor]:
        """
        Render a batch of representations.

        Args:
            reps: The dictionary of lists of representations.
            extrinsics: The [N x 4 x 4] tensor of extrinsics.
            intrinsics: The [N x 3 x 3] tensor of intrinsics.

        Returns:
            a dict with
                base_color : [N x 3 x H x W] tensor of base color.
                metallic   : [N x 1 x H x W] tensor of metallic.
                roughness  : [N x 1 x H x W] tensor of roughness.
                alpha      : [N x 1 x H x W] tensor of alpha.
+ """ + ret = {k : [] for k in ['base_color', 'metallic', 'roughness', 'alpha']} + for i, rep in enumerate(reps): + self.renderer.rendering_options['near'] = near[i] + self.renderer.rendering_options['far'] = near[i] + 2 + out_dict = self.renderer.render(rep, extrinsics[i], intrinsics[i], return_types=['attr']) + for k in out_dict: + ret[k].append(out_dict[k]) + for k in ret: + ret[k] = torch.stack(ret[k]) + return ret + + def training_losses( + self, + x: sp.SparseTensor, + mesh: List[MeshWithPbrMaterial] = None, + **kwargs + ) -> Tuple[Dict, Dict]: + """ + Compute training losses. + + Args: + x (SparseTensor): Input sparse tensor for pbr materials. + mesh (List[MeshWithPbrMaterial]): The list of meshes with PBR materials. + + Returns: + a dict with the key "loss" containing a scalar tensor. + may also contain other keys for different terms. + + """ + z, mean, logvar = self.training_models['encoder'](x, sample_posterior=True, return_raw=True) + y = self.training_models['decoder'](z) + + terms = edict(loss = 0.0) + + # direct regression + if self.loss_type == 'l1': + terms["l1"] = l1_loss(x.feats, y.feats) + terms["loss"] = terms["loss"] + terms["l1"] + elif self.loss_type == 'l2': + terms["l2"] = l2_loss(x.feats, y.feats) + terms["loss"] = terms["loss"] + terms["l2"] + else: + raise ValueError(f'Invalid loss type {self.loss_type}') + + # rendering loss + if self.lambda_render != 0.0: + recon = [MeshWithVoxel( + m.vertices, + m.faces, + [-0.5, -0.5, -0.5], + 1 / self.dataset.resolution, + v.coords[:, 1:], + v.feats * 0.5 + 0.5, + torch.Size([*v.shape, *v.spatial_shape]), + layout={ + 'base_color': slice(0, 3), + 'metallic': slice(3, 4), + 'roughness': slice(4, 5), + 'alpha': slice(5, 6), + } + ) for m, v in zip(mesh, y)] + cameras = self._randomize_camera(len(mesh)) + gt_renders = self._render_batch(mesh, **cameras) + pred_renders = self._render_batch(recon, **cameras) + gt_base_color = gt_renders['base_color'] + pred_base_color = pred_renders['base_color'] + 
gt_mra = torch.cat([gt_renders['metallic'], gt_renders['roughness'], gt_renders['alpha']], dim=1) + pred_mra = torch.cat([pred_renders['metallic'], pred_renders['roughness'], pred_renders['alpha']], dim=1) + terms['render/base_color/ssim'] = 1 - ssim(pred_base_color, gt_base_color) + terms['render/base_color/lpips'] = lpips(pred_base_color, gt_base_color) + terms['render/mra/ssim'] = 1 - ssim(pred_mra, gt_mra) + terms['render/mra/lpips'] = lpips(pred_mra, gt_mra) + terms['loss'] = terms['loss'] + \ + self.lambda_render * (self.lambda_ssim * terms['render/base_color/ssim'] + self.lambda_lpips * terms['render/base_color/lpips'] + \ + self.lambda_ssim * terms['render/mra/ssim'] + self.lambda_lpips * terms['render/mra/lpips']) + + # KL regularization + terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1) + terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"] + + return terms, {} + + @torch.no_grad() + def run_snapshot( + self, + num_samples: int, + batch_size: int, + verbose: bool = False, + ) -> Dict: + dataloader = DataLoader( + copy.deepcopy(self.dataset), + batch_size=batch_size, + shuffle=True, + num_workers=1, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + ) + dataloader.dataset.with_mesh = True + + # inference + gts = [] + recons = [] + self.models['encoder'].eval() + self.models['decoder'].eval() + for i in range(0, num_samples, batch_size): + batch = min(batch_size, num_samples - i) + data = next(iter(dataloader)) + args = {k: v[:batch] for k, v in data.items()} + args = recursive_to_device(args, self.device) + z = self.models['encoder'](args['x']) + y = self.models['decoder'](z) + gts.extend(args['mesh']) + recons.extend([MeshWithVoxel( + m.vertices, + m.faces, + [-0.5, -0.5, -0.5], + 1 / self.dataset.resolution, + v.coords[:, 1:], + v.feats * 0.5 + 0.5, + torch.Size([*v.shape, *v.spatial_shape]), + layout={ + 'base_color': slice(0, 3), + 'metallic': slice(3, 4), + 'roughness': slice(4, 5), 
+ 'alpha': slice(5, 6), + } + ) for m, v in zip(args['mesh'], y)]) + self.models['encoder'].train() + self.models['decoder'].train() + + cameras = self._randomize_camera(num_samples) + gt_renders = self._render_batch(gts, **cameras) + pred_renders = self._render_batch(recons, **cameras) + + sample_dict = { + 'gt_base_color': {'value': gt_renders['base_color'] * 2 - 1, 'type': 'image'}, + 'pred_base_color': {'value': pred_renders['base_color'] * 2 - 1, 'type': 'image'}, + 'gt_mra': {'value': torch.cat([gt_renders['metallic'], gt_renders['roughness'], gt_renders['alpha']], dim=1) * 2 - 1, 'type': 'image'}, + 'pred_mra': {'value': torch.cat([pred_renders['metallic'], pred_renders['roughness'], pred_renders['alpha']], dim=1) * 2 - 1, 'type': 'image'}, + } + + return sample_dict diff --git a/trellis2/trainers/vae/shape_vae.py b/trellis2/trainers/vae/shape_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..0fd36879b4859725a8b9454155865d18f7ae556c --- /dev/null +++ b/trellis2/trainers/vae/shape_vae.py @@ -0,0 +1,266 @@ +from typing import * +import os +import copy +import functools +import numpy as np +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +import utils3d +from easydict import EasyDict as edict + +from ..basic import BasicTrainer +from ...modules import sparse as sp +from ...renderers import MeshRenderer +from ...representations import Mesh +from ...utils.data_utils import recursive_to_device, cycle, BalancedResumableSampler +from ...utils.loss_utils import l1_loss, ssim, lpips + + +class ShapeVaeTrainer(BasicTrainer): + """ + Trainer for Shape VAE + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. 
+ batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + lambda_subdiv (float): Subdivision loss weight. + lambda_intersected (float): Intersected loss weight. + lambda_vertice (float): Vertice loss weight. + lambda_kl (float): KL loss weight. + lambda_ssim (float): SSIM loss weight. + lambda_lpips (float): LPIPS loss weight. 
+ """ + + def __init__( + self, + *args, + lambda_subdiv: float = 0.1, + lambda_intersected: float = 0.1, + lambda_vertice: float = 1e-2, + lambda_mask: float = 1, + lambda_depth: float = 10, + lambda_normal: float = 1, + lambda_kl: float = 1e-6, + lambda_ssim: float = 0.2, + lambda_lpips: float = 0.2, + render_resolution: float = 1024, + camera_randomization_config: dict = { + 'radius_range': [2, 100], + }, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lambda_subdiv = lambda_subdiv + self.lambda_intersected = lambda_intersected + self.lambda_mask = lambda_mask + self.lambda_vertice = lambda_vertice + self.lambda_depth = lambda_depth + self.lambda_normal = lambda_normal + self.lambda_kl = lambda_kl + self.lambda_ssim = lambda_ssim + self.lambda_lpips = lambda_lpips + self.camera_randomization_config = camera_randomization_config + + self.renderer = MeshRenderer({'near': 1, 'far': 3, 'resolution': render_resolution}, device=self.device) + + def prepare_dataloader(self, **kwargs): + """ + Prepare dataloader. 
+ """ + self.data_sampler = BalancedResumableSampler( + self.dataset, + shuffle=True, + batch_size=self.batch_size_per_gpu, + ) + self.dataloader = DataLoader( + self.dataset, + batch_size=self.batch_size_per_gpu, + num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())), + pin_memory=True, + drop_last=True, + persistent_workers=True, + collate_fn=functools.partial(self.dataset.collate_fn, split_size=self.batch_split), + sampler=self.data_sampler, + ) + self.data_iterator = cycle(self.dataloader) + + def _randomize_camera(self, num_samples: int): + # sample radius and fov + r_min, r_max = self.camera_randomization_config['radius_range'] + k_min = 1 / r_max**2 + k_max = 1 / r_min**2 + ks = torch.rand(num_samples, device=self.device) * (k_max - k_min) + k_min + radius = 1 / torch.sqrt(ks) + fov = 2 * torch.arcsin(0.5 / radius) + origin = radius.unsqueeze(-1) * F.normalize(torch.randn(num_samples, 3, device=self.device), dim=-1) + + # build camera + extrinsics = utils3d.torch.extrinsics_look_at(origin, torch.zeros_like(origin), torch.tensor([0, 0, 1], dtype=torch.float32, device=self.device)) + intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + near = [np.random.uniform(r - 1, r) for r in radius.tolist()] + + return { + 'extrinsics': extrinsics, + 'intrinsics': intrinsics, + 'near': near, + } + + def _render_batch(self, reps: List[Mesh], extrinsics: torch.Tensor, intrinsics: torch.Tensor, near: List, + return_types=['mask', 'normal', 'depth']) -> Dict[str, torch.Tensor]: + """ + Render a batch of representations. + + Args: + reps: The dictionary of lists of representations. + extrinsics: The [N x 4 x 4] tensor of extrinsics. + intrinsics: The [N x 3 x 3] tensor of intrinsics. 
+ return_types: vary in ['mask', 'normal', 'depth', 'normal_map', 'color'] + + Returns: + a dict with + mask : [N x 1 x H x W] tensor of rendered masks + normal : [N x 3 x H x W] tensor of rendered normals + depth : [N x 1 x H x W] tensor of rendered depths + """ + ret = {k : [] for k in return_types} + for i, rep in enumerate(reps): + self.renderer.rendering_options['near'] = near[i] + self.renderer.rendering_options['far'] = near[i] + 2 + out_dict = self.renderer.render(rep, extrinsics[i], intrinsics[i], return_types=return_types) + for k in out_dict: + ret[k].append(out_dict[k][None] if k in ['mask', 'depth'] else out_dict[k]) + for k in ret: + ret[k] = torch.stack(ret[k]) + return ret + + def training_losses( + self, + vertices: sp.SparseTensor, + intersected: sp.SparseTensor, + mesh: List[Mesh], + ) -> Tuple[Dict, Dict]: + """ + Compute training losses. + + Args: + vertices (SparseTensor): vertices of each active voxel + intersected (SparseTensor): intersected flag of each active voxel + mesh (List[Mesh]): the list of meshes to render + + Returns: + a dict with the key "loss" containing a scalar tensor. + may also contain other keys for different terms. 
+ """ + z, mean, logvar = self.training_models['encoder'](vertices, intersected, sample_posterior=True, return_raw=True) + recon, pred_vertice, pred_intersected, subs_gt, subs = self.training_models['decoder'](z, intersected) + + terms = edict(loss = 0.0) + + # direct regression + if self.lambda_intersected > 0: + terms["direct/intersected"] = F.binary_cross_entropy_with_logits(pred_intersected.feats.flatten(), intersected.feats.flatten().float()) + terms["loss"] = terms["loss"] + self.lambda_intersected * terms["direct/intersected"] + if self.lambda_vertice > 0: + terms["direct/vertice"] = F.mse_loss(pred_vertice.feats, vertices.feats) + terms["loss"] = terms["loss"] + self.lambda_vertice * terms["direct/vertice"] + + # subdivision prediction loss + for i, (sub_gt, sub) in enumerate(zip(subs_gt, subs)): + terms[f"bce_sub{i}"] = F.binary_cross_entropy_with_logits(sub.feats, sub_gt.float()) + terms["loss"] = terms["loss"] + self.lambda_subdiv * terms[f"bce_sub{i}"] + + # rendering loss + cameras = self._randomize_camera(len(mesh)) + gt_renders = self._render_batch(mesh, **cameras, return_types=['mask', 'normal', 'depth']) + pred_renders = self._render_batch(recon, **cameras, return_types=['mask', 'normal', 'depth']) + terms['render/mask'] = l1_loss(pred_renders['mask'], gt_renders['mask']) + terms['render/depth'] = l1_loss(pred_renders['depth'], gt_renders['depth']) + terms['render/normal/l1'] = l1_loss(pred_renders['normal'], gt_renders['normal']) + terms['render/normal/ssim'] = 1 - ssim(pred_renders['normal'], gt_renders['normal']) + terms['render/normal/lpips'] = lpips(pred_renders['normal'], gt_renders['normal']) + terms['loss'] = terms['loss'] + \ + self.lambda_mask * terms['render/mask'] + \ + self.lambda_depth * terms['render/depth'] + \ + self.lambda_normal * (terms['render/normal/l1'] + self.lambda_ssim * terms['render/normal/ssim'] + self.lambda_lpips * terms['render/normal/lpips']) + + # KL regularization + terms["kl"] = 0.5 * torch.mean(mean.pow(2) + 
logvar.exp() - logvar - 1) + terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"] + + return terms, {} + + @torch.no_grad() + def run_snapshot( + self, + num_samples: int, + batch_size: int, + verbose: bool = False, + ) -> Dict: + dataloader = DataLoader( + copy.deepcopy(self.dataset), + batch_size=batch_size, + shuffle=True, + num_workers=1, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + ) + + # inference + gts = [] + recons = [] + recons2 = [] + self.models['encoder'].eval() + for i in range(0, num_samples, batch_size): + batch = min(batch_size, num_samples - i) + data = next(iter(dataloader)) + args = {k: v[:batch] for k, v in data.items()} + args = recursive_to_device(args, self.device) + z = self.models['encoder'](args['vertices'], args['intersected']) + self.models['decoder'].train() + y = self.models['decoder'](z, args['intersected'])[0] + z.clear_spatial_cache() + self.models['decoder'].eval() + y2 = self.models['decoder'](z) + gts.extend(args['mesh']) + recons.extend(y) + recons2.extend(y2) + self.models['encoder'].train() + self.models['decoder'].train() + + cameras = self._randomize_camera(num_samples) + gt_renders = self._render_batch(gts, **cameras, return_types=['normal']) + recons_renders = self._render_batch(recons, **cameras, return_types=['normal']) + recons2_renders = self._render_batch(recons2, **cameras, return_types=['normal']) + + sample_dict = { + 'gt': {'value': gt_renders['normal'], 'type': 'image'}, + 'rec': {'value': recons_renders['normal'], 'type': 'image'}, + 'rec2': {'value': recons2_renders['normal'], 'type': 'image'}, + } + + return sample_dict diff --git a/trellis2/trainers/vae/sparse_structure_vae.py b/trellis2/trainers/vae/sparse_structure_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..9233d216d23b072cca4bda85d86be3289740532b --- /dev/null +++ b/trellis2/trainers/vae/sparse_structure_vae.py @@ -0,0 +1,130 @@ +from typing import * +import copy +import 
torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +from easydict import EasyDict as edict + +from ..basic import BasicTrainer + + +class SparseStructureVaeTrainer(BasicTrainer): + """ + Trainer for Sparse Structure VAE. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + loss_type (str): Loss type. 'bce' for binary cross entropy, 'l1' for L1 loss, 'dice' for Dice loss. + lambda_kl (float): KL divergence loss weight. + """ + + def __init__( + self, + *args, + loss_type='bce', + lambda_kl=1e-6, + **kwargs + ): + super().__init__(*args, **kwargs) + self.loss_type = loss_type + self.lambda_kl = lambda_kl + + def training_losses( + self, + ss: torch.Tensor, + **kwargs + ) -> Tuple[Dict, Dict]: + """ + Compute training losses. + + Args: + ss: The [N x 1 x H x W x D] tensor of binary sparse structure. 
+ + Returns: + a dict with the key "loss" containing a scalar tensor. + may also contain other keys for different terms. + """ + z, mean, logvar = self.training_models['encoder'](ss.float(), sample_posterior=True, return_raw=True) + logits = self.training_models['decoder'](z) + + terms = edict(loss = 0.0) + if self.loss_type == 'bce': + terms["bce"] = F.binary_cross_entropy_with_logits(logits, ss.float(), reduction='mean') + terms["loss"] = terms["loss"] + terms["bce"] + elif self.loss_type == 'l1': + terms["l1"] = F.l1_loss(F.sigmoid(logits), ss.float(), reduction='mean') + terms["loss"] = terms["loss"] + terms["l1"] + elif self.loss_type == 'dice': + logits = F.sigmoid(logits) + terms["dice"] = 1 - (2 * (logits * ss.float()).sum() + 1) / (logits.sum() + ss.float().sum() + 1) + terms["loss"] = terms["loss"] + terms["dice"] + else: + raise ValueError(f'Invalid loss type {self.loss_type}') + terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1) + terms["loss"] = terms["loss"] + self.lamda_kl * terms["kl"] + + return terms, {} + + @torch.no_grad() + def snapshot(self, suffix=None, num_samples=64, batch_size=1, verbose=False): + super().snapshot(suffix=suffix, num_samples=num_samples, batch_size=batch_size, verbose=verbose) + + @torch.no_grad() + def run_snapshot( + self, + num_samples: int, + batch_size: int, + verbose: bool = False, + ) -> Dict: + dataloader = DataLoader( + copy.deepcopy(self.dataset), + batch_size=batch_size, + shuffle=True, + num_workers=0, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + ) + + # inference + gts = [] + recons = [] + for i in range(0, num_samples, batch_size): + batch = min(batch_size, num_samples - i) + data = next(iter(dataloader)) + args = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()} + z = self.models['encoder'](args['ss'].float(), sample_posterior=False) + logits = self.models['decoder'](z) + recon = (logits > 0).long() + 
gts.append(args['ss']) + recons.append(recon) + + sample_dict = { + 'gt': {'value': torch.cat(gts, dim=0), 'type': 'sample'}, + 'recon': {'value': torch.cat(recons, dim=0), 'type': 'sample'}, + } + return sample_dict diff --git a/trellis2/utils/__init__.py b/trellis2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/trellis2/utils/__pycache__/__init__.cpython-311.pyc b/trellis2/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc811204391d369834dd0e5ba5db237db09ed507 Binary files /dev/null and b/trellis2/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/trellis2/utils/__pycache__/elastic_utils.cpython-311.pyc b/trellis2/utils/__pycache__/elastic_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6577b322a4d4486019f4e40796d588355ae7173 Binary files /dev/null and b/trellis2/utils/__pycache__/elastic_utils.cpython-311.pyc differ diff --git a/trellis2/utils/data_utils.py b/trellis2/utils/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f10e4a25afe0b5f1471f3e6856377ec803da3bd6 --- /dev/null +++ b/trellis2/utils/data_utils.py @@ -0,0 +1,226 @@ +from typing import * +import math +import torch +import numpy as np +from torch.utils.data import Sampler, Dataset, DataLoader, DistributedSampler +import torch.distributed as dist + + +def recursive_to_device( + data: Any, + device: torch.device, + non_blocking: bool = False, +) -> Any: + """ + Recursively move all tensors in a data structure to a device. 
    """
    # Anything exposing .to() (tensors, modules, sparse tensors, ...) moves
    # directly; lists/tuples/dicts recurse; everything else passes through.
    if hasattr(data, "to"):
        return data.to(device, non_blocking=non_blocking)
    elif isinstance(data, (list, tuple)):
        return type(data)(recursive_to_device(d, device, non_blocking) for d in data)
    elif isinstance(data, dict):
        return {k: recursive_to_device(v, device, non_blocking) for k, v in data.items()}
    else:
        return data


def load_balanced_group_indices(
    load: List[int],
    num_groups: int,
    equal_size: bool = False,
) -> List[List[int]]:
    """
    Split indices into groups with balanced load.

    Greedy longest-processing-time assignment: visit indices from heaviest to
    lightest and place each into the currently lightest group.

    Args:
        load: Per-index load values.
        num_groups: Number of groups to produce.
        equal_size: If True, cap every group at len(load) // num_groups items;
            a full group's load is set to +inf so it is never picked again.

    Returns:
        A list of num_groups lists of indices.
    """
    if equal_size:
        group_size = len(load) // num_groups
    indices = np.argsort(load)[::-1]  # heaviest first
    groups = [[] for _ in range(num_groups)]
    group_load = np.zeros(num_groups)
    for idx in indices:
        min_group_idx = np.argmin(group_load)
        groups[min_group_idx].append(idx)
        if equal_size and len(groups[min_group_idx]) == group_size:
            group_load[min_group_idx] = float('inf')
        else:
            group_load[min_group_idx] += load[idx]
    return groups


def cycle(data_loader: DataLoader) -> Iterator:
    # Infinite iterator over the dataloader that keeps a ResumableSampler's
    # bookkeeping (epoch counter, in-epoch index) in sync so that training can
    # resume mid-epoch from a checkpoint.
    while True:
        for data in data_loader:
            if isinstance(data_loader.sampler, ResumableSampler):
                data_loader.sampler.idx += data_loader.batch_size  # type: ignore[attr-defined]
            yield data
        if isinstance(data_loader.sampler, DistributedSampler):
            data_loader.sampler.epoch += 1
        if isinstance(data_loader.sampler, ResumableSampler):
            # Epoch finished: bump the epoch (new shuffle seed) and rewind idx.
            data_loader.sampler.epoch += 1
            data_loader.sampler.idx = 0


class ResumableSampler(Sampler):
    """
    Distributed sampler that is resumable.

    Args:
        dataset: Dataset used for sampling.
        rank (int, optional): Rank of the current process within :attr:`num_replicas`.
            By default, :attr:`rank` is retrieved from the current distributed
            group.
        shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
            indices.
        seed (int, optional): random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: ``0``.
+ drop_last (bool, optional): if ``True``, then the sampler will drop the + tail of the data to make it evenly divisible across the number of + replicas. If ``False``, the sampler will add extra indices to make + the data evenly divisible across the replicas. Default: ``False``. + """ + + def __init__( + self, + dataset: Dataset, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ) -> None: + self.dataset = dataset + self.epoch = 0 + self.idx = 0 + self.drop_last = drop_last + self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + self.rank = dist.get_rank() if dist.is_initialized() else 0 + # If the dataset length is evenly divisible by # of replicas, then there + # is no need to drop any data, since the dataset will be split equally. + if self.drop_last and len(self.dataset) % self.world_size != 0: # type: ignore[arg-type] + # Split to nearest available length that is evenly divisible. + # This is to ensure each rank receives the same amount of data when + # using this Sampler. 
+ self.num_samples = math.ceil( + (len(self.dataset) - self.world_size) / self.world_size # type: ignore[arg-type] + ) + else: + self.num_samples = math.ceil(len(self.dataset) / self.world_size) # type: ignore[arg-type] + self.total_size = self.num_samples * self.world_size + self.shuffle = shuffle + self.seed = seed + + def __iter__(self) -> Iterator: + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[ + :padding_size + ] + else: + # remove tail of data to make it evenly divisible. + indices = indices[: self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.world_size] + + # resume from previous state + indices = indices[self.idx:] + + return iter(indices) + + def __len__(self) -> int: + return self.num_samples + + def state_dict(self) -> dict[str, int]: + return { + 'epoch': self.epoch, + 'idx': self.idx, + } + + def load_state_dict(self, state_dict): + self.epoch = state_dict['epoch'] + self.idx = state_dict['idx'] + + +class BalancedResumableSampler(ResumableSampler): + """ + Distributed sampler that is resumable and balances the load among the processes. + + Args: + dataset: Dataset used for sampling. + rank (int, optional): Rank of the current process within :attr:`num_replicas`. + By default, :attr:`rank` is retrieved from the current distributed + group. + shuffle (bool, optional): If ``True`` (default), sampler will shuffle the + indices. 
+ seed (int, optional): random seed used to shuffle the sampler if + :attr:`shuffle=True`. This number should be identical across all + processes in the distributed group. Default: ``0``. + drop_last (bool, optional): if ``True``, then the sampler will drop the + tail of the data to make it evenly divisible across the number of + replicas. If ``False``, the sampler will add extra indices to make + the data evenly divisible across the replicas. Default: ``False``. + """ + + def __init__( + self, + dataset: Dataset, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + batch_size: int = 1, + ) -> None: + assert hasattr(dataset, 'loads'), 'Dataset must have "loads" attribute to use BalancedResumableSampler' + super().__init__(dataset, shuffle, seed, drop_last) + self.batch_size = batch_size + self.loads = dataset.loads + + def __iter__(self) -> Iterator: + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[ + :padding_size + ] + else: + # remove tail of data to make it evenly divisible. 
+ indices = indices[: self.total_size] + assert len(indices) == self.total_size + + # balance load among processes + num_batches = len(indices) // (self.batch_size * self.world_size) + balanced_indices = [] + for i in range(num_batches): + start_idx = i * self.batch_size * self.world_size + end_idx = (i + 1) * self.batch_size * self.world_size + batch_indices = indices[start_idx:end_idx] + batch_loads = [self.loads[idx] for idx in batch_indices] + groups = load_balanced_group_indices(batch_loads, self.world_size, equal_size=True) + balanced_indices.extend([batch_indices[j] for j in groups[self.rank]]) + + # resume from previous state + indices = balanced_indices[self.idx:] + + return iter(indices) diff --git a/trellis2/utils/dist_utils.py b/trellis2/utils/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a4e3d8e7ab0aa0228638fb7d423fa0f28e2e306f --- /dev/null +++ b/trellis2/utils/dist_utils.py @@ -0,0 +1,93 @@ +import os +import io +from contextlib import contextmanager +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + + +def setup_dist(rank, local_rank, world_size, master_addr, master_port): + os.environ['MASTER_ADDR'] = master_addr + os.environ['MASTER_PORT'] = master_port + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = str(local_rank) + torch.cuda.set_device(local_rank) + dist.init_process_group('nccl', rank=rank, world_size=world_size) + + +def read_file_dist(path): + """ + Read the binary file distributedly. + File is only read once by the rank 0 process and broadcasted to other processes. + + Returns: + data (io.BytesIO): The binary data read from the file. 
+ """ + if dist.is_initialized() and dist.get_world_size() > 1: + # read file + size = torch.LongTensor(1).cuda() + if dist.get_rank() == 0: + with open(path, 'rb') as f: + data = f.read() + data = torch.ByteTensor( + torch.UntypedStorage.from_buffer(data, dtype=torch.uint8) + ).cuda() + size[0] = data.shape[0] + # broadcast size + dist.broadcast(size, src=0) + if dist.get_rank() != 0: + data = torch.ByteTensor(size[0].item()).cuda() + # broadcast data + dist.broadcast(data, src=0) + # convert to io.BytesIO + data = data.cpu().numpy().tobytes() + data = io.BytesIO(data) + return data + else: + with open(path, 'rb') as f: + data = f.read() + data = io.BytesIO(data) + return data + + +def unwrap_dist(model): + """ + Unwrap the model from distributed training. + """ + if isinstance(model, DDP): + return model.module + return model + + +@contextmanager +def master_first(): + """ + A context manager that ensures master process executes first. + """ + if not dist.is_initialized(): + yield + else: + if dist.get_rank() == 0: + yield + dist.barrier() + else: + dist.barrier() + yield + + +@contextmanager +def local_master_first(): + """ + A context manager that ensures local master process executes first. + """ + if not dist.is_initialized(): + yield + else: + if dist.get_rank() % torch.cuda.device_count() == 0: + yield + dist.barrier() + else: + dist.barrier() + yield + \ No newline at end of file diff --git a/trellis2/utils/elastic_utils.py b/trellis2/utils/elastic_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..84347059e18a60754966dac86960777e68980b1b --- /dev/null +++ b/trellis2/utils/elastic_utils.py @@ -0,0 +1,228 @@ +from abc import abstractmethod +from contextlib import contextmanager +from typing import Tuple +import torch +import torch.nn as nn +import numpy as np + + +class MemoryController: + """ + Base class for memory management during training. 
+ """ + + _last_input_size = None + _last_mem_ratio = [] + + @contextmanager + def record(self): + pass + + def update_run_states(self, input_size=None, mem_ratio=None): + if self._last_input_size is None: + self._last_input_size = input_size + elif self._last_input_size!= input_size: + raise ValueError(f'Input size should not change for different ElasticModules.') + self._last_mem_ratio.append(mem_ratio) + + @abstractmethod + def get_mem_ratio(self, input_size): + pass + + @abstractmethod + def state_dict(self): + pass + + @abstractmethod + def log(self): + pass + + +class LinearMemoryController(MemoryController): + """ + A simple controller for memory management during training. + The memory usage is modeled as a linear function of: + - the number of input parameters + - the ratio of memory the model use compared to the maximum usage (with no checkpointing) + memory_usage = k * input_size * mem_ratio + b + The controller keeps track of the memory usage and gives the + expected memory ratio to keep the memory usage under a target + """ + def __init__( + self, + buffer_size=1000, + update_every=500, + target_ratio=0.8, + available_memory=None, + max_mem_ratio_start=0.1, + params=None, + device=None + ): + self.buffer_size = buffer_size + self.update_every = update_every + self.target_ratio = target_ratio + self.device = device or torch.cuda.current_device() + self.available_memory = available_memory or torch.cuda.get_device_properties(self.device).total_memory / 1024**3 + + self._memory = np.zeros(buffer_size, dtype=np.float32) + self._input_size = np.zeros(buffer_size, dtype=np.float32) + self._mem_ratio = np.zeros(buffer_size, dtype=np.float32) + self._buffer_ptr = 0 + self._buffer_length = 0 + self._params = tuple(params) if params is not None else (0.0, 0.0) + self._max_mem_ratio = max_mem_ratio_start + self.step = 0 + + def __repr__(self): + return f'LinearMemoryController(target_ratio={self.target_ratio}, available_memory={self.available_memory})' + + def 
_add_sample(self, memory, input_size, mem_ratio): + self._memory[self._buffer_ptr] = memory + self._input_size[self._buffer_ptr] = input_size + self._mem_ratio[self._buffer_ptr] = mem_ratio + self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size + self._buffer_length = min(self._buffer_length + 1, self.buffer_size) + + @contextmanager + def record(self): + torch.cuda.reset_peak_memory_stats(self.device) + self._last_input_size = None + self._last_mem_ratio = [] + yield + self._last_memory = torch.cuda.max_memory_allocated(self.device) / 1024**3 + self._last_mem_ratio = sum(self._last_mem_ratio) / len(self._last_mem_ratio) + self._add_sample(self._last_memory, self._last_input_size, self._last_mem_ratio) + self.step += 1 + if self.step % self.update_every == 0: + self._max_mem_ratio = min(1.0, self._max_mem_ratio + 0.1) + self._fit_params() + + def _fit_params(self): + memory_usage = self._memory[:self._buffer_length] + input_size = self._input_size[:self._buffer_length] + mem_ratio = self._mem_ratio[:self._buffer_length] + + x = input_size * mem_ratio + y = memory_usage + k, b = np.polyfit(x, y, 1) + self._params = (k, b) + # self._visualize() + + def _visualize(self): + import matplotlib.pyplot as plt + memory_usage = self._memory[:self._buffer_length] + input_size = self._input_size[:self._buffer_length] + mem_ratio = self._mem_ratio[:self._buffer_length] + k, b = self._params + + plt.scatter(input_size * mem_ratio, memory_usage, c=mem_ratio, cmap='viridis') + x = np.array([0.0, 20000.0]) + plt.plot(x, k * x + b, c='r') + plt.savefig(f'linear_memory_controller_{self.step}.png') + plt.cla() + + def get_mem_ratio(self, input_size): + k, b = self._params + if k == 0: return np.random.rand() * self._max_mem_ratio + pred = (self.available_memory * self.target_ratio - b) / (k * input_size) + return min(self._max_mem_ratio, max(0.0, pred)) + + def state_dict(self): + return { + 'params': self._params, + } + + def load_state_dict(self, state_dict): + self._params 
= tuple(state_dict['params']) + + def log(self): + return { + 'params/k': self._params[0], + 'params/b': self._params[1], + 'memory': self._last_memory, + 'input_size': self._last_input_size, + 'mem_ratio': self._last_mem_ratio, + } + + +class ElasticModule(nn.Module): + """ + Module for training with elastic memory management. + """ + def __init__(self): + super().__init__() + self._memory_controller: MemoryController = None + + @abstractmethod + def _get_input_size(self, *args, **kwargs) -> int: + """ + Get the size of the input data. + + Returns: + int: The size of the input data. + """ + pass + + @abstractmethod + def _forward_with_mem_ratio(self, *args, mem_ratio=0.0, **kwargs) -> Tuple[float, Tuple]: + """ + Forward with a given memory ratio. + """ + pass + + def register_memory_controller(self, memory_controller: MemoryController): + self._memory_controller = memory_controller + + def forward(self, *args, **kwargs): + if self._memory_controller is None or not torch.is_grad_enabled() or not self.training: + _, ret = self._forward_with_mem_ratio(*args, **kwargs) + else: + input_size = self._get_input_size(*args, **kwargs) + mem_ratio = self._memory_controller.get_mem_ratio(input_size) + mem_ratio, ret = self._forward_with_mem_ratio(*args, mem_ratio=mem_ratio, **kwargs) + self._memory_controller.update_run_states(input_size, mem_ratio) + return ret + + +class ElasticModuleMixin: + """ + Mixin for training with elastic memory management. + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._memory_controller: MemoryController = None + + @abstractmethod + def _get_input_size(self, *args, **kwargs) -> int: + """ + Get the size of the input data. + + Returns: + int: The size of the input data. + """ + pass + + @abstractmethod + @contextmanager + def with_mem_ratio(self, mem_ratio=1.0) -> float: + """ + Context manager for training with a reduced memory ratio compared to the full memory usage. 
+ + Returns: + float: The exact memory ratio used during the forward pass. + """ + pass + + def register_memory_controller(self, memory_controller: MemoryController): + self._memory_controller = memory_controller + + def forward(self, *args, **kwargs): + if self._memory_controller is None or not torch.is_grad_enabled() or not self.training: + ret = super().forward(*args, **kwargs) + else: + input_size = self._get_input_size(*args, **kwargs) + mem_ratio = self._memory_controller.get_mem_ratio(input_size) + with self.with_mem_ratio(mem_ratio) as exact_mem_ratio: + ret = super().forward(*args, **kwargs) + self._memory_controller.update_run_states(input_size, exact_mem_ratio) + return ret diff --git a/trellis2/utils/general_utils.py b/trellis2/utils/general_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..94a7cb0d76f9af02e59b56c4b9bf8a3c4d33118d --- /dev/null +++ b/trellis2/utils/general_utils.py @@ -0,0 +1,373 @@ +import re +import numpy as np +import cv2 +import torch +import contextlib + + +# Dictionary utils +def _dict_merge(dicta, dictb, prefix=''): + """ + Merge two dictionaries. + """ + assert isinstance(dicta, dict), 'input must be a dictionary' + assert isinstance(dictb, dict), 'input must be a dictionary' + dict_ = {} + all_keys = set(dicta.keys()).union(set(dictb.keys())) + for key in all_keys: + if key in dicta.keys() and key in dictb.keys(): + if isinstance(dicta[key], dict) and isinstance(dictb[key], dict): + dict_[key] = _dict_merge(dicta[key], dictb[key], prefix=f'{prefix}.{key}') + else: + raise ValueError(f'Duplicate key {prefix}.{key} found in both dictionaries. Types: {type(dicta[key])}, {type(dictb[key])}') + elif key in dicta.keys(): + dict_[key] = dicta[key] + else: + dict_[key] = dictb[key] + return dict_ + + +def dict_merge(dicta, dictb): + """ + Merge two dictionaries. 
+ """ + return _dict_merge(dicta, dictb, prefix='') + + +def dict_foreach(dic, func, special_func={}): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + dic[key] = dict_foreach(dic[key], func) + else: + if key in special_func.keys(): + dic[key] = special_func[key](dic[key]) + else: + dic[key] = func(dic[key]) + return dic + + +def dict_reduce(dicts, func, special_func={}): + """ + Reduce a list of dictionaries. Leaf values must be scalars. + """ + assert isinstance(dicts, list), 'input must be a list of dictionaries' + assert all([isinstance(d, dict) for d in dicts]), 'input must be a list of dictionaries' + assert len(dicts) > 0, 'input must be a non-empty list of dictionaries' + all_keys = set([key for dict_ in dicts for key in dict_.keys()]) + reduced_dict = {} + for key in all_keys: + vlist = [dict_[key] for dict_ in dicts if key in dict_.keys()] + if isinstance(vlist[0], dict): + reduced_dict[key] = dict_reduce(vlist, func, special_func) + else: + if key in special_func.keys(): + reduced_dict[key] = special_func[key](vlist) + else: + reduced_dict[key] = func(vlist) + return reduced_dict + + +def dict_any(dic, func): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + if dict_any(dic[key], func): + return True + else: + if func(dic[key]): + return True + return False + + +def dict_all(dic, func): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. 
+ """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + if not dict_all(dic[key], func): + return False + else: + if not func(dic[key]): + return False + return True + + +def dict_flatten(dic, sep='.'): + """ + Flatten a nested dictionary into a dictionary with no nested dictionaries. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + flat_dict = {} + for key in dic.keys(): + if isinstance(dic[key], dict): + sub_dict = dict_flatten(dic[key], sep=sep) + for sub_key in sub_dict.keys(): + flat_dict[str(key) + sep + str(sub_key)] = sub_dict[sub_key] + else: + flat_dict[key] = dic[key] + return flat_dict + + +# Context utils +@contextlib.contextmanager +def nested_contexts(*contexts): + with contextlib.ExitStack() as stack: + for ctx in contexts: + stack.enter_context(ctx()) + yield + + +# Image utils +def make_grid(images, nrow=None, ncol=None, aspect_ratio=None): + num_images = len(images) + if nrow is None and ncol is None: + if aspect_ratio is not None: + nrow = int(np.round(np.sqrt(num_images / aspect_ratio))) + else: + nrow = int(np.sqrt(num_images)) + ncol = (num_images + nrow - 1) // nrow + elif nrow is None and ncol is not None: + nrow = (num_images + ncol - 1) // ncol + elif nrow is not None and ncol is None: + ncol = (num_images + nrow - 1) // nrow + else: + assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images' + + if images[0].ndim == 2: + grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1]), dtype=images[0].dtype) + else: + grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype) + for i, img in enumerate(images): + row = i // ncol + col = i % ncol + grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img + return grid + + +def notes_on_image(img, notes=None): + img = np.pad(img, ((0, 32), (0, 0), (0, 
def text_image(text, resolution=(512, 512), max_size=0.5, h_align="left", v_align="center"):
    """
    Draw text on an image of the given resolution. The text is automatically wrapped
    and scaled so that it fits completely within the image while preserving any explicit
    line breaks and original spacing. Horizontal and vertical alignment can be controlled
    via flags.

    Parameters:
        text (str): The input text. Newline characters and spacing are preserved.
        resolution (tuple): The image resolution as (width, height).
        max_size (float): The maximum font *scale* passed to cv2.putText (not pixels).
        h_align (str): Horizontal alignment. Options: "left", "center", "right".
        v_align (str): Vertical alignment. Options: "top", "center", "bottom".

    Returns:
        numpy.ndarray: The resulting image with the text drawn. Black text on a
        white background, so BGR/RGB channel order is indistinguishable here.
    """
    width, height = resolution
    # Create a white background image
    img = np.full((height, width, 3), 255, dtype=np.uint8)

    # Set margins and compute available drawing area
    margin = 10
    avail_width = width - 2 * margin
    avail_height = height - 2 * margin

    # Choose OpenCV font and text thickness
    font = cv2.FONT_HERSHEY_SIMPLEX
    thickness = 1
    # Ratio for additional spacing between lines (relative to the height of "A")
    line_spacing_ratio = 0.5

    def wrap_line(line, max_width, font, thickness, scale):
        """
        Wrap a single line of text into multiple lines such that each line's
        width (measured at the given scale) does not exceed max_width.
        Preserves the original spacing by splitting the line into tokens
        (words and whitespace) using a regular expression.

        Parameters:
            line (str): The input text line.
            max_width (int): Maximum allowed width in pixels.
            font (int): OpenCV font identifier.
            thickness (int): Text thickness.
            scale (float): The current font scale.

        Returns:
            List[str]: A list of wrapped lines.
        """
        # Split the line into tokens (words and whitespace), preserving spacing
        tokens = re.split(r'(\s+)', line)
        if not tokens:
            return ['']

        wrapped_lines = []
        current_line = ""
        for token in tokens:
            candidate = current_line + token
            candidate_width = cv2.getTextSize(candidate, font, scale, thickness)[0][0]
            if candidate_width <= max_width:
                current_line = candidate
            else:
                # If current_line is empty, the token itself is too wide;
                # break the token character by character.
                if current_line == "":
                    sub_token = ""
                    for char in token:
                        candidate_char = sub_token + char
                        if cv2.getTextSize(candidate_char, font, scale, thickness)[0][0] <= max_width:
                            sub_token = candidate_char
                        else:
                            if sub_token:
                                wrapped_lines.append(sub_token)
                            sub_token = char
                    current_line = sub_token
                else:
                    wrapped_lines.append(current_line)
                    current_line = token
        if current_line:
            wrapped_lines.append(current_line)
        return wrapped_lines

    def compute_text_block(scale):
        """
        Wrap the entire text (splitting at explicit newline characters) using the
        provided scale, and then compute the overall width and height of the text block.

        Returns:
            wrapped_lines (List[str]): The list of wrapped lines.
            block_width (int): Maximum width among the wrapped lines.
            block_height (int): Total height of the text block including spacing.
            sizes (List[tuple]): A list of (width, height) for each wrapped line.
            spacing (int): The spacing between lines (from the scaled "A" height).
        """
        # Split text by explicit newlines
        input_lines = text.splitlines() if text else ['']
        wrapped_lines = []
        for line in input_lines:
            wrapped = wrap_line(line, avail_width, font, thickness, scale)
            wrapped_lines.extend(wrapped)

        sizes = []
        for line in wrapped_lines:
            (text_size, _) = cv2.getTextSize(line, font, scale, thickness)
            sizes.append(text_size)  # (width, height)

        block_width = max((w for w, h in sizes), default=0)
        # Use the height of "A" (at the current scale) to compute line spacing
        base_height = cv2.getTextSize("A", font, scale, thickness)[0][1]
        spacing = int(line_spacing_ratio * base_height)
        block_height = sum(h for w, h in sizes) + spacing * (len(sizes) - 1) if sizes else 0

        return wrapped_lines, block_width, block_height, sizes, spacing

    # Binary search for the maximum scale at which the block still fits.
    lo = 0.001
    hi = max_size
    eps = 0.001  # convergence threshold on the scale
    best_scale = lo
    best_result = None

    while hi - lo > eps:
        mid = (lo + hi) / 2
        wrapped_lines, block_width, block_height, sizes, spacing = compute_text_block(mid)
        # Ensure that both width and height constraints are met
        if block_width <= avail_width and block_height <= avail_height:
            best_scale = mid
            best_result = (wrapped_lines, block_width, block_height, sizes, spacing)
            lo = mid  # try a larger scale
        else:
            hi = mid  # reduce the scale

    if best_result is None:
        # NOTE(review): fallback when even the smallest probed scale did not
        # fit — the fixed 0.5 scale may still overflow the image; confirm
        # whether clipping is acceptable for callers.
        best_scale = 0.5
        best_result = compute_text_block(best_scale)

    wrapped_lines, block_width, block_height, sizes, spacing = best_result

    # Compute starting y-coordinate based on vertical alignment flag
    if v_align == "top":
        y_top = margin
    elif v_align == "center":
        y_top = margin + (avail_height - block_height) // 2
    elif v_align == "bottom":
        y_top = margin + (avail_height - block_height)
    else:
        y_top = margin + (avail_height - block_height) // 2  # default to center if invalid flag

    # For cv2.putText, the y coordinate represents the text baseline;
    # so for the first line add its height.
    y = y_top + (sizes[0][1] if sizes else 0)

    # Draw each line with horizontal alignment based on the flag
    for i, line in enumerate(wrapped_lines):
        line_width, line_height = sizes[i]
        if h_align == "left":
            x = margin
        elif h_align == "center":
            x = margin + (avail_width - line_width) // 2
        elif h_align == "right":
            x = margin + (avail_width - line_width)
        else:
            x = margin  # default to left if invalid flag

        cv2.putText(img, line, (x, y), font, best_scale, (0, 0, 0), thickness, cv2.LINE_AA)
        y += line_height + spacing

    return img


def save_image_with_notes(img, path, notes=None):
    """
    Save an image with notes.

    Accepts an HWC uint8 array or a CHW torch tensor (floats are assumed to
    be in [0, 1] and are scaled to uint8). The notes strip is rendered by
    notes_on_image and the result is written RGB->BGR via cv2.imwrite.
    """
    if isinstance(img, torch.Tensor):
        img = img.cpu().numpy().transpose(1, 2, 0)
    if img.dtype == np.float32 or img.dtype == np.float64:
        img = np.clip(img * 255, 0, 255).astype(np.uint8)
    img = notes_on_image(img, notes)
    cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))


# debug utils

def atol(x, y):
    """
    Absolute tolerance: element-wise |x - y|.
    """
    return torch.abs(x - y)


def rtol(x, y):
    """
    Relative tolerance: |x - y| / max(|x|, |y|), with the denominator
    clamped to 1e-12 to avoid division by zero.
    """
    return torch.abs(x - y) / torch.clamp_min(torch.maximum(torch.abs(x), torch.abs(y)), 1e-12)


# print utils
def indent(s, n=4):
    """
    Indent every line of a string except the first by ``n`` spaces.
    """
    lines = s.split('\n')
    for i in range(1, len(lines)):
        lines[i] = ' ' * n + lines[i]
    return '\n'.join(lines)
+ """ + def __init__( + self, + max_norm=None, + clip_percentile=95.0, + buffer_size=1000, + ): + self.max_norm = max_norm + self.clip_percentile = clip_percentile + self.buffer_size = buffer_size + + self._grad_norm = np.zeros(buffer_size, dtype=np.float32) + self._max_norm = max_norm + self._buffer_ptr = 0 + self._buffer_length = 0 + + def __repr__(self): + return f'AdaptiveGradClipper(max_norm={self.max_norm}, clip_percentile={self.clip_percentile})' + + def state_dict(self): + return { + 'grad_norm': self._grad_norm, + 'max_norm': self._max_norm, + 'buffer_ptr': self._buffer_ptr, + 'buffer_length': self._buffer_length, + } + + def load_state_dict(self, state_dict): + self._grad_norm = state_dict['grad_norm'] + self._max_norm = state_dict['max_norm'] + self._buffer_ptr = state_dict['buffer_ptr'] + self._buffer_length = state_dict['buffer_length'] + + def log(self): + return { + 'max_norm': self._max_norm, + } + + def __call__(self, parameters, norm_type=2.0, error_if_nonfinite=False, foreach=None): + """Clip the gradient norm of an iterable of parameters. + + The norm is computed over all gradients together, as if they were + concatenated into a single vector. Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + norm_type (float): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + foreach (bool): use the faster foreach-based implementation. + If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently + fall back to the slow implementation for other device types. + Default: ``None`` + + Returns: + Total norm of the parameter gradients (viewed as a single vector). 
+ """ + max_norm = self._max_norm if self._max_norm is not None else float('inf') + grad_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm, norm_type=norm_type, error_if_nonfinite=error_if_nonfinite, foreach=foreach) + + if torch.isfinite(grad_norm): + self._grad_norm[self._buffer_ptr] = grad_norm + self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size + self._buffer_length = min(self._buffer_length + 1, self.buffer_size) + if self._buffer_length == self.buffer_size: + self._max_norm = np.percentile(self._grad_norm, self.clip_percentile) + self._max_norm = min(self._max_norm, self.max_norm) if self.max_norm is not None else self._max_norm + + return grad_norm \ No newline at end of file diff --git a/trellis2/utils/loss_utils.py b/trellis2/utils/loss_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bed3c139a9ecab87159a993e78eafcef5573dc28 --- /dev/null +++ b/trellis2/utils/loss_utils.py @@ -0,0 +1,92 @@ +import torch +import torch.nn.functional as F +from torch.autograd import Variable +from math import exp +from lpips import LPIPS + + +def smooth_l1_loss(pred, target, beta=1.0): + diff = torch.abs(pred - target) + loss = torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta) + return loss.mean() + + +def l1_loss(network_output, gt): + return torch.abs((network_output - gt)).mean() + + +def l2_loss(network_output, gt): + return ((network_output - gt) ** 2).mean() + + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) + return gauss / gauss.sum() + + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + + +def psnr(img1, img2, max_val=1.0): + mse = F.mse_loss(img1, img2) + return 
20 * torch.log10(max_val / torch.sqrt(mse)) + + +def ssim(img1, img2, window_size=11, size_average=True): + channel = img1.size(-3) + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, size_average) + +def _ssim(img1, img2, window, window_size, channel, size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 + + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + + +loss_fn_vgg = None +def lpips(img1, img2, value_range=(0, 1)): + global loss_fn_vgg + if loss_fn_vgg is None: + loss_fn_vgg = LPIPS(net='vgg').cuda().eval() + # normalize to [-1, 1] + img1 = (img1 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1 + img2 = (img2 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1 + return loss_fn_vgg(img1, img2).mean() + + +def normal_angle(pred, gt): + pred = pred * 2.0 - 1.0 + gt = gt * 2.0 - 1.0 + norms = pred.norm(dim=-1) * gt.norm(dim=-1) + cos_sim = (pred * gt).sum(-1) / (norms + 1e-9) + cos_sim = torch.clamp(cos_sim, -1.0, 1.0) + ang = torch.rad2deg(torch.acos(cos_sim[norms > 1e-9])).mean() + if ang.isnan(): + return -1 + return ang diff --git a/trellis2/utils/mesh_utils.py b/trellis2/utils/mesh_utils.py new file mode 100644 index 
def read_ply(filename):
    """Read a PLY file and return vertices, triangle faces, and quad faces.

    Supports ASCII and binary little-endian PLY with xyz-only vertices and
    faces listed as "count idx0 idx1 ...".

    Args:
        filename (str): path to read from.

    Returns:
        vertices (np.ndarray): [N, 3] float32 vertex positions.
        tris (np.ndarray): [M, 3] int32 triangle indices (empty if none).
        quads (np.ndarray): [K, 4] int32 quad indices (empty if none).
    """
    with open(filename, 'rb') as f:
        # Read the header line-by-line until 'end_header'
        header_bytes = b""
        while True:
            line = f.readline()
            if not line:
                raise ValueError("PLY header not found")
            header_bytes += line
            if b"end_header" in line:
                break
        header = header_bytes.decode('utf-8')

        is_ascii = "ascii" in header

        vertex_match = re.search(r'element vertex (\d+)', header)
        if vertex_match is None:
            raise ValueError("Vertex count not found in header")
        num_vertices = int(vertex_match.group(1))

        face_match = re.search(r'element face (\d+)', header)
        if face_match is None:
            raise ValueError("Face count not found in header")
        num_faces = int(face_match.group(1))

        vertices, tris, quads = [], [], []

        if is_ascii:
            # each vertex line: three floats
            for _ in range(num_vertices):
                line = f.readline().decode('utf-8').strip()
                if not line:
                    continue
                parts = line.split()
                vertices.append([float(parts[0]), float(parts[1]), float(parts[2])])
            # each face line: count followed by indices
            for _ in range(num_faces):
                line = f.readline().decode('utf-8').strip()
                if not line:
                    continue
                parts = line.split()
                count = int(parts[0])
                indices = list(map(int, parts[1:]))
                if count == 3:
                    tris.append(indices)
                elif count == 4:
                    quads.append(indices)
                # faces with other vertex counts are skipped
        else:
            # NOTE(review): the binary branch was reconstructed from a garbled
            # patch; layout assumed: '<fff' per vertex, then per face a uchar
            # count followed by `count` int32 indices — matches write_ply below.
            for _ in range(num_vertices):
                data = f.read(12)
                if len(data) < 12:
                    raise ValueError("Insufficient vertex data")
                vertices.append(list(struct.unpack('<fff', data)))
            for _ in range(num_faces):
                cdata = f.read(1)
                if len(cdata) < 1:
                    raise ValueError("Insufficient face data")
                count = struct.unpack('<B', cdata)[0]
                idata = f.read(4 * count)
                if len(idata) < 4 * count:
                    raise ValueError("Insufficient face data")
                indices = list(struct.unpack('<' + 'i' * count, idata))
                if count == 3:
                    tris.append(indices)
                elif count == 4:
                    quads.append(indices)

    vertices = np.array(vertices, dtype=np.float32)
    tris = np.array(tris, dtype=np.int32) if len(tris) > 0 else np.empty((0, 3), dtype=np.int32)
    quads = np.array(quads, dtype=np.int32) if len(quads) > 0 else np.empty((0, 4), dtype=np.int32)

    return vertices, tris, quads


def write_ply(
    filename: str,
    vertices: np.ndarray,
    tris: np.ndarray,
    quads: np.ndarray,
    vertex_colors: np.ndarray = None,
    ascii: bool = False
):
    """Write a triangle/quad mesh to a PLY file.

    Args:
        filename: output path.
        vertices: [N, 3] vertex positions.
        tris: [M, 3] triangle indices.
        quads: [K, 4] quad indices.
        vertex_colors: optional [N, 3] or [N, 4] uint8 RGB(A) per-vertex colors.
        ascii: write ASCII when True, otherwise binary little-endian.

    The parameter name `ascii` shadows the builtin but is kept for
    backward compatibility with existing callers.
    """
    # (redundant function-local `import struct` removed — module imports it)
    num_vertices = len(vertices)
    num_faces = len(tris) + len(quads)

    # Build header
    header_lines = [
        "ply",
        f"format {'ascii 1.0' if ascii else 'binary_little_endian 1.0'}",
        f"element vertex {num_vertices}",
        "property float x",
        "property float y",
        "property float z",
    ]

    # Optional per-vertex colors: expected as uint8 values in 0..255
    has_color = vertex_colors is not None
    if has_color:
        header_lines += [
            "property uchar red",
            "property uchar green",
            "property uchar blue",
        ]
        if vertex_colors.shape[1] == 4:
            header_lines.append("property uchar alpha")

    header_lines += [
        f"element face {num_faces}",
        "property list uchar int vertex_index",
        "end_header",
        "",  # trailing element yields the newline after end_header
    ]
    header = "\n".join(header_lines)

    mode = 'w' if ascii else 'wb'
    with open(filename, mode) as f:
        f.write(header if ascii else header.encode('utf-8'))

        # vertex records
        for i, v in enumerate(vertices):
            if ascii:
                line = f"{v[0]} {v[1]} {v[2]}"
                if has_color:
                    col = vertex_colors[i]
                    line += ' ' + ' '.join(str(int(c)) for c in col)
                f.write(line + '\n')
            else:
                # pack xyz as little-endian float32
                f.write(struct.pack('<fff', float(v[0]), float(v[1]), float(v[2])))
                if has_color:
                    col = vertex_colors[i]
                    f.write(struct.pack('<' + 'B' * len(col), *[int(c) for c in col]))

        # face records: vertex count followed by the indices
        for face in list(tris) + list(quads):
            n = len(face)
            if ascii:
                f.write(str(n) + ' ' + ' '.join(str(int(idx)) for idx in face) + '\n')
            else:
                f.write(struct.pack('<B', n))
                f.write(struct.pack('<' + 'i' * n, *[int(idx) for idx in face]))


# ---- Low-discrepancy sampling helpers (random_utils module) ----
# NOTE(review): PRIMES and the `radical_inverse` signature are not visible in
# the patch (swallowed by extraction garbling) and were reconstructed — confirm
# against the original random_utils.py.
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]


def radical_inverse(base, n):
    """Radical inverse of integer `n` in the given `base` (van der Corput)."""
    val = 0
    inv_base = 1.0 / base
    inv_base_n = inv_base
    while n > 0:
        digit = n % base
        val += digit * inv_base_n
        n //= base
        inv_base_n *= inv_base
    return val


def halton_sequence(dim, n):
    """`dim`-dimensional Halton point for index `n`, one prime base per dimension."""
    # original shadowed `dim` with the loop variable; renamed for clarity —
    # behavior is identical since range(dim) is evaluated before binding.
    return [radical_inverse(PRIMES[d], n) for d in range(dim)]


def hammersley_sequence(dim, n, num_samples):
    """`dim`-dimensional Hammersley point: n/num_samples plus a Halton tail."""
    return [n / num_samples] + halton_sequence(dim - 1, n)


def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False):
    """Map Hammersley sample `n` of `num_samples` to sphere angles [phi, theta].

    Args:
        n: sample index.
        num_samples: total number of samples.
        offset: (u, v) jitter; u is scaled by 1/num_samples.
        remap: when True, warp u to bias samples away from the poles.
    """
    u, v = hammersley_sequence(2, n, num_samples)
    u += offset[0] / num_samples
    v += offset[1]
    if remap:
        u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
    theta = np.arccos(1 - 2 * u) - np.pi / 2
    phi = v * 2 * np.pi
    return [phi, theta]
def yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs):
    """Convert yaw/pitch/radius/fov camera parameters to extrinsics/intrinsics.

    Scalars or lists are accepted; scalar `rs`/`fovs` are broadcast to the
    number of cameras. Cameras orbit the origin with +z up.
    NOTE(review): tensors are placed on CUDA unconditionally — assumes a GPU.

    Returns:
        (extrinsics, intrinsics): lists of 4x4 / 3x3 tensors, or a single pair
        when scalar inputs were given.
    """
    is_list = isinstance(yaws, list)
    if not is_list:
        yaws = [yaws]
        pitchs = [pitchs]
    if not isinstance(rs, list):
        rs = [rs] * len(yaws)
    if not isinstance(fovs, list):
        fovs = [fovs] * len(yaws)
    extrinsics = []
    intrinsics = []
    for yaw, pitch, r, fov in zip(yaws, pitchs, rs, fovs):
        fov = torch.deg2rad(torch.tensor(float(fov))).cuda()
        yaw = torch.tensor(float(yaw)).cuda()
        pitch = torch.tensor(float(pitch)).cuda()
        # camera position on a sphere of radius r around the origin
        orig = torch.tensor([
            torch.sin(yaw) * torch.cos(pitch),
            torch.cos(yaw) * torch.cos(pitch),
            torch.sin(pitch),
        ]).cuda() * r
        extr = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
        intr = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
        extrinsics.append(extr)
        intrinsics.append(intr)
    if not is_list:
        extrinsics = extrinsics[0]
        intrinsics = intrinsics[0]
    return extrinsics, intrinsics


def get_renderer(sample, **kwargs):
    """Build a renderer appropriate for the given representation type.

    kwargs override per-renderer defaults (resolution, near, far, ssaa, ...).
    Raises ValueError for unsupported sample types.
    """
    if isinstance(sample, (MeshWithPbrMaterial, MeshWithVoxel)):
        renderer = PbrMeshRenderer()
        renderer.rendering_options.resolution = kwargs.get('resolution', 512)
        renderer.rendering_options.near = kwargs.get('near', 1)
        renderer.rendering_options.far = kwargs.get('far', 100)
        renderer.rendering_options.ssaa = kwargs.get('ssaa', 2)
        renderer.rendering_options.peel_layers = kwargs.get('peel_layers', 8)
    elif isinstance(sample, Mesh):
        renderer = MeshRenderer()
        renderer.rendering_options.resolution = kwargs.get('resolution', 512)
        renderer.rendering_options.near = kwargs.get('near', 1)
        renderer.rendering_options.far = kwargs.get('far', 100)
        renderer.rendering_options.ssaa = kwargs.get('ssaa', 2)
        renderer.rendering_options.chunk_size = kwargs.get('chunk_size', None)
    elif isinstance(sample, Voxel):
        renderer = VoxelRenderer()
        renderer.rendering_options.resolution = kwargs.get('resolution', 512)
        renderer.rendering_options.near = kwargs.get('near', 0.1)
        renderer.rendering_options.far = kwargs.get('far', 10.0)
        renderer.rendering_options.ssaa = kwargs.get('ssaa', 2)
    else:
        raise ValueError(f'Unsupported sample type: {type(sample)}')
    return renderer


def render_frames(sample, extrinsics, intrinsics, options=None, verbose=True, **kwargs):
    """Render one frame per camera and collect each render pass as uint8 images.

    Args:
        sample: representation to render (see `get_renderer`).
        extrinsics / intrinsics: equal-length lists of camera matrices.
        options: optional dict forwarded to `get_renderer`.
        verbose: show a tqdm progress bar.
        **kwargs: forwarded to `renderer.render`.

    Returns:
        Dict mapping render-pass name -> list of HxWx3 uint8 arrays.
    """
    if options is None:  # avoid the shared mutable-default-dict pitfall
        options = {}
    renderer = get_renderer(sample, **options)
    rets = {}
    for extr, intr in tqdm(zip(extrinsics, intrinsics), total=len(extrinsics),
                           desc='Rendering', disable=not verbose):
        res = renderer.render(sample, extr, intr, **kwargs)
        for key, val in res.items():
            rets.setdefault(key, [])
            if val.dim() == 2:  # single-channel pass -> replicate to 3 channels
                val = val[None].repeat(3, 1, 1)
            frame = np.clip(val.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255)
            rets[key].append(frame.astype(np.uint8))
    return rets


def render_video(sample, resolution=1024, bg_color=(0, 0, 0), num_frames=120, r=2, fov=40, **kwargs):
    """Render a turntable video: a full yaw orbit with a gently bobbing pitch.

    NOTE(review): `2 * 3.1415` (not 2*pi) is kept from the original — the
    orbit is fractionally short of a full turn; confirm whether intentional.
    """
    yaws = -torch.linspace(0, 2 * 3.1415, num_frames) + np.pi / 2
    pitch = 0.25 + 0.5 * torch.sin(torch.linspace(0, 2 * 3.1415, num_frames))
    yaws = yaws.tolist()
    pitch = pitch.tolist()
    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitch, r, fov)
    return render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)


def render_multiview(sample, resolution=512, nviews=30):
    """Render `nviews` views from a sphere-Hammersley camera distribution.

    Returns:
        (color_frames, extrinsics, intrinsics)
    """
    r = 2
    fov = 40
    cams = [sphere_hammersley_sequence(i, nviews) for i in range(nviews)]
    yaws = [cam[0] for cam in cams]
    pitchs = [cam[1] for cam in cams]
    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, r, fov)
    res = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': (0, 0, 0)})
    return res['color'], extrinsics, intrinsics


def render_snapshot(samples, resolution=512, bg_color=(0, 0, 0), offset=(-16 / 180 * np.pi, 20 / 180 * np.pi), r=10, fov=8, nviews=4, **kwargs):
    """Render `nviews` evenly spaced snapshot views with a fixed yaw/pitch offset."""
    yaw = np.linspace(0, 2 * np.pi, nviews, endpoint=False)
    yaw_offset = offset[0]
    yaw = [y + yaw_offset for y in yaw]
    pitch = [offset[1] for _ in range(nviews)]
    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, r, fov)
    return render_frames(samples, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)


def make_pbr_vis_frames(result, resolution=1024):
    """Compose PBR render passes into a single visualization frame per view.

    Layout: top row shaded | normal at full resolution; bottom row
    base_color | metallic | roughness | alpha at half resolution.

    Args:
        result: dict of per-pass uint8 frame lists, as produced by render_frames.
        resolution: output width of each top-row panel.
    """
    num_frames = len(result['shaded'])
    frames = []
    for i in range(num_frames):
        shaded = Image.fromarray(result['shaded'][i])
        normal = Image.fromarray(result['normal'][i])
        base_color = Image.fromarray(result['base_color'][i])
        metallic = Image.fromarray(result['metallic'][i])
        roughness = Image.fromarray(result['roughness'][i])
        alpha = Image.fromarray(result['alpha'][i])
        shaded = shaded.resize((resolution, resolution))
        normal = normal.resize((resolution, resolution))
        base_color = base_color.resize((resolution // 2, resolution // 2))
        metallic = metallic.resize((resolution // 2, resolution // 2))
        roughness = roughness.resize((resolution // 2, resolution // 2))
        alpha = alpha.resize((resolution // 2, resolution // 2))
        row1 = np.concatenate([shaded, normal], axis=1)
        row2 = np.concatenate([base_color, metallic, roughness, alpha], axis=1)
        frame = np.concatenate([row1, row2], axis=0)
        frames.append(frame)
    return frames
def pca_color(feats: torch.Tensor, channels: Tuple[int, int, int] = (0, 1, 2)) -> torch.Tensor:
    """Project per-point features onto principal components for RGB coloring.

    Args:
        feats: [N, C] feature matrix.
        channels: which left-singular-vector columns become R, G, B.

    Returns:
        [N, 3] tensor min-max normalized to [0, 1] per channel.
    """
    feats = feats.detach()
    # NOTE: torch.svd is deprecated in favor of torch.linalg.svd; kept to
    # preserve the exact sign convention of existing visualizations.
    u, s, v = torch.svd(feats)
    color = u[:, channels]
    cmin = color.min(dim=0, keepdim=True)[0]
    cmax = color.max(dim=0, keepdim=True)[0]
    # clamp the denominator so a constant component yields 0 instead of NaN
    color = (color - cmin) / (cmax - cmin).clamp_min(1e-12)
    return color


def vis_sparse_tensor(
    x: "sp.SparseTensor",  # quoted so the module imports without `sparse` in scope
    num_frames: int = 300,
):
    """Render a turntable video of a sparse tensor, colored by feature PCA.

    Args:
        x: batch-size-1 sparse tensor with [N, 4] (batch, z, y, x) coords.
        num_frames: number of frames in the rendered orbit.

    Returns:
        List of rendered color frames (the 'color' pass of render_video).
    NOTE(review): moves data to CUDA unconditionally — assumes a GPU.
    """
    assert x.shape[0] == 1, "Only support batch size 1"
    assert x.coords.shape[1] == 4, "Only support 3D coordinates"

    coords = x.coords.cuda().detach()[:, 1:]
    feats = x.feats.cuda().detach()
    color = pca_color(feats)

    # round the spatial resolution up to the next power of two
    resolution = max(list(x.spatial_shape))
    resolution = int(2 ** np.ceil(np.log2(resolution)))

    rep = Voxel(
        origin=[-0.5, -0.5, -0.5],
        voxel_size=1 / resolution,
        coords=coords,
        attrs=color,
        layout={
            'color': slice(0, 3),
        }
    )

    return render_video(rep, colors_overwrite=color, num_frames=num_frames)['color']