jsnavarroo committed on
Commit
e8cd8e2
·
1 Parent(s): 93e4a74

Vendor TRELLIS source code for inference endpoint

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. requirements.txt +1 -2
  2. trellis/__init__.py +6 -0
  3. trellis/datasets/__init__.py +58 -0
  4. trellis/datasets/components.py +137 -0
  5. trellis/datasets/sparse_feat2render.py +134 -0
  6. trellis/datasets/sparse_structure.py +107 -0
  7. trellis/datasets/sparse_structure_latent.py +188 -0
  8. trellis/datasets/structured_latent.py +217 -0
  9. trellis/datasets/structured_latent2render.py +160 -0
  10. trellis/models/__init__.py +96 -0
  11. trellis/models/sparse_elastic_mixin.py +24 -0
  12. trellis/models/sparse_structure_flow.py +200 -0
  13. trellis/models/sparse_structure_vae.py +306 -0
  14. trellis/models/structured_latent_flow.py +276 -0
  15. trellis/models/structured_latent_vae/__init__.py +4 -0
  16. trellis/models/structured_latent_vae/base.py +117 -0
  17. trellis/models/structured_latent_vae/decoder_gs.py +131 -0
  18. trellis/models/structured_latent_vae/decoder_mesh.py +176 -0
  19. trellis/models/structured_latent_vae/decoder_rf.py +113 -0
  20. trellis/models/structured_latent_vae/encoder.py +80 -0
  21. trellis/modules/attention/__init__.py +36 -0
  22. trellis/modules/attention/full_attn.py +140 -0
  23. trellis/modules/attention/modules.py +146 -0
  24. trellis/modules/norm.py +25 -0
  25. trellis/modules/sparse/__init__.py +102 -0
  26. trellis/modules/sparse/attention/__init__.py +4 -0
  27. trellis/modules/sparse/attention/full_attn.py +215 -0
  28. trellis/modules/sparse/attention/modules.py +139 -0
  29. trellis/modules/sparse/attention/serialized_attn.py +193 -0
  30. trellis/modules/sparse/attention/windowed_attn.py +135 -0
  31. trellis/modules/sparse/basic.py +459 -0
  32. trellis/modules/sparse/conv/__init__.py +21 -0
  33. trellis/modules/sparse/conv/conv_spconv.py +80 -0
  34. trellis/modules/sparse/conv/conv_torchsparse.py +38 -0
  35. trellis/modules/sparse/linear.py +15 -0
  36. trellis/modules/sparse/nonlinearity.py +35 -0
  37. trellis/modules/sparse/norm.py +58 -0
  38. trellis/modules/sparse/spatial.py +110 -0
  39. trellis/modules/sparse/transformer/__init__.py +2 -0
  40. trellis/modules/sparse/transformer/blocks.py +151 -0
  41. trellis/modules/sparse/transformer/modulated.py +166 -0
  42. trellis/modules/spatial.py +48 -0
  43. trellis/modules/transformer/__init__.py +2 -0
  44. trellis/modules/transformer/blocks.py +182 -0
  45. trellis/modules/transformer/modulated.py +157 -0
  46. trellis/modules/utils.py +54 -0
  47. trellis/pipelines/__init__.py +25 -0
  48. trellis/pipelines/base.py +68 -0
  49. trellis/pipelines/samplers/__init__.py +2 -0
  50. trellis/pipelines/samplers/base.py +20 -0
requirements.txt CHANGED
@@ -3,5 +3,4 @@ numpy
3
  scipy
4
  trimesh
5
  scikit-image
6
- pyembree
7
- trellis @ git+https://github.com/microsoft/TRELLIS.git
 
3
  scipy
4
  trimesh
5
  scikit-image
6
+ pyembree
 
trellis/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from . import models
2
+ from . import modules
3
+ from . import pipelines
4
+ from . import renderers
5
+ from . import representations
6
+ from . import utils
trellis/datasets/__init__.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import importlib

# Map each lazily-exported attribute to the submodule that defines it.
_lazy_attributes = {
    'SparseStructure': 'sparse_structure',

    'SparseFeat2Render': 'sparse_feat2render',
    'SLat2Render': 'structured_latent2render',
    'Slat2RenderGeo': 'structured_latent2render',

    'SparseStructureLatent': 'sparse_structure_latent',
    'TextConditionedSparseStructureLatent': 'sparse_structure_latent',
    'ImageConditionedSparseStructureLatent': 'sparse_structure_latent',

    'SLat': 'structured_latent',
    'TextConditionedSLat': 'structured_latent',
    'ImageConditionedSLat': 'structured_latent',
}

_lazy_submodules = []

__all__ = list(_lazy_attributes) + _lazy_submodules


def __getattr__(name):
    """PEP 562 module hook: import exported names on first access and cache them."""
    if name in globals():
        return globals()[name]
    if name in _lazy_attributes:
        module = importlib.import_module(f".{_lazy_attributes[name]}", __name__)
        value = getattr(module, name)
    elif name in _lazy_submodules:
        value = importlib.import_module(f".{name}", __name__)
    else:
        raise AttributeError(f"module {__name__} has no attribute {name}")
    globals()[name] = value
    return value


# For Pylance
if __name__ == '__main__':
    from .sparse_structure import SparseStructure

    from .sparse_feat2render import SparseFeat2Render
    from .structured_latent2render import (
        SLat2Render,
        Slat2RenderGeo,
    )

    from .sparse_structure_latent import (
        SparseStructureLatent,
        TextConditionedSparseStructureLatent,
        ImageConditionedSparseStructureLatent,
    )

    from .structured_latent import (
        SLat,
        TextConditionedSLat,
        ImageConditionedSLat,
    )
trellis/datasets/components.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ from abc import abstractmethod
3
+ import os
4
+ import json
5
+ import torch
6
+ import numpy as np
7
+ import pandas as pd
8
+ from PIL import Image
9
+ from torch.utils.data import Dataset
10
+
11
+
12
class StandardDatasetBase(Dataset):
    """
    Base class for standard datasets.

    Args:
        roots (str): comma-separated paths to the dataset roots
    """

    def __init__(self,
        roots: str,
    ):
        super().__init__()
        self.roots = roots.split(',')
        self.instances = []
        self.metadata = pd.DataFrame()
        self._stats = {}

        for root in self.roots:
            source = os.path.basename(root)
            self._stats[source] = {}
            source_meta = pd.read_csv(os.path.join(root, 'metadata.csv'))
            self._stats[source]['Total'] = len(source_meta)
            source_meta, source_stats = self.filter_metadata(source_meta)
            self._stats[source].update(source_stats)
            self.instances += [(root, sha256) for sha256 in source_meta['sha256'].values]
            source_meta.set_index('sha256', inplace=True)
            self.metadata = pd.concat([self.metadata, source_meta])

    @abstractmethod
    def filter_metadata(self, metadata: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, int]]:
        # Subclasses drop rows that cannot be used and report per-filter counts.
        pass

    @abstractmethod
    def get_instance(self, root: str, instance: str) -> Dict[str, Any]:
        # Subclasses load a single training sample by (root, sha256).
        pass

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, index) -> Dict[str, Any]:
        # Best-effort loading: a failing sample is logged and replaced by a
        # randomly chosen one so a corrupt file does not abort training.
        try:
            root, instance = self.instances[index]
            return self.get_instance(root, instance)
        except Exception as e:
            print(e)
            return self.__getitem__(np.random.randint(0, len(self)))

    def __str__(self):
        lines = [
            self.__class__.__name__,
            f' - Total instances: {len(self)}',
            ' - Sources:',
        ]
        for source, source_stats in self._stats.items():
            lines.append(f' - {source}:')
            for stat_name, count in source_stats.items():
                lines.append(f' - {stat_name}: {count}')
        return '\n'.join(lines)
69
+
70
+
71
class TextConditionedMixin:
    """Mixin that attaches a randomly chosen text caption as the 'cond' entry."""

    def __init__(self, roots, **kwargs):
        super().__init__(roots, **kwargs)
        # Parse every instance's caption list once up front.
        self.captions = {}
        for _, sha256 in self.instances:
            self.captions[sha256] = json.loads(self.metadata.loc[sha256]['captions'])

    def filter_metadata(self, metadata):
        metadata, stats = super().filter_metadata(metadata)
        metadata = metadata[metadata['captions'].notna()]
        stats['With captions'] = len(metadata)
        return metadata, stats

    def get_instance(self, root, instance):
        pack = super().get_instance(root, instance)
        pack['cond'] = np.random.choice(self.captions[instance])
        return pack
90
+
91
+
92
class ImageConditionedMixin:
    """Mixin that attaches a cropped, alpha-matted conditioning image as 'cond'."""

    def __init__(self, roots, *, image_size=518, **kwargs):
        self.image_size = image_size
        super().__init__(roots, **kwargs)

    def filter_metadata(self, metadata):
        metadata, stats = super().filter_metadata(metadata)
        metadata = metadata[metadata['cond_rendered']]
        stats['Cond rendered'] = len(metadata)
        return metadata, stats

    def get_instance(self, root, instance):
        pack = super().get_instance(root, instance)

        # Pick a random conditioning view from the per-instance transforms file.
        image_root = os.path.join(root, 'renders_cond', instance)
        with open(os.path.join(image_root, 'transforms.json')) as f:
            frames = json.load(f)['frames']
        frame = frames[np.random.randint(len(frames))]

        image = Image.open(os.path.join(image_root, frame['file_path']))

        # Square crop around the alpha-channel bounding box, padded by 20%.
        alpha_mask = np.array(image.getchannel(3))
        rows, cols = np.array(alpha_mask).nonzero()
        left, top, right, bottom = cols.min(), rows.min(), cols.max(), rows.max()
        cx = (left + right) / 2
        cy = (top + bottom) / 2
        half = max(right - left, bottom - top) / 2
        half = half * 1.2  # pad the crop box
        crop_box = [int(cx - half), int(cy - half), int(cx + half), int(cy + half)]
        image = image.crop(crop_box)

        # Resize, then premultiply RGB by alpha so the background goes to black.
        image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
        alpha = torch.tensor(np.array(image.getchannel(3))).float() / 255.0
        rgb = torch.tensor(np.array(image.convert('RGB'))).permute(2, 0, 1).float() / 255.0
        pack['cond'] = rgb * alpha.unsqueeze(0)

        return pack
137
+
trellis/datasets/sparse_feat2render.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image
3
+ import json
4
+ import numpy as np
5
+ import pandas as pd
6
+ import torch
7
+ import utils3d.torch
8
+ from ..modules.sparse.basic import SparseTensor
9
+ from .components import StandardDatasetBase
10
+
11
+
12
class SparseFeat2Render(StandardDatasetBase):
    """
    SparseFeat2Render dataset: pairs sparse voxel features with rendered views.

    Args:
        roots (str): paths to the dataset
        image_size (int): size of the image
        model (str): feature-extractor model name
        resolution (int): resolution of the data
        min_aesthetic_score (float): minimum aesthetic score
        max_num_voxels (int): maximum number of voxels
    """
    def __init__(
        self,
        roots: str,
        image_size: int,
        model: str = 'dinov2_vitl14_reg',
        resolution: int = 64,
        min_aesthetic_score: float = 5.0,
        max_num_voxels: int = 32768,
    ):
        self.image_size = image_size
        self.model = model
        self.resolution = resolution
        self.min_aesthetic_score = min_aesthetic_score
        self.max_num_voxels = max_num_voxels
        self.value_range = (0, 1)
        super().__init__(roots)

    def filter_metadata(self, metadata):
        stats = {}
        metadata = metadata[metadata[f'feature_{self.model}']]
        stats['With features'] = len(metadata)
        metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
        stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
        metadata = metadata[metadata['num_voxels'] <= self.max_num_voxels]
        stats[f'Num voxels <= {self.max_num_voxels}'] = len(metadata)
        return metadata, stats

    def _get_image(self, root, instance):
        # Sample a random rendered view together with its camera parameters.
        with open(os.path.join(root, 'renders', instance, 'transforms.json')) as f:
            frames = json.load(f)['frames']
        frame = frames[np.random.randint(len(frames))]

        fov = frame['camera_angle_x']
        intrinsics = utils3d.torch.intrinsics_from_fov_xy(torch.tensor(fov), torch.tensor(fov))
        c2w = torch.tensor(frame['transform_matrix'])
        c2w[:3, 1:3] *= -1  # flip the y/z camera axes before inverting
        extrinsics = torch.inverse(c2w)

        image = Image.open(os.path.join(root, 'renders', instance, frame['file_path']))
        alpha = image.getchannel(3)
        image = image.convert('RGB')
        image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
        alpha = alpha.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
        image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
        alpha = torch.tensor(np.array(alpha)).float() / 255.0

        return {
            'image': image,
            'alpha': alpha,
            'extrinsics': extrinsics,
            'intrinsics': intrinsics,
        }

    def _get_feat(self, root, instance):
        DATA_RESOLUTION = 64
        packed = np.load(os.path.join(root, 'features', self.model, f'{instance}.npz'), allow_pickle=True)
        coords = torch.tensor(packed['indices']).int()
        feats = torch.tensor(packed['patchtokens']).float()

        if self.resolution != DATA_RESOLUTION:
            # Downsample: voxels that collapse to the same coarse cell are
            # deduplicated and their features averaged.
            factor = DATA_RESOLUTION // self.resolution
            coords = coords // factor
            coords, idx = coords.unique(return_inverse=True, dim=0)
            feats = torch.scatter_reduce(
                torch.zeros(coords.shape[0], feats.shape[1], device=feats.device),
                dim=0,
                index=idx.unsqueeze(-1).expand(-1, feats.shape[1]),
                src=feats,
                reduce='mean',
            )

        return {
            'coords': coords,
            'feats': feats,
        }

    @torch.no_grad()
    def visualize_sample(self, sample: dict):
        return sample['image']

    @staticmethod
    def collate_fn(batch):
        # Prefix each voxel coordinate with its batch index, then concatenate.
        coords = torch.cat([
            torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1)
            for i, b in enumerate(batch)
        ])
        feats = torch.cat([b['feats'] for b in batch])
        return {
            'feats': SparseTensor(coords=coords, feats=feats),
            'image': torch.stack([b['image'] for b in batch]),
            'alpha': torch.stack([b['alpha'] for b in batch]),
            'extrinsics': torch.stack([b['extrinsics'] for b in batch]),
            'intrinsics': torch.stack([b['intrinsics'] for b in batch]),
        }

    def get_instance(self, root, instance):
        return {
            **self._get_image(root, instance),
            **self._get_feat(root, instance),
        }
trellis/datasets/sparse_structure.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from typing import Union
4
+ import numpy as np
5
+ import pandas as pd
6
+ import torch
7
+ from torch.utils.data import Dataset
8
+ import utils3d
9
+ from .components import StandardDatasetBase
10
+ from ..representations.octree import DfsOctree as Octree
11
+ from ..renderers import OctreeRenderer
12
+
13
+
14
class SparseStructure(StandardDatasetBase):
    """
    Sparse structure dataset: binary voxel occupancy grids.

    Args:
        roots (str): path to the dataset
        resolution (int): resolution of the voxel grid
        min_aesthetic_score (float): minimum aesthetic score of the instances to be included
    """

    def __init__(self,
        roots,
        resolution: int = 64,
        min_aesthetic_score: float = 5.0,
    ):
        self.resolution = resolution
        self.min_aesthetic_score = min_aesthetic_score
        self.value_range = (0, 1)
        super().__init__(roots)

    def filter_metadata(self, metadata):
        stats = {}
        metadata = metadata[metadata['voxelized']]
        stats['Voxelized'] = len(metadata)
        metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
        stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
        return metadata, stats

    def get_instance(self, root, instance):
        # Voxel centers are stored as positions in [-0.5, 0.5]; map to grid indices.
        position = utils3d.io.read_ply(os.path.join(root, 'voxels', f'{instance}.ply'))[0]
        coords = ((torch.tensor(position) + 0.5) * self.resolution).int().contiguous()
        ss = torch.zeros(1, self.resolution, self.resolution, self.resolution, dtype=torch.long)
        ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
        return {'ss': ss}

    @torch.no_grad()
    def visualize_sample(self, ss: Union[torch.Tensor, dict]):
        ss = ss if isinstance(ss, torch.Tensor) else ss['ss']

        renderer = OctreeRenderer()
        renderer.rendering_options.resolution = 512
        renderer.rendering_options.near = 0.8
        renderer.rendering_options.far = 1.6
        renderer.rendering_options.bg_color = (0, 0, 0)
        renderer.rendering_options.ssaa = 4
        renderer.pipe.primitive = 'voxel'

        # Four 90-degree-spaced yaws with one shared random offset, random pitches.
        yaw_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
        yaws = [i * np.pi / 2 + yaw_offset for i in range(4)]
        pitches = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]

        exts = []
        ints = []
        for yaw, pitch in zip(yaws, pitches):
            orig = torch.tensor([
                np.sin(yaw) * np.cos(pitch),
                np.cos(yaw) * np.cos(pitch),
                np.sin(pitch),
            ]).float().cuda() * 2
            fov = torch.deg2rad(torch.tensor(30)).cuda()
            exts.append(utils3d.torch.extrinsics_look_at(
                orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()))
            ints.append(utils3d.torch.intrinsics_from_fov_xy(fov, fov))

        # Render each grid in the batch into a 2x2 tile of the four views.
        images = []
        ss = ss.cuda()
        for i in range(ss.shape[0]):
            representation = Octree(
                depth=10,
                aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
                device='cuda',
                primitive='voxel',
                sh_degree=0,
                primitive_config={'solid': True},
            )
            occupied = torch.nonzero(ss[i, 0], as_tuple=False)
            representation.position = occupied.float() / self.resolution
            representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda')

            image = torch.zeros(3, 1024, 1024).cuda()
            for j, (ext, intr) in enumerate(zip(exts, ints)):
                res = renderer.render(representation, ext, intr, colors_overwrite=representation.position)
                row, col = j // 2, j % 2
                image[:, 512 * row:512 * (row + 1), 512 * col:512 * (col + 1)] = res['color']
            images.append(image)

        return torch.stack(images)
+
trellis/datasets/sparse_structure_latent.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from typing import *
4
+ import numpy as np
5
+ import torch
6
+ import utils3d
7
+ from ..representations.octree import DfsOctree as Octree
8
+ from ..renderers import OctreeRenderer
9
+ from .components import StandardDatasetBase, TextConditionedMixin, ImageConditionedMixin
10
+ from .. import models
11
+
12
+
13
class SparseStructureLatentVisMixin:
    """Mixin that decodes sparse-structure latents with a (lazily loaded)
    pretrained decoder and renders the decoded voxel grids for visualization.
    """

    def __init__(
        self,
        *args,
        pretrained_ss_dec: str = 'microsoft/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16',
        ss_dec_path: Optional[str] = None,
        ss_dec_ckpt: Optional[str] = None,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        # The decoder is loaded on demand and released after each use so it
        # does not occupy GPU memory while the dataset is merely iterated.
        self.ss_dec = None
        self.pretrained_ss_dec = pretrained_ss_dec
        self.ss_dec_path = ss_dec_path
        self.ss_dec_ckpt = ss_dec_ckpt

    def _loading_ss_dec(self):
        """Load the decoder once; a local checkpoint path overrides the pretrained model."""
        if self.ss_dec is not None:
            return
        if self.ss_dec_path is not None:
            # FIX: use a context manager so the config file handle is closed
            # (the original `json.load(open(...))` leaked it).
            with open(os.path.join(self.ss_dec_path, 'config.json'), 'r') as f:
                cfg = json.load(f)
            decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
            ckpt_path = os.path.join(self.ss_dec_path, 'ckpts', f'decoder_{self.ss_dec_ckpt}.pt')
            decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
        else:
            decoder = models.from_pretrained(self.pretrained_ss_dec)
        self.ss_dec = decoder.cuda().eval()

    def _delete_ss_dec(self):
        """Drop the decoder reference to free its (GPU) memory."""
        del self.ss_dec
        self.ss_dec = None

    @torch.no_grad()
    def decode_latent(self, z, batch_size=4):
        """Decode latents to dense occupancy grids in mini-batches.

        NOTE(review): `self.normalization` / `self.mean` / `self.std` are
        provided by the host dataset class (e.g. SparseStructureLatent).
        """
        self._loading_ss_dec()
        ss = []
        if self.normalization is not None:
            # Undo dataset normalization before decoding.
            z = z * self.std.to(z.device) + self.mean.to(z.device)
        for i in range(0, z.shape[0], batch_size):
            ss.append(self.ss_dec(z[i:i+batch_size]))
        ss = torch.cat(ss, dim=0)
        self._delete_ss_dec()
        return ss

    @torch.no_grad()
    def visualize_sample(self, x_0: Union[torch.Tensor, dict]):
        x_0 = x_0 if isinstance(x_0, torch.Tensor) else x_0['x_0']
        x_0 = self.decode_latent(x_0.cuda())

        renderer = OctreeRenderer()
        renderer.rendering_options.resolution = 512
        renderer.rendering_options.near = 0.8
        renderer.rendering_options.far = 1.6
        renderer.rendering_options.bg_color = (0, 0, 0)
        renderer.rendering_options.ssaa = 4
        renderer.pipe.primitive = 'voxel'

        # Four 90-degree-spaced yaws with one shared random offset, random pitches.
        yaw_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
        yaws = [i * np.pi / 2 + yaw_offset for i in range(4)]
        pitches = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]

        exts = []
        ints = []
        for yaw, pitch in zip(yaws, pitches):
            orig = torch.tensor([
                np.sin(yaw) * np.cos(pitch),
                np.cos(yaw) * np.cos(pitch),
                np.sin(pitch),
            ]).float().cuda() * 2
            fov = torch.deg2rad(torch.tensor(30)).cuda()
            exts.append(utils3d.torch.extrinsics_look_at(
                orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()))
            ints.append(utils3d.torch.intrinsics_from_fov_xy(fov, fov))

        # Render each decoded grid into a 2x2 tile of the four views.
        images = []
        x_0 = x_0.cuda()
        for i in range(x_0.shape[0]):
            representation = Octree(
                depth=10,
                aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
                device='cuda',
                primitive='voxel',
                sh_degree=0,
                primitive_config={'solid': True},
            )
            coords = torch.nonzero(x_0[i, 0] > 0, as_tuple=False)
            resolution = x_0.shape[-1]
            representation.position = coords.float() / resolution
            representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(resolution)), dtype=torch.uint8, device='cuda')

            image = torch.zeros(3, 1024, 1024).cuda()
            for j, (ext, intr) in enumerate(zip(exts, ints)):
                res = renderer.render(representation, ext, intr, colors_overwrite=representation.position)
                row, col = j // 2, j % 2
                image[:, 512 * row:512 * (row + 1), 512 * col:512 * (col + 1)] = res['color']
            images.append(image)

        return torch.stack(images)
115
+
116
+
117
class SparseStructureLatent(SparseStructureLatentVisMixin, StandardDatasetBase):
    """
    Sparse structure latent dataset

    Args:
        roots (str): path to the dataset
        latent_model (str): name of the latent model
        min_aesthetic_score (float): minimum aesthetic score
        normalization (dict): normalization stats
        pretrained_ss_dec (str): name of the pretrained sparse structure decoder
        ss_dec_path (str): path to the sparse structure decoder; if given, overrides pretrained_ss_dec
        ss_dec_ckpt (str): name of the sparse structure decoder checkpoint
    """
    def __init__(self,
        roots: str,
        *,
        latent_model: str,
        min_aesthetic_score: float = 5.0,
        normalization: Optional[dict] = None,
        pretrained_ss_dec: str = 'microsoft/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16',
        ss_dec_path: Optional[str] = None,
        ss_dec_ckpt: Optional[str] = None,
    ):
        self.latent_model = latent_model
        self.min_aesthetic_score = min_aesthetic_score
        self.normalization = normalization
        self.value_range = (0, 1)

        super().__init__(
            roots,
            pretrained_ss_dec=pretrained_ss_dec,
            ss_dec_path=ss_dec_path,
            ss_dec_ckpt=ss_dec_ckpt,
        )

        if self.normalization is not None:
            # Per-channel stats, reshaped to broadcast over the 4-D latent.
            self.mean = torch.tensor(self.normalization['mean']).reshape(-1, 1, 1, 1)
            self.std = torch.tensor(self.normalization['std']).reshape(-1, 1, 1, 1)

    def filter_metadata(self, metadata):
        stats = {}
        metadata = metadata[metadata[f'ss_latent_{self.latent_model}']]
        stats['With sparse structure latents'] = len(metadata)
        metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
        stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
        return metadata, stats

    def get_instance(self, root, instance):
        latent = np.load(os.path.join(root, 'ss_latents', self.latent_model, f'{instance}.npz'))
        z = torch.tensor(latent['mean']).float()
        if self.normalization is not None:
            z = (z - self.mean) / self.std
        return {'x_0': z}
174
+
175
+
176
class TextConditionedSparseStructureLatent(TextConditionedMixin, SparseStructureLatent):
    """Text-conditioned sparse structure dataset."""
181
+
182
+
183
class ImageConditionedSparseStructureLatent(ImageConditionedMixin, SparseStructureLatent):
    """Image-conditioned sparse structure dataset."""
188
+
trellis/datasets/structured_latent.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from typing import *
4
+ import numpy as np
5
+ import torch
6
+ import utils3d.torch
7
+ from .components import StandardDatasetBase, TextConditionedMixin, ImageConditionedMixin
8
+ from ..modules.sparse.basic import SparseTensor
9
+ from .. import models
10
+ from ..utils.render_utils import get_renderer
11
+ from ..utils.data_utils import load_balanced_group_indices
12
+
13
+
14
class SLatVisMixin:
    """Mixin that decodes structured latents with a (lazily loaded) pretrained
    decoder and renders the decoded representations for visualization.
    """

    def __init__(
        self,
        *args,
        pretrained_slat_dec: str = 'microsoft/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16',
        slat_dec_path: Optional[str] = None,
        slat_dec_ckpt: Optional[str] = None,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        # The decoder is loaded on demand and released after each use so it
        # does not occupy GPU memory while the dataset is merely iterated.
        self.slat_dec = None
        self.pretrained_slat_dec = pretrained_slat_dec
        self.slat_dec_path = slat_dec_path
        self.slat_dec_ckpt = slat_dec_ckpt

    def _loading_slat_dec(self):
        """Load the decoder once; a local checkpoint path overrides the pretrained model."""
        if self.slat_dec is not None:
            return
        if self.slat_dec_path is not None:
            # FIX: use a context manager so the config file handle is closed
            # (the original `json.load(open(...))` leaked it).
            with open(os.path.join(self.slat_dec_path, 'config.json'), 'r') as f:
                cfg = json.load(f)
            decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
            ckpt_path = os.path.join(self.slat_dec_path, 'ckpts', f'decoder_{self.slat_dec_ckpt}.pt')
            decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
        else:
            decoder = models.from_pretrained(self.pretrained_slat_dec)
        self.slat_dec = decoder.cuda().eval()

    def _delete_slat_dec(self):
        """Drop the decoder reference to free its (GPU) memory."""
        del self.slat_dec
        self.slat_dec = None

    @torch.no_grad()
    def decode_latent(self, z, batch_size=4):
        """Decode latents to renderable representations in mini-batches.

        NOTE(review): `self.normalization` / `self.mean` / `self.std` are
        provided by the host dataset class (e.g. SLat).
        """
        self._loading_slat_dec()
        reps = []
        if self.normalization is not None:
            # Undo dataset normalization before decoding.
            z = z * self.std.to(z.device) + self.mean.to(z.device)
        for i in range(0, z.shape[0], batch_size):
            reps.append(self.slat_dec(z[i:i+batch_size]))
        reps = sum(reps, [])  # each decoder call returns a list of representations
        self._delete_slat_dec()
        return reps

    @torch.no_grad()
    def visualize_sample(self, x_0: Union[SparseTensor, dict]):
        x_0 = x_0 if isinstance(x_0, SparseTensor) else x_0['x_0']
        reps = self.decode_latent(x_0.cuda())

        # Four 90-degree-spaced yaws with one shared random offset, random pitches.
        yaw_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
        yaws = [i * np.pi / 2 + yaw_offset for i in range(4)]
        pitches = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]

        exts = []
        ints = []
        for yaw, pitch in zip(yaws, pitches):
            orig = torch.tensor([
                np.sin(yaw) * np.cos(pitch),
                np.cos(yaw) * np.cos(pitch),
                np.sin(pitch),
            ]).float().cuda() * 2
            fov = torch.deg2rad(torch.tensor(40)).cuda()
            exts.append(utils3d.torch.extrinsics_look_at(
                orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()))
            ints.append(utils3d.torch.intrinsics_from_fov_xy(fov, fov))

        # Render each decoded representation into a 2x2 tile of the four views.
        renderer = get_renderer(reps[0])
        images = []
        for representation in reps:
            image = torch.zeros(3, 1024, 1024).cuda()
            for j, (ext, intr) in enumerate(zip(exts, ints)):
                res = renderer.render(representation, ext, intr)
                row, col = j // 2, j % 2
                image[:, 512 * row:512 * (row + 1), 512 * col:512 * (col + 1)] = res['color']
            images.append(image)

        return torch.stack(images)
94
+
95
+
96
class SLat(SLatVisMixin, StandardDatasetBase):
    """
    structured latent dataset

    Args:
        roots (str): path to the dataset
        latent_model (str): name of the latent model
        min_aesthetic_score (float): minimum aesthetic score
        max_num_voxels (int): maximum number of voxels
        normalization (dict): normalization stats
        pretrained_slat_dec (str): name of the pretrained slat decoder
        slat_dec_path (str): path to the slat decoder; if given, overrides pretrained_slat_dec
        slat_dec_ckpt (str): name of the slat decoder checkpoint
    """
    def __init__(self,
        roots: str,
        *,
        latent_model: str,
        min_aesthetic_score: float = 5.0,
        max_num_voxels: int = 32768,
        normalization: Optional[dict] = None,
        pretrained_slat_dec: str = 'microsoft/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16',
        slat_dec_path: Optional[str] = None,
        slat_dec_ckpt: Optional[str] = None,
    ):
        self.normalization = normalization
        self.latent_model = latent_model
        self.min_aesthetic_score = min_aesthetic_score
        self.max_num_voxels = max_num_voxels
        self.value_range = (0, 1)

        super().__init__(
            roots,
            pretrained_slat_dec=pretrained_slat_dec,
            slat_dec_path=slat_dec_path,
            slat_dec_ckpt=slat_dec_ckpt,
        )

        # Per-instance voxel counts, used for load-balanced batching.
        self.loads = [self.metadata.loc[sha256, 'num_voxels'] for _, sha256 in self.instances]

        if self.normalization is not None:
            # Per-channel stats, reshaped to broadcast over (num_voxels, channels).
            self.mean = torch.tensor(self.normalization['mean']).reshape(1, -1)
            self.std = torch.tensor(self.normalization['std']).reshape(1, -1)

    def filter_metadata(self, metadata):
        stats = {}
        metadata = metadata[metadata[f'latent_{self.latent_model}']]
        stats['With latent'] = len(metadata)
        metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
        stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
        metadata = metadata[metadata['num_voxels'] <= self.max_num_voxels]
        stats[f'Num voxels <= {self.max_num_voxels}'] = len(metadata)
        return metadata, stats

    def get_instance(self, root, instance):
        data = np.load(os.path.join(root, 'latents', self.latent_model, f'{instance}.npz'))
        coords = torch.tensor(data['coords']).int()
        feats = torch.tensor(data['feats']).float()
        if self.normalization is not None:
            feats = (feats - self.mean) / self.std
        return {
            'coords': coords,
            'feats': feats,
        }

    @staticmethod
    def collate_fn(batch, split_size=None):
        # Optionally split the batch into load-balanced groups by voxel count.
        if split_size is None:
            groups = [list(range(len(batch)))]
        else:
            groups = load_balanced_group_indices([b['coords'].shape[0] for b in batch], split_size)

        packs = []
        for indices in groups:
            items = [batch[i] for i in indices]
            coords = []
            feats = []
            layout = []
            offset = 0
            for i, item in enumerate(items):
                n = item['coords'].shape[0]
                # Prefix each voxel coordinate with its index within the group.
                coords.append(torch.cat([torch.full((n, 1), i, dtype=torch.int32), item['coords']], dim=-1))
                feats.append(item['feats'])
                layout.append(slice(offset, offset + n))
                offset += n

            pack = {}
            pack['x_0'] = SparseTensor(
                coords=torch.cat(coords),
                feats=torch.cat(feats),
            )
            # NOTE(review): pokes SparseTensor internals to record the batched
            # shape and the per-sample layout slices.
            pack['x_0']._shape = torch.Size([len(indices), *items[0]['feats'].shape[1:]])
            pack['x_0'].register_spatial_cache('layout', layout)

            # Collate every remaining key generically.
            for k in items[0]:
                if k in ('coords', 'feats'):
                    continue
                first = items[0][k]
                if isinstance(first, torch.Tensor):
                    pack[k] = torch.stack([item[k] for item in items])
                elif isinstance(first, list):
                    pack[k] = sum([item[k] for item in items], [])
                else:
                    pack[k] = [item[k] for item in items]

            packs.append(pack)

        return packs[0] if split_size is None else packs
204
+
205
+
206
class TextConditionedSLat(TextConditionedMixin, SLat):
    """Text conditioned structured latent dataset."""
211
+
212
+
213
class ImageConditionedSLat(ImageConditionedMixin, SLat):
    """Image conditioned structured latent dataset."""
trellis/datasets/structured_latent2render.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image
3
+ import json
4
+ import numpy as np
5
+ import torch
6
+ import utils3d.torch
7
+ from ..modules.sparse.basic import SparseTensor
8
+ from .components import StandardDatasetBase
9
+
10
+
11
class SLat2Render(StandardDatasetBase):
    """
    Dataset for Structured Latent and rendered images.

    Each instance yields one randomly chosen rendered view (image, alpha,
    camera extrinsics/intrinsics) together with the instance's sparse
    structured-latent coords/feats.

    Args:
        roots (str): paths to the dataset
        image_size (int): size of the image
        latent_model (str): latent model name
        min_aesthetic_score (float): minimum aesthetic score
        max_num_voxels (int): maximum number of voxels
    """
    def __init__(
        self,
        roots: str,
        image_size: int,
        latent_model: str,
        min_aesthetic_score: float = 5.0,
        max_num_voxels: int = 32768,
    ):
        self.image_size = image_size
        self.latent_model = latent_model
        self.min_aesthetic_score = min_aesthetic_score
        self.max_num_voxels = max_num_voxels
        # Output images are normalized floats in [0, 1].
        self.value_range = (0, 1)

        super().__init__(roots)

    def filter_metadata(self, metadata):
        """Keep rows that have a latent, pass the aesthetic threshold, and
        fit within the voxel budget. Returns (filtered_metadata, stats)."""
        stats = {}
        metadata = metadata[metadata[f'latent_{self.latent_model}']]
        stats['With latent'] = len(metadata)
        metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
        stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
        metadata = metadata[metadata['num_voxels'] <= self.max_num_voxels]
        stats[f'Num voxels <= {self.max_num_voxels}'] = len(metadata)
        return metadata, stats

    def _get_image(self, root, instance):
        """Load one random rendered view plus its camera parameters."""
        with open(os.path.join(root, 'renders', instance, 'transforms.json')) as f:
            metadata = json.load(f)
        n_views = len(metadata['frames'])
        view = np.random.randint(n_views)
        metadata = metadata['frames'][view]
        fov = metadata['camera_angle_x']
        # Same FoV is used for both axes (square renders).
        intrinsics = utils3d.torch.intrinsics_from_fov_xy(torch.tensor(fov), torch.tensor(fov))
        c2w = torch.tensor(metadata['transform_matrix'])
        # Flip Y/Z axes: convert from the renderer's camera convention
        # (OpenGL-style camera-to-world) before inverting to extrinsics.
        c2w[:3, 1:3] *= -1
        extrinsics = torch.inverse(c2w)

        image_path = os.path.join(root, 'renders', instance, metadata['file_path'])
        image = Image.open(image_path)
        # NOTE(review): assumes the render has an alpha channel (RGBA) —
        # getchannel(3) raises otherwise; confirm renders are always RGBA.
        alpha = image.getchannel(3)
        image = image.convert('RGB')
        image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
        alpha = alpha.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
        image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
        alpha = torch.tensor(np.array(alpha)).float() / 255.0

        return {
            'image': image,
            'alpha': alpha,
            'extrinsics': extrinsics,
            'intrinsics': intrinsics,
        }

    def _get_latent(self, root, instance):
        """Load the instance's sparse latent (int coords, float feats)."""
        data = np.load(os.path.join(root, 'latents', self.latent_model, f'{instance}.npz'))
        coords = torch.tensor(data['coords']).int()
        feats = torch.tensor(data['feats']).float()
        return {
            'coords': coords,
            'feats': feats,
        }

    @torch.no_grad()
    def visualize_sample(self, sample: dict):
        # Visualization is simply the rendered image itself.
        return sample['image']

    @staticmethod
    def collate_fn(batch):
        """Collate latents into one SparseTensor; stack/collect the rest."""
        pack = {}
        coords = []
        for i, b in enumerate(batch):
            # Prepend the batch index as the first coordinate column.
            coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1))
        coords = torch.cat(coords)
        feats = torch.cat([b['feats'] for b in batch])
        pack['latents'] = SparseTensor(
            coords=coords,
            feats=feats,
        )

        # collate other data
        keys = [k for k in batch[0].keys() if k not in ['coords', 'feats']]
        for k in keys:
            if isinstance(batch[0][k], torch.Tensor):
                pack[k] = torch.stack([b[k] for b in batch])
            elif isinstance(batch[0][k], list):
                pack[k] = sum([b[k] for b in batch], [])
            else:
                pack[k] = [b[k] for b in batch]

        return pack

    def get_instance(self, root, instance):
        """Return the merged view + latent sample dict for one instance."""
        image = self._get_image(root, instance)
        latent = self._get_latent(root, instance)
        return {
            **image,
            **latent,
        }
121
+
122
+
123
class Slat2RenderGeo(SLat2Render):
    """
    SLat2Render variant that additionally loads the ground-truth mesh
    geometry ('mesh.ply') for each instance.
    """
    def __init__(
        self,
        roots: str,
        image_size: int,
        latent_model: str,
        min_aesthetic_score: float = 5.0,
        max_num_voxels: int = 32768,
    ):
        super().__init__(
            roots,
            image_size,
            latent_model,
            min_aesthetic_score,
            max_num_voxels,
        )

    def _get_geo(self, root, instance):
        """Load the instance mesh as torch vertex/face tensors."""
        # NOTE(review): this calls utils3d.io but the module only imports
        # utils3d.torch — verify utils3d exposes .io on package import,
        # otherwise this raises AttributeError at runtime.
        verts, face = utils3d.io.read_ply(os.path.join(root, 'renders', instance, 'mesh.ply'))
        mesh = {
            "vertices" : torch.from_numpy(verts),
            "faces" : torch.from_numpy(face),
        }
        return {
            "mesh" : mesh,
        }

    def get_instance(self, root, instance):
        """Return the merged view + latent + geometry sample dict."""
        image = self._get_image(root, instance)
        latent = self._get_latent(root, instance)
        geo = self._get_geo(root, instance)
        return {
            **image,
            **latent,
            **geo,
        }
159
+
160
+
trellis/models/__init__.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+
3
# Map of public model class names to the submodule that defines each one.
# Used by the module-level __getattr__ below to import models lazily, so
# importing this package does not pull in every model's dependencies.
__attributes = {
    'SparseStructureEncoder': 'sparse_structure_vae',
    'SparseStructureDecoder': 'sparse_structure_vae',

    'SparseStructureFlowModel': 'sparse_structure_flow',

    'SLatEncoder': 'structured_latent_vae',
    'SLatGaussianDecoder': 'structured_latent_vae',
    'SLatRadianceFieldDecoder': 'structured_latent_vae',
    'SLatMeshDecoder': 'structured_latent_vae',
    'ElasticSLatEncoder': 'structured_latent_vae',
    'ElasticSLatGaussianDecoder': 'structured_latent_vae',
    'ElasticSLatRadianceFieldDecoder': 'structured_latent_vae',
    'ElasticSLatMeshDecoder': 'structured_latent_vae',

    'SLatFlowModel': 'structured_latent_flow',
    'ElasticSLatFlowModel': 'structured_latent_flow',
}

# Submodules re-exported whole (currently none).
__submodules = []

# Public API: every lazily-loaded class plus any re-exported submodule.
__all__ = list(__attributes.keys()) + __submodules
25
+
26
def __getattr__(name):
    """
    Module-level lazy attribute hook (PEP 562).

    Imports the submodule that defines ``name`` on first access and caches
    the resolved object in this module's globals, so subsequent accesses
    are plain attribute lookups.

    Raises:
        AttributeError: if ``name`` is not a known class or submodule.
    """
    if name not in globals():
        if name in __attributes:
            module_name = __attributes[name]
            module = importlib.import_module(f".{module_name}", __name__)
            globals()[name] = getattr(module, name)
        elif name in __submodules:
            module = importlib.import_module(f".{name}", __name__)
            globals()[name] = module
        else:
            raise AttributeError(f"module {__name__} has no attribute {name}")
    return globals()[name]
38
+
39
+
40
def from_pretrained(path: str, **kwargs):
    """
    Load a model from a pretrained checkpoint.

    Args:
        path: The path to the checkpoint. Can be either local path or a Hugging Face model name.
            NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively.
        **kwargs: Additional arguments for the model constructor.

    Returns:
        The instantiated model with weights loaded.
    """
    import os
    import json
    from safetensors.torch import load_file
    # Local checkpoints take priority; anything else is treated as a
    # Hugging Face repo path.
    is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors")

    if is_local:
        config_file = f"{path}.json"
        model_file = f"{path}.safetensors"
    else:
        from huggingface_hub import hf_hub_download
        # NOTE(review): assumes path has the form '<org>/<repo>/<model...>'
        # (at least two '/'-separated parts) — shorter paths raise IndexError.
        path_parts = path.split('/')
        repo_id = f'{path_parts[0]}/{path_parts[1]}'
        model_name = '/'.join(path_parts[2:])
        config_file = hf_hub_download(repo_id, f"{model_name}.json")
        model_file = hf_hub_download(repo_id, f"{model_name}.safetensors")

    with open(config_file, 'r') as f:
        config = json.load(f)
    # The config names the model class; resolve it via the lazy loader and
    # construct with the saved args plus any caller overrides.
    model = __getattr__(config['name'])(**config['args'], **kwargs)
    model.load_state_dict(load_file(model_file))

    return model
71
+
72
+
73
# For Pylance
# These imports never execute at runtime (a package __init__ is not run as
# __main__); they exist only so static analyzers can resolve the lazily
# loaded names declared in __attributes above.
if __name__ == '__main__':
    from .sparse_structure_vae import (
        SparseStructureEncoder,
        SparseStructureDecoder,
    )

    from .sparse_structure_flow import SparseStructureFlowModel

    from .structured_latent_vae import (
        SLatEncoder,
        SLatGaussianDecoder,
        SLatRadianceFieldDecoder,
        SLatMeshDecoder,
        ElasticSLatEncoder,
        ElasticSLatGaussianDecoder,
        ElasticSLatRadianceFieldDecoder,
        ElasticSLatMeshDecoder,
    )

    from .structured_latent_flow import (
        SLatFlowModel,
        ElasticSLatFlowModel,
    )
trellis/models/sparse_elastic_mixin.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import contextmanager
2
+ from typing import *
3
+ import math
4
+ from ..modules import sparse as sp
5
+ from ..utils.elastic_utils import ElasticModuleMixin
6
+
7
+
8
class SparseTransformerElasticMixin(ElasticModuleMixin):
    """
    Elastic-memory mixin for sparse-transformer models.

    Trades memory for compute by enabling gradient checkpointing on a
    prefix of ``self.blocks`` proportional to the requested memory ratio.
    """

    def _get_input_size(self, x: sp.SparseTensor, *args, **kwargs):
        # The "size" of a sparse input is its number of active voxels.
        return x.feats.shape[0]

    @contextmanager
    def with_mem_ratio(self, mem_ratio=1.0):
        """
        Context manager that checkpoints enough leading blocks so activation
        memory is roughly ``mem_ratio`` of the full amount.

        Yields:
            float: the exact memory ratio actually achieved.
        """
        if mem_ratio == 1.0:
            yield 1.0
            return
        num_blocks = len(self.blocks)
        num_checkpoint_blocks = min(math.ceil((1 - mem_ratio) * num_blocks) + 1, num_blocks)
        exact_mem_ratio = 1 - (num_checkpoint_blocks - 1) / num_blocks
        for i in range(num_blocks):
            self.blocks[i].use_checkpoint = i < num_checkpoint_blocks
        try:
            yield exact_mem_ratio
        finally:
            # Restore the default (no checkpointing) even if the body raised;
            # the original code skipped this cleanup on exception, leaving
            # the model permanently in checkpointing mode.
            for i in range(num_blocks):
                self.blocks[i].use_checkpoint = False
trellis/models/sparse_structure_flow.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ from ..modules.utils import convert_module_to_f16, convert_module_to_f32
7
+ from ..modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock
8
+ from ..modules.spatial import patchify, unpatchify
9
+
10
+
11
class TimestepEmbedder(nn.Module):
    """
    Maps scalar diffusion timesteps to vector embeddings: sinusoidal
    frequency features followed by a two-layer MLP.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        # Keep the exact Sequential layout (Linear, SiLU, Linear): callers
        # initialize weights by indexing self.mlp[0] and self.mlp[2].
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        Args:
            t: a 1-D Tensor of N indices, one per batch element.
                These may be fractional.
            dim: the dimension of the output.
            max_period: controls the minimum frequency of the embeddings.

        Returns:
            an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        # Geometric frequency ladder from 1 down to ~1/max_period.
        freqs = torch.exp(
            -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        phases = t[:, None].float() * freqs[None]
        emb = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
        if dim % 2 == 1:
            # Odd target dimension: pad with a single zero column.
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
        return emb

    def forward(self, t):
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(features)
53
+
54
+
55
class SparseStructureFlowModel(nn.Module):
    """
    Transformer-based flow model over dense sparse-structure grids.

    The input is a dense (B, C, R, R, R) voxel grid, patchified into tokens,
    processed by modulated cross-attention transformer blocks conditioned on
    a timestep embedding and an external condition sequence, then
    unpatchified back to a grid.
    """
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        model_channels: int,
        cond_channels: int,
        out_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        patch_size: int = 2,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        share_mod: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
    ):
        super().__init__()
        self.resolution = resolution
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.cond_channels = cond_channels
        self.out_channels = out_channels
        self.num_blocks = num_blocks
        # If num_heads is not given, derive it from a per-head channel budget.
        self.num_heads = num_heads or model_channels // num_head_channels
        self.mlp_ratio = mlp_ratio
        self.patch_size = patch_size
        self.pe_mode = pe_mode
        self.use_fp16 = use_fp16
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        self.qk_rms_norm = qk_rms_norm
        self.qk_rms_norm_cross = qk_rms_norm_cross
        self.dtype = torch.float16 if use_fp16 else torch.float32

        self.t_embedder = TimestepEmbedder(model_channels)
        if share_mod:
            # One shared adaLN modulation over the timestep embedding,
            # instead of a per-block modulation.
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(model_channels, 6 * model_channels, bias=True)
            )

        if pe_mode == "ape":
            # Precompute absolute position embeddings for every patch-grid
            # cell and store them as a non-trainable buffer.
            pos_embedder = AbsolutePositionEmbedder(model_channels, 3)
            coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution // patch_size] * 3], indexing='ij')
            coords = torch.stack(coords, dim=-1).reshape(-1, 3)
            pos_emb = pos_embedder(coords)
            self.register_buffer("pos_emb", pos_emb)

        # Each token carries a full patch: in_channels * patch_size^3 values.
        self.input_layer = nn.Linear(in_channels * patch_size**3, model_channels)

        self.blocks = nn.ModuleList([
            ModulatedTransformerCrossBlock(
                model_channels,
                cond_channels,
                num_heads=self.num_heads,
                mlp_ratio=self.mlp_ratio,
                attn_mode='full',
                use_checkpoint=self.use_checkpoint,
                use_rope=(pe_mode == "rope"),
                share_mod=share_mod,
                qk_rms_norm=self.qk_rms_norm,
                qk_rms_norm_cross=self.qk_rms_norm_cross,
            )
            for _ in range(num_blocks)
        ])

        self.out_layer = nn.Linear(model_channels, out_channels * patch_size**3)

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        self.blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        self.blocks.apply(convert_module_to_f32)

    def initialize_weights(self) -> None:
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        if self.share_mod:
            nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
        else:
            for block in self.blocks:
                nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
                nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, in_channels, R, R, R) dense grid at model resolution.
            t: (B,) scalar timesteps.
            cond: condition sequence for cross-attention
                (shape beyond channels is determined by the blocks).

        Returns:
            (B, out_channels, R, R, R) grid in x's dtype.
        """
        assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
            f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"

        # Patchify to (B, C*ps^3, N) and transpose to token-major (B, N, C*ps^3).
        h = patchify(x, self.patch_size)
        h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous()

        h = self.input_layer(h)
        h = h + self.pos_emb[None]
        t_emb = self.t_embedder(t)
        if self.share_mod:
            t_emb = self.adaLN_modulation(t_emb)
        # The transformer torso may run in fp16; cast inputs to match.
        t_emb = t_emb.type(self.dtype)
        h = h.type(self.dtype)
        cond = cond.type(self.dtype)
        for block in self.blocks:
            h = block(h, t_emb, cond)
        # Back to the caller's dtype before the (fp32) output head.
        h = h.type(x.dtype)
        h = F.layer_norm(h, h.shape[-1:])
        h = self.out_layer(h)

        # Un-transpose and unpatchify back to a dense grid. Note h.shape on
        # the RHS refers to the pre-permute tensor, so h.shape[2] is the
        # per-token channel count out_channels * patch_size^3.
        h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3)
        h = unpatchify(h, self.patch_size).contiguous()

        return h
trellis/models/sparse_structure_vae.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from ..modules.norm import GroupNorm32, ChannelLayerNorm32
6
+ from ..modules.spatial import pixel_shuffle_3d
7
+ from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
8
+
9
+
10
def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module:
    """
    Build a normalization layer of the requested kind.

    Args:
        norm_type: "group" for a 32-group GroupNorm32, or "layer" for a
            ChannelLayerNorm32. Remaining args are forwarded unchanged.

    Raises:
        ValueError: if norm_type is not recognized.
    """
    if norm_type == "group":
        return GroupNorm32(32, *args, **kwargs)
    if norm_type == "layer":
        return ChannelLayerNorm32(*args, **kwargs)
    raise ValueError(f"Invalid norm type {norm_type}")
20
+
21
+
22
class ResBlock3d(nn.Module):
    """
    3D residual block: two norm -> SiLU -> conv stages plus a skip
    connection (1x1 conv projection when the channel count changes).
    """
    def __init__(
        self,
        channels: int,
        out_channels: Optional[int] = None,
        norm_type: Literal["group", "layer"] = "layer",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels

        self.norm1 = norm_layer(norm_type, channels)
        self.norm2 = norm_layer(norm_type, self.out_channels)
        self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
        # Second conv is zero-initialized so the block starts as an
        # identity (plus skip projection).
        self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1))
        if channels != self.out_channels:
            self.skip_connection = nn.Conv3d(channels, self.out_channels, 1)
        else:
            self.skip_connection = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.skip_connection(x)
        h = self.conv1(F.silu(self.norm1(x)))
        h = self.conv2(F.silu(self.norm2(h)))
        return h + residual
48
+
49
+
50
class DownsampleBlock3d(nn.Module):
    """
    Halve the spatial resolution of a dense 3D feature map, either with a
    learned strided convolution or with average pooling.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mode: Literal["conv", "avgpool"] = "conv",
    ):
        assert mode in ["conv", "avgpool"], f"Invalid mode {mode}"

        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        if mode == "conv":
            # 2x2x2 kernel at stride 2: each output voxel sees one disjoint block.
            self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2)
        elif mode == "avgpool":
            assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        conv = getattr(self, "conv", None)
        if conv is None:
            return F.avg_pool3d(x, 2)
        return conv(x)
73
+
74
+
75
class UpsampleBlock3d(nn.Module):
    """
    Double the spatial resolution of a dense 3D feature map, either with a
    conv + 3D pixel shuffle or with nearest-neighbor interpolation.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mode: Literal["conv", "nearest"] = "conv",
    ):
        assert mode in ["conv", "nearest"], f"Invalid mode {mode}"

        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        if mode == "conv":
            # Produce 8x channels, then shuffle them into a 2x larger grid.
            self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1)
        elif mode == "nearest":
            assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not hasattr(self, "conv"):
            return F.interpolate(x, scale_factor=2, mode="nearest")
        return pixel_shuffle_3d(self.conv(x), 2)
99
+
100
+
101
class SparseStructureEncoder(nn.Module):
    """
    Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3).

    Downsampling conv tower over a dense voxel grid producing a Gaussian
    posterior (mean, logvar) over latents.

    Args:
        in_channels (int): Channels of the input.
        latent_channels (int): Channels of the latent representation.
        num_res_blocks (int): Number of residual blocks at each resolution.
        channels (List[int]): Channels of the encoder blocks.
        num_res_blocks_middle (int): Number of residual blocks in the middle.
        norm_type (Literal["group", "layer"]): Type of normalization layer.
        use_fp16 (bool): Whether to use FP16.
    """
    def __init__(
        self,
        in_channels: int,
        latent_channels: int,
        num_res_blocks: int,
        channels: List[int],
        num_res_blocks_middle: int = 2,
        norm_type: Literal["group", "layer"] = "layer",
        use_fp16: bool = False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.latent_channels = latent_channels
        self.num_res_blocks = num_res_blocks
        self.channels = channels
        self.num_res_blocks_middle = num_res_blocks_middle
        self.norm_type = norm_type
        self.use_fp16 = use_fp16
        self.dtype = torch.float16 if use_fp16 else torch.float32

        self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1)

        # Per-resolution residual stacks with a downsample between stages.
        self.blocks = nn.ModuleList([])
        for i, ch in enumerate(channels):
            self.blocks.extend([
                ResBlock3d(ch, ch)
                for _ in range(num_res_blocks)
            ])
            if i < len(channels) - 1:
                self.blocks.append(
                    DownsampleBlock3d(ch, channels[i+1])
                )

        self.middle_block = nn.Sequential(*[
            ResBlock3d(channels[-1], channels[-1])
            for _ in range(num_res_blocks_middle)
        ])

        # Output head produces 2*latent_channels: (mean, logvar).
        self.out_layer = nn.Sequential(
            norm_layer(norm_type, channels[-1]),
            nn.SiLU(),
            nn.Conv3d(channels[-1], latent_channels*2, 3, padding=1)
        )

        if use_fp16:
            self.convert_to_fp16()

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        self.use_fp16 = True
        self.dtype = torch.float16
        self.blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        self.use_fp16 = False
        self.dtype = torch.float32
        self.blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)

    def forward(self, x: torch.Tensor, sample_posterior: bool = False, return_raw: bool = False) -> torch.Tensor:
        """
        Encode a dense grid to a latent.

        Args:
            x: (B, in_channels, D, H, W) input grid.
            sample_posterior: if True, draw z via the reparameterization
                trick; otherwise return the posterior mean.
            return_raw: if True, also return (mean, logvar).
        """
        h = self.input_layer(x)
        # Torso (blocks + middle) may run in fp16; the in/out layers stay fp32.
        h = h.type(self.dtype)

        for block in self.blocks:
            h = block(h)
        h = self.middle_block(h)

        h = h.type(x.dtype)
        h = self.out_layer(h)

        mean, logvar = h.chunk(2, dim=1)

        if sample_posterior:
            std = torch.exp(0.5 * logvar)
            z = mean + std * torch.randn_like(std)
        else:
            z = mean

        if return_raw:
            return z, mean, logvar
        return z
208
+
209
+
210
class SparseStructureDecoder(nn.Module):
    """
    Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3).

    Upsampling conv tower mapping a latent grid back to a dense output grid.

    Args:
        out_channels (int): Channels of the output.
        latent_channels (int): Channels of the latent representation.
        num_res_blocks (int): Number of residual blocks at each resolution.
        channels (List[int]): Channels of the decoder blocks.
        num_res_blocks_middle (int): Number of residual blocks in the middle.
        norm_type (Literal["group", "layer"]): Type of normalization layer.
        use_fp16 (bool): Whether to use FP16.
    """
    def __init__(
        self,
        out_channels: int,
        latent_channels: int,
        num_res_blocks: int,
        channels: List[int],
        num_res_blocks_middle: int = 2,
        norm_type: Literal["group", "layer"] = "layer",
        use_fp16: bool = False,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.latent_channels = latent_channels
        self.num_res_blocks = num_res_blocks
        self.channels = channels
        self.num_res_blocks_middle = num_res_blocks_middle
        self.norm_type = norm_type
        self.use_fp16 = use_fp16
        self.dtype = torch.float16 if use_fp16 else torch.float32

        self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1)

        self.middle_block = nn.Sequential(*[
            ResBlock3d(channels[0], channels[0])
            for _ in range(num_res_blocks_middle)
        ])

        # Per-resolution residual stacks with an upsample between stages
        # (mirrors the encoder's downsampling layout).
        self.blocks = nn.ModuleList([])
        for i, ch in enumerate(channels):
            self.blocks.extend([
                ResBlock3d(ch, ch)
                for _ in range(num_res_blocks)
            ])
            if i < len(channels) - 1:
                self.blocks.append(
                    UpsampleBlock3d(ch, channels[i+1])
                )

        self.out_layer = nn.Sequential(
            norm_layer(norm_type, channels[-1]),
            nn.SiLU(),
            nn.Conv3d(channels[-1], out_channels, 3, padding=1)
        )

        if use_fp16:
            self.convert_to_fp16()

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        self.use_fp16 = True
        self.dtype = torch.float16
        self.blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        self.use_fp16 = False
        self.dtype = torch.float32
        self.blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Decode a latent grid (B, latent_channels, d, h, w) to an output grid.
        """
        h = self.input_layer(x)

        # Torso (middle + blocks) may run in fp16; the in/out layers stay fp32.
        h = h.type(self.dtype)

        h = self.middle_block(h)
        for block in self.blocks:
            h = block(h)

        h = h.type(x.dtype)
        h = self.out_layer(h)
        return h
trellis/models/structured_latent_flow.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
7
+ from ..modules.transformer import AbsolutePositionEmbedder
8
+ from ..modules.norm import LayerNorm32
9
+ from ..modules import sparse as sp
10
+ from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock
11
+ from .sparse_structure_flow import TimestepEmbedder
12
+ from .sparse_elastic_mixin import SparseTransformerElasticMixin
13
+
14
+
15
class SparseResBlock3d(nn.Module):
    """
    Residual block over sparse 3D tensors with adaLN-style conditioning:
    the timestep/condition embedding produces a (scale, shift) pair applied
    after the second normalization. Optionally down- or up-samples by 2x
    before the residual computation.
    """
    def __init__(
        self,
        channels: int,
        emb_channels: int,
        out_channels: Optional[int] = None,
        downsample: bool = False,
        upsample: bool = False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.out_channels = out_channels or channels
        self.downsample = downsample
        self.upsample = upsample

        assert not (downsample and upsample), "Cannot downsample and upsample at the same time"

        self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        # norm2 has no affine parameters: scale/shift come from the embedding.
        self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
        self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
        # Zero-initialized second conv: the block starts as identity + skip.
        self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_channels, 2 * self.out_channels, bias=True),
        )
        self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
        self.updown = None
        if self.downsample:
            self.updown = sp.SparseDownsample(2)
        elif self.upsample:
            self.updown = sp.SparseUpsample(2)

    def _updown(self, x: sp.SparseTensor) -> sp.SparseTensor:
        # Apply the optional resolution change (no-op when self.updown is None).
        if self.updown is not None:
            x = self.updown(x)
        return x

    def forward(self, x: sp.SparseTensor, emb: torch.Tensor) -> sp.SparseTensor:
        """
        Args:
            x: sparse input features.
            emb: (B, emb_channels) conditioning embedding; mapped to
                per-channel (scale, shift).
        """
        emb_out = self.emb_layers(emb).type(x.dtype)
        scale, shift = torch.chunk(emb_out, 2, dim=1)

        # Resample first so the residual add happens at the new resolution.
        x = self._updown(x)
        h = x.replace(self.norm1(x.feats))
        h = h.replace(F.silu(h.feats))
        h = self.conv1(h)
        # adaLN modulation: (1 + scale) * norm(h) + shift.
        h = h.replace(self.norm2(h.feats)) * (1 + scale) + shift
        h = h.replace(F.silu(h.feats))
        h = self.conv2(h)
        h = h + self.skip_connection(x)

        return h
67
+
68
+
69
class SLatFlowModel(nn.Module):
    """
    Flow-matching transformer over structured latents (sparse voxel features).

    Pipeline: a linear input projection, an optional tower of sparse ResBlocks
    that downsamples once per entry of ``io_block_channels``, ``num_blocks``
    modulated cross-attention transformer blocks conditioned on ``cond``, and
    a mirrored upsampling tower with optional U-Net style skip connections,
    finished by a linear output layer.
    """
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        model_channels: int,
        cond_channels: int,
        out_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        patch_size: int = 2,
        num_io_res_blocks: int = 2,
        io_block_channels: List[int] = None,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        use_skip_connection: bool = True,
        share_mod: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
    ):
        super().__init__()
        self.resolution = resolution
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.cond_channels = cond_channels
        self.out_channels = out_channels
        self.num_blocks = num_blocks
        # Derive the head count from the per-head width unless given explicitly.
        self.num_heads = num_heads or model_channels // num_head_channels
        self.mlp_ratio = mlp_ratio
        self.patch_size = patch_size
        self.num_io_res_blocks = num_io_res_blocks
        self.io_block_channels = io_block_channels
        self.pe_mode = pe_mode
        self.use_fp16 = use_fp16
        self.use_checkpoint = use_checkpoint
        self.use_skip_connection = use_skip_connection
        self.share_mod = share_mod
        self.qk_rms_norm = qk_rms_norm
        self.qk_rms_norm_cross = qk_rms_norm_cross
        self.dtype = torch.float16 if use_fp16 else torch.float32

        if self.io_block_channels is not None:
            assert int(np.log2(patch_size)) == np.log2(patch_size), "Patch size must be a power of 2"
            assert np.log2(patch_size) == len(io_block_channels), "Number of IO ResBlocks must match the number of stages"

        self.t_embedder = TimestepEmbedder(model_channels)
        if share_mod:
            # A single adaLN modulation shared by every transformer block.
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(model_channels, 6 * model_channels, bias=True)
            )

        if pe_mode == "ape":
            self.pos_embedder = AbsolutePositionEmbedder(model_channels)

        self.input_layer = sp.SparseLinear(
            in_channels,
            model_channels if io_block_channels is None else io_block_channels[0],
        )

        # Downsampling tower: (num_io_res_blocks - 1) same-resolution blocks,
        # then one downsampling block per stage.
        self.input_blocks = nn.ModuleList([])
        if io_block_channels is not None:
            for stage_chs, next_stage_chs in zip(io_block_channels, io_block_channels[1:] + [model_channels]):
                self.input_blocks.extend(
                    SparseResBlock3d(
                        stage_chs,
                        model_channels,
                        out_channels=stage_chs,
                    )
                    for _ in range(num_io_res_blocks - 1)
                )
                self.input_blocks.append(
                    SparseResBlock3d(
                        stage_chs,
                        model_channels,
                        out_channels=next_stage_chs,
                        downsample=True,
                    )
                )

        self.blocks = nn.ModuleList(
            ModulatedSparseTransformerCrossBlock(
                model_channels,
                cond_channels,
                num_heads=self.num_heads,
                mlp_ratio=self.mlp_ratio,
                attn_mode='full',
                use_checkpoint=self.use_checkpoint,
                use_rope=(pe_mode == "rope"),
                share_mod=self.share_mod,
                qk_rms_norm=self.qk_rms_norm,
                qk_rms_norm_cross=self.qk_rms_norm_cross,
            )
            for _ in range(num_blocks)
        )

        # Upsampling tower mirroring the input tower; channel widths are
        # doubled at the input when skip connections are concatenated.
        self.out_blocks = nn.ModuleList([])
        if io_block_channels is not None:
            for stage_chs, prev_stage_chs in zip(reversed(io_block_channels), [model_channels] + list(reversed(io_block_channels[1:]))):
                self.out_blocks.append(
                    SparseResBlock3d(
                        prev_stage_chs * 2 if self.use_skip_connection else prev_stage_chs,
                        model_channels,
                        out_channels=stage_chs,
                        upsample=True,
                    )
                )
                self.out_blocks.extend(
                    SparseResBlock3d(
                        stage_chs * 2 if self.use_skip_connection else stage_chs,
                        model_channels,
                        out_channels=stage_chs,
                    )
                    for _ in range(num_io_res_blocks - 1)
                )

        self.out_layer = sp.SparseLinear(
            model_channels if io_block_channels is None else io_block_channels[0],
            out_channels,
        )

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.blocks.apply(convert_module_to_f16)
        self.out_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.blocks.apply(convert_module_to_f32)
        self.out_blocks.apply(convert_module_to_f32)

    def initialize_weights(self) -> None:
        """Xavier-init linears; zero-init adaLN modulations and the head."""
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Timestep embedding MLP gets a small normal init.
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-init adaLN modulation so blocks start as identity mappings.
        if self.share_mod:
            nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
        else:
            for block in self.blocks:
                nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
                nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-init the output layer so the initial flow prediction is zero.
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def forward(self, x: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor) -> sp.SparseTensor:
        """
        Predict the flow for sparse latent ``x`` at timestep ``t`` given ``cond``.
        """
        h = self.input_layer(x).type(self.dtype)
        t_emb = self.t_embedder(t)
        if self.share_mod:
            t_emb = self.adaLN_modulation(t_emb)
        t_emb = t_emb.type(self.dtype)
        cond = cond.type(self.dtype)

        # Downsample, recording features for the skip connections.
        skip_feats = []
        for block in self.input_blocks:
            h = block(h, t_emb)
            skip_feats.append(h.feats)

        if self.pe_mode == "ape":
            h = h + self.pos_embedder(h.coords[:, 1:]).type(self.dtype)
        for block in self.blocks:
            h = block(h, t_emb, cond)

        # Upsample, consuming the recorded skips in reverse order.
        for block, skip in zip(self.out_blocks, reversed(skip_feats)):
            if self.use_skip_connection:
                h = block(h.replace(torch.cat([h.feats, skip], dim=1)), t_emb)
            else:
                h = block(h, t_emb)

        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        h = self.out_layer(h.type(x.dtype))
        return h
269
+
270
+
271
class ElasticSLatFlowModel(SparseTransformerElasticMixin, SLatFlowModel):
    """
    SLatFlowModel variant whose forward pass uses elastic memory management
    (via SparseTransformerElasticMixin), trading compute for lower peak VRAM
    during training.
    """
    pass
trellis/models/structured_latent_vae/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .encoder import SLatEncoder, ElasticSLatEncoder
2
+ from .decoder_gs import SLatGaussianDecoder, ElasticSLatGaussianDecoder
3
+ from .decoder_rf import SLatRadianceFieldDecoder, ElasticSLatRadianceFieldDecoder
4
+ from .decoder_mesh import SLatMeshDecoder, ElasticSLatMeshDecoder
trellis/models/structured_latent_vae/base.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from ...modules.utils import convert_module_to_f16, convert_module_to_f32
5
+ from ...modules import sparse as sp
6
+ from ...modules.transformer import AbsolutePositionEmbedder
7
+ from ...modules.sparse.transformer import SparseTransformerBlock
8
+
9
+
10
def block_attn_config(self):
    """
    Yield one attention-configuration tuple per transformer block:
    (attn_mode, window_size, shift_sequence, shift_window, serialize_mode).

    Alternating (odd-indexed) blocks shift their windows / sequences /
    serialization orders so information mixes across window boundaries.
    """
    for idx in range(self.num_blocks):
        odd = idx % 2
        if self.attn_mode == "shift_window":
            yield "serialized", self.window_size, 0, (16 * odd,) * 3, sp.SerializeMode.Z_ORDER
        elif self.attn_mode == "shift_sequence":
            yield "serialized", self.window_size, self.window_size // 2 * odd, (0, 0, 0), sp.SerializeMode.Z_ORDER
        elif self.attn_mode == "shift_order":
            # Cycle through the four serialization orders.
            yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[idx % 4]
        elif self.attn_mode == "full":
            yield "full", None, None, None, None
        elif self.attn_mode == "swin":
            yield "windowed", self.window_size, None, self.window_size // 2 * odd, None
+
26
+
27
class SparseTransformerBase(nn.Module):
    """
    Sparse Transformer torso (input projection + attention blocks) without
    output layers. Serves as the base class for the SLat VAE encoder and
    decoders.
    """
    def __init__(
        self,
        in_channels: int,
        model_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
        window_size: Optional[int] = None,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.num_blocks = num_blocks
        self.window_size = window_size
        # Derive the head count from the per-head width unless given explicitly.
        self.num_heads = num_heads or model_channels // num_head_channels
        self.mlp_ratio = mlp_ratio
        self.attn_mode = attn_mode
        self.pe_mode = pe_mode
        self.use_fp16 = use_fp16
        self.use_checkpoint = use_checkpoint
        self.qk_rms_norm = qk_rms_norm
        self.dtype = torch.float16 if use_fp16 else torch.float32

        if pe_mode == "ape":
            self.pos_embedder = AbsolutePositionEmbedder(model_channels)

        self.input_layer = sp.SparseLinear(in_channels, model_channels)
        # One transformer block per configuration tuple from block_attn_config.
        self.blocks = nn.ModuleList(
            SparseTransformerBlock(
                model_channels,
                num_heads=self.num_heads,
                mlp_ratio=self.mlp_ratio,
                attn_mode=mode,
                window_size=win,
                shift_sequence=seq_shift,
                shift_window=win_shift,
                serialize_mode=ser_mode,
                use_checkpoint=self.use_checkpoint,
                use_rope=(pe_mode == "rope"),
                qk_rms_norm=self.qk_rms_norm,
            )
            for mode, win, seq_shift, win_shift, ser_mode in block_attn_config(self)
        )

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        self.blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        self.blocks.apply(convert_module_to_f32)

    def initialize_weights(self) -> None:
        """Xavier-init every linear layer (zero bias)."""
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        """Project input features, add APE if configured, run all blocks."""
        h = self.input_layer(x)
        if self.pe_mode == "ape":
            h = h + self.pos_embedder(x.coords[:, 1:])
        h = h.type(self.dtype)
        for block in self.blocks:
            h = block(h)
        return h
trellis/models/structured_latent_vae/decoder_gs.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from ...modules import sparse as sp
6
+ from ...utils.random_utils import hammersley_sequence
7
+ from .base import SparseTransformerBase
8
+ from ...representations import Gaussian
9
+ from ..sparse_elastic_mixin import SparseTransformerElasticMixin
10
+
11
+
12
class SLatGaussianDecoder(SparseTransformerBase):
    """
    Decode structured latents into 3D Gaussian representations.

    A sparse transformer torso is followed by a zero-initialized linear head
    whose channels are partitioned (see ``_calc_layout``) into per-voxel
    Gaussian attributes: position offsets, DC colors, scalings, rotations
    and opacities.
    """
    def __init__(
        self,
        resolution: int,
        model_channels: int,
        latent_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
        representation_config: dict = None,
    ):
        super().__init__(
            in_channels=latent_channels,
            model_channels=model_channels,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
        )
        self.resolution = resolution
        self.rep_config = representation_config
        self._calc_layout()
        self.out_layer = sp.SparseLinear(model_channels, self.out_channels)
        self._build_perturbation()

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()

    def initialize_weights(self) -> None:
        super().initialize_weights()
        # Zero-init the head so decoding starts from neutral attributes.
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def _build_perturbation(self) -> None:
        """Precompute quasi-random (Hammersley) offsets, one per Gaussian."""
        num_gs = self.rep_config['num_gaussians']
        samples = [hammersley_sequence(3, i, num_gs) for i in range(num_gs)]
        perturbation = torch.tensor(samples).float() * 2 - 1  # [0,1) -> (-1,1)
        perturbation = perturbation / self.rep_config['voxel_size']
        # Stored in pre-tanh space; the forward pass applies tanh.
        perturbation = torch.atanh(perturbation).to(self.device)
        self.register_buffer('offset_perturbation', perturbation)

    def _calc_layout(self) -> None:
        """Assign each Gaussian attribute a contiguous channel range."""
        num_gs = self.rep_config['num_gaussians']
        self.layout = {
            '_xyz' : {'shape': (num_gs, 3), 'size': num_gs * 3},
            '_features_dc' : {'shape': (num_gs, 1, 3), 'size': num_gs * 3},
            '_scaling' : {'shape': (num_gs, 3), 'size': num_gs * 3},
            '_rotation' : {'shape': (num_gs, 4), 'size': num_gs * 4},
            '_opacity' : {'shape': (num_gs, 1), 'size': num_gs},
        }
        cursor = 0
        for spec in self.layout.values():
            spec['range'] = (cursor, cursor + spec['size'])
            cursor += spec['size']
        self.out_channels = cursor

    def to_representation(self, x: sp.SparseTensor) -> List[Gaussian]:
        """
        Convert a batch of network outputs to 3D representations.

        Args:
            x: The [N x * x C] sparse tensor output by the network.

        Returns:
            list of representations
        """
        representations = []
        for b in range(x.shape[0]):
            gs = Gaussian(
                sh_degree=0,
                aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
                mininum_kernel_size = self.rep_config['3d_filter_kernel_size'],
                scaling_bias = self.rep_config['scaling_bias'],
                opacity_bias = self.rep_config['opacity_bias'],
                scaling_activation = self.rep_config['scaling_activation']
            )
            # Voxel centers in [0, 1] normalized coordinates.
            centers = (x.coords[x.layout[b]][:, 1:].float() + 0.5) / self.resolution
            for name, spec in self.layout.items():
                raw = x.feats[x.layout[b]][:, spec['range'][0]:spec['range'][1]].reshape(-1, *spec['shape'])
                if name == '_xyz':
                    offset = raw * self.rep_config['lr'][name]
                    if self.rep_config['perturb_offset']:
                        offset = offset + self.offset_perturbation
                    # tanh bounds offsets to half a voxel around the center.
                    offset = torch.tanh(offset) / self.resolution * 0.5 * self.rep_config['voxel_size']
                    setattr(gs, name, (centers.unsqueeze(1) + offset).flatten(0, 1))
                else:
                    setattr(gs, name, (raw * self.rep_config['lr'][name]).flatten(0, 1))
            representations.append(gs)
        return representations

    def forward(self, x: sp.SparseTensor) -> List[Gaussian]:
        """Run the torso, normalize, project, and build Gaussians."""
        h = super().forward(x)
        h = h.type(x.dtype)
        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        h = self.out_layer(h)
        return self.to_representation(h)
124
+
125
+
126
class ElasticSLatGaussianDecoder(SparseTransformerElasticMixin, SLatGaussianDecoder):
    """
    SLatGaussianDecoder variant with elastic memory management (via
    SparseTransformerElasticMixin), intended for low-VRAM training.
    """
    pass
trellis/models/structured_latent_vae/decoder_mesh.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
7
+ from ...modules import sparse as sp
8
+ from .base import SparseTransformerBase
9
+ from ...representations import MeshExtractResult
10
+ from ...representations.mesh import SparseFeatures2Mesh
11
+ from ..sparse_elastic_mixin import SparseTransformerElasticMixin
12
+
13
+
14
class SparseSubdivideBlock3d(nn.Module):
    """
    Sparse residual block that doubles the spatial resolution of its input
    via sparse subdivision.

    Args:
        channels: channels in the inputs and outputs.
        resolution: input resolution (output resolution is 2x).
        out_channels: if specified, the number of output channels.
        num_groups: the number of groups for the group norm.
    """
    def __init__(
        self,
        channels: int,
        resolution: int,
        out_channels: Optional[int] = None,
        num_groups: int = 32
    ):
        super().__init__()
        self.channels = channels
        self.resolution = resolution
        self.out_resolution = resolution * 2
        self.out_channels = out_channels or channels

        # Pre-activation (norm + SiLU) applied before subdividing.
        self.act_layers = nn.Sequential(
            sp.SparseGroupNorm32(num_groups, channels),
            sp.SparseSiLU()
        )

        self.sub = sp.SparseSubdivide()

        # Residual path at the subdivided resolution; the final conv is
        # zero-initialized so the block starts as (roughly) the skip path.
        self.out_layers = nn.Sequential(
            sp.SparseConv3d(channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"),
            sp.SparseGroupNorm32(num_groups, self.out_channels),
            sp.SparseSiLU(),
            zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}")),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        else:
            # 1x1 conv to match channel widths on the skip path.
            self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}")

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        """
        Subdivide and transform a sparse tensor.

        Args:
            x: an [N x C x ...] Tensor of features.
        Returns:
            an [N x C x ...] Tensor of outputs at doubled resolution.
        """
        h = self.act_layers(x)
        # Subdivide both paths so the skip operates at the new resolution.
        h = self.sub(h)
        x = self.sub(x)
        h = self.out_layers(h)
        return h + self.skip_connection(x)
70
+
71
+
72
class SLatMeshDecoder(SparseTransformerBase):
    """
    Decode structured latents into meshes.

    The transformer torso runs at the latent resolution; two sparse
    subdivision blocks then upsample features 4x before a linear head
    produces the fields consumed by ``SparseFeatures2Mesh``.
    """
    def __init__(
        self,
        resolution: int,
        model_channels: int,
        latent_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
        representation_config: dict = None,
    ):
        super().__init__(
            in_channels=latent_channels,
            model_channels=model_channels,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
        )
        self.resolution = resolution
        self.rep_config = representation_config
        # Mesh extraction runs at 4x the latent resolution (two subdivisions).
        self.mesh_extractor = SparseFeatures2Mesh(res=self.resolution*4, use_color=self.rep_config.get('use_color', False))
        self.out_channels = self.mesh_extractor.feats_channels
        self.upsample = nn.ModuleList([
            SparseSubdivideBlock3d(
                channels=model_channels,
                resolution=resolution,
                out_channels=model_channels // 4
            ),
            SparseSubdivideBlock3d(
                channels=model_channels // 4,
                resolution=resolution * 2,
                out_channels=model_channels // 8
            )
        ])
        self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels)

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()

    def initialize_weights(self) -> None:
        super().initialize_weights()
        # Zero-init the head so decoding starts from neutral mesh fields.
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model (including the upsampler) to float16.
        """
        super().convert_to_fp16()
        self.upsample.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model (including the upsampler) to float32.
        """
        super().convert_to_fp32()
        self.upsample.apply(convert_module_to_f32)

    def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
        """
        Convert a batch of network outputs to 3D representations.

        Args:
            x: The [N x * x C] sparse tensor output by the network.

        Returns:
            list of representations
        """
        meshes = []
        for b in range(x.shape[0]):
            meshes.append(self.mesh_extractor(x[b], training=self.training))
        return meshes

    def forward(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
        """Run the torso, upsample 4x, project, and extract meshes."""
        h = super().forward(x)
        for block in self.upsample:
            h = block(h)
        h = h.type(x.dtype)
        h = self.out_layer(h)
        return self.to_representation(h)
169
+
170
+
171
class ElasticSLatMeshDecoder(SparseTransformerElasticMixin, SLatMeshDecoder):
    """
    SLatMeshDecoder variant with elastic memory management (via
    SparseTransformerElasticMixin), intended for low-VRAM training.
    """
    pass
trellis/models/structured_latent_vae/decoder_rf.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ from ...modules import sparse as sp
7
+ from .base import SparseTransformerBase
8
+ from ...representations import Strivec
9
+ from ..sparse_elastic_mixin import SparseTransformerElasticMixin
10
+
11
+
12
class SLatRadianceFieldDecoder(SparseTransformerBase):
    """
    Decode structured latents into Strivec radiance-field representations.

    A sparse transformer torso is followed by a zero-initialized linear head
    whose channels are partitioned (see ``_calc_layout``) into the per-voxel
    Strivec attributes: trivec factors, densities and DC features.
    """
    def __init__(
        self,
        resolution: int,
        model_channels: int,
        latent_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
        representation_config: dict = None,
    ):
        super().__init__(
            in_channels=latent_channels,
            model_channels=model_channels,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
        )
        self.resolution = resolution
        self.rep_config = representation_config
        self._calc_layout()
        self.out_layer = sp.SparseLinear(model_channels, self.out_channels)

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()

    def initialize_weights(self) -> None:
        super().initialize_weights()
        # Zero-init the head so decoding starts from neutral fields.
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def _calc_layout(self) -> None:
        """Assign each Strivec attribute a contiguous channel range."""
        rank = self.rep_config['rank']
        dim = self.rep_config['dim']
        self.layout = {
            'trivec': {'shape': (rank, 3, dim), 'size': rank * 3 * dim},
            'density': {'shape': (rank,), 'size': rank},
            'features_dc': {'shape': (rank, 1, 3), 'size': rank * 3},
        }
        cursor = 0
        for spec in self.layout.values():
            spec['range'] = (cursor, cursor + spec['size'])
            cursor += spec['size']
        self.out_channels = cursor

    def to_representation(self, x: sp.SparseTensor) -> List[Strivec]:
        """
        Convert a batch of network outputs to 3D representations.

        Args:
            x: The [N x * x C] sparse tensor output by the network.

        Returns:
            list of representations
        """
        representations = []
        for b in range(x.shape[0]):
            # NOTE(review): device is hard-coded to 'cuda' here (as in the
            # original); confirm CPU inference is never expected.
            rf = Strivec(
                sh_degree=0,
                resolution=self.resolution,
                aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
                rank=self.rep_config['rank'],
                dim=self.rep_config['dim'],
                device='cuda',
            )
            rf.density_shift = 0.0
            # Voxel centers in [0, 1] normalized coordinates.
            rf.position = (x.coords[x.layout[b]][:, 1:].float() + 0.5) / self.resolution
            rf.depth = torch.full((rf.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda')
            for name, spec in self.layout.items():
                setattr(rf, name, x.feats[x.layout[b]][:, spec['range'][0]:spec['range'][1]].reshape(-1, *spec['shape']))
            # Shift trivec so the zero-initialized head decodes to ones.
            rf.trivec = rf.trivec + 1
            representations.append(rf)
        return representations

    def forward(self, x: sp.SparseTensor) -> List[Strivec]:
        """Run the torso, normalize, project, and build radiance fields."""
        h = super().forward(x)
        h = h.type(x.dtype)
        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        h = self.out_layer(h)
        return self.to_representation(h)
106
+
107
+
108
class ElasticSLatRadianceFieldDecoder(SparseTransformerElasticMixin, SLatRadianceFieldDecoder):
    """
    SLatRadianceFieldDecoder variant with elastic memory management (via
    SparseTransformerElasticMixin), intended for low-VRAM training.
    """
    pass
trellis/models/structured_latent_vae/encoder.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from ...modules import sparse as sp
6
+ from .base import SparseTransformerBase
7
+ from ..sparse_elastic_mixin import SparseTransformerElasticMixin
8
+
9
+
10
class SLatEncoder(SparseTransformerBase):
    """
    VAE encoder mapping sparse voxel features to structured latents.

    The output head produces ``2 * latent_channels`` channels interpreted
    as the mean and log-variance of a diagonal Gaussian posterior.
    """
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        model_channels: int,
        latent_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
    ):
        super().__init__(
            in_channels=in_channels,
            model_channels=model_channels,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
        )
        self.resolution = resolution
        # Head outputs [mean | logvar] of the latent posterior.
        self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels)

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()

    def initialize_weights(self) -> None:
        super().initialize_weights()
        # Zero-init the head so the posterior starts at N(0, 1).
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def forward(self, x: sp.SparseTensor, sample_posterior=True, return_raw=False):
        """
        Encode sparse features into structured latents.

        Args:
            x: input sparse tensor.
            sample_posterior: if True, sample via the reparameterization
                trick; otherwise return the posterior mean.
            return_raw: if True, also return (mean, logvar).
        """
        h = super().forward(x)
        h = h.type(x.dtype)
        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        h = self.out_layer(h)

        # Split head output into posterior mean and log-variance.
        mean, logvar = h.feats.chunk(2, dim=-1)
        if sample_posterior:
            std = torch.exp(0.5 * logvar)
            z = mean + std * torch.randn_like(std)
        else:
            z = mean
        z = h.replace(z)

        return (z, mean, logvar) if return_raw else z
74
+
75
+
76
class ElasticSLatEncoder(SparseTransformerElasticMixin, SLatEncoder):
    """
    SLatEncoder variant with elastic memory management (via
    SparseTransformerElasticMixin), intended for low-VRAM training.
    """
trellis/modules/attention/__init__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+
3
# Attention backend selection; overridable via environment variables
# (ATTN_BACKEND / ATTN_DEBUG) or the setters below.
BACKEND = 'flash_attn'
DEBUG = False

def __from_env():
    """Read ATTN_BACKEND / ATTN_DEBUG from the environment, once at import."""
    import os

    global BACKEND
    global DEBUG

    backend_override = os.environ.get('ATTN_BACKEND')
    debug_override = os.environ.get('ATTN_DEBUG')

    # Only accept known backend names; anything else keeps the default.
    if backend_override is not None and backend_override in ['xformers', 'flash_attn', 'sdpa', 'naive']:
        BACKEND = backend_override
    if debug_override is not None:
        DEBUG = debug_override == '1'

    print(f"[ATTENTION] Using backend: {BACKEND}")


__from_env()


def set_backend(backend: Literal['xformers', 'flash_attn']):
    """Programmatically select the attention backend."""
    global BACKEND
    BACKEND = backend

def set_debug(debug: bool):
    """Enable or disable attention debug mode."""
    global DEBUG
    DEBUG = debug
33
+
34
+
35
+ from .full_attn import *
36
+ from .modules import *
trellis/modules/attention/full_attn.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import math
4
+ from . import DEBUG, BACKEND
5
+
6
+ if BACKEND == 'xformers':
7
+ import xformers.ops as xops
8
+ elif BACKEND == 'flash_attn':
9
+ import flash_attn
10
+ elif BACKEND == 'sdpa':
11
+ from torch.nn.functional import scaled_dot_product_attention as sdpa
12
+ elif BACKEND == 'naive':
13
+ pass
14
+ else:
15
+ raise ValueError(f"Unknown attention backend: {BACKEND}")
16
+
17
+
18
+ __all__ = [
19
+ 'scaled_dot_product_attention',
20
+ ]
21
+
22
+
23
+ def _naive_sdpa(q, k, v):
24
+ """
25
+ Naive implementation of scaled dot product attention.
26
+ """
27
+ q = q.permute(0, 2, 1, 3) # [N, H, L, C]
28
+ k = k.permute(0, 2, 1, 3) # [N, H, L, C]
29
+ v = v.permute(0, 2, 1, 3) # [N, H, L, C]
30
+ scale_factor = 1 / math.sqrt(q.size(-1))
31
+ attn_weight = q @ k.transpose(-2, -1) * scale_factor
32
+ attn_weight = torch.softmax(attn_weight, dim=-1)
33
+ out = attn_weight @ v
34
+ out = out.permute(0, 2, 1, 3) # [N, L, H, C]
35
+ return out
36
+
37
+
38
@overload
def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor:
    """
    Apply scaled dot product attention on a packed QKV tensor.

    Args:
        qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs.
    """
    ...

@overload
def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
    """
    Apply scaled dot product attention with a packed KV tensor.

    Args:
        q (torch.Tensor): A [N, L, H, C] tensor containing Qs.
        kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs.
    """
    ...

@overload
def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """
    Apply scaled dot product attention with separate Q, K and V.

    Args:
        q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs.
        k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks.
        v (torch.Tensor): A [N, L, H, Co] tensor containing Vs.

    Note:
        k and v are assumed to have the same coordinate map.
    """
    ...

def scaled_dot_product_attention(*args, **kwargs):
    """
    Dispatch scaled dot product attention to the configured BACKEND.

    Accepts one of three calling conventions (see the overloads above):
    packed qkv, q + packed kv, or separate q/k/v.
    """
    positional_names = {
        1: ['qkv'],
        2: ['q', 'kv'],
        3: ['q', 'k', 'v'],
    }
    num_all_args = len(args) + len(kwargs)
    assert num_all_args in positional_names, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
    # Any argument not given positionally must be present as a keyword.
    for key in positional_names[num_all_args][len(args):]:
        assert key in kwargs, f"Missing argument {key}"

    # Validate shapes for the chosen calling convention.
    if num_all_args == 1:
        qkv = args[0] if len(args) > 0 else kwargs['qkv']
        assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]"
        device = qkv.device

    elif num_all_args == 2:
        q = args[0] if len(args) > 0 else kwargs['q']
        kv = args[1] if len(args) > 1 else kwargs['kv']
        assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
        assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
        assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
        device = q.device

    elif num_all_args == 3:
        q = args[0] if len(args) > 0 else kwargs['q']
        k = args[1] if len(args) > 1 else kwargs['k']
        v = args[2] if len(args) > 2 else kwargs['v']
        assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
        assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
        assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
        assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
        device = q.device

    if BACKEND == 'xformers':
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=2)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=2)
        out = xops.memory_efficient_attention(q, k, v)
    elif BACKEND == 'flash_attn':
        # flash_attn provides a dedicated entry point per packing layout.
        if num_all_args == 1:
            out = flash_attn.flash_attn_qkvpacked_func(qkv)
        elif num_all_args == 2:
            out = flash_attn.flash_attn_kvpacked_func(q, kv)
        elif num_all_args == 3:
            out = flash_attn.flash_attn_func(q, k, v)
    elif BACKEND == 'sdpa':
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=2)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=2)
        # torch sdpa expects [N, H, L, C]; permute in and back out.
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)
        out = sdpa(q, k, v)
        out = out.permute(0, 2, 1, 3)  # [N, L, H, C]
    elif BACKEND == 'naive':
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=2)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=2)
        out = _naive_sdpa(q, k, v)
    else:
        raise ValueError(f"Unknown attention module: {BACKEND}")

    return out
trellis/modules/attention/modules.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from .full_attn import scaled_dot_product_attention
6
+
7
+
8
class MultiHeadRMSNorm(nn.Module):
    """RMS-normalize the channel dimension of each attention head.

    The last dimension of the input is L2-normalized (computed in float32
    for numerical stability), rescaled by sqrt(dim), and modulated by a
    learnable per-head gain of shape [heads, dim]. The result is cast back
    to the input's dtype.
    """

    def __init__(self, dim: int, heads: int):
        super().__init__()
        # sqrt(dim) restores the expected magnitude after unit-normalization.
        self.scale = dim ** 0.5
        # One learnable gain vector per head, initialized to identity.
        self.gamma = nn.Parameter(torch.ones(heads, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_dtype = x.dtype
        unit = F.normalize(x.float(), dim=-1)
        return (unit * self.gamma * self.scale).to(in_dtype)
16
+
17
+
18
class RotaryPositionEmbedder(nn.Module):
    """Rotary position embedding (RoPE) driven by multi-dimensional spatial indices.

    The hidden dimension is split evenly across ``in_channels`` coordinate axes;
    each axis contributes ``hidden_size // in_channels // 2`` complex frequency
    bands. Channels left over when the split is not exact are given identity
    (zero-angle) rotations in ``forward``.
    """

    def __init__(self, hidden_size: int, in_channels: int = 3):
        super().__init__()
        assert hidden_size % 2 == 0, "Hidden size must be divisible by 2"
        self.hidden_size = hidden_size
        self.in_channels = in_channels
        # Number of frequency bands allotted to each coordinate axis.
        self.freq_dim = hidden_size // in_channels // 2
        # Standard RoPE frequency schedule: 10000^(-i / freq_dim).
        # NOTE(review): stored as a plain tensor attribute rather than a
        # registered buffer, so it does not follow module.to(device);
        # _get_phases migrates it lazily instead.
        self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
        self.freqs = 1.0 / (10000 ** self.freqs)

    def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
        # Lazily move the frequency table to the indices' device.
        self.freqs = self.freqs.to(indices.device)
        phases = torch.outer(indices, self.freqs)
        # Unit complex numbers e^{i * phase}.
        phases = torch.polar(torch.ones_like(phases), phases)
        return phases

    def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
        # Pair adjacent channels as complex numbers, rotate, then unpack.
        x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        x_rotated = x_complex * phases
        x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
        return x_embed

    def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply rotary embeddings to queries and keys.

        Args:
            q (torch.Tensor): [..., N, D] tensor of queries
            k (torch.Tensor): [..., N, D] tensor of keys
            indices (torch.Tensor): [..., N, C] tensor of spatial positions;
                defaults to sequential 1-D positions when omitted.
        """
        if indices is None:
            indices = torch.arange(q.shape[-2], device=q.device)
            if len(q.shape) > 2:
                indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,))

        # Phases for every (position, axis, band) triple, flattened per token.
        phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
        # Pad with identity rotations when the per-axis bands do not fill
        # hidden_size // 2 complex channels.
        if phases.shape[1] < self.hidden_size // 2:
            phases = torch.cat([phases, torch.polar(
                torch.ones(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device),
                torch.zeros(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device)
            )], dim=-1)
        q_embed = self._rotary_embedding(q, phases)
        k_embed = self._rotary_embedding(k, phases)
        return q_embed, k_embed
61
+
62
+
63
class MultiHeadAttention(nn.Module):
    """Dense multi-head self- or cross-attention.

    Supports optional rotary position embeddings (RoPE) and optional RMS
    normalization of queries and keys. The actual attention kernel is
    dispatched by ``scaled_dot_product_attention`` based on the configured
    backend.
    """

    def __init__(
        self,
        channels: int,
        num_heads: int,
        ctx_channels: Optional[int]=None,
        type: Literal["self", "cross"] = "self",
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        qkv_bias: bool = True,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
    ):
        super().__init__()
        assert channels % num_heads == 0
        assert type in ["self", "cross"], f"Invalid attention type: {type}"
        assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}"
        assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"

        # Windowed mode is accepted by the signature but not implemented.
        if attn_mode == "windowed":
            raise NotImplementedError("Windowed attention is not yet implemented")

        self.channels = channels
        self.head_dim = channels // num_heads
        self.ctx_channels = ctx_channels if ctx_channels is not None else channels
        self.num_heads = num_heads
        self._type = type
        self.attn_mode = attn_mode
        self.window_size = window_size
        self.shift_window = shift_window
        self.use_rope = use_rope
        self.qk_rms_norm = qk_rms_norm

        if self._type == "self":
            # Fused projection producing Q, K, and V in one matmul.
            self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
        else:
            # Cross-attention: Q from x, fused K/V from the context.
            self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
            self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)

        if self.qk_rms_norm:
            self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
            self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)

        self.to_out = nn.Linear(channels, channels)

        if use_rope:
            self.rope = RotaryPositionEmbedder(channels)

    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): [B, L, C] input sequence.
            context (torch.Tensor): [B, Lkv, ctx_channels] context for
                cross-attention; ignored for self-attention.
            indices (torch.Tensor): optional spatial positions forwarded to
                RoPE (self-attention with use_rope only).

        Returns:
            torch.Tensor: [B, L, C] attended features.
        """
        B, L, C = x.shape
        if self._type == "self":
            qkv = self.to_qkv(x)
            qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
            if self.use_rope:
                # RoPE rotates Q and K only; V passes through unchanged.
                q, k, v = qkv.unbind(dim=2)
                q, k = self.rope(q, k, indices)
                qkv = torch.stack([q, k, v], dim=2)
            if self.attn_mode == "full":
                if self.qk_rms_norm:
                    # QK norm is applied after RoPE (when both are enabled).
                    q, k, v = qkv.unbind(dim=2)
                    q = self.q_rms_norm(q)
                    k = self.k_rms_norm(k)
                    h = scaled_dot_product_attention(q, k, v)
                else:
                    h = scaled_dot_product_attention(qkv)
            elif self.attn_mode == "windowed":
                raise NotImplementedError("Windowed attention is not yet implemented")
        else:
            Lkv = context.shape[1]
            q = self.to_q(x)
            kv = self.to_kv(context)
            q = q.reshape(B, L, self.num_heads, -1)
            kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
            if self.qk_rms_norm:
                q = self.q_rms_norm(q)
                k, v = kv.unbind(dim=2)
                k = self.k_rms_norm(k)
                h = scaled_dot_product_attention(q, k, v)
            else:
                h = scaled_dot_product_attention(q, kv)
        # Merge heads and project back to the model dimension.
        h = h.reshape(B, L, -1)
        h = self.to_out(h)
        return h
trellis/modules/norm.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class LayerNorm32(nn.LayerNorm):
    """LayerNorm computed in float32, with the result cast back to the input dtype."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        out = super().forward(x.float())
        return out.type(orig_dtype)
8
+
9
+
10
class GroupNorm32(nn.GroupNorm):
    """
    A GroupNorm layer that converts to float32 before the forward pass.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(orig_dtype)
16
+
17
+
18
class ChannelLayerNorm32(LayerNorm32):
    """LayerNorm32 over the channel axis (dim 1) of an [N, C, *spatial] tensor.

    Channels are temporarily moved to the last position so that the parent
    LayerNorm normalizes over C, then moved back.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # movedim(1, -1) is equivalent to permute(0, *range(2, dim), 1).
        x = x.movedim(1, -1).contiguous()
        x = super().forward(x)
        return x.movedim(-1, 1).contiguous()
25
+
trellis/modules/sparse/__init__.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+
3
BACKEND = 'spconv'   # sparse convolution backend: 'spconv' or 'torchsparse'
DEBUG = False        # enables extra consistency assertions in sparse ops
ATTN = 'flash_attn'  # sparse attention backend: 'xformers' or 'flash_attn'

def __from_env():
    """Override BACKEND/DEBUG/ATTN from environment variables at import time."""
    import os

    global BACKEND
    global DEBUG
    global ATTN

    env_sparse_backend = os.environ.get('SPARSE_BACKEND')
    env_sparse_debug = os.environ.get('SPARSE_DEBUG')
    env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND')
    # Fall back to the global attention backend variable when the
    # sparse-specific one is not set.
    if env_sparse_attn is None:
        env_sparse_attn = os.environ.get('ATTN_BACKEND')

    # Unrecognized values are silently ignored, keeping the defaults above.
    if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']:
        BACKEND = env_sparse_backend
    if env_sparse_debug is not None:
        DEBUG = env_sparse_debug == '1'
    if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']:
        ATTN = env_sparse_attn

    print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")


# Run once at package import so env vars take effect before submodules load.
__from_env()
31
+
32
+
33
# NOTE(review): submodules read BACKEND/ATTN at their own import time, so
# calling these setters after those submodules are imported may have no
# effect — confirm intended usage.
def set_backend(backend: Literal['spconv', 'torchsparse']):
    """Select the sparse convolution backend at runtime."""
    global BACKEND
    BACKEND = backend

def set_debug(debug: bool):
    """Toggle extra consistency assertions in sparse ops."""
    global DEBUG
    DEBUG = debug

def set_attn(attn: Literal['xformers', 'flash_attn']):
    """Select the sparse attention backend at runtime."""
    global ATTN
    ATTN = attn
44
+
45
+
46
+ import importlib
47
+
48
# Lazily-exported public names, mapped to the submodule that defines each one.
# Resolution happens on first attribute access via __getattr__ below.
__attributes = {
    'SparseTensor': 'basic',
    'sparse_batch_broadcast': 'basic',
    'sparse_batch_op': 'basic',
    'sparse_cat': 'basic',
    'sparse_unbind': 'basic',
    'SparseGroupNorm': 'norm',
    'SparseLayerNorm': 'norm',
    'SparseGroupNorm32': 'norm',
    'SparseLayerNorm32': 'norm',
    'SparseReLU': 'nonlinearity',
    'SparseSiLU': 'nonlinearity',
    'SparseGELU': 'nonlinearity',
    'SparseActivation': 'nonlinearity',
    'SparseLinear': 'linear',
    'sparse_scaled_dot_product_attention': 'attention',
    'SerializeMode': 'attention',
    'sparse_serialized_scaled_dot_product_self_attention': 'attention',
    'sparse_windowed_scaled_dot_product_self_attention': 'attention',
    'SparseMultiHeadAttention': 'attention',
    'SparseConv3d': 'conv',
    'SparseInverseConv3d': 'conv',
    'SparseDownsample': 'spatial',
    'SparseUpsample': 'spatial',
    'SparseSubdivide' : 'spatial'
}

# Submodules importable as attributes of this package (also lazy).
__submodules = ['transformer']

__all__ = list(__attributes.keys()) + __submodules
78
+
79
def __getattr__(name):
    """Module-level lazy loader (PEP 562): import symbols on first access.

    Resolved names are cached in globals() so subsequent lookups bypass
    this hook entirely.
    """
    if name not in globals():
        if name in __attributes:
            module_name = __attributes[name]
            module = importlib.import_module(f".{module_name}", __name__)
            globals()[name] = getattr(module, name)
        elif name in __submodules:
            module = importlib.import_module(f".{name}", __name__)
            globals()[name] = module
        else:
            raise AttributeError(f"module {__name__} has no attribute {name}")
    return globals()[name]
91
+
92
+
93
# For static analyzers (Pylance): make the lazily-exported names visible.
# TYPE_CHECKING is always False at runtime, so these eager imports carry no
# import-time cost. (The previous `if __name__ == '__main__':` guard also
# never ran, but its `import transformer` referenced a nonexistent top-level
# module; the package-relative form below is the resolvable one.)
if TYPE_CHECKING:
    from .basic import *
    from .norm import *
    from .nonlinearity import *
    from .linear import *
    from .attention import *
    from .conv import *
    from .spatial import *
    from . import transformer
trellis/modules/sparse/attention/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .full_attn import *
2
+ from .serialized_attn import *
3
+ from .windowed_attn import *
4
+ from .modules import *
trellis/modules/sparse/attention/full_attn.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import *
import torch
from .. import SparseTensor
from .. import DEBUG, ATTN

# Import the attention backend chosen at package import time; fail fast on
# unsupported values rather than at the first attention call.
if ATTN == 'xformers':
    import xformers.ops as xops
elif ATTN == 'flash_attn':
    import flash_attn
else:
    raise ValueError(f"Unknown attention module: {ATTN}")


__all__ = [
    'sparse_scaled_dot_product_attention',
]
17
+
18
+
19
@overload
def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        qkv (SparseTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
    """
    ...

@overload
def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]) -> SparseTensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        q (SparseTensor): A [N, *, H, C] sparse tensor containing Qs.
        kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs.
    """
    ...

@overload
def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> torch.Tensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        q (torch.Tensor): A [N, L, H, C] dense tensor containing Qs.
        kv (SparseTensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs.
    """
    ...

@overload
def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: SparseTensor) -> SparseTensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
        k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
        v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.

    Note:
        k and v are assumed to have the same coordinate map.
    """
    ...

@overload
def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: torch.Tensor) -> SparseTensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
        k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks.
        v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs.
    """
    ...

@overload
def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: SparseTensor) -> torch.Tensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs.
        k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
        v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.
    """
    ...
90
def sparse_scaled_dot_product_attention(*args, **kwargs):
    """
    Dispatching implementation for the overloads above.

    Accepts (qkv), (q, kv), or (q, k, v), where each operand may be a
    SparseTensor or a dense torch.Tensor (see the @overload stubs for the
    accepted combinations). All operands are flattened to variable-length
    token sequences and attended with the backend selected by ATTN
    ('xformers' block-diagonal mask or 'flash_attn' varlen kernels).

    Returns:
        SparseTensor when q was sparse (same coordinate map, new feats),
        otherwise a dense tensor of shape [N, L, H, C].
    """
    arg_names_dict = {
        1: ['qkv'],
        2: ['q', 'kv'],
        3: ['q', 'k', 'v']
    }
    num_all_args = len(args) + len(kwargs)
    assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
    for key in arg_names_dict[num_all_args][len(args):]:
        assert key in kwargs, f"Missing argument {key}"

    if num_all_args == 1:
        # Fused self-attention: a single sparse [N, *, 3, H, C] tensor.
        qkv = args[0] if len(args) > 0 else kwargs['qkv']
        assert isinstance(qkv, SparseTensor), f"qkv must be a SparseTensor, got {type(qkv)}"
        assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
        device = qkv.device

        s = qkv
        q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])]
        kv_seqlen = q_seqlen
        qkv = qkv.feats     # [T, 3, H, C]

    elif num_all_args == 2:
        # Cross-attention with fused K/V; exactly one side may be dense.
        q = args[0] if len(args) > 0 else kwargs['q']
        kv = args[1] if len(args) > 1 else kwargs['kv']
        assert isinstance(q, SparseTensor) and isinstance(kv, (SparseTensor, torch.Tensor)) or \
               isinstance(q, torch.Tensor) and isinstance(kv, SparseTensor), \
               f"Invalid types, got {type(q)} and {type(kv)}"
        assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
        device = q.device

        if isinstance(q, SparseTensor):
            assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]"
            s = q
            q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
            q = q.feats     # [T_Q, H, C]
        else:
            assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
            s = None
            N, L, H, C = q.shape
            q_seqlen = [L] * N
            q = q.reshape(N * L, H, C)   # [T_Q, H, C]

        if isinstance(kv, SparseTensor):
            assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]"
            kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])]
            kv = kv.feats     # [T_KV, 2, H, C]
        else:
            assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
            N, L, _, H, C = kv.shape
            kv_seqlen = [L] * N
            kv = kv.reshape(N * L, 2, H, C)   # [T_KV, 2, H, C]

    elif num_all_args == 3:
        # Fully split q, k, v; k and v must be of the same kind.
        q = args[0] if len(args) > 0 else kwargs['q']
        k = args[1] if len(args) > 1 else kwargs['k']
        v = args[2] if len(args) > 2 else kwargs['v']
        assert isinstance(q, SparseTensor) and isinstance(k, (SparseTensor, torch.Tensor)) and type(k) == type(v) or \
               isinstance(q, torch.Tensor) and isinstance(k, SparseTensor) and isinstance(v, SparseTensor), \
               f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}"
        assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
        device = q.device

        if isinstance(q, SparseTensor):
            assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]"
            s = q
            q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
            q = q.feats     # [T_Q, H, Ci]
        else:
            assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
            s = None
            N, L, H, CI = q.shape
            q_seqlen = [L] * N
            q = q.reshape(N * L, H, CI)   # [T_Q, H, Ci]

        if isinstance(k, SparseTensor):
            assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]"
            assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]"
            kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])]
            k = k.feats     # [T_KV, H, Ci]
            v = v.feats     # [T_KV, H, Co]
        else:
            assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
            assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
            N, L, H, CI, CO = *k.shape, v.shape[-1]
            kv_seqlen = [L] * N
            k = k.reshape(N * L, H, CI)   # [T_KV, H, Ci]
            v = v.reshape(N * L, H, CO)   # [T_KV, H, Co]

    if DEBUG:
        if s is not None:
            for i in range(s.shape[0]):
                assert (s.coords[s.layout[i]] == i).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch"
        # Fixed: the previous checks compared a torch.Size slice against a
        # Python list (always unequal) with the wrong expected shape, so
        # enabling DEBUG unconditionally raised. The flattened token dim
        # must match the summed per-batch sequence lengths.
        if num_all_args in [2, 3]:
            assert q.shape[0] == sum(q_seqlen), f"SparseScaledDotProductSelfAttention: q shape mismatch"
        if num_all_args == 3:
            assert k.shape[0] == sum(kv_seqlen), f"SparseScaledDotProductSelfAttention: k shape mismatch"
            assert v.shape[0] == sum(kv_seqlen), f"SparseScaledDotProductSelfAttention: v shape mismatch"

    if ATTN == 'xformers':
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=1)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=1)
        # xformers expects a leading batch dim; ragged boundaries are encoded
        # in the block-diagonal mask instead.
        q = q.unsqueeze(0)
        k = k.unsqueeze(0)
        v = v.unsqueeze(0)
        mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
        out = xops.memory_efficient_attention(q, k, v, mask)[0]
    elif ATTN == 'flash_attn':
        # flash-attn varlen kernels take cumulative sequence-length offsets.
        cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
        if num_all_args in [2, 3]:
            cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
        if num_all_args == 1:
            out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
        elif num_all_args == 2:
            out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
        elif num_all_args == 3:
            out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
    else:
        raise ValueError(f"Unknown attention module: {ATTN}")

    if s is not None:
        # Sparse query: rewrap the attended feats on the original coords.
        return s.replace(out)
    else:
        return out.reshape(N, L, H, -1)
trellis/modules/sparse/attention/modules.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from .. import SparseTensor
6
+ from .full_attn import sparse_scaled_dot_product_attention
7
+ from .serialized_attn import SerializeMode, sparse_serialized_scaled_dot_product_self_attention
8
+ from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention
9
+ from ...attention import RotaryPositionEmbedder
10
+
11
+
12
class SparseMultiHeadRMSNorm(nn.Module):
    """Per-head RMS normalization accepting either a SparseTensor or a dense tensor.

    The last (channel) dimension is L2-normalized in float32, scaled by
    sqrt(dim), and modulated by a learnable [heads, dim] gain, then cast
    back to the input dtype.
    """
    def __init__(self, dim: int, heads: int):
        super().__init__()
        # sqrt(dim) restores the expected magnitude after unit-normalization.
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, dim))

    def forward(self, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
        x_type = x.dtype
        # NOTE(review): assumes SparseTensor mirrors the tensor API used here
        # (.dtype, .float(), elementwise *, .to(dtype)) — confirm in basic.py.
        x = x.float()
        if isinstance(x, SparseTensor):
            x = x.replace(F.normalize(x.feats, dim=-1))
        else:
            x = F.normalize(x, dim=-1)
        return (x * self.gamma * self.scale).to(x_type)
26
+
27
+
28
class SparseMultiHeadAttention(nn.Module):
    """Multi-head attention over SparseTensors.

    Supports self- and cross-attention, three attention patterns for the
    self case ("full", "serialized", "windowed"), optional rotary position
    embeddings derived from voxel coordinates, and optional QK RMS norm.
    """
    def __init__(
        self,
        channels: int,
        num_heads: int,
        ctx_channels: Optional[int] = None,
        type: Literal["self", "cross"] = "self",
        attn_mode: Literal["full", "serialized", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_sequence: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        serialize_mode: Optional[SerializeMode] = None,
        qkv_bias: bool = True,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
    ):
        super().__init__()
        assert channels % num_heads == 0
        assert type in ["self", "cross"], f"Invalid attention type: {type}"
        assert attn_mode in ["full", "serialized", "windowed"], f"Invalid attention mode: {attn_mode}"
        assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
        assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention"
        self.channels = channels
        self.ctx_channels = ctx_channels if ctx_channels is not None else channels
        self.num_heads = num_heads
        self._type = type
        self.attn_mode = attn_mode
        self.window_size = window_size
        self.shift_sequence = shift_sequence
        self.shift_window = shift_window
        self.serialize_mode = serialize_mode
        self.use_rope = use_rope
        self.qk_rms_norm = qk_rms_norm

        if self._type == "self":
            # Fused projection producing Q, K, and V in one matmul.
            self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
        else:
            # Cross-attention: Q from x, fused K/V from the context.
            self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
            self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)

        if self.qk_rms_norm:
            self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
            self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)

        self.to_out = nn.Linear(channels, channels)

        if use_rope:
            self.rope = RotaryPositionEmbedder(channels)

    @staticmethod
    def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
        # Apply a Linear to sparse feats or to a dense tensor transparently.
        if isinstance(x, SparseTensor):
            return x.replace(module(x.feats))
        else:
            return module(x)

    @staticmethod
    def _reshape_chs(x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[SparseTensor, torch.Tensor]:
        # Reshape the channel tail; sparse tensors reshape per-token feats,
        # dense tensors keep their [N, L] prefix.
        if isinstance(x, SparseTensor):
            return x.reshape(*shape)
        else:
            return x.reshape(*x.shape[:2], *shape)

    def _fused_pre(self, x: Union[SparseTensor, torch.Tensor], num_fused: int) -> Union[SparseTensor, torch.Tensor]:
        # Split a fused projection into [..., num_fused, H, C_head].
        if isinstance(x, SparseTensor):
            x_feats = x.feats.unsqueeze(0)
        else:
            x_feats = x
        x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1)
        return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats

    def _rope(self, qkv: SparseTensor) -> SparseTensor:
        # Rotate Q and K by phases derived from the voxel coordinates
        # (coords[:, 0] is the batch index and is skipped); V is untouched.
        q, k, v = qkv.feats.unbind(dim=1)   # [T, H, C]
        q, k = self.rope(q, k, qkv.coords[:, 1:])
        qkv = qkv.replace(torch.stack([q, k, v], dim=1))
        return qkv

    def forward(self, x: Union[SparseTensor, torch.Tensor], context: Optional[Union[SparseTensor, torch.Tensor]] = None) -> Union[SparseTensor, torch.Tensor]:
        """
        Args:
            x: query-side input (sparse or dense).
            context: key/value-side input for cross-attention; ignored for
                self-attention.

        Returns:
            Attended features with the same container type as x.
        """
        if self._type == "self":
            qkv = self._linear(self.to_qkv, x)
            qkv = self._fused_pre(qkv, num_fused=3)
            if self.use_rope:
                qkv = self._rope(qkv)
            if self.qk_rms_norm:
                # QK norm is applied after RoPE (when both are enabled).
                q, k, v = qkv.unbind(dim=1)
                q = self.q_rms_norm(q)
                k = self.k_rms_norm(k)
                qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1))
            if self.attn_mode == "full":
                h = sparse_scaled_dot_product_attention(qkv)
            elif self.attn_mode == "serialized":
                h = sparse_serialized_scaled_dot_product_self_attention(
                    qkv, self.window_size, serialize_mode=self.serialize_mode, shift_sequence=self.shift_sequence, shift_window=self.shift_window
                )
            elif self.attn_mode == "windowed":
                h = sparse_windowed_scaled_dot_product_self_attention(
                    qkv, self.window_size, shift_window=self.shift_window
                )
        else:
            q = self._linear(self.to_q, x)
            q = self._reshape_chs(q, (self.num_heads, -1))
            kv = self._linear(self.to_kv, context)
            kv = self._fused_pre(kv, num_fused=2)
            if self.qk_rms_norm:
                q = self.q_rms_norm(q)
                k, v = kv.unbind(dim=1)
                k = self.k_rms_norm(k)
                kv = kv.replace(torch.stack([k.feats, v.feats], dim=1))
            h = sparse_scaled_dot_product_attention(q, kv)
        # Merge heads and project back to the model dimension.
        h = self._reshape_chs(h, (-1,))
        h = self._linear(self.to_out, h)
        return h
trellis/modules/sparse/attention/serialized_attn.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ from enum import Enum
3
+ import torch
4
+ import math
5
+ from .. import SparseTensor
6
+ from .. import DEBUG, ATTN
7
+
8
+ if ATTN == 'xformers':
9
+ import xformers.ops as xops
10
+ elif ATTN == 'flash_attn':
11
+ import flash_attn
12
+ else:
13
+ raise ValueError(f"Unknown attention module: {ATTN}")
14
+
15
+
16
+ __all__ = [
17
+ 'sparse_serialized_scaled_dot_product_self_attention',
18
+ ]
19
+
20
+
21
class SerializeMode(Enum):
    """Space-filling-curve orderings used to serialize sparse voxel coordinates."""
    Z_ORDER = 0
    Z_ORDER_TRANSPOSED = 1
    HILBERT = 2
    HILBERT_TRANSPOSED = 3


# All serialization modes, in enum-value order (e.g. for cycling across layers).
SerializeModes = [
    SerializeMode.Z_ORDER,
    SerializeMode.Z_ORDER_TRANSPOSED,
    SerializeMode.HILBERT,
    SerializeMode.HILBERT_TRANSPOSED
]
34
+
35
+
36
def calc_serialization(
    tensor: SparseTensor,
    window_size: int,
    serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
    shift_sequence: int = 0,
    shift_window: Tuple[int, int, int] = (0, 0, 0)
) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
    """
    Calculate serialization and partitioning for a set of coordinates.

    Args:
        tensor (SparseTensor): The input tensor.
        window_size (int): The window size to use.
        serialize_mode (SerializeMode): The serialization mode to use.
        shift_sequence (int): The shift of serialized sequence.
        shift_window (Tuple[int, int, int]): The shift of serialized coordinates.

    Returns:
        (torch.Tensor): Forward (gather) indices into the original token order.
        (torch.Tensor): Backward (scatter) indices restoring the original order.
        (List[int]): Sequence length of each serialized window.
        (List[int]): Batch index of each serialized window.
    """
    fwd_indices = []
    bwd_indices = []
    seq_lens = []
    seq_batch_indices = []
    offsets = [0]

    # Lazy import: vox2seq is only needed when serialized attention is used.
    if 'vox2seq' not in globals():
        import vox2seq

    # Serialize the input
    serialize_coords = tensor.coords[:, 1:].clone()
    serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3)
    if serialize_mode == SerializeMode.Z_ORDER:
        code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2])
    elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED:
        code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2])
    elif serialize_mode == SerializeMode.HILBERT:
        code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2])
    elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED:
        code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2])
    else:
        raise ValueError(f"Unknown serialize mode: {serialize_mode}")

    for bi, s in enumerate(tensor.layout):
        num_points = s.stop - s.start
        num_windows = (num_points + window_size - 1) // window_size
        # Fractional target size, balancing points evenly across windows.
        valid_window_size = num_points / num_windows
        to_ordered = torch.argsort(code[s.start:s.stop])
        if num_windows == 1:
            # Whole batch item fits in a single (possibly short) window.
            fwd_indices.append(to_ordered)
            bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device)))
            fwd_indices[-1] += s.start
            bwd_indices[-1] += offsets[-1]
            seq_lens.append(num_points)
            seq_batch_indices.append(bi)
            offsets.append(offsets[-1] + seq_lens[-1])
        else:
            # Partition the input
            offset = 0
            mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)]
            split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)]
            bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device)
            for i in range(num_windows):
                mid = mids[i]
                valid_start = split[i]
                valid_end = split[i + 1]
                # Each window is padded to exactly window_size tokens, centered
                # on its valid span and wrapping modulo num_points; only the
                # [valid_start, valid_end) positions scatter into bwd_index.
                padded_start = math.floor(mid - 0.5 * window_size)
                padded_end = padded_start + window_size
                fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points])
                offset += valid_start - padded_start
                bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device))
                offset += padded_end - valid_start
                fwd_indices[-1] += s.start
            seq_lens.extend([window_size] * num_windows)
            seq_batch_indices.extend([bi] * num_windows)
            bwd_indices.append(bwd_index + offsets[-1])
            offsets.append(offsets[-1] + num_windows * window_size)

    fwd_indices = torch.cat(fwd_indices)
    bwd_indices = torch.cat(bwd_indices)

    return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
118
+
119
+
120
def sparse_serialized_scaled_dot_product_self_attention(
    qkv: SparseTensor,
    window_size: int,
    serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
    shift_sequence: int = 0,
    shift_window: Tuple[int, int, int] = (0, 0, 0)
) -> SparseTensor:
    """
    Apply serialized scaled dot product self attention to a sparse tensor.

    Tokens are ordered along a space-filling curve, chunked into windows of
    ``window_size``, attended within each window, then scattered back to the
    original ordering.

    Args:
        qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
        window_size (int): The window size to use.
        serialize_mode (SerializeMode): The serialization mode to use.
        shift_sequence (int): The shift of serialized sequence.
        shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
    """
    assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"

    # Reuse the (expensive) serialization across calls with identical settings
    # via the tensor's spatial cache.
    serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}'
    serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
    if serialization_spatial_cache is None:
        fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window)
        qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
    else:
        fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache

    M = fwd_indices.shape[0]    # total tokens after (padded) windowing
    T = qkv.feats.shape[0]      # original token count
    H = qkv.feats.shape[2]
    C = qkv.feats.shape[3]

    # Gather tokens into serialized window order (padding may duplicate tokens).
    qkv_feats = qkv.feats[fwd_indices]  # [M, 3, H, C]

    if DEBUG:
        start = 0
        qkv_coords = qkv.coords[fwd_indices]
        for i in range(len(seq_lens)):
            assert (qkv_coords[start:start+seq_lens[i], 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
            start += seq_lens[i]

    if all([seq_len == window_size for seq_len in seq_lens]):
        # Uniform windows: use the faster fixed-length batched kernels.
        B = len(seq_lens)
        N = window_size
        qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
        if ATTN == 'xformers':
            q, k, v = qkv_feats.unbind(dim=2)                       # [B, N, H, C]
            out = xops.memory_efficient_attention(q, k, v)          # [B, N, H, C]
        elif ATTN == 'flash_attn':
            out = flash_attn.flash_attn_qkvpacked_func(qkv_feats)   # [B, N, H, C]
        else:
            raise ValueError(f"Unknown attention module: {ATTN}")
        out = out.reshape(B * N, H, C)   # [M, H, C]
    else:
        # Ragged windows: varlen attention over a block-diagonal layout.
        if ATTN == 'xformers':
            q, k, v = qkv_feats.unbind(dim=1)                       # [M, H, C]
            q = q.unsqueeze(0)                                      # [1, M, H, C]
            k = k.unsqueeze(0)                                      # [1, M, H, C]
            v = v.unsqueeze(0)                                      # [1, M, H, C]
            mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
            out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
        elif ATTN == 'flash_attn':
            cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
                        .to(qkv.device).int()
            out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens))  # [M, H, C]

    # Scatter attended tokens back to the original ordering.
    out = out[bwd_indices]  # [T, H, C]

    if DEBUG:
        qkv_coords = qkv_coords[bwd_indices]
        assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"

    return qkv.replace(out)
trellis/modules/sparse/attention/windowed_attn.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import math
4
+ from .. import SparseTensor
5
+ from .. import DEBUG, ATTN
6
+
7
+ if ATTN == 'xformers':
8
+ import xformers.ops as xops
9
+ elif ATTN == 'flash_attn':
10
+ import flash_attn
11
+ else:
12
+ raise ValueError(f"Unknown attention module: {ATTN}")
13
+
14
+
15
+ __all__ = [
16
+ 'sparse_windowed_scaled_dot_product_self_attention',
17
+ ]
18
+
19
+
20
def calc_window_partition(
    tensor: SparseTensor,
    window_size: Union[int, Tuple[int, ...]],
    shift_window: Union[int, Tuple[int, ...]] = 0
) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
    """
    Calculate serialization and partitioning for a set of coordinates.

    Voxels are assigned a linear window id derived from (batch index, window
    coordinates); sorting by that id groups every window's voxels contiguously.

    Args:
        tensor (SparseTensor): The input tensor.
        window_size (Union[int, Tuple[int, ...]]): The window size to use
            (an int is broadcast to all spatial dimensions).
        shift_window (Union[int, Tuple[int, ...]]): The shift of serialized
            coordinates (an int is broadcast to all spatial dimensions).

    Returns:
        (torch.Tensor): Forwards indices (gather order that serializes the voxels).
        (torch.Tensor): Backwards indices (inverse permutation of the forwards indices).
        (List[int]): Sequence lengths (voxel count of each non-empty window).
        (List[int]): Sequence batch indices (batch index of each non-empty window).
    """
    # Number of spatial dimensions; coords column 0 is the batch index.
    DIM = tensor.coords.shape[1] - 1
    shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
    window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
    shifted_coords = tensor.coords.clone().detach()
    shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0)

    # Number of windows along each spatial dimension after shifting.
    MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist()
    NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
    # Mixed-radix strides for (batch, wx, wy, ...); OFFSET[0] is the total
    # number of windows per batch element.
    OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]

    # Convert spatial coords to window coords, then collapse (batch, window
    # coords) into a single linear window id per voxel.
    shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0)
    shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1)
    # Stable grouping: sorting by window id makes each window contiguous.
    fwd_indices = torch.argsort(shifted_indices)
    bwd_indices = torch.empty_like(fwd_indices)
    bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
    seq_lens = torch.bincount(shifted_indices)
    # window id // windows-per-batch recovers the batch index of each window.
    seq_batch_indices = torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32) // OFFSET[0]
    # Drop empty windows so downstream attention sees only real sequences.
    mask = seq_lens != 0
    seq_lens = seq_lens[mask].tolist()
    seq_batch_indices = seq_batch_indices[mask].tolist()

    return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
61
+
62
+
63
def sparse_windowed_scaled_dot_product_self_attention(
    qkv: SparseTensor,
    window_size: int,
    shift_window: Tuple[int, int, int] = (0, 0, 0)
) -> SparseTensor:
    """
    Apply windowed scaled dot product self attention to a sparse tensor.

    Voxels are serialized into windows (via calc_window_partition), attention
    is run per window with the configured backend (xformers or flash_attn),
    and the result is scattered back to the original voxel order.

    Args:
        qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
        window_size (int): The window size to use.
        shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
    """
    assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"

    # The partition depends only on coordinates, so cache it on the tensor and
    # reuse it across layers that share the same (window_size, shift_window).
    serialization_spatial_cache_name = f'window_partition_{window_size}_{shift_window}'
    serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
    if serialization_spatial_cache is None:
        fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(qkv, window_size, shift_window)
        qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
    else:
        fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache

    M = fwd_indices.shape[0]    # number of serialized tokens (equals T: fwd_indices is a permutation)
    T = qkv.feats.shape[0]      # number of input tokens
    H = qkv.feats.shape[2]      # number of attention heads
    C = qkv.feats.shape[3]      # channels per head

    qkv_feats = qkv.feats[fwd_indices]      # [M, 3, H, C]

    if DEBUG:
        start = 0
        qkv_coords = qkv.coords[fwd_indices]
        for i in range(len(seq_lens)):
            seq_coords = qkv_coords[start:start+seq_lens[i]]
            assert (seq_coords[:, 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
            assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \
                f"SparseWindowedScaledDotProductSelfAttention: window size exceeded"
            start += seq_lens[i]

    if all([seq_len == window_size for seq_len in seq_lens]):
        # Fast path: every window is exactly full, so batch them as [B, N, ...].
        B = len(seq_lens)
        N = window_size
        qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
        if ATTN == 'xformers':
            q, k, v = qkv_feats.unbind(dim=2)                       # [B, N, H, C]
            out = xops.memory_efficient_attention(q, k, v)          # [B, N, H, C]
        elif ATTN == 'flash_attn':
            out = flash_attn.flash_attn_qkvpacked_func(qkv_feats)   # [B, N, H, C]
        else:
            raise ValueError(f"Unknown attention module: {ATTN}")
        out = out.reshape(B * N, H, C)                              # [M, H, C]
    else:
        # Variable-length path: one packed sequence with block-diagonal masking.
        # NOTE(review): this branch has no `else: raise` for an unknown ATTN; it
        # relies on the module-level import check having already rejected it.
        if ATTN == 'xformers':
            q, k, v = qkv_feats.unbind(dim=1)                       # [M, H, C]
            q = q.unsqueeze(0)                                      # [1, M, H, C]
            k = k.unsqueeze(0)                                      # [1, M, H, C]
            v = v.unsqueeze(0)                                      # [1, M, H, C]
            mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
            out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
        elif ATTN == 'flash_attn':
            # cu_seqlens are the cumulative window boundaries flash_attn expects.
            cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
                        .to(qkv.device).int()
            out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]

    # Undo the serialization so rows line up with the original coordinates.
    out = out[bwd_indices]      # [T, H, C]

    if DEBUG:
        qkv_coords = qkv_coords[bwd_indices]
        assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"

    return qkv.replace(out)
trellis/modules/sparse/basic.py ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from . import BACKEND, DEBUG
5
+ SparseTensorData = None # Lazy import
6
+
7
+
8
+ __all__ = [
9
+ 'SparseTensor',
10
+ 'sparse_batch_broadcast',
11
+ 'sparse_batch_op',
12
+ 'sparse_cat',
13
+ 'sparse_unbind',
14
+ ]
15
+
16
+
17
class SparseTensor:
    """
    Sparse tensor with support for both torchsparse and spconv backends.

    Parameters:
    - feats (torch.Tensor): Features of the sparse tensor.
    - coords (torch.Tensor): Coordinates of the sparse tensor.
    - shape (torch.Size): Shape of the sparse tensor.
    - layout (List[slice]): Layout of the sparse tensor for each batch
    - data (SparseTensorData): Sparse tensor data used for convolution

    NOTE:
    - Data corresponding to a same batch should be contiguous.
    - Coords should be in [0, 1023]
    """
    @overload
    def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...

    @overload
    def __init__(self, data, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...

    def __init__(self, *args, **kwargs):
        # Lazy import of sparse tensor backend
        global SparseTensorData
        if SparseTensorData is None:
            import importlib
            if BACKEND == 'torchsparse':
                SparseTensorData = importlib.import_module('torchsparse').SparseTensor
            elif BACKEND == 'spconv':
                SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor

        # Dispatch between the two overloads: method 0 is (feats, coords, ...),
        # method 1 wraps an existing backend tensor (data, ...).
        method_id = 0
        if len(args) != 0:
            method_id = 0 if isinstance(args[0], torch.Tensor) else 1
        else:
            method_id = 1 if 'data' in kwargs else 0

        if method_id == 0:
            feats, coords, shape, layout = args + (None,) * (4 - len(args))
            if 'feats' in kwargs:
                feats = kwargs['feats']
                del kwargs['feats']
            if 'coords' in kwargs:
                coords = kwargs['coords']
                del kwargs['coords']
            if 'shape' in kwargs:
                shape = kwargs['shape']
                del kwargs['shape']
            if 'layout' in kwargs:
                layout = kwargs['layout']
                del kwargs['layout']

            if shape is None:
                shape = self.__cal_shape(feats, coords)
            if layout is None:
                layout = self.__cal_layout(coords, shape[0])
            if BACKEND == 'torchsparse':
                self.data = SparseTensorData(feats, coords, **kwargs)
            elif BACKEND == 'spconv':
                # spconv needs an explicit dense spatial extent; derive it from coords.
                spatial_shape = list(coords.max(0)[0] + 1)[1:]
                self.data = SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape, shape[0], **kwargs)
                # Keep the (possibly multi-dim) feature view; spconv stores a flattened copy.
                self.data._features = feats
        elif method_id == 1:
            data, shape, layout = args + (None,) * (3 - len(args))
            if 'data' in kwargs:
                data = kwargs['data']
                del kwargs['data']
            if 'shape' in kwargs:
                shape = kwargs['shape']
                del kwargs['shape']
            if 'layout' in kwargs:
                layout = kwargs['layout']
                del kwargs['layout']

            self.data = data
            if shape is None:
                shape = self.__cal_shape(self.feats, self.coords)
            if layout is None:
                layout = self.__cal_layout(self.coords, shape[0])

        self._shape = shape
        self._layout = layout
        # Spatial downsampling scale relative to the original resolution.
        self._scale = kwargs.get('scale', (1, 1, 1))
        # Per-scale cache of auxiliary data (window partitions, conv sort orders, ...).
        self._spatial_cache = kwargs.get('spatial_cache', {})

        if DEBUG:
            try:
                assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}"
                assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}"
                assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}"
                for i in range(self.shape[0]):
                    assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous"
            except Exception as e:
                print('Debugging information:')
                print(f"- Shape: {self.shape}")
                print(f"- Layout: {self.layout}")
                print(f"- Scale: {self._scale}")
                print(f"- Coords: {self.coords}")
                raise e

    def __cal_shape(self, feats, coords):
        # Logical shape: (batch_size, *feature_dims).
        shape = []
        shape.append(coords[:, 0].max().item() + 1)
        shape.extend([*feats.shape[1:]])
        return torch.Size(shape)

    def __cal_layout(self, coords, batch_size):
        # One contiguous row-slice per batch element (relies on batch-contiguous rows).
        seq_len = torch.bincount(coords[:, 0], minlength=batch_size)
        offset = torch.cumsum(seq_len, dim=0)
        layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)]
        return layout

    @property
    def shape(self) -> torch.Size:
        # Logical shape (batch_size, *feature_dims), not the dense spatial shape.
        return self._shape

    def dim(self) -> int:
        return len(self.shape)

    @property
    def layout(self) -> List[slice]:
        # Row slice of each batch element within feats/coords.
        return self._layout

    @property
    def feats(self) -> torch.Tensor:
        if BACKEND == 'torchsparse':
            return self.data.F
        elif BACKEND == 'spconv':
            return self.data.features

    @feats.setter
    def feats(self, value: torch.Tensor):
        if BACKEND == 'torchsparse':
            self.data.F = value
        elif BACKEND == 'spconv':
            self.data.features = value

    @property
    def coords(self) -> torch.Tensor:
        # [num_rows, 1 + DIM] int coordinates; column 0 is the batch index.
        if BACKEND == 'torchsparse':
            return self.data.C
        elif BACKEND == 'spconv':
            return self.data.indices

    @coords.setter
    def coords(self, value: torch.Tensor):
        if BACKEND == 'torchsparse':
            self.data.C = value
        elif BACKEND == 'spconv':
            self.data.indices = value

    @property
    def dtype(self):
        return self.feats.dtype

    @property
    def device(self):
        return self.feats.device

    @overload
    def to(self, dtype: torch.dtype) -> 'SparseTensor': ...

    @overload
    def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None) -> 'SparseTensor': ...

    def to(self, *args, **kwargs) -> 'SparseTensor':
        # Mirrors torch.Tensor.to semantics for (device, dtype) positionally or by keyword.
        device = None
        dtype = None
        if len(args) == 2:
            device, dtype = args
        elif len(args) == 1:
            if isinstance(args[0], torch.dtype):
                dtype = args[0]
            else:
                device = args[0]
        if 'dtype' in kwargs:
            assert dtype is None, "to() received multiple values for argument 'dtype'"
            dtype = kwargs['dtype']
        if 'device' in kwargs:
            assert device is None, "to() received multiple values for argument 'device'"
            device = kwargs['device']

        new_feats = self.feats.to(device=device, dtype=dtype)
        new_coords = self.coords.to(device=device)
        return self.replace(new_feats, new_coords)

    def type(self, dtype):
        new_feats = self.feats.type(dtype)
        return self.replace(new_feats)

    def cpu(self) -> 'SparseTensor':
        new_feats = self.feats.cpu()
        new_coords = self.coords.cpu()
        return self.replace(new_feats, new_coords)

    def cuda(self) -> 'SparseTensor':
        new_feats = self.feats.cuda()
        new_coords = self.coords.cuda()
        return self.replace(new_feats, new_coords)

    def half(self) -> 'SparseTensor':
        new_feats = self.feats.half()
        return self.replace(new_feats)

    def float(self) -> 'SparseTensor':
        new_feats = self.feats.float()
        return self.replace(new_feats)

    def detach(self) -> 'SparseTensor':
        new_coords = self.coords.detach()
        new_feats = self.feats.detach()
        return self.replace(new_feats, new_coords)

    def dense(self) -> torch.Tensor:
        # Both backends expose a dense() conversion with the same name.
        if BACKEND == 'torchsparse':
            return self.data.dense()
        elif BACKEND == 'spconv':
            return self.data.dense()

    def reshape(self, *shape) -> 'SparseTensor':
        # Reshape only the per-row feature dims; the row (voxel) dim is preserved.
        new_feats = self.feats.reshape(self.feats.shape[0], *shape)
        return self.replace(new_feats)

    def unbind(self, dim: int) -> List['SparseTensor']:
        return sparse_unbind(self, dim)

    def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor':
        # Build a new SparseTensor sharing this one's structure (layout, scale,
        # spatial cache, backend bookkeeping) but with new features/coords.
        new_shape = [self.shape[0]]
        new_shape.extend(feats.shape[1:])
        if BACKEND == 'torchsparse':
            new_data = SparseTensorData(
                feats=feats,
                coords=self.data.coords if coords is None else coords,
                stride=self.data.stride,
                spatial_range=self.data.spatial_range,
            )
            new_data._caches = self.data._caches
        elif BACKEND == 'spconv':
            new_data = SparseTensorData(
                self.data.features.reshape(self.data.features.shape[0], -1),
                self.data.indices,
                self.data.spatial_shape,
                self.data.batch_size,
                self.data.grid,
                self.data.voxel_num,
                self.data.indice_dict
            )
            new_data._features = feats
            new_data.benchmark = self.data.benchmark
            new_data.benchmark_record = self.data.benchmark_record
            new_data.thrust_allocator = self.data.thrust_allocator
            new_data._timer = self.data._timer
            new_data.force_algo = self.data.force_algo
            new_data.int8_scale = self.data.int8_scale
            if coords is not None:
                new_data.indices = coords
        new_tensor = SparseTensor(new_data, shape=torch.Size(new_shape), layout=self.layout, scale=self._scale, spatial_cache=self._spatial_cache)
        return new_tensor

    @staticmethod
    def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor':
        """Create a dense-in-AABB sparse tensor filled with a constant value.

        aabb: (x0, y0, z0, x1, y1, z1) inclusive bounds; dim: (batch_size, channels).
        """
        N, C = dim
        x = torch.arange(aabb[0], aabb[3] + 1)
        y = torch.arange(aabb[1], aabb[4] + 1)
        z = torch.arange(aabb[2], aabb[5] + 1)
        coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3)
        coords = torch.cat([
            torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1),
            coords.repeat(N, 1),
        ], dim=1).to(dtype=torch.int32, device=device)
        feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device)
        return SparseTensor(feats=feats, coords=coords)

    def __merge_sparse_cache(self, other: 'SparseTensor') -> dict:
        # Union of both operands' per-scale caches; entries present in both are merged.
        new_cache = {}
        for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())):
            if k in self._spatial_cache:
                new_cache[k] = self._spatial_cache[k]
            if k in other._spatial_cache:
                if k not in new_cache:
                    new_cache[k] = other._spatial_cache[k]
                else:
                    new_cache[k].update(other._spatial_cache[k])
        return new_cache

    def __neg__(self) -> 'SparseTensor':
        return self.replace(-self.feats)

    def __elemwise__(self, other: Union[torch.Tensor, 'SparseTensor'], op: callable) -> 'SparseTensor':
        if isinstance(other, torch.Tensor):
            try:
                other = torch.broadcast_to(other, self.shape)
                other = sparse_batch_broadcast(self, other)
            # NOTE(review): bare except silently falls back to raw torch
            # broadcasting below; a narrower exception type would be safer.
            except:
                pass
        if isinstance(other, SparseTensor):
            other = other.feats
        new_feats = op(self.feats, other)
        new_tensor = self.replace(new_feats)
        # NOTE(review): `other` was rebound to `other.feats` above, so this
        # isinstance check is always False and the cache merge never runs —
        # confirm whether merging was intended here.
        if isinstance(other, SparseTensor):
            new_tensor._spatial_cache = self.__merge_sparse_cache(other)
        return new_tensor

    def __add__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.add)

    def __radd__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.add)

    def __sub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.sub)

    def __rsub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, lambda x, y: torch.sub(y, x))

    def __mul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.mul)

    def __rmul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.mul)

    def __truediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.div)

    def __rtruediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, lambda x, y: torch.div(y, x))

    def __getitem__(self, idx):
        # Select batch elements; accepts int, slice, bool mask, or int index tensor.
        if isinstance(idx, int):
            idx = [idx]
        elif isinstance(idx, slice):
            idx = range(*idx.indices(self.shape[0]))
        elif isinstance(idx, torch.Tensor):
            if idx.dtype == torch.bool:
                assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}"
                idx = idx.nonzero().squeeze(1)
            elif idx.dtype in [torch.int32, torch.int64]:
                assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
            else:
                raise ValueError(f"Unknown index type: {idx.dtype}")
        else:
            raise ValueError(f"Unknown index type: {type(idx)}")

        coords = []
        feats = []
        for new_idx, old_idx in enumerate(idx):
            coords.append(self.coords[self.layout[old_idx]].clone())
            # Renumber batch indices so the result is densely numbered from 0.
            coords[-1][:, 0] = new_idx
            feats.append(self.feats[self.layout[old_idx]])
        coords = torch.cat(coords, dim=0).contiguous()
        feats = torch.cat(feats, dim=0).contiguous()
        return SparseTensor(feats=feats, coords=coords)

    def register_spatial_cache(self, key, value) -> None:
        """
        Register a spatial cache.
        The spatial cache can be any thing you want to cache.
        The registery and retrieval of the cache is based on current scale.
        """
        scale_key = str(self._scale)
        if scale_key not in self._spatial_cache:
            self._spatial_cache[scale_key] = {}
        self._spatial_cache[scale_key][key] = value

    def get_spatial_cache(self, key=None):
        """
        Get a spatial cache.

        With key=None, returns the whole cache dict for the current scale;
        otherwise returns the cached value or None if absent.
        """
        scale_key = str(self._scale)
        cur_scale_cache = self._spatial_cache.get(scale_key, {})
        if key is None:
            return cur_scale_cache
        return cur_scale_cache.get(key, None)
390
+
391
+
392
def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor:
    """
    Broadcast a per-batch tensor across the feature rows of a sparse tensor.

    Args:
        input (SparseTensor): Sparse tensor whose batch layout defines the broadcast.
        other (torch.Tensor): Tensor indexed by batch; other[k] is the value
            assigned to every row belonging to batch k.

    Returns:
        torch.Tensor: Dense tensor shaped like input.feats where each row of
        batch k holds other[k].
    """
    broadcasted = torch.zeros_like(input.feats)
    for k in range(input.shape[0]):
        # input.layout[k] is the contiguous row slice of batch element k.
        broadcasted[input.layout[k]] = other[k]
    return broadcasted
406
+
407
+
408
def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = torch.add) -> SparseTensor:
    """
    Broadcast a per-batch tensor along the batch dimension of a sparse tensor,
    then apply an element-wise operation.

    Args:
        input (SparseTensor): Sparse tensor providing the features and batch layout.
        other (torch.Tensor): Tensor indexed by batch; other[k] is combined with
            the rows of batch k.
        op (callable): Binary operation applied as op(input.feats, broadcasted).
            Defaults to torch.add.

    Returns:
        SparseTensor: New sparse tensor with the combined features.
    """
    return input.replace(op(input.feats, sparse_batch_broadcast(input, other)))
418
+
419
+
420
def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor:
    """
    Concatenate a list of sparse tensors.

    Args:
        inputs (List[SparseTensor]): List of sparse tensors to concatenate.
        dim (int): Dimension along which to concatenate (0 = batch dimension).
    """
    if dim != 0:
        # Feature-dimension concat: the voxel set is unchanged, only feats grow.
        merged_feats = torch.cat([t.feats for t in inputs], dim=dim)
        return inputs[0].replace(merged_feats)

    # Batch concat: shift each tensor's batch indices past those before it.
    shifted_coords = []
    batch_offset = 0
    for t in inputs:
        c = t.coords.clone()
        c[:, 0] += batch_offset
        shifted_coords.append(c)
        batch_offset += t.shape[0]
    return SparseTensor(
        coords=torch.cat(shifted_coords, dim=0),
        feats=torch.cat([t.feats for t in inputs], dim=0),
    )
445
+
446
+
447
def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
    """
    Unbind a sparse tensor along a dimension.

    Args:
        input (SparseTensor): Sparse tensor to unbind.
        dim (int): Dimension to unbind (0 = batch dimension).
    """
    if dim == 0:
        # Batch dimension: indexing extracts each batch element as its own tensor.
        return [input[b] for b in range(input.shape[0])]
    # Feature dimension: split feats and wrap each slice with the shared coords.
    return [input.replace(piece) for piece in input.feats.unbind(dim)]
trellis/modules/sparse/conv/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .. import BACKEND
2
+
3
+
4
# spconv convolution algorithm selector; overridable via the SPCONV_ALGO env var.
SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native'

def __from_env():
    """Override SPCONV_ALGO from the SPCONV_ALGO environment variable, if it names a valid algorithm."""
    import os

    global SPCONV_ALGO
    env_spconv_algo = os.environ.get('SPCONV_ALGO')
    if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']:
        SPCONV_ALGO = env_spconv_algo
    print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}")


__from_env()
17
+
18
+ if BACKEND == 'torchsparse':
19
+ from .conv_torchsparse import *
20
+ elif BACKEND == 'spconv':
21
+ from .conv_spconv import *
trellis/modules/sparse/conv/conv_spconv.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from .. import SparseTensor
4
+ from .. import DEBUG
5
+ from . import SPCONV_ALGO
6
+
7
class SparseConv3d(nn.Module):
    """
    3D sparse convolution backed by spconv.

    Uses a submanifold convolution (SubMConv3d) when stride == 1 and no padding
    is given, otherwise a regular strided spconv.SparseConv3d.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
        super(SparseConv3d, self).__init__()
        if 'spconv' not in globals():
            # Bind spconv at module scope: forward() also references `spconv`
            # (spconv.SparseConvTensor), so a function-local import would leave
            # it undefined there and the `not in globals()` guard would never hold.
            global spconv
            import spconv.pytorch as spconv
        algo = None
        if SPCONV_ALGO == 'native':
            algo = spconv.ConvAlgo.Native
        elif SPCONV_ALGO == 'implicit_gemm':
            algo = spconv.ConvAlgo.MaskImplicitGemm
        if stride == 1 and (padding is None):
            self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo)
        else:
            self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo)
        self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
        self.padding = padding

    def forward(self, x: SparseTensor) -> SparseTensor:
        """Apply the convolution; re-sorts output rows by batch index when a
        strided/padded conv breaks batch contiguity (multi-batch inputs only)."""
        spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None)
        new_data = self.conv(x.data)
        new_shape = [x.shape[0], self.conv.out_channels]
        new_layout = None if spatial_changed else x.layout

        if spatial_changed and (x.shape[0] != 1):
            # spconv with non-1 stride breaks the batch contiguity of the output
            # tensor; sort rows by batch index to restore it.
            fwd = new_data.indices[:, 0].argsort()
            bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device))
            sorted_feats = new_data.features[fwd]
            sorted_coords = new_data.indices[fwd]
            unsorted_data = new_data
            new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size)  # type: ignore

        out = SparseTensor(
            new_data, shape=torch.Size(new_shape), layout=new_layout,
            scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]),
            spatial_cache=x._spatial_cache,
        )

        if spatial_changed and (x.shape[0] != 1):
            # Stash the pre-sort tensor and the inverse permutation so the
            # matching SparseInverseConv3d can undo the reorder.
            out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data)
            out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd)

        return out
50
+
51
+
52
class SparseInverseConv3d(nn.Module):
    """
    3D sparse inverse (transposed) convolution backed by spconv.

    Must be paired (via indice_key) with the SparseConv3d that produced its
    input, so spconv can reuse the forward indices.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
        super(SparseInverseConv3d, self).__init__()
        if 'spconv' not in globals():
            import spconv.pytorch as spconv
        self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key)
        self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)

    def forward(self, x: SparseTensor) -> SparseTensor:
        """Apply the inverse convolution, restoring spconv's original row order first."""
        spatial_changed = any(s != 1 for s in self.stride)
        if spatial_changed:
            # Recover spconv's original (pre-sort) row order cached by the matching SparseConv3d.
            data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data')
            bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd')
            if data is None or bwd is None:
                # Batch size 1: the forward conv skipped the re-sort (it only
                # caches when x.shape[0] != 1), so x.data is already in spconv order.
                data = x.data
            else:
                data = data.replace_feature(x.feats[bwd])
                if DEBUG:
                    assert torch.equal(data.indices, x.coords[bwd]), 'Recover the original order failed'
        else:
            data = x.data

        new_data = self.conv(data)
        new_shape = [x.shape[0], self.conv.out_channels]
        new_layout = None if spatial_changed else x.layout
        out = SparseTensor(
            new_data, shape=torch.Size(new_shape), layout=new_layout,
            scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]),
            spatial_cache=x._spatial_cache,
        )
        return out
trellis/modules/sparse/conv/conv_torchsparse.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from .. import SparseTensor
4
+
5
+
6
class SparseConv3d(nn.Module):
    """3D sparse convolution backed by torchsparse."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
        super(SparseConv3d, self).__init__()
        if 'torchsparse' not in globals():
            import torchsparse
        self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias)

    def forward(self, x: SparseTensor) -> SparseTensor:
        conv_out = self.conv(x.data)
        # A unit stride keeps the voxel set, so the batch layout can be reused.
        stride_is_identity = all(s == 1 for s in self.conv.stride)
        result = SparseTensor(
            conv_out,
            shape=torch.Size([x.shape[0], self.conv.out_channels]),
            layout=x.layout if stride_is_identity else None,
        )
        result._spatial_cache = x._spatial_cache
        result._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)])
        return result
20
+
21
+
22
class SparseInverseConv3d(nn.Module):
    """3D sparse inverse (transposed) convolution backed by torchsparse."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
        super(SparseInverseConv3d, self).__init__()
        if 'torchsparse' not in globals():
            import torchsparse
        self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True)

    def forward(self, x: SparseTensor) -> SparseTensor:
        conv_out = self.conv(x.data)
        # A unit stride keeps the voxel set, so the batch layout can be reused.
        stride_is_identity = all(s == 1 for s in self.conv.stride)
        result = SparseTensor(
            conv_out,
            shape=torch.Size([x.shape[0], self.conv.out_channels]),
            layout=x.layout if stride_is_identity else None,
        )
        result._spatial_cache = x._spatial_cache
        # Transposed conv upsamples, so the spatial scale shrinks by the stride.
        result._scale = tuple([s // stride for s, stride in zip(x._scale, self.conv.stride)])
        return result
36
+
37
+
38
+
trellis/modules/sparse/linear.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from . import SparseTensor
4
+
5
+ __all__ = [
6
+ 'SparseLinear'
7
+ ]
8
+
9
+
10
class SparseLinear(nn.Linear):
    """Linear layer applied to the feature matrix of a SparseTensor."""

    def __init__(self, in_features, out_features, bias=True):
        super(SparseLinear, self).__init__(in_features, out_features, bias)

    def forward(self, input: SparseTensor) -> SparseTensor:
        projected = super().forward(input.feats)
        return input.replace(projected)
trellis/modules/sparse/nonlinearity.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from . import SparseTensor
4
+
5
+ __all__ = [
6
+ 'SparseReLU',
7
+ 'SparseSiLU',
8
+ 'SparseGELU',
9
+ 'SparseActivation'
10
+ ]
11
+
12
+
13
class SparseReLU(nn.ReLU):
    """ReLU applied element-wise to the features of a SparseTensor."""

    def forward(self, input: SparseTensor) -> SparseTensor:
        activated = super().forward(input.feats)
        return input.replace(activated)
16
+
17
+
18
class SparseSiLU(nn.SiLU):
    """SiLU applied element-wise to the features of a SparseTensor."""

    def forward(self, input: SparseTensor) -> SparseTensor:
        activated = super().forward(input.feats)
        return input.replace(activated)
21
+
22
+
23
class SparseGELU(nn.GELU):
    """GELU applied element-wise to the features of a SparseTensor."""

    def forward(self, input: SparseTensor) -> SparseTensor:
        activated = super().forward(input.feats)
        return input.replace(activated)
26
+
27
+
28
class SparseActivation(nn.Module):
    """Wrap an arbitrary dense activation module so it operates on SparseTensor features."""

    def __init__(self, activation: nn.Module):
        super().__init__()
        self.activation = activation

    def forward(self, input: SparseTensor) -> SparseTensor:
        activated = self.activation(input.feats)
        return input.replace(activated)
35
+
trellis/modules/sparse/norm.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from . import SparseTensor
4
+ from . import DEBUG
5
+
6
+ __all__ = [
7
+ 'SparseGroupNorm',
8
+ 'SparseLayerNorm',
9
+ 'SparseGroupNorm32',
10
+ 'SparseLayerNorm32',
11
+ ]
12
+
13
+
14
class SparseGroupNorm(nn.GroupNorm):
    """
    GroupNorm over a SparseTensor, normalizing each batch element independently.

    Voxels of one batch element are gathered, viewed as a (1, C, N_voxels)
    dense tensor so nn.GroupNorm sees channels on dim 1, normalized, and
    scattered back into place.
    """
    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
        super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine)

    def forward(self, input: SparseTensor) -> SparseTensor:
        nfeats = torch.zeros_like(input.feats)
        # input.layout[k] is assumed to index the feature rows of batch
        # element k — TODO confirm against SparseTensor's layout contract.
        for k in range(input.shape[0]):
            if DEBUG:
                assert (input.coords[input.layout[k], 0] == k).all(), f"SparseGroupNorm: batch index mismatch"
            bfeats = input.feats[input.layout[k]]
            # (N, C) -> (1, C, N): one "sample" per batch element.
            bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
            bfeats = super().forward(bfeats)
            bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
            nfeats[input.layout[k]] = bfeats
        return input.replace(nfeats)
29
+
30
+
31
class SparseLayerNorm(nn.LayerNorm):
    """
    LayerNorm over a SparseTensor, processing each batch element separately.

    Mirrors SparseGroupNorm's gather/normalize/scatter structure.
    """
    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine)

    def forward(self, input: SparseTensor) -> SparseTensor:
        nfeats = torch.zeros_like(input.feats)
        for k in range(input.shape[0]):
            bfeats = input.feats[input.layout[k]]
            # NOTE(review): features are reshaped to (1, C, N) before
            # nn.LayerNorm, which normalizes over the trailing dims matching
            # `normalized_shape` — confirm `normalized_shape` is chosen to
            # match this layout.
            bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
            bfeats = super().forward(bfeats)
            bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
            nfeats[input.layout[k]] = bfeats
        return input.replace(nfeats)
44
+
45
+
46
class SparseGroupNorm32(SparseGroupNorm):
    """
    A GroupNorm layer that converts to float32 before the forward pass.
    """
    def forward(self, x: SparseTensor) -> SparseTensor:
        # Normalize in fp32 for numerical stability, then cast back to the
        # input dtype. (`x.float()` / `.type()` presumably cast the
        # SparseTensor's feature tensor — confirm in SparseTensor.)
        return super().forward(x.float()).type(x.dtype)
52
+
53
class SparseLayerNorm32(SparseLayerNorm):
    """
    A LayerNorm layer that converts to float32 before the forward pass.
    """
    def forward(self, x: SparseTensor) -> SparseTensor:
        # Same fp32-compute / original-dtype-return pattern as SparseGroupNorm32.
        return super().forward(x.float()).type(x.dtype)
trellis/modules/sparse/spatial.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from . import SparseTensor
5
+
6
+ __all__ = [
7
+ 'SparseDownsample',
8
+ 'SparseUpsample',
9
+ 'SparseSubdivide'
10
+ ]
11
+
12
+
13
class SparseDownsample(nn.Module):
    """
    Downsample a sparse tensor by a factor of `factor`.
    Implemented as average pooling.
    """
    def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]):
        super(SparseDownsample, self).__init__()
        # An int factor is broadcast to every spatial dim at forward time.
        self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor

    def forward(self, input: SparseTensor) -> SparseTensor:
        DIM = input.coords.shape[-1] - 1  # coords are (batch, x, y, z, ...)
        factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
        assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.'

        # Integer-divide spatial coords: voxels falling in the same coarse
        # cell collapse to identical coordinates.
        coord = list(input.coords.unbind(dim=-1))
        for i, f in enumerate(factor):
            coord[i+1] = coord[i+1] // f

        # Pack (batch, x, y, ...) into one scalar key (mixed-radix encoding,
        # OFFSET = place values) so `unique` can deduplicate cells.
        MAX = [coord[i+1].max().item() + 1 for i in range(DIM)]
        OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
        code = sum([c * o for c, o in zip(coord, OFFSET)])
        code, idx = code.unique(return_inverse=True)

        # Average features of all source voxels that merged into one cell.
        # NOTE(review): torch.scatter_reduce defaults to include_self=True, so
        # the zero initial value is counted in the mean — confirm intended.
        new_feats = torch.scatter_reduce(
            torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=input.feats.dtype),
            dim=0,
            index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]),
            src=input.feats,
            reduce='mean'
        )
        # Decode scalar keys back into (batch, x, y, ...) coordinates.
        new_coords = torch.stack(
            [code // OFFSET[0]] +
            [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
            dim=-1
        )
        out = SparseTensor(new_feats, new_coords, input.shape,)
        out._scale = tuple([s // f for s, f in zip(input._scale, factor)])
        out._spatial_cache = input._spatial_cache

        # Cache everything SparseUpsample needs to invert this exact pooling.
        out.register_spatial_cache(f'upsample_{factor}_coords', input.coords)
        out.register_spatial_cache(f'upsample_{factor}_layout', input.layout)
        out.register_spatial_cache(f'upsample_{factor}_idx', idx)

        return out
57
+
58
+
59
class SparseUpsample(nn.Module):
    """
    Upsample a sparse tensor by a factor of `factor`.
    Implemented as nearest neighbor interpolation.

    Requires the spatial cache written by a matching SparseDownsample; the
    original fine coordinates/layout are restored and each fine voxel copies
    the feature of the coarse cell it was pooled into.
    """
    def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]):
        super(SparseUpsample, self).__init__()
        self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor

    def forward(self, input: SparseTensor) -> SparseTensor:
        DIM = input.coords.shape[-1] - 1
        factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
        assert DIM == len(factor), 'Input coordinates must have the same dimension as the upsample factor.'

        # Recover the pre-downsample structure from the cache.
        new_coords = input.get_spatial_cache(f'upsample_{factor}_coords')
        new_layout = input.get_spatial_cache(f'upsample_{factor}_layout')
        idx = input.get_spatial_cache(f'upsample_{factor}_idx')
        if any([x is None for x in [new_coords, new_layout, idx]]):
            raise ValueError('Upsample cache not found. SparseUpsample must be paired with SparseDownsample.')
        # idx maps each original fine voxel to its coarse cell (nearest copy).
        new_feats = input.feats[idx]
        out = SparseTensor(new_feats, new_coords, input.shape, new_layout)
        out._scale = tuple([s * f for s, f in zip(input._scale, factor)])
        out._spatial_cache = input._spatial_cache
        return out
83
+
84
class SparseSubdivide(nn.Module):
    """
    Subdivide a sparse tensor: each voxel is split into 2**DIM children at
    double the spatial resolution, all inheriting the parent's features.
    """
    def __init__(self):
        super(SparseSubdivide, self).__init__()

    def forward(self, input: SparseTensor) -> SparseTensor:
        DIM = input.coords.shape[-1] - 1
        # Enumerate the corner offsets of a 2x...x2 cube (batch offset = 0).
        n_cube = torch.ones([2] * DIM, device=input.device, dtype=torch.int)
        n_coords = torch.nonzero(n_cube)
        n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1)
        factor = n_coords.shape[0]
        assert factor == 2 ** DIM
        # Double spatial coords, then add every corner offset to each voxel.
        new_coords = input.coords.clone()
        new_coords[:, 1:] *= 2
        new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype)

        # Each child copies its parent's feature vector.
        new_feats = input.feats.unsqueeze(1).expand(input.feats.shape[0], factor, *input.feats.shape[1:])
        out = SparseTensor(new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape)
        out._scale = input._scale * 2
        out._spatial_cache = input._spatial_cache
        return out
110
+
trellis/modules/sparse/transformer/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .blocks import *
2
+ from .modulated import *
trellis/modules/sparse/transformer/blocks.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from ..basic import SparseTensor
5
+ from ..linear import SparseLinear
6
+ from ..nonlinearity import SparseGELU
7
+ from ..attention import SparseMultiHeadAttention, SerializeMode
8
+ from ...norm import LayerNorm32
9
+
10
+
11
class SparseFeedForwardNet(nn.Module):
    """Position-wise MLP for sparse tokens: Linear -> GELU(tanh) -> Linear."""

    def __init__(self, channels: int, mlp_ratio: float = 4.0):
        super().__init__()
        hidden_channels = int(channels * mlp_ratio)
        self.mlp = nn.Sequential(
            SparseLinear(channels, hidden_channels),
            SparseGELU(approximate="tanh"),
            SparseLinear(hidden_channels, channels),
        )

    def forward(self, x: SparseTensor) -> SparseTensor:
        return self.mlp(x)
22
+
23
+
24
class SparseTransformerBlock(nn.Module):
    """
    Sparse Transformer block (MSA + FFN).

    Pre-norm residual layout: x = x + Attn(LN(x)); x = x + FFN(LN(x)).
    LayerNorm runs on the raw feature matrix and is re-wrapped via `replace`.
    """
    def __init__(
        self,
        channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
        window_size: Optional[int] = None,
        shift_sequence: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        serialize_mode: Optional[SerializeMode] = None,
        use_checkpoint: bool = False,  # gradient checkpointing to trade compute for memory
        use_rope: bool = False,        # rotary position embeddings
        qk_rms_norm: bool = False,     # RMS-normalize q/k before attention
        qkv_bias: bool = True,
        ln_affine: bool = False,       # learnable affine params in LayerNorms
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            attn_mode=attn_mode,
            window_size=window_size,
            shift_sequence=shift_sequence,
            shift_window=shift_window,
            serialize_mode=serialize_mode,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.mlp = SparseFeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )

    def _forward(self, x: SparseTensor) -> SparseTensor:
        # Attention branch with residual.
        h = x.replace(self.norm1(x.feats))
        h = self.attn(h)
        x = x + h
        # MLP branch with residual.
        h = x.replace(self.norm2(x.feats))
        h = self.mlp(h)
        x = x + h
        return x

    def forward(self, x: SparseTensor) -> SparseTensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
        else:
            return self._forward(x)
79
+
80
+
81
class SparseTransformerCrossBlock(nn.Module):
    """
    Sparse Transformer cross-attention block (MSA + MCA + FFN).

    Pre-norm residual layout: self-attention over sparse tokens, then
    cross-attention against a dense `context` tensor, then a feed-forward net.
    """
    def __init__(
        self,
        channels: int,
        ctx_channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
        window_size: Optional[int] = None,
        shift_sequence: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        serialize_mode: Optional[SerializeMode] = None,
        use_checkpoint: bool = False,      # gradient checkpointing
        use_rope: bool = False,            # rotary position embeddings (self-attn)
        qk_rms_norm: bool = False,         # q/k RMS norm for self-attention
        qk_rms_norm_cross: bool = False,   # q/k RMS norm for cross-attention
        qkv_bias: bool = True,
        ln_affine: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.self_attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_sequence=shift_sequence,
            shift_window=shift_window,
            serialize_mode=serialize_mode,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.cross_attn = SparseMultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        self.mlp = SparseFeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )

    def _forward(self, x: SparseTensor, context: torch.Tensor) -> SparseTensor:
        # BUG FIX: this signature previously declared an unused `mod`
        # parameter, but `forward` calls `self._forward(x, context)` with two
        # arguments — every call raised TypeError. The spurious parameter is
        # removed (this block has no modulation; see the Modulated* variants).
        h = x.replace(self.norm1(x.feats))
        h = self.self_attn(h)
        x = x + h
        h = x.replace(self.norm2(x.feats))
        h = self.cross_attn(h, context)
        x = x + h
        h = x.replace(self.norm3(x.feats))
        h = self.mlp(h)
        x = x + h
        return x

    def forward(self, x: SparseTensor, context: torch.Tensor):
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
        else:
            return self._forward(x, context)
trellis/modules/sparse/transformer/modulated.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from ..basic import SparseTensor
5
+ from ..attention import SparseMultiHeadAttention, SerializeMode
6
+ from ...norm import LayerNorm32
7
+ from .blocks import SparseFeedForwardNet
8
+
9
+
10
class ModulatedSparseTransformerBlock(nn.Module):
    """
    Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning.

    `mod` is a per-sample conditioning vector mapped (DiT-style adaLN) to
    shift/scale/gate for both the attention and MLP branches. When
    `share_mod` is True, `mod` is expected to already be the 6*channels
    output of a shared modulation projection owned by the parent model.
    """
    def __init__(
        self,
        channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
        window_size: Optional[int] = None,
        shift_sequence: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        serialize_mode: Optional[SerializeMode] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        # Non-affine norms: scale/shift are supplied by the modulation instead.
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            attn_mode=attn_mode,
            window_size=window_size,
            shift_sequence=shift_sequence,
            shift_window=shift_window,
            serialize_mode=serialize_mode,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.mlp = SparseFeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )
        if not share_mod:
            # Per-block adaLN projection producing the six modulation tensors.
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )

    def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
        if self.share_mod:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
        # Attention branch: modulate, attend, gate, residual-add.
        # NOTE(review): SparseTensor arithmetic presumably broadcasts the
        # per-sample modulation across each sample's voxels — confirm.
        h = x.replace(self.norm1(x.feats))
        h = h * (1 + scale_msa) + shift_msa
        h = self.attn(h)
        h = h * gate_msa
        x = x + h
        # MLP branch: same modulate/gate pattern.
        h = x.replace(self.norm2(x.feats))
        h = h * (1 + scale_mlp) + shift_mlp
        h = self.mlp(h)
        h = h * gate_mlp
        x = x + h
        return x

    def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
        else:
            return self._forward(x, mod)
79
+
80
+
81
class ModulatedSparseTransformerCrossBlock(nn.Module):
    """
    Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive
    layer norm conditioning.

    Self-attention and MLP branches are adaLN-modulated by `mod`; the
    cross-attention branch is not modulated (norm2 keeps its own affine
    parameters instead).
    """
    def __init__(
        self,
        channels: int,
        ctx_channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
        window_size: Optional[int] = None,
        shift_sequence: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        serialize_mode: Optional[SerializeMode] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        # Affine norm for the (un-modulated) cross-attention branch.
        self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.self_attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_sequence=shift_sequence,
            shift_window=shift_window,
            serialize_mode=serialize_mode,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.cross_attn = SparseMultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        self.mlp = SparseFeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )
        if not share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )

    def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor:
        if self.share_mod:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
        # Modulated self-attention branch.
        h = x.replace(self.norm1(x.feats))
        h = h * (1 + scale_msa) + shift_msa
        h = self.self_attn(h)
        h = h * gate_msa
        x = x + h
        # Un-modulated cross-attention branch against the dense context.
        h = x.replace(self.norm2(x.feats))
        h = self.cross_attn(h, context)
        x = x + h
        # Modulated MLP branch.
        h = x.replace(self.norm3(x.feats))
        h = h * (1 + scale_mlp) + shift_mlp
        h = self.mlp(h)
        h = h * gate_mlp
        x = x + h
        return x

    def forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
        else:
            return self._forward(x, mod, context)
trellis/modules/spatial.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
    """
    3D pixel shuffle: rearrange channel blocks into spatial detail.

    Args:
        x: (B, C, H, W, D) tensor with C divisible by scale_factor**3.
        scale_factor: per-axis spatial upscaling factor.

    Returns:
        (B, C // scale_factor**3, H*s, W*s, D*s) tensor.
    """
    batch, channels, height, width, depth = x.shape
    s = scale_factor
    out_channels = channels // s ** 3
    # Split the channel dim into (out_channels, s, s, s), interleave each
    # sub-factor with its spatial axis, then merge into upscaled dims.
    y = x.reshape(batch, out_channels, s, s, s, height, width, depth)
    y = y.permute(0, 1, 5, 2, 6, 3, 7, 4)
    return y.reshape(batch, out_channels, height * s, width * s, depth * s)
14
+
15
+
16
def patchify(x: torch.Tensor, patch_size: int):
    """
    Fold non-overlapping patches into the channel dimension.

    Args:
        x (torch.Tensor): (N, C, *spatial) tensor; every spatial dim must be
            divisible by patch_size.
        patch_size (int): Patch size.

    Returns:
        (N, C * patch_size**DIM, *spatial // patch_size) tensor.
    """
    dim = x.dim() - 2
    for axis in range(2, dim + 2):
        assert x.shape[axis] % patch_size == 0, (
            f"Dimension {axis} of input tensor must be divisible by patch size, "
            f"got {x.shape[axis]} and {patch_size}"
        )

    # Split every spatial axis into (coarse, patch) pairs.
    split_dims = []
    for axis in range(2, dim + 2):
        split_dims += [x.shape[axis] // patch_size, patch_size]
    x = x.reshape(*x.shape[:2], *split_dims)
    # Move all patch sub-axes next to the channel dim, coarse axes to the end.
    perm = [0, 1] + [2 * i + 3 for i in range(dim)] + [2 * i + 2 for i in range(dim)]
    x = x.permute(*perm)
    # Merge channel dim with the patch sub-axes.
    return x.reshape(x.shape[0], x.shape[1] * (patch_size ** dim), *(x.shape[-dim:]))
32
+
33
+
34
def unpatchify(x: torch.Tensor, patch_size: int):
    """
    Unfold patches from the channel dimension back into spatial detail
    (inverse of `patchify`).

    Args:
        x (torch.Tensor): (N, C, *spatial) tensor; C must be divisible by
            patch_size**DIM.
        patch_size (int): Patch size.

    Returns:
        (N, C // patch_size**DIM, *spatial * patch_size) tensor.
    """
    dim = x.dim() - 2
    assert x.shape[1] % (patch_size ** dim) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** dim}"

    # Split channels into (out_channels, patch, patch, ...).
    x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** dim), *([patch_size] * dim), *(x.shape[-dim:]))
    # Interleave each coarse spatial axis with its patch sub-axis.
    perm = [0, 1]
    for i in range(dim):
        perm += [2 + dim + i, 2 + i]
    x = x.permute(*perm)
    # Merge each (coarse, patch) pair into one upscaled spatial axis.
    return x.reshape(x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(dim)])
trellis/modules/transformer/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .blocks import *
2
+ from .modulated import *
trellis/modules/transformer/blocks.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from ..attention import MultiHeadAttention
5
+ from ..norm import LayerNorm32
6
+
7
+
8
class AbsolutePositionEmbedder(nn.Module):
    """
    Embeds spatial positions into vector representations via fixed
    sinusoidal frequency bands (one sin/cos band set per input channel).
    """
    def __init__(self, channels: int, in_channels: int = 3):
        super().__init__()
        self.channels = channels
        self.in_channels = in_channels
        self.freq_dim = channels // in_channels // 2
        # Geometric frequency ladder from 1 down to ~1/10000 (transformer PE).
        exponents = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
        self.freqs = 1.0 / (10000 ** exponents)

    def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor:
        """
        Create sinusoidal position embeddings.

        Args:
            x: a 1-D Tensor of N indices

        Returns:
            an (N, D) Tensor of positional embeddings.
        """
        # Lazily migrate the frequency table to the input's device.
        self.freqs = self.freqs.to(x.device)
        phases = torch.outer(x, self.freqs)
        return torch.cat([torch.sin(phases), torch.cos(phases)], dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): (N, D) tensor of spatial positions
        """
        n_pos, n_dims = x.shape
        assert n_dims == self.in_channels, "Input dimension must match number of input channels"
        embed = self._sin_cos_embedding(x.reshape(-1)).reshape(n_pos, -1)
        # Zero-pad when channels is not an exact multiple of 2 * in_channels.
        if embed.shape[1] < self.channels:
            pad = torch.zeros(n_pos, self.channels - embed.shape[1], device=embed.device)
            embed = torch.cat([embed, pad], dim=-1)
        return embed
47
+
48
+
49
class FeedForwardNet(nn.Module):
    """Two-layer position-wise MLP: Linear -> GELU(tanh) -> Linear."""

    def __init__(self, channels: int, mlp_ratio: float = 4.0):
        super().__init__()
        hidden = int(channels * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(channels, hidden),
            nn.GELU(approximate="tanh"),
            nn.Linear(hidden, channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)
60
+
61
+
62
class TransformerBlock(nn.Module):
    """
    Transformer block (MSA + FFN).

    Pre-norm residual layout: x = x + Attn(LN(x)); x = x + FFN(LN(x)).
    """
    def __init__(
        self,
        channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[int] = None,
        use_checkpoint: bool = False,  # gradient checkpointing
        use_rope: bool = False,        # rotary position embeddings
        qk_rms_norm: bool = False,     # RMS-normalize q/k before attention
        qkv_bias: bool = True,
        ln_affine: bool = False,       # learnable affine params in LayerNorms
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.attn = MultiHeadAttention(
            channels,
            num_heads=num_heads,
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.mlp = FeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        # Attention branch with residual.
        h = self.norm1(x)
        h = self.attn(h)
        x = x + h
        # MLP branch with residual.
        h = self.norm2(x)
        h = self.mlp(h)
        x = x + h
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
        else:
            return self._forward(x)
113
+
114
+
115
class TransformerCrossBlock(nn.Module):
    """
    Transformer cross-attention block (MSA + MCA + FFN).

    Pre-norm residual layout: self-attention, cross-attention against
    `context`, then a feed-forward net.
    """
    def __init__(
        self,
        channels: int,
        ctx_channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        use_checkpoint: bool = False,      # gradient checkpointing
        use_rope: bool = False,            # rotary embeddings (self-attn)
        qk_rms_norm: bool = False,         # q/k RMS norm for self-attention
        qk_rms_norm_cross: bool = False,   # q/k RMS norm for cross-attention
        qkv_bias: bool = True,
        ln_affine: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.self_attn = MultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.cross_attn = MultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        self.mlp = FeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )

    def _forward(self, x: torch.Tensor, context: torch.Tensor):
        # Self-attention branch.
        h = self.norm1(x)
        h = self.self_attn(h)
        x = x + h
        # Cross-attention branch against the context tokens.
        h = self.norm2(x)
        h = self.cross_attn(h, context)
        x = x + h
        # MLP branch.
        h = self.norm3(x)
        h = self.mlp(h)
        x = x + h
        return x

    def forward(self, x: torch.Tensor, context: torch.Tensor):
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
        else:
            return self._forward(x, context)
+ return self._forward(x, context)
182
+
trellis/modules/transformer/modulated.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from ..attention import MultiHeadAttention
5
+ from ..norm import LayerNorm32
6
+ from .blocks import FeedForwardNet
7
+
8
+
9
class ModulatedTransformerBlock(nn.Module):
    """
    Transformer block (MSA + FFN) with adaptive layer norm conditioning.

    `mod` is a per-sample (N, C) conditioning vector mapped (DiT-style adaLN)
    to shift/scale/gate for both branches; when `share_mod` is True, `mod` is
    expected to already be the 6*channels output of a shared projection.
    """
    def __init__(
        self,
        channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        # Non-affine norms: scale/shift come from the modulation instead.
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.attn = MultiHeadAttention(
            channels,
            num_heads=num_heads,
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.mlp = FeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )
        if not share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )

    def _forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
        if self.share_mod:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
        # Attention branch: modulate (broadcast over sequence), attend, gate.
        h = self.norm1(x)
        h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        h = self.attn(h)
        h = h * gate_msa.unsqueeze(1)
        x = x + h
        # MLP branch: same modulate/gate pattern.
        h = self.norm2(x)
        h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        h = self.mlp(h)
        h = h * gate_mlp.unsqueeze(1)
        x = x + h
        return x

    def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
        else:
            return self._forward(x, mod)
74
+
75
+
76
class ModulatedTransformerCrossBlock(nn.Module):
    """
    Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer
    norm conditioning.

    Self-attention and MLP branches are adaLN-modulated by `mod`; the
    cross-attention branch is not modulated (norm2 keeps its own affine
    parameters instead).
    """
    def __init__(
        self,
        channels: int,
        ctx_channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        # Affine norm for the (un-modulated) cross-attention branch.
        self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.self_attn = MultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.cross_attn = MultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        self.mlp = FeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )
        if not share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )

    def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor):
        if self.share_mod:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
        # Modulated self-attention branch (modulation broadcast over sequence).
        h = self.norm1(x)
        h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        h = self.self_attn(h)
        h = h * gate_msa.unsqueeze(1)
        x = x + h
        # Un-modulated cross-attention branch against the context tokens.
        h = self.norm2(x)
        h = self.cross_attn(h, context)
        x = x + h
        # Modulated MLP branch.
        h = self.norm3(x)
        h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        h = self.mlp(h)
        h = h * gate_mlp.unsqueeze(1)
        x = x + h
        return x

    def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor):
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
        else:
            return self._forward(x, mod, context)
157
+
trellis/modules/utils.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from ..modules import sparse as sp
3
+
4
# Primitive module types whose parameters are cast by convert_module_to_f16 /
# convert_module_to_f32 below: dense conv/linear layers plus their sparse
# counterparts from trellis.modules.sparse.
FP16_MODULES = (
    nn.Conv1d,
    nn.Conv2d,
    nn.Conv3d,
    nn.ConvTranspose1d,
    nn.ConvTranspose2d,
    nn.ConvTranspose3d,
    nn.Linear,
    sp.SparseConv3d,
    sp.SparseInverseConv3d,
    sp.SparseLinear,
)
16
+
17
def convert_module_to_f16(l):
    """
    Cast the parameters of *l* to float16 in place, if *l* is one of the
    primitive module types listed in FP16_MODULES; otherwise do nothing.
    """
    if not isinstance(l, FP16_MODULES):
        return
    for param in l.parameters():
        param.data = param.data.half()
24
+
25
+
26
def convert_module_to_f32(l):
    """
    Cast the parameters of *l* back to float32 in place, undoing
    convert_module_to_f16(); modules not in FP16_MODULES are left untouched.
    """
    if not isinstance(l, FP16_MODULES):
        return
    for param in l.parameters():
        param.data = param.data.float()
33
+
34
+
35
def zero_module(module):
    """Zero out every parameter of *module* in place and return it.

    Parameters are detached first so the in-place write is not tracked by
    autograd.
    """
    for tensor in (p.detach() for p in module.parameters()):
        tensor.zero_()
    return module
42
+
43
+
44
def scale_module(module, scale):
    """Multiply every parameter of *module* by *scale* in place and return it.

    Parameters are detached first so the in-place multiply is not tracked by
    autograd.
    """
    for tensor in (p.detach() for p in module.parameters()):
        tensor.mul_(scale)
    return module
51
+
52
+
53
def modulate(x, shift, scale):
    """Apply adaLN-style modulation to *x*: scale by (1 + scale) and add shift,
    with both terms broadcast along dim 1 (the token dimension)."""
    scaled = x * (scale.unsqueeze(1) + 1)
    return scaled + shift.unsqueeze(1)
trellis/pipelines/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from . import samplers
2
+ from .trellis_image_to_3d import TrellisImageTo3DPipeline
3
+ from .trellis_text_to_3d import TrellisTextTo3DPipeline
4
+
5
+
6
def from_pretrained(path: str):
    """
    Load a pipeline from a model folder or a Hugging Face model hub.

    Args:
        path: The path to the model. Can be either local path or a Hugging Face model name.
    """
    import os
    import json

    local_config = f"{path}/pipeline.json"
    if os.path.exists(local_config):
        config_file = local_config
    else:
        # Not on disk: fetch the pipeline config from the Hugging Face hub.
        from huggingface_hub import hf_hub_download
        config_file = hf_hub_download(path, "pipeline.json")

    with open(config_file, 'r') as f:
        config = json.load(f)
    # Dispatch to the pipeline class named in the config (one of the classes
    # imported at the top of this module).
    return globals()[config['name']].from_pretrained(path)
trellis/pipelines/base.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from .. import models
5
+
6
+
7
class Pipeline:
    """
    A base class for pipelines: a named collection of models that is loaded
    from a shared pipeline.json config and moved between devices as a unit.
    """
    def __init__(
        self,
        models: Optional[dict[str, nn.Module]] = None,
    ):
        # Subclasses may construct with no models and populate state
        # themselves (e.g. when restoring from a pretrained checkpoint).
        if models is None:
            return
        self.models = models
        # Pipelines are inference-only: put every model in eval mode.
        for model in self.models.values():
            model.eval()

    @staticmethod
    def from_pretrained(path: str) -> "Pipeline":
        """
        Load a pretrained pipeline.

        Args:
            path: Local model folder containing pipeline.json, or a Hugging
                Face model name from which pipeline.json is downloaded.
        """
        import os
        import json
        is_local = os.path.exists(f"{path}/pipeline.json")

        if is_local:
            config_file = f"{path}/pipeline.json"
        else:
            from huggingface_hub import hf_hub_download
            config_file = hf_hub_download(path, "pipeline.json")

        with open(config_file, 'r') as f:
            args = json.load(f)['args']

        _models = {}
        for k, v in args['models'].items():
            try:
                # Prefer a model stored inside this pipeline's own folder.
                _models[k] = models.from_pretrained(f"{path}/{v}")
            except Exception:
                # Fall back to treating the value as a standalone model
                # path/name. (Was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit.)
                _models[k] = models.from_pretrained(v)

        new_pipeline = Pipeline(_models)
        new_pipeline._pretrained_args = args
        return new_pipeline

    @property
    def device(self) -> torch.device:
        """
        The device the pipeline's models live on.

        Prefers a model exposing an explicit ``device`` attribute; otherwise
        infers it from the first model that has parameters.

        Raises:
            RuntimeError: If no model exposes a device or parameters.
        """
        for model in self.models.values():
            if hasattr(model, 'device'):
                return model.device
        for model in self.models.values():
            if hasattr(model, 'parameters'):
                return next(model.parameters()).device
        raise RuntimeError("No device found.")

    def to(self, device: torch.device) -> None:
        """Move all models to *device* in place."""
        for model in self.models.values():
            model.to(device)

    def cuda(self) -> None:
        """Move all models to the default CUDA device."""
        self.to(torch.device("cuda"))

    def cpu(self) -> None:
        """Move all models to the CPU."""
        self.to(torch.device("cpu"))
trellis/pipelines/samplers/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .base import Sampler
2
+ from .flow_euler import FlowEulerSampler, FlowEulerCfgSampler, FlowEulerGuidanceIntervalSampler
trellis/pipelines/samplers/base.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ from abc import ABC, abstractmethod
3
+
4
+
5
class Sampler(ABC):
    """Abstract interface for sampling algorithms.

    Concrete samplers implement :meth:`sample` to draw outputs from a model.
    """

    @abstractmethod
    def sample(self, model, **kwargs):
        """Sample from *model*; sampler-specific options go in ``**kwargs``."""
        pass
20
+